Fraser-Greenlee
add dreamcoder codebase
e1c1753
raw
history blame
11 kB
class UnificationFailure(Exception):
    """Raised by unification when two types cannot be made equal."""
class Occurs(UnificationFailure):
    """Unification failure raised when a variable occurs inside the type it would be bound to."""
class Type(object):
    """Base class of the type representations in this module.

    Subclasses provide ``show(isReturn)``; both string conversions defined
    here delegate to it.
    """

    def __repr__(self):
        return str(self)

    def __str__(self):
        return self.show(True)

    @staticmethod
    def fromjson(j):
        """Rebuild a type from its dictionary form.

        A dictionary with an ``"index"`` key becomes a ``TypeVariable``; one
        with a ``"constructor"`` key becomes a ``TypeConstructor`` whose
        ``"arguments"`` entries are decoded recursively. Any other input trips
        the final assertion.
        """
        if "index" in j:
            return TypeVariable(j["index"])
        if "constructor" in j:
            constructorName = j["constructor"]
            decoded = []
            for encodedArgument in j["arguments"]:
                decoded.append(Type.fromjson(encodedArgument))
            return TypeConstructor(constructorName, decoded)
        assert False
class TypeConstructor(Type):
    """A named type constructor applied to zero or more argument types.

    ``name`` is the constructor's name and ``arguments`` holds its argument
    types. The constructor named ``ARROW`` is treated as a function type whose
    two arguments are the parameter type and the result type.
    ``isPolymorphic`` is True when any argument is polymorphic.
    """

    def __init__(self, name, arguments):
        self.name = name
        self.arguments = arguments
        # Computed once so the traversals below can skip monomorphic subtrees.
        self.isPolymorphic = any(a.isPolymorphic for a in arguments)

    def makeDummyMonomorphic(self, mapping=None):
        """Return a copy whose arguments have each been made monomorphic.

        ``mapping`` is shared across the whole traversal and handed to every
        argument's ``makeDummyMonomorphic``.
        """
        mapping = mapping if mapping is not None else {}
        return TypeConstructor(self.name,
                               [a.makeDummyMonomorphic(mapping) for a in self.arguments])

    def __eq__(self, other):
        # The explicit arity comparison matters: zip() stops at the shorter
        # sequence, so without it ``f(a)`` would compare equal to ``f(a, b)``.
        return isinstance(other, TypeConstructor) and \
            self.name == other.name and \
            len(self.arguments) == len(other.arguments) and \
            all(x == y for x, y in zip(self.arguments, other.arguments))

    def __hash__(self):
        return hash((self.name,) + tuple(self.arguments))

    def __ne__(self, other):
        return not (self == other)

    def show(self, isReturn):
        """Render the type as text.

        Function types are written infix with ``ARROW`` and are parenthesised
        unless ``isReturn`` is true. A constructor without arguments is shown
        as its bare name; anything else as ``name(arg, ...)``.
        """
        if self.name == ARROW:
            rendered = "%s %s %s" % (self.arguments[0].show(False),
                                     ARROW,
                                     self.arguments[1].show(True))
            if isReturn:
                return rendered
            return "(%s)" % rendered
        # Truthiness rather than ``== []`` so an empty tuple is handled too.
        if not self.arguments:
            return self.name
        return "%s(%s)" % (self.name,
                           ", ".join(x.show(True) for x in self.arguments))

    def json(self):
        """Return the dictionary form understood by ``Type.fromjson``."""
        return {"constructor": self.name,
                "arguments": [a.json() for a in self.arguments]}

    def isArrow(self):
        return self.name == ARROW

    def functionArguments(self):
        """Return the list of parameter types of a (curried) function type.

        Non-function types yield an empty list.
        """
        if self.name == ARROW:
            xs = self.arguments[1].functionArguments()
            return [self.arguments[0]] + xs
        return []

    def returns(self):
        """Return the final result type after stripping every leading arrow."""
        if self.name == ARROW:
            return self.arguments[1].returns()
        return self

    def apply(self, context):
        """Return this type with ``context`` applied to each argument.

        Monomorphic types are returned unchanged.
        """
        if not self.isPolymorphic:
            return self
        return TypeConstructor(self.name,
                               [x.apply(context) for x in self.arguments])

    def applyMutable(self, context):
        """Same as ``apply`` but delegating to the arguments' ``applyMutable``."""
        if not self.isPolymorphic:
            return self
        return TypeConstructor(self.name,
                               [x.applyMutable(context) for x in self.arguments])

    def occurs(self, v):
        """Return True when variable index ``v`` occurs in any argument."""
        if not self.isPolymorphic:
            return False
        return any(x.occurs(v) for x in self.arguments)

    def negateVariables(self):
        """Return a copy with ``negateVariables`` applied to every argument."""
        return TypeConstructor(self.name,
                               [a.negateVariables() for a in self.arguments])

    def instantiate(self, context, bindings=None):
        """Instantiate every argument, threading ``context`` through them.

        Returns a ``(context, type)`` pair. ``bindings`` is shared between the
        arguments. Monomorphic types are returned unchanged together with the
        incoming context.
        """
        if not self.isPolymorphic:
            return context, self
        if bindings is None:
            bindings = {}
        newArguments = []
        for x in self.arguments:
            (context, x) = x.instantiate(context, bindings)
            newArguments.append(x)
        return (context, TypeConstructor(self.name, newArguments))

    def instantiateMutable(self, context, bindings=None):
        """Instantiate every argument via ``instantiateMutable``.

        Returns only the new type; ``bindings`` is shared between the
        arguments. Monomorphic types are returned unchanged.
        """
        if not self.isPolymorphic:
            return self
        if bindings is None:
            bindings = {}
        return TypeConstructor(self.name,
                               [x.instantiateMutable(context, bindings)
                                for x in self.arguments])

    def canonical(self, bindings=None):
        """Return a copy with ``canonical`` applied to every argument.

        ``bindings`` is shared between the arguments. Monomorphic types are
        returned unchanged.
        """
        if not self.isPolymorphic:
            return self
        if bindings is None:
            bindings = {}
        return TypeConstructor(self.name,
                               [x.canonical(bindings) for x in self.arguments])
class TypeVariable(Type):
    """A type variable identified by the integer index ``v``.

    Every type variable reports ``isPolymorphic = True``.
    """

    def __init__(self, j):
        assert isinstance(j, int)
        self.v = j
        self.isPolymorphic = True

    def makeDummyMonomorphic(self, mapping=None):
        """Return the placeholder base type standing in for this variable.

        The first time an index is seen, a fresh ``dummy_type_<n>`` constructor
        is stored in ``mapping``; later requests for the same index reuse it.
        """
        mapping = mapping if mapping is not None else {}
        if self.v not in mapping:
            mapping[self.v] = TypeConstructor(f"dummy_type_{len(mapping)}", [])
        return mapping[self.v]

    def __eq__(self, other):
        return isinstance(other, TypeVariable) and self.v == other.v

    def __ne__(self, other):
        # Defined through __eq__ so that comparing against an object without a
        # ``v`` attribute (e.g. a TypeConstructor) yields True instead of
        # raising AttributeError, and so that == and != always agree.
        return not (self == other)

    def __hash__(self):
        return self.v

    def show(self, _):
        return "t%d" % self.v

    def json(self):
        """Return the dictionary form understood by ``Type.fromjson``."""
        return {"index": self.v}

    def returns(self):
        return self

    def isArrow(self):
        return False

    def functionArguments(self):
        return []

    def apply(self, context):
        """Resolve this variable through ``context.substitution``.

        The substitution is scanned as ``(index, type)`` pairs; the first
        match is itself resolved through ``context`` and returned. An unbound
        variable is returned as-is.
        """
        for v, t in context.substitution:
            if v == self.v:
                return t.apply(context)
        return self

    def applyMutable(self, context):
        """Resolve this variable through an index-addressed substitution.

        ``context.substitution[self.v]`` is ``None`` for an unbound variable.
        Otherwise the bound type is resolved and the result is written back to
        that slot before being returned.
        """
        s = context.substitution[self.v]
        if s is None:
            return self
        new = s.applyMutable(context)
        context.substitution[self.v] = new
        return new

    def occurs(self, v):
        return v == self.v

    def instantiate(self, context, bindings=None):
        """Map this variable to a fresh one, returning ``(context, variable)``.

        A variable already present in ``bindings`` reuses its recorded
        replacement and the context is returned unchanged. Otherwise a new
        variable numbered ``context.nextVariable`` is recorded and a context
        with the counter advanced by one is returned.
        """
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return (context, bindings[self.v])
        new = TypeVariable(context.nextVariable)
        bindings[self.v] = new
        context = Context(context.nextVariable + 1, context.substitution)
        return (context, new)

    def instantiateMutable(self, context, bindings=None):
        """Map this variable to a fresh one obtained from ``context.makeVariable()``.

        A variable already present in ``bindings`` reuses its replacement.
        """
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return bindings[self.v]
        new = context.makeVariable()
        bindings[self.v] = new
        return new

    def canonical(self, bindings=None):
        """Renumber this variable by order of first appearance in ``bindings``.

        A variable not yet seen is assigned index ``len(bindings)``.
        """
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return bindings[self.v]
        new = TypeVariable(len(bindings))
        bindings[self.v] = new
        return new

    def negateVariables(self):
        """Return the variable with index ``-1 - v``."""
        return TypeVariable(-1 - self.v)
class Context(object):
    """Immutable-style unification context.

    ``nextVariable`` is the index the next fresh type variable will receive.
    ``substitution`` is a list of ``(variable index, type)`` pairs. Methods
    that change the context return a new ``Context`` rather than modifying
    this one.
    """

    def __init__(self, nextVariable=0, substitution=None):
        self.nextVariable = nextVariable
        # A None sentinel instead of a literal [] default: a default list would
        # be a single object shared by every Context built without arguments.
        self.substitution = [] if substitution is None else substitution

    def extend(self, j, t):
        """Return a new context that additionally binds variable ``j`` to ``t``."""
        return Context(self.nextVariable, [(j, t)] + self.substitution)

    def makeVariable(self):
        """Return ``(context, variable)`` for a freshly numbered type variable."""
        return (Context(self.nextVariable + 1, self.substitution),
                TypeVariable(self.nextVariable))

    def unify(self, t1, t2):
        """Unify ``t1`` with ``t2`` and return the resulting context.

        Raises ``Occurs`` when a variable would be bound to a type containing
        it, and ``UnificationFailure`` when the types cannot be made equal.
        """
        t1 = t1.apply(self)
        t2 = t2.apply(self)
        if t1 == t2:
            return self
        # t1&t2 are not equal
        if not t1.isPolymorphic and not t2.isPolymorphic:
            raise UnificationFailure(t1, t2)

        if isinstance(t1, TypeVariable):
            if t2.occurs(t1.v):
                raise Occurs()
            return self.extend(t1.v, t2)
        if isinstance(t2, TypeVariable):
            if t1.occurs(t2.v):
                raise Occurs()
            return self.extend(t2.v, t1)
        if t1.name != t2.name:
            raise UnificationFailure(t1, t2)
        # zip() below would silently ignore surplus arguments, so constructors
        # of different arity are rejected explicitly.
        if len(t1.arguments) != len(t2.arguments):
            raise UnificationFailure(t1, t2)
        k = self
        for x, y in zip(t2.arguments, t1.arguments):
            k = k.unify(x, y)
        return k

    def __str__(self):
        return "Context(next = %d, {%s})" % (self.nextVariable, ", ".join(
            "t%d ||> %s" % (k, v.apply(self)) for k, v in self.substitution))

    def __repr__(self):
        return str(self)
class MutableContext(object):
    """Unification context that is updated in place.

    ``substitution`` is a list indexed by variable number; an entry of
    ``None`` means the variable is unbound.
    """

    def __init__(self):
        self.substitution = []

    def extend(self, i, t):
        """Bind variable ``i`` to ``t``; the variable must currently be unbound."""
        assert self.substitution[i] is None
        self.substitution[i] = t

    def makeVariable(self):
        """Allocate a new unbound slot and return the variable that names it."""
        self.substitution.append(None)
        return TypeVariable(len(self.substitution) - 1)

    def unify(self, t1, t2):
        """Unify ``t1`` with ``t2``, recording bindings in this context.

        Returns ``None``. Raises ``Occurs`` when a variable would be bound to a
        type containing it, and ``UnificationFailure`` when the types cannot be
        made equal.
        """
        t1 = t1.applyMutable(self)
        t2 = t2.applyMutable(self)

        if t1 == t2:
            return
        # t1&t2 are not equal
        if not t1.isPolymorphic and not t2.isPolymorphic:
            raise UnificationFailure(t1, t2)

        if isinstance(t1, TypeVariable):
            if t2.occurs(t1.v):
                raise Occurs()
            self.extend(t1.v, t2)
            return
        if isinstance(t2, TypeVariable):
            if t1.occurs(t2.v):
                raise Occurs()
            self.extend(t2.v, t1)
            return
        if t1.name != t2.name:
            raise UnificationFailure(t1, t2)
        # zip() below would silently ignore surplus arguments, so constructors
        # of different arity are rejected explicitly.
        if len(t1.arguments) != len(t2.arguments):
            raise UnificationFailure(t1, t2)
        for x, y in zip(t2.arguments, t1.arguments):
            self.unify(x, y)
# Class-level starting context: variable counter at 0 and an empty substitution.
Context.EMPTY = Context(0, [])
def canonicalTypes(ts):
    """Canonicalise each type in ``ts`` against one shared bindings table.

    Because the table is shared, the types are renumbered jointly rather than
    independently. Returns the list of canonicalised types.
    """
    shared = {}
    canonicalised = []
    for tp in ts:
        canonicalised.append(tp.canonical(shared))
    return canonicalised
def instantiateTypes(context, ts):
    """Instantiate every type in ``ts`` with one shared bindings table.

    The context returned by each ``instantiate`` call is fed into the next.
    Returns ``(final context, list of instantiated types)``.
    """
    shared = {}
    current = context
    instantiated = []
    for original in ts:
        current, fresh = original.instantiate(current, shared)
        instantiated.append(fresh)
    return current, instantiated
def baseType(n):
    """Return the type constructor named ``n`` with no arguments."""
    noArguments = []
    return TypeConstructor(n, noArguments)
# Predefined base types (constructors without arguments).
tint = baseType("int")
treal = baseType("real")
tbool = baseType("bool")
tboolean = tbool # alias
tcharacter = baseType("char")
def tlist(t):
    """Return the ``list`` type constructor applied to element type ``t``."""
    element = [t]
    return TypeConstructor("list", element)
def tpair(a, b):
    """Return the ``pair`` type constructor applied to ``a`` and ``b``."""
    components = [a, b]
    return TypeConstructor("pair", components)
def tmaybe(t):
    """Return the ``maybe`` type constructor applied to ``t``."""
    wrapped = [t]
    return TypeConstructor("maybe", wrapped)
# Strings are represented as lists of characters.
tstr = tlist(tcharacter)
# Ready-made type variables with indices 0, 1 and 2.
t0 = TypeVariable(0)
t1 = TypeVariable(1)
t2 = TypeVariable(2)
# regex types
tpregex = baseType("pregex")
# Constructor name that marks a function type (used by arrow() and TypeConstructor).
ARROW = "->"
def arrow(*arguments):
    """Build the right-nested function type ``a0 -> (a1 -> (... -> an))``.

    A single argument is returned unchanged; calling with no arguments raises
    ``IndexError``.
    """
    # Start from the last type and wrap it with each preceding one, right to left.
    result = arguments[-1]
    for parameter in reversed(arguments[:-1]):
        result = TypeConstructor(ARROW, [parameter, result])
    return result
def inferArg(tp, tcaller):
    """Solve ``tcaller = targ -> tp`` for ``targ`` and return it.

    Both inputs are instantiated starting from ``Context.EMPTY``, a fresh
    variable is unified into the argument position, and that variable is
    returned after the resulting context has been applied to it.
    """
    context, resultType = tp.instantiate(Context.EMPTY)
    context, callerType = tcaller.instantiate(context)
    context, argument = context.makeVariable()
    expected = arrow(argument, resultType)
    solved = context.unify(callerType, expected)
    return argument.apply(solved)
def guess_type(xs):
    """
    Guess the type shared by every value in xs from their Python types.

    bool is tested before int because bool is a subclass of int. Lists are
    flattened one level and the element type is guessed recursively.
    Raises ValueError if no single type fits all the values.
    """
    def every(kind):
        return all(isinstance(x, kind) for x in xs)

    if every(bool):
        return tbool
    if every(int):
        return tint
    if every(str):
        return tstr
    if every(list):
        flattened = []
        for inner in xs:
            flattened.extend(inner)
        return tlist(guess_type(flattened))
    raise ValueError("cannot guess type from {}".format(xs))
def guess_arrow_type(examples):
    """Guess a function type from ``(inputs, output)`` example pairs.

    The number of inputs is read from the first example. Each input position
    and the output are guessed with ``guess_type`` across all examples, and the
    results are combined with ``arrow``.
    """
    arity = len(examples[0][0])
    input_types = [
        guess_type([inputs[position] for inputs, _ in examples])
        for position in range(arity)
    ]
    output_type = guess_type([output for _, output in examples])
    return arrow(*input_types, output_type)
def canUnify(t1, t2):
    """Return True when ``t1`` and ``t2`` unify, False on ``UnificationFailure``.

    Each type is instantiated separately into one scratch ``MutableContext``
    before unification is attempted.
    """
    scratch = MutableContext()
    left = t1.instantiateMutable(scratch)
    right = t2.instantiateMutable(scratch)
    try:
        scratch.unify(left, right)
    except UnificationFailure:
        return False
    else:
        return True