|
class UnificationFailure(Exception):
    """Raised when two types cannot be unified; the offending types are
    passed as the exception arguments where available."""
|
|
|
|
|
class Occurs(UnificationFailure):
    """Raised by unification when the occurs check fails, i.e. a type
    variable would have to be bound to a type that contains it."""
|
|
|
|
|
class Type(object):
    """Base class of TypeConstructor and TypeVariable.

    Subclasses implement ``show(isReturn)``, which ``__str__`` and
    ``__repr__`` delegate to.
    """

    def __str__(self):
        return self.show(True)

    def __repr__(self):
        return str(self)

    @staticmethod
    def fromjson(j):
        """Rebuild a type from the dict produced by a type's ``json()``.

        ``{"index": i}`` becomes ``TypeVariable(i)``;
        ``{"constructor": name, "arguments": [...]}`` becomes a
        ``TypeConstructor`` whose arguments are decoded recursively.

        Raises:
            ValueError: if ``j`` has neither an ``"index"`` nor a
                ``"constructor"`` key. (An ``assert`` would be stripped
                under ``python -O`` and the method would return None.)
        """
        if "index" in j:
            return TypeVariable(j["index"])
        if "constructor" in j:
            return TypeConstructor(j["constructor"],
                                   [Type.fromjson(a) for a in j["arguments"]])
        raise ValueError("cannot decode a type from JSON: %r" % (j,))
|
|
|
|
|
class TypeConstructor(Type):
    """A type formed by applying a named constructor to argument types.

    A constructor with no arguments is a ground type (e.g. ``int``); the
    function type ``a -> b`` is the constructor ``ARROW`` applied to
    ``[a, b]``.
    """

    def __init__(self, name, arguments):
        self.name = name
        self.arguments = arguments
        # Polymorphic iff some argument (recursively) contains a type
        # variable; several methods below short-circuit on this flag.
        self.isPolymorphic = any(a.isPolymorphic for a in arguments)

    def makeDummyMonomorphic(self, mapping=None):
        """Return a copy in which every type variable is replaced by a
        nullary dummy constructor.

        ``mapping`` is shared across the recursion so that repeated
        occurrences of one variable get the same dummy type.
        """
        mapping = mapping if mapping is not None else {}
        return TypeConstructor(self.name,
                               [a.makeDummyMonomorphic(mapping) for a in self.arguments])

    def __eq__(self, other):
        # Compare arity explicitly: zip() alone stops at the shorter list,
        # which made f(int) == f(int, bool) even though their hashes differ.
        return (isinstance(other, TypeConstructor)
                and self.name == other.name
                and len(self.arguments) == len(other.arguments)
                and all(x == y for x, y in zip(self.arguments, other.arguments)))

    def __hash__(self):
        return hash((self.name,) + tuple(self.arguments))

    def __ne__(self, other):
        return not (self == other)

    def show(self, isReturn):
        """Render this type as a string.

        ``isReturn`` is True where an arrow type needs no parentheses (top
        level, or the right-hand side of an arrow); arrow types on the
        left-hand side of an arrow are parenthesised.
        """
        if self.name == ARROW:
            if isReturn:
                return "%s %s %s" % (self.arguments[0].show(False), ARROW,
                                     self.arguments[1].show(True))
            else:
                return "(%s %s %s)" % (self.arguments[0].show(False), ARROW,
                                       self.arguments[1].show(True))
        elif self.arguments == []:
            return self.name
        else:
            return "%s(%s)" % (self.name,
                               ", ".join(x.show(True) for x in self.arguments))

    def json(self):
        """Serialize to a dict that ``Type.fromjson`` can decode."""
        return {"constructor": self.name,
                "arguments": [a.json() for a in self.arguments]}

    def isArrow(self):
        return self.name == ARROW

    def functionArguments(self):
        """Return the argument types of a curried arrow type, left to right
        ([] for a non-arrow type)."""
        if self.name == ARROW:
            xs = self.arguments[1].functionArguments()
            return [self.arguments[0]] + xs
        return []

    def returns(self):
        """Return the final result type after peeling off all arrows."""
        if self.name == ARROW:
            return self.arguments[1].returns()
        else:
            return self

    def apply(self, context):
        """Apply the immutable ``context``'s substitution to this type."""
        if not self.isPolymorphic:
            return self
        return TypeConstructor(self.name,
                               [x.apply(context) for x in self.arguments])

    def applyMutable(self, context):
        """Apply a MutableContext's substitution to this type."""
        if not self.isPolymorphic:
            return self
        return TypeConstructor(self.name,
                               [x.applyMutable(context) for x in self.arguments])

    def occurs(self, v):
        """Return True if the variable with index ``v`` appears in this type."""
        if not self.isPolymorphic:
            return False
        return any(x.occurs(v) for x in self.arguments)

    def negateVariables(self):
        """Return a copy with every variable index ``i`` mapped to ``-1 - i``."""
        return TypeConstructor(self.name,
                               [a.negateVariables() for a in self.arguments])

    def instantiate(self, context, bindings=None):
        """Replace variables with fresh ones from ``context``.

        Returns ``(new_context, new_type)``. ``bindings`` is shared across
        the arguments so a repeated variable maps to one fresh variable.
        """
        if not self.isPolymorphic:
            return context, self
        if bindings is None:
            bindings = {}
        newArguments = []
        for x in self.arguments:
            context, x = x.instantiate(context, bindings)
            newArguments.append(x)
        return context, TypeConstructor(self.name, newArguments)

    def instantiateMutable(self, context, bindings=None):
        """Like ``instantiate`` but allocates fresh variables from a
        MutableContext in place and returns only the new type."""
        if not self.isPolymorphic:
            return self
        if bindings is None:
            bindings = {}
        return TypeConstructor(self.name,
                               [x.instantiateMutable(context, bindings)
                                for x in self.arguments])

    def canonical(self, bindings=None):
        """Renumber variables in order of first appearance (t0, t1, ...)."""
        if not self.isPolymorphic:
            return self
        if bindings is None:
            bindings = {}
        return TypeConstructor(self.name,
                               [x.canonical(bindings) for x in self.arguments])
|
|
|
|
|
class TypeVariable(Type):
    """A type variable identified by the integer index ``v`` (shown as ``t<v>``)."""

    def __init__(self, j):
        assert isinstance(j, int)
        self.v = j
        # Variables are always polymorphic; TypeConstructor derives its flag
        # from its arguments.
        self.isPolymorphic = True

    def makeDummyMonomorphic(self, mapping=None):
        """Return the dummy nullary constructor standing in for this variable.

        An index not yet in ``mapping`` is assigned ``dummy_type_<k>``, where
        ``k`` is the number of entries already in ``mapping``; later lookups
        of the same index reuse that entry.
        """
        mapping = mapping if mapping is not None else {}
        if self.v not in mapping:
            mapping[self.v] = TypeConstructor(f"dummy_type_{len(mapping)}", [])
        return mapping[self.v]

    def __eq__(self, other):
        return isinstance(other, TypeVariable) and self.v == other.v

    def __ne__(self, other):
        # Delegate to __eq__ so that comparing with a non-variable (which has
        # no ``v`` attribute) returns True instead of raising AttributeError.
        return not (self == other)

    def __hash__(self):
        return self.v

    def show(self, _):
        return "t%d" % self.v

    def json(self):
        """Serialize to a dict that ``Type.fromjson`` can decode."""
        return {"index": self.v}

    def returns(self):
        return self

    def isArrow(self):
        return False

    def functionArguments(self):
        return []

    def apply(self, context):
        """Resolve this variable under an immutable Context.

        Returns the (recursively applied) type from the first
        ``(index, type)`` pair in ``context.substitution`` whose index is
        ``v``, or ``self`` when the variable is unbound.
        """
        for v, t in context.substitution:
            if v == self.v:
                return t.apply(context)
        return self

    def applyMutable(self, context):
        """Resolve this variable under a MutableContext."""
        s = context.substitution[self.v]
        if s is None:
            return self
        new = s.applyMutable(context)
        # Overwrite the binding with its fully-resolved form so later
        # lookups do not walk the same chain again.
        context.substitution[self.v] = new
        return new

    def occurs(self, v):
        return v == self.v

    def instantiate(self, context, bindings=None):
        """Return ``(new_context, fresh_variable)`` for this variable.

        ``bindings`` maps original indices to already-created fresh variables,
        so one variable occurring several times is replaced consistently.
        When a new variable is created, the returned context's
        ``nextVariable`` is advanced by one.
        """
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return context, bindings[self.v]
        new = TypeVariable(context.nextVariable)
        bindings[self.v] = new
        context = Context(context.nextVariable + 1, context.substitution)
        return context, new

    def instantiateMutable(self, context, bindings=None):
        """Like ``instantiate`` but allocates from a MutableContext in place."""
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return bindings[self.v]
        new = context.makeVariable()
        bindings[self.v] = new
        return new

    def canonical(self, bindings=None):
        """Renumber to ``t<k>`` where ``k`` is the order of first appearance."""
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return bindings[self.v]
        new = TypeVariable(len(bindings))
        bindings[self.v] = new
        return new

    def negateVariables(self):
        return TypeVariable(-1 - self.v)
|
|
|
|
|
class Context(object):
    """A non-destructive unification context.

    ``nextVariable`` is the index of the next fresh type variable and
    ``substitution`` is a list of ``(variable index, type)`` pairs with the
    newest binding first. ``extend``, ``makeVariable`` and ``unify`` return a
    new Context instead of modifying this one.
    """

    def __init__(self, nextVariable=0, substitution=None):
        self.nextVariable = nextVariable
        # None default (rather than ``[]``) so instances never share one
        # default list object.
        self.substitution = [] if substitution is None else substitution

    def extend(self, j, t):
        """Return a new context that additionally binds ``t<j>`` to ``t``."""
        return Context(self.nextVariable, [(j, t)] + self.substitution)

    def makeVariable(self):
        """Return ``(new_context, fresh TypeVariable)``."""
        return (Context(self.nextVariable + 1, self.substitution),
                TypeVariable(self.nextVariable))

    def unify(self, t1, t2):
        """Return a context, extending this one, under which ``t1`` and ``t2``
        become equal.

        Raises:
            Occurs: binding a variable would produce a cyclic type.
            UnificationFailure: the types have incompatible structure
                (different constructor names or arities).
        """
        t1 = t1.apply(self)
        t2 = t2.apply(self)
        if t1 == t2:
            return self

        # Two distinct variable-free types can never be made equal.
        if not t1.isPolymorphic and not t2.isPolymorphic:
            raise UnificationFailure(t1, t2)

        if isinstance(t1, TypeVariable):
            if t2.occurs(t1.v):
                raise Occurs()
            return self.extend(t1.v, t2)
        if isinstance(t2, TypeVariable):
            if t1.occurs(t2.v):
                raise Occurs()
            return self.extend(t2.v, t1)

        # Arity is checked explicitly: zip() would otherwise silently ignore
        # the extra arguments of the longer constructor.
        if t1.name != t2.name or len(t1.arguments) != len(t2.arguments):
            raise UnificationFailure(t1, t2)
        k = self
        for x, y in zip(t2.arguments, t1.arguments):
            k = k.unify(x, y)
        return k

    def __str__(self):
        return "Context(next = %d, {%s})" % (self.nextVariable, ", ".join(
            "t%d ||> %s" % (k, v.apply(self)) for k, v in self.substitution))

    def __repr__(self):
        return str(self)
|
|
|
class MutableContext(object):
    """Unification context that is updated in place.

    ``substitution[i]`` holds the type bound to variable ``t<i>``, or None
    while ``t<i>`` is unbound. Variables are allocated with ``makeVariable``.
    """

    def __init__(self):
        self.substitution = []

    def extend(self, i, t):
        """Bind the currently unbound variable ``t<i>`` to ``t`` in place."""
        assert self.substitution[i] is None
        self.substitution[i] = t

    def makeVariable(self):
        """Allocate and return a fresh, unbound TypeVariable."""
        self.substitution.append(None)
        return TypeVariable(len(self.substitution) - 1)

    def unify(self, t1, t2):
        """Unify ``t1`` and ``t2``, recording new bindings in place; returns None.

        Bindings recorded before a failure are not rolled back.

        Raises:
            Occurs: binding a variable would produce a cyclic type.
            UnificationFailure: the types have incompatible structure
                (different constructor names or arities).
        """
        t1 = t1.applyMutable(self)
        t2 = t2.applyMutable(self)

        if t1 == t2:
            return

        # Two distinct variable-free types can never be made equal.
        if not t1.isPolymorphic and not t2.isPolymorphic:
            raise UnificationFailure(t1, t2)

        if isinstance(t1, TypeVariable):
            if t2.occurs(t1.v):
                raise Occurs()
            self.extend(t1.v, t2)
            return
        if isinstance(t2, TypeVariable):
            if t1.occurs(t2.v):
                raise Occurs()
            self.extend(t2.v, t1)
            return

        # Arity is checked explicitly: zip() would otherwise silently ignore
        # the extra arguments of the longer constructor.
        if t1.name != t2.name or len(t1.arguments) != len(t2.arguments):
            raise UnificationFailure(t1, t2)

        for x, y in zip(t2.arguments, t1.arguments):
            self.unify(x, y)
|
|
|
|
|
# Shared starting context: no bindings, and fresh variables begin at t0.
Context.EMPTY = Context(nextVariable=0, substitution=[])
|
|
|
|
|
def canonicalTypes(ts):
    """Canonicalize every type in ``ts`` with one shared bindings dict, so a
    variable that appears in several of the types is renamed consistently."""
    shared = {}
    canonicalized = []
    for tp in ts:
        canonicalized.append(tp.canonical(shared))
    return canonicalized
|
|
|
|
|
def instantiateTypes(context, ts):
    """Instantiate every type in ``ts`` with fresh variables.

    The context is threaded through the calls and one bindings dict is
    shared, so a variable occurring in several types maps to the same fresh
    variable. Returns ``(final_context, instantiated_types)``.
    """
    shared = {}
    instantiated = []
    for tp in ts:
        context, fresh = tp.instantiate(context, shared)
        instantiated.append(fresh)
    return context, instantiated
|
|
|
|
|
def baseType(n):
    """Return the nullary (argument-free) type constructor named ``n``."""
    return TypeConstructor(name=n, arguments=[])
|
|
|
|
|
# Nullary ground types, created in the order int, real, bool, char.
tint, treal, tbool, tcharacter = map(baseType, ("int", "real", "bool", "char"))
tboolean = tbool  # alias: the very same object as tbool
|
|
|
|
|
def tlist(t):
    """Return the type ``list(t)``."""
    return TypeConstructor(name="list", arguments=[t])
|
|
|
|
|
def tpair(a, b):
    """Return the type ``pair(a, b)``."""
    return TypeConstructor(name="pair", arguments=[a, b])
|
|
|
|
|
def tmaybe(t):
    """Return the type ``maybe(t)``."""
    return TypeConstructor(name="maybe", arguments=[t])
|
|
|
|
|
# Strings are represented as lists of characters.
tstr = tlist(tcharacter)

# Pre-built type variables t0, t1, t2.
t0, t1, t2 = [TypeVariable(index) for index in range(3)]

tpregex = baseType("pregex")

# Constructor name used for function (arrow) types.
ARROW = "->"
|
|
|
|
|
def arrow(*arguments):
    """Build the curried function type ``a1 -> (a2 -> (... -> an))``.

    A single argument is returned unchanged.
    """
    # Fold from the right: the last argument is the final return type.
    result = arguments[-1]
    for argument in reversed(arguments[:-1]):
        result = TypeConstructor(ARROW, [argument, result])
    return result
|
|
|
|
|
def inferArg(tp, tcaller):
    """Infer the argument type a function of type ``tcaller`` must receive
    in order to return ``tp``.

    Both types are instantiated with fresh variables (separately), then
    ``tcaller`` is unified with ``targ -> tp`` for a fresh ``targ``; the
    resolved ``targ`` is returned. Unification errors propagate.
    """
    context, target = tp.instantiate(Context.EMPTY)
    context, caller = tcaller.instantiate(context)
    context, argument = context.makeVariable()
    context = context.unify(caller, arrow(argument, target))
    return argument.apply(context)
|
|
|
|
|
def guess_type(xs):
    """
    Return the type shared by every value in ``xs``, guessed from the values'
    Python types.

    bool -> tbool, int -> tint, str -> tstr, and list -> tlist of the type
    guessed from all the lists' elements taken together. Because bool is
    tested before int, a mix of bools and ints is typed as tint.

    An empty ``xs`` yields tbool (every ``all(...)`` test is vacuously true),
    so, for example, a collection of only empty lists is typed ``list(bool)``.

    Raises ValueError if no single supported type fits all of ``xs``.
    """
    if all(isinstance(x, bool) for x in xs):
        return tbool
    elif all(isinstance(x, int) for x in xs):
        return tint
    elif all(isinstance(x, str) for x in xs):
        return tstr
    elif all(isinstance(x, list) for x in xs):
        return tlist(guess_type([y for ys in xs for y in ys]))
    else:
        raise ValueError("cannot guess type from {}".format(xs))
|
|
|
|
|
def guess_arrow_type(examples):
    """
    Guess a function type from input/output examples.

    ``examples`` is a non-empty sequence of ``(inputs, output)`` pairs, where
    ``inputs`` is a sequence of argument values. Every argument position and
    the output are typed with ``guess_type``; with zero inputs the result is
    just the output type.

    Raises ValueError if ``examples`` is empty, if the examples disagree on
    the number of inputs, or if ``guess_type`` cannot type some position.
    """
    if not examples:
        raise ValueError("cannot guess an arrow type from zero examples")
    arity = len(examples[0][0])
    # Reject ragged examples instead of silently ignoring extra inputs or
    # failing with an IndexError part-way through.
    if any(len(inputs) != arity for inputs, _ in examples):
        raise ValueError(
            "all examples must have the same number of inputs ({})".format(arity))
    input_types = [guess_type([inputs[n] for inputs, _ in examples])
                   for n in range(arity)]
    output_type = guess_type([y for _, y in examples])
    return arrow(*(input_types + [output_type]))
|
|
|
def canUnify(t1, t2):
    """Return True if ``t1`` and ``t2`` can be unified.

    Each type is instantiated into a fresh MutableContext with its own
    bindings, so the two types' variables are treated as distinct. Occurs
    failures count as failures too (Occurs subclasses UnificationFailure).
    """
    context = MutableContext()
    left = t1.instantiateMutable(context)
    right = t2.instantiateMutable(context)
    try:
        context.unify(left, right)
    except UnificationFailure:
        return False
    return True
|
|
|
|