Scala Type Inference
之劍
2016.5.1 00:38:12
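Type systems
What are static types? Why are they useful?
According to Pierce: "A type system is a syntactic method for automatically checking the absence of certain erroneous behaviors by classifying program phrases according to the kinds of values they compute."
Types let you denote the domain and codomain of a function. For example, in mathematics we are used to seeing: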
f: R -> N
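This definition tells us that the function "f" maps values from the set of real numbers to values in the set of natural numbers.
In the abstract, this is exactly what concrete types are. Type systems give us more powerful ways to express these sets.
Given these annotations, the compiler can now statically (at compile time) verify that the program is sound.
In other words, compilation fails if values (at runtime) could not comply with the constraints imposed by the program.
Generally speaking, the typechecker can only guarantee that unsound programs do not compile. It cannot guarantee that every sound program will compile.
As type systems grow more expressive, we can produce more reliable code, because they let us prove invariants about a program before it even runs (ruling out whole classes of bugs at the type level). Academia keeps pushing the limits of expressiveness, including value-dependent types.
Note that (generic) type information is removed at compile time, because it is not needed afterwards. This is called type erasure; Java's implementation of generics is one example.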
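A minimal sketch of what erasure looks like from Scala on the JVM. The object name ErasureDemo and its members are made up for illustration; the point is only that the element type of a List is not visible at runtime.

```scala
object ErasureDemo {
  val ints: List[Int] = List(1, 2, 3)
  val strings: List[String] = List("a", "b")

  // The element type is not part of the runtime class:
  // both values are instances of the same List cell class.
  val sameRuntimeClass: Boolean = ints.getClass == strings.getClass

  // Because of erasure, a runtime test can ask "is this a List?"
  // but not "is this a List of String?".
  def isSomeList(x: Any): Boolean = x match {
    case _: List[_] => true
    case _          => false
  }
}
```

Here ErasureDemo.sameRuntimeClass is true even though the two lists have different static types.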
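Types in Scala
Scala's powerful type system allows for very rich expression. Some of its chief features are:
Support for parametric polymorphism, i.e. generic programming
Support for (local) type inference, which is why you do not need to write val i: Int = 12: Int
Support for existential quantification: roughly, defining operations for some unnamed type
Support for views: roughly, the "castability" of values of one type to another
Parametric polymorphism
Polymorphism is used to write generic code (for values of different types) without compromising the richness of static typing.
For example, without parametric polymorphism, a generic list data structure would always look like this (and indeed it did look like this in Java before generics):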
scala> 2 :: 1 :: "bar" :: "foo" :: Nil
res5: List[Any] = List(2, 1, bar, foo)
Now we cannot recover any type information about the individual elements.
scala> res5.head
res6: Any = 2
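And so our application would devolve into a series of casts ("asInstanceOf[]") and the code would lack type safety (because the checks are all dynamic).
Polymorphism is achieved by specifying type variables.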
scala> def drop1[A](l: List[A]) = l.tail
drop1: [A](l: List[A])List[A]
scala> drop1(List(1,2,3))
res1: List[Int] = List(2, 3)
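Scala has rank-1 polymorphism
Roughly, this means that there are some type concepts you would like to express in Scala that are "too generic" for the compiler to understand. All type variables have to be fixed at the invocation site.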
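The restriction can be sketched as follows. The names Rank1Demo, toList, foo and applyTo are illustrative assumptions, not taken from the text above; the commented-out definition is the one the compiler rejects.

```scala
object Rank1Demo {
  def toList[A](a: A): List[A] = List(a)

  // Does not compile: inside foo, f is a function for one fixed A,
  // so it cannot be applied to a value of the unrelated type B.
  // def foo[A, B](f: A => List[A], b: B) = f(b)

  // Compiles: A is chosen once per call, at the invocation site.
  def applyTo[A](f: A => List[A], a: A): List[A] = f(a)

  val ints: List[Int] = applyTo[Int](a => toList(a), 1)
  val strs: List[String] = applyTo[String](a => toList(a), "s")
}
```

Each call to applyTo picks its own A, but within a single call f only ever works at that one type.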
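A traditional objection to static typing is that it carries a lot of syntactic overhead. Scala alleviates this by providing type inference.
The classic method for type inference in functional programming languages is Hindley-Milner, and it was first employed in ML.
Scala's type inference works a little differently, but it is similar in spirit: infer constraints, then attempt to unify them into a type.
In Scala, for example, you cannot write this: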
scala> { x => x }
<console>:7: error: missing parameter type
{ x => x }
Whereas in OCaml, you can:
# fun x -> x;;
- : 'a -> 'a = <fun>
In Scala all type inference is local. Scala considers one expression at a time. For example:
scala> def id[T](x: T) = x
id: [T](x: T)T
scala> val x = id(322)
x: Int = 322
scala> val x = id("hey")
x: java.lang.String = hey
scala> val x = id(Array(1,2,3,4))
x: Array[Int] = Array(1, 2, 3, 4)
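Here the types are preserved without being written out: the Scala compiler infers the type parameter for us. Note also that we did not have to specify the return type explicitly.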
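A small sketch of what "local" means in practice, using made-up names (LocalInferenceDemo, inc, twice, factorial). The compiler uses the expected type or a direct annotation next to the expression; it does not solve for types globally.

```scala
object LocalInferenceDemo {
  // The expected type on the left supplies the parameter type of the lambda.
  val inc: Int => Int = x => x + 1

  // Alternatively, annotate the lambda parameter itself.
  val twice = (x: Int) => x * 2

  // A recursive method needs an explicit result type: the compiler
  // does not infer it from the body of a recursive definition.
  def factorial(n: Int): BigInt =
    if (n <= 1) BigInt(1) else factorial(n - 1) * n
}
```

Dropping both the expected type and the parameter annotation, as in the { x => x } example above, leaves the compiler with nothing to infer from.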
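Variance
Scala's type system has to account for class hierarchies together with polymorphism. Class hierarchies express subtype relationships. A central question that comes up when mixing object orientation with polymorphism is: if T' is a subclass of T, is Container[T'] a subclass of Container[T]? Variance annotations let you express the following relationships between class hierarchies and polymorphic types: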
<table>
<tr>
<td></td><td>Meaning</td><td>Scala notation</td>
</tr>
<tr>
<td>covariant</td><td>C[T'] is a subclass of C[T]</td><td>[+T]</td>
</tr>
<tr>
<td>contravariant</td><td>C[T] is a subclass of C[T']</td><td>[-T]</td>
</tr>
<tr>
<td>invariant</td><td>C[T] and C[T'] are not related</td><td>[T]</td>
</tr>
</table>
What the subtype relationship really means: for a given type T, if T' is a subtype, can T' be substituted for T?
scala> class Contravariant[-A]
defined class Contravariant
scala> val cv: Contravariant[String] = new Contravariant[AnyRef]
cv: Contravariant[String] = Contravariant@49fa7ba
scala> val fail: Contravariant[AnyRef] = new Contravariant[String]
<console>:6: error: type mismatch;
found : Contravariant[String]
required: Contravariant[AnyRef]
val fail: Contravariant[AnyRef] = new Contravariant[String]
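To make both directions concrete, here is a hedged sketch with hypothetical classes (Animal, Cat, Box, Printer are all invented for this example). A covariant type behaves like a producer of values, a contravariant type like a consumer.

```scala
object VarianceDemo {
  class Animal { def name: String = "animal" }
  class Cat extends Animal { override def name: String = "cat" }

  // Covariant producer: a Box[Cat] can stand in for a Box[Animal].
  class Box[+A](val value: A)
  val animalBox: Box[Animal] = new Box[Cat](new Cat)

  // Contravariant consumer: a Printer[Animal] can stand in for a Printer[Cat].
  class Printer[-A](describe: A => String) {
    def print(a: A): String = describe(a)
  }
  val catPrinter: Printer[Cat] = new Printer[Animal](a => "I see a " + a.name)
}
```

Assigning in the opposite direction (a Box[Animal] to a Box[Cat], or a Printer[Cat] to a Printer[Animal]) is rejected with a type mismatch like the one shown above.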
Quantification
Sometimes you do not need to give a type variable a name, for example:
scala> def count[A](l: List[A]) = l.size
count: [A](List[A])Int
Instead you can use "wildcards":
scala> def count(l: List[_]) = l.size
count: (List[_])Int
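A short sketch of wildcards in use. The names WildcardDemo and totalLength are assumptions for illustration; the bounded form shows that a wildcard can still carry some knowledge about the unnamed element type.

```scala
object WildcardDemo {
  // The element type is never named, so any List is accepted.
  def count(l: List[_]): Int = l.size

  // A bounded wildcard: the elements are of some unnamed subtype of CharSequence,
  // which is enough to call length on each of them.
  def totalLength(l: List[_ <: CharSequence]): Int = l.map(_.length).sum
}
```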
What is type inference?
Let us start with some code:
Map<Integer, Map<String, String>> m = new HashMap<Integer, Map<String, String>>();
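Yes, that is far too long. We cannot help sighing that the compiler is rather dumb: almost half of the characters are repeated!
To address this verbosity when declaring and instantiating generics, Java 7 introduced the diamond operator. The magical Project Coin granted your wish.
So, from Java 7 onwards, you can write: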
Map<Integer, Map<String, String>> m = new HashMap<>();
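The diamond operator simplifies the creation of generic objects: the programmer no longer has to spell out redundant type arguments, and, unlike with a raw type such as new HashMap(), the code stays fully type-checked, which helps avoid ClassCastExceptions at runtime. The compiler infers the type arguments during type checking (not during lexical analysis), fills them in for you, and the compiled bytecode is the same as before.
In the original proposal this feature was called "Improved Type Inference for Generic Instance Creation"; the acronym ITIGIC sounds odd, but why is it called the diamond operator? Not everything in this world comes with a why.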
It is precisely because Scala does type inference that coders feel almost as if they were writing code in a dynamic language.
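For comparison, a rough Scala counterpart of the Java declaration above. The names InferredMapDemo, m and n are illustrative; the point is that the type is written at most once, or not at all.

```scala
object InferredMapDemo {
  import scala.collection.mutable

  // The type arguments are written once, on the right; the type of m is inferred.
  val m = mutable.HashMap.empty[Int, Map[String, String]]
  m(1) = Map("lang" -> "Scala")

  // Nothing is annotated here: n is inferred as Map[Int, Map[String, String]].
  val n = Map(1 -> Map("lang" -> "Scala"))
}
```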
In Scala, anonymous functions are frequently passed to higher-order functions. For example,
here is the signature of a generic function:
def dropWhile[A](list: List[A], f: A => Boolean): List[A]
When we call it with an anonymous function f,
val mylist: List[Int] = List(1,2,3,4,5)
val listDropped = dropWhile( mylist, (x: Int) => x < 4 )
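the value of listDropped is List(4,5).
We can easily work out in our heads that, since the type A in list: List[A] was already fixed to Int when mylist was declared, the x in the second parameter must obviously be an Int as well.
Fortunately the Scala designers anticipated this, and the Scala compiler can infer it. However, you have to write the signature of dropWhile the way Scala requires, in curried form, so that it is called as dropWhile( mylist )( f ):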
def dropWhile[A] ( list: List[A] ) ( f: A => Boolean ) : List[A] = list match {
  // h :: t is the standard-library List pattern (Cons(h, t) in a hand-written List)
  case h :: t if f(h) => dropWhile(t)(f)
  case _ => list
}
With that in place, we can use the function directly like this:
val mylist: List[Int] = List(1,2,3,4,5)
val droppedList = dropWhile( mylist ) ( x => x < 4 )
Note that the parameter x is not annotated with Int: the compiler infers directly from the element type Int of mylist that x is also an Int.
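The two signatures can be put side by side in a self-contained sketch. It assumes the standard-library List and Scala 2 inference rules; DropWhileDemo and dropWhile2 are names invented here.

```scala
object DropWhileDemo {
  // Curried: A is fixed by the first argument list before f is type-checked.
  def dropWhile[A](list: List[A])(f: A => Boolean): List[A] = list match {
    case h :: t if f(h) => dropWhile(t)(f)
    case _              => list
  }

  // Uncurried variant, kept only for comparison.
  def dropWhile2[A](list: List[A], f: A => Boolean): List[A] = dropWhile(list)(f)

  val mylist: List[Int] = List(1, 2, 3, 4, 5)

  // No annotation needed on x.
  val curried: List[Int] = dropWhile(mylist)(x => x < 4)

  // Under Scala 2 the lambda parameter must be annotated in the uncurried form.
  val uncurried: List[Int] = dropWhile2(mylist, (x: Int) => x < 4)
}
```

Both calls return List(4, 5); only the amount of annotation differs.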
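Type inference is a deep subject backed by a great deal of theory; these are problems that have to be solved when designing and building a compiler.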
|Scala|Haskell,ML|
|---------|--------|
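|Local, flow-based type inference|Global Hindley-Milner type inference|
The book "Programming in Scala" notes that flow-based type inference has its limitations, but that it handles object-oriented subtyping more gracefully than Hindley-Milner.
One such limitation: with flow-based inference, parameter types cannot always be omitted when a function is partially applied.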
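One way to read this limitation, sketched with invented names (PartialApplicationDemo, sum, addTo3, add): placeholder parameters get their types either from an explicit annotation or from an expected type, and with neither the compiler reports a missing parameter type.

```scala
object PartialApplicationDemo {
  def sum(a: Int, b: Int, c: Int): Int = a + b + c

  // Partially applied: the remaining parameter is annotated explicitly.
  val addTo3 = sum(1, _: Int, 2)

  // Does not compile: there is no expected type to infer the parameters from.
  // val bad = _ + _

  // Compiles: the expected type on the left supplies both parameter types.
  val add: (Int, Int) => Int = _ + _
}
```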
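Type inference algorithms
Type inference is an increasingly common feature of modern high-level languages. In fact, it has long been widely used in functional languages, and the Hindley-Milner algorithm is the classic foundation on which many type inferencers are built.
A simple Hindley-Milner inferencer implemented in Scala: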
/*
* http://dysphoria.net/code/hindley-milner/HindleyMilner.scala
* Andrew Forrest
*
* Implementation of basic polymorphic type-checking for a simple language.
* Based heavily on Nikita Borisov’s Perl implementation at
* http://web.archive.org/web/20050420002559/www.cs.berkeley.edu/~nikitab/courses/cs263/hm.html
* which in turn is based on the paper by Luca Cardelli at
* http://lucacardelli.name/Papers/BasicTypechecking.pdf
*
* If you run it with "scala HindleyMilner.scala" it will attempt to report the types
* for a few example expressions. (It uses UTF-8 for output, so you may need to set your
* terminal accordingly.)
*
* Changes
* June 30, 2011 by Liang Kun(liangkun(AT)baidu.com)
* 1. Modify to enhance readability
* 2. Extend to Support if expression in syntax
*
*
*
* Do with it what you will. :)
*/
/** Syntax definition. This is a simple lambda calculus syntax.
 *  Expression ::= Identifier
 *              |  Constant
 *              |  "if" Expression "then" Expression "else" Expression
 *              |  "lambda(" Identifier ") " Expression
 *              |  Expression "(" Expression ")"
 *              |  "let" Identifier "=" Expression "in" Expression
 *              |  "letrec" Identifier "=" Expression "in" Expression
 *              |  "(" Expression ")"
 *  See the examples below in main function.
 */
sealed abstract class Expression
case class Identifier(name: String) extends Expression {
  override def toString = name
}
case class Constant(value: String) extends Expression {
  override def toString = value
}
// `then` is a reserved word in newer Scala versions, so the field is named thenExpr.
case class If(condition: Expression, thenExpr: Expression, other: Expression) extends Expression {
  override def toString = "(if " + condition + " then " + thenExpr + " else " + other + ")"
}
case class Lambda(argument: Identifier, body: Expression) extends Expression {
  override def toString = "(lambda " + argument + " → " + body + ")"
}
case class Apply(function: Expression, argument: Expression) extends Expression {
  override def toString = "(" + function + " " + argument + ")"
}
case class Let(binding: Identifier, definition: Expression, body: Expression) extends Expression {
  override def toString = "(let " + binding + " = " + definition + " in " + body + ")"
}
case class Letrec(binding: Identifier, definition: Expression, body: Expression) extends Expression {
  override def toString = "(letrec " + binding + " = " + definition + " in " + body + ")"
}
/** Exceptions that may be thrown */
class TypeError(msg: String) extends Exception(msg)
class ParseError(msg: String) extends Exception(msg)
/** Type inference system */
object TypeSystem {
  type Env = Map[Identifier, Type]
  val EmptyEnv: Map[Identifier, Type] = Map.empty
  // type variable and type operator
  sealed abstract class Type
  case class Variable(id: Int) extends Type {
    var instance: Option[Type] = None
    lazy val name = nextUniqueName()
    override def toString = instance match {
      case Some(t) => t.toString
      case None => name
    }
  }
  case class Operator(name: String, args: Seq[Type]) extends Type {
    override def toString = {
      if (args.length == 0)
        name
      else if (args.length == 2)
        "[" + args(0) + " " + name + " " + args(1) + "]"
      else
        args.mkString(name + "[", ", ", "]")
    }
  }
  // builtin types, types can be extended by environment
  def Function(from: Type, to: Type) = Operator("→", Seq(from, to))
  val Integer = Operator("Integer", Seq.empty[Type])
  val Boolean = Operator("Boolean", Seq.empty[Type])
  protected var _nextVariableName = 'α'
  protected def nextUniqueName() = {
    val result = _nextVariableName
    _nextVariableName = (_nextVariableName.toInt + 1).toChar
    result.toString
  }
  protected var _nextVariableId = 0
  def newVariable(): Variable = {
    val result = _nextVariableId
    _nextVariableId += 1
    Variable(result)
  }
  // main entry point
  def analyze(expr: Expression, env: Env): Type = analyze(expr, env, Set.empty)
  def analyze(expr: Expression, env: Env, nongeneric: Set[Variable]): Type = expr match {
    case i: Identifier => getIdentifierType(i, env, nongeneric)
    case Constant(value) => getConstantType(value)
    // thenExpr instead of `then`, which is a reserved word in newer Scala versions
    case If(cond, thenExpr, other) => {
      val condType = analyze(cond, env, nongeneric)
      val thenType = analyze(thenExpr, env, nongeneric)
      val otherType = analyze(other, env, nongeneric)
      unify(condType, Boolean)
      unify(thenType, otherType)
      thenType
    }
    case Apply(func, arg) => {
      val funcType = analyze(func, env, nongeneric)
      val argType = analyze(arg, env, nongeneric)
      val resultType = newVariable()
      unify(Function(argType, resultType), funcType)
      resultType
    }
    case Lambda(arg, body) => {
      val argType = newVariable()
      val resultType = analyze(body,
                               env + (arg -> argType),
                               nongeneric + argType)
      Function(argType, resultType)
    }
    case Let(binding, definition, body) => {
      val definitionType = analyze(definition, env, nongeneric)
      val newEnv = env + (binding -> definitionType)
      analyze(body, newEnv, nongeneric)
    }
    case Letrec(binding, definition, body) => {
      val newType = newVariable()
      val newEnv = env + (binding -> newType)
      val definitionType = analyze(definition, newEnv, nongeneric + newType)
      unify(newType, definitionType)
      analyze(body, newEnv, nongeneric)
    }
  }
  protected def getIdentifierType(id: Identifier, env: Env, nongeneric: Set[Variable]): Type = {
    if (env.contains(id))
      fresh(env(id), nongeneric)
    else
      throw new ParseError("Undefined symbol: " + id)
  }
  protected def getConstantType(value: String): Type = {
    if (isIntegerLiteral(value))
      Integer
    else
      throw new ParseError("Undefined symbol: " + value)
  }
  // Copies t, replacing each generic variable by a fresh one.
  protected def fresh(t: Type, nongeneric: Set[Variable]) = {
    import scala.collection.mutable
    val mappings = new mutable.HashMap[Variable, Variable]
    def freshrec(tp: Type): Type = {
      prune(tp) match {
        case v: Variable =>
          if (isgeneric(v, nongeneric))
            mappings.getOrElseUpdate(v, newVariable())
          else
            v
        case Operator(name, args) =>
          Operator(name, args.map(freshrec(_)))
      }
    }
    freshrec(t)
  }
  // Explicit ": Unit =" replaces the deprecated procedure syntax.
  protected def unify(t1: Type, t2: Type): Unit = {
    val type1 = prune(t1)
    val type2 = prune(t2)
    (type1, type2) match {
      case (a: Variable, b) => if (a != b) {
        if (occursintype(a, b))
          throw new TypeError("Recursive unification")
        a.instance = Some(b)
      }
      case (a: Operator, b: Variable) => unify(b, a)
      case (a: Operator, b: Operator) => {
        if (a.name != b.name ||
            a.args.length != b.args.length) throw new TypeError("Type mismatch: " + a + " ≠ " + b)
        for (i <- 0 until a.args.length)
          unify(a.args(i), b.args(i))
      }
    }
  }
  // Returns the currently defining instance of t.
  // As a side effect, collapses the list of type instances.
  protected def prune(t: Type): Type = t match {
    case v: Variable if v.instance.isDefined => {
      val inst = prune(v.instance.get)
      v.instance = Some(inst)
      inst
    }
    case _ => t
  }
  // Note: must be called with v 'pre-pruned'
  protected def isgeneric(v: Variable, nongeneric: Set[Variable]) = !(occursin(v, nongeneric))
  // Note: must be called with v 'pre-pruned'
  protected def occursintype(v: Variable, type2: Type): Boolean = {
    prune(type2) match {
      case `v` => true
      case Operator(name, args) => occursin(v, args)
      case _ => false
    }
  }
  protected def occursin(t: Variable, list: Iterable[Type]) =
    list exists (t2 => occursintype(t, t2))
  protected val checkDigits = "^(\\d+)$".r
  protected def isIntegerLiteral(name: String) = checkDigits.findFirstIn(name).isDefined
}
/** Demo program */
object HindleyMilner {
  def main(args: Array[String]): Unit = {
    // Print as UTF-8 (Console.withOut is used instead of the deprecated Console.setOut)
    val utf8Out = new java.io.PrintStream(Console.out, true, "utf-8")
    // extends the system with a new type[pair] and some builtin functions
    val left = TypeSystem.newVariable()
    val right = TypeSystem.newVariable()
    val pairType = TypeSystem.Operator("×", Seq(left, right))
    val myenv: TypeSystem.Env = TypeSystem.EmptyEnv ++ Seq(
      Identifier("pair") -> TypeSystem.Function(left, TypeSystem.Function(right, pairType)),
      Identifier("true") -> TypeSystem.Boolean,
      Identifier("false")-> TypeSystem.Boolean,
      Identifier("zero") -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Boolean),
      Identifier("pred") -> TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer),
      Identifier("times")-> TypeSystem.Function(TypeSystem.Integer,
        TypeSystem.Function(TypeSystem.Integer, TypeSystem.Integer))
    )
    // example expressions
    val pair = Apply(
      Apply(
        Identifier("pair"), Apply(Identifier("f"), Constant("4"))
      ),
      Apply(Identifier("f"), Identifier("true"))
    )
    val examples = Array[Expression](
      // factorial
      Letrec(Identifier("factorial"), // letrec factorial =
        Lambda(Identifier("n"), // lambda n =>
          If(
            Apply(Identifier("zero"), Identifier("n")),
            Constant("1"),
            Apply(
              Apply(Identifier("times"), Identifier("n")),
              Apply(
                Identifier("factorial"),
                Apply(Identifier("pred"), Identifier("n"))
              )
            )
          )
        ), // in
        Apply(Identifier("factorial"), Constant("5"))
      ),
      // Should fail:
      // fn x => (pair(x(3) (x(true))))
      Lambda(Identifier("x"),
        Apply(
          Apply(Identifier("pair"),
            Apply(Identifier("x"), Constant("3"))
          ),
          Apply(Identifier("x"), Identifier("true"))
        )
      ),
      // pair(f(4), f(true)), with f unbound in the environment
      Apply(
        Apply(Identifier("pair"), Apply(Identifier("f"), Constant("4"))),
        Apply(Identifier("f"), Identifier("true"))
      ),
      // let f = (fn x => x) in ((pair (f 4)) (f true))
      Let(Identifier("f"), Lambda(Identifier("x"), Identifier("x")), pair),
      // Should fail:
      // fn f => f f
      Lambda(Identifier("f"), Apply(Identifier("f"), Identifier("f"))),
      // let g = fn f => 5 in g g
      Let(
        Identifier("g"),
        Lambda(Identifier("f"), Constant("5")),
        Apply(Identifier("g"), Identifier("g"))
      ),
      // example that demonstrates generic and non-generic variables:
      // fn g => let f = fn x => g in pair (f 3, f true)
      Lambda(Identifier("g"),
        Let(Identifier("f"),
          Lambda(Identifier("x"), Identifier("g")),
          Apply(
            Apply(Identifier("pair"),
              Apply(Identifier("f"), Constant("3"))
            ),
            Apply(Identifier("f"), Identifier("true"))
          )
        )
      ),
      // Function composition
      // fn f (fn g (fn arg (g (f arg))))
      Lambda( Identifier("f"),
        Lambda( Identifier("g"),
          Lambda( Identifier("arg"),
            Apply(Identifier("g"), Apply(Identifier("f"), Identifier("arg")))
          )
        )
      )
    )
    Console.withOut(utf8Out) {
      for (eg <- examples) {
        tryexp(myenv, eg)
      }
    }
  }
  def tryexp(env: TypeSystem.Env, expr: Expression): Unit = {
    try {
      val t = TypeSystem.analyze(expr, env)
      print(t)
    } catch {
      case t: ParseError => print(t.getMessage)
      case t: TypeError => print(t.getMessage)
    }
    println(":\t" + expr)
  }
}
// Only needed when the file is run as a script; `args` is the script's argument array.
HindleyMilner.main(args)
A simple implementation of the unification algorithm, written in Haskell:
https://github.com/yihuang/haskell-snippets/blob/master/Unif.hs
module Main where
import Data.List (intersperse)
import Control.Monad
-- utils --
mapFst :: (a -> b) -> (a, c) -> (b, c)
mapFst f (a, c) = (f a, c)
-- types --
type Name = String
data Term = Var Name
          | App Name [Term]
-- A single substitution: replace the variable Name by Term
type Sub = (Term, Name)
-- implementation --
-- Check whether the variable Name occurs in Term
-- (it occurs in an application if it occurs in ANY argument)
occurs :: Name -> Term -> Bool
occurs x t = case t of
    (Var y)    -> x==y
    (App _ ts) -> any (occurs x) ts
-- Apply one Sub to a Term
sub :: Sub -> Term -> Term
sub (t1, y) t@(Var a)
    | a==y      = t1
    | otherwise = t
sub s (App f ts) = App f $ map (sub s) ts
-- Apply a list of Subs to a Term
subs :: [Sub] -> Term -> Term
subs ss t = foldl (flip sub) t ss
-- Combine two substitution lists, applying each newly added
-- substitution to all Terms already in the list
compose :: [Sub] -> [Sub] -> [Sub]
compose [] s1 = s1
compose (s:ss) s1 = compose ss $ s : iter s s1
  where
    iter :: Sub -> [Sub] -> [Sub]
    iter s ss = map (mapFst (sub s)) ss
-- The unification function
unify :: Term -> Term -> Maybe [Sub]
unify t1 t2 = case (t1, t2) of
    (Var x, Var y)   -> if x==y then Just [] else Just [(t1, y)]
    (Var x, App _ _) -> if occurs x t2 then Nothing else Just [(t2, x)]
    (App _ _, Var x) -> if occurs x t1 then Nothing else Just [(t1, x)]
    (App n1 ts1, App n2 ts2)
                     -> if n1/=n2 then Nothing else unify_args ts1 ts2
  where
    unify_args [] [] = Just []
    unify_args _ []  = Nothing
    unify_args [] _  = Nothing
    unify_args (t1:ts1) (t2:ts2) = do
        u <- unify t1 t2
        let update = map (subs u)
        u1 <- unify_args (update ts1) (update ts2)
        return (u1 `compose` u)
-- display --
instance Show Term where
    show (Var s) = s
    show (App name ts) = name++"("++(concat . intersperse "," $ (map show ts))++")"
showSub (t, s) = s ++ " -> " ++ show t
-- test cases --
a = Var "a"
b = Var "b"
c = Var "c"
d = Var "d"
x = Var "x"
y = Var "y"
z = Var "z"
f = App "f"
g = App "g"
j = App "j"
test t1 t2 = do
    putStrLn $ show t1 ++ " <==> " ++ show t2
    case unify t1 t2 of
        Nothing -> putStrLn "unify fail"
        Just u  -> putStrLn $ concat . intersperse "\n" $ map showSub u
testcases = [(j [x,y,z],
              j [f [y,y], f [z,z], f [a,a]])
            ,(x,
              f [x])
            ,(f [x],
              y)
            ,(f [a, f [b, c], g [b, a, c]],
              f [a, a, x])
            ,(f [d, d, x],
              f [a, f [b, c], f [b, a, c]])
            ]
-- forM_ discards the list of unit results
main = forM_ testcases (uncurry test)
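For readers who prefer Scala, the same idea can be sketched as follows. This is not a port of the Haskell code above but an independent minimal version under stated assumptions: substitutions are kept in a Map from variable names to terms, and all names (UnifyDemo, Subst, applySubst) are invented for this example.

```scala
object UnifyDemo {
  sealed trait Term
  case class Var(name: String) extends Term
  case class App(name: String, args: List[Term]) extends Term

  // A substitution maps variable names to terms.
  type Subst = Map[String, Term]

  // Apply a substitution, following chains of bound variables.
  def applySubst(s: Subst, t: Term): Term = t match {
    case Var(n)        => s.get(n).map(applySubst(s, _)).getOrElse(t)
    case App(f, args)  => App(f, args.map(applySubst(s, _)))
  }

  // Occurs check: does variable x appear anywhere inside t?
  def occurs(x: String, t: Term): Boolean = t match {
    case Var(y)       => x == y
    case App(_, args) => args.exists(occurs(x, _))
  }

  // Extend s so that t1 and t2 become equal, or return None if impossible.
  def unify(t1: Term, t2: Term, s: Subst = Map.empty): Option[Subst] =
    (applySubst(s, t1), applySubst(s, t2)) match {
      case (Var(x), Var(y)) if x == y => Some(s)
      case (Var(x), t) => if (occurs(x, t)) None else Some(s + (x -> t))
      case (t, Var(x)) => if (occurs(x, t)) None else Some(s + (x -> t))
      case (App(f, xs), App(g, ys)) =>
        if (f != g || xs.length != ys.length) None
        else xs.zip(ys).foldLeft(Option(s)) {
          case (acc, (a, b)) => acc.flatMap(unify(a, b, _))
        }
    }
}
```

Unifying x with f(x) fails on the occurs check, matching the second Haskell test case, while unifying f(x, b) with f(a, y) binds x to a and y to b.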