Fraser-Greenlee
add dreamcoder codebase
e1c1753
raw
history blame
52.9 kB
from dreamcoder.grammar import *
# Small additive cost charged for every abstraction and application node when the
# routines below compute the cost of a program (leaves cost 1).
epsilon = 0.001
def instantiate(context, environment, tp):
    """Instantiate a type and its environment with one shared set of bindings.

    `tp` is instantiated first; every type in `environment` is then
    instantiated with the same `bindings` dictionary, threading the context
    through each call.

    Returns a triple (context, instantiated environment, instantiated tp).
    """
    shared_bindings = {}
    context, instantiated_tp = tp.instantiate(context, shared_bindings)

    fresh_environment = {}
    for variable, variable_type in environment.items():
        context, fresh_type = variable_type.instantiate(context, shared_bindings)
        fresh_environment[variable] = fresh_type

    return context, fresh_environment, instantiated_tp
def unify(*environmentsAndTypes):
    """Unify several (environment, type) pairs into one environment and type.

    Every pair is instantiated in a common context. All the types are unified
    with a single fresh type variable, and whenever two environments mention
    the same key their types are unified with each other.

    Returns (environment, type), both with the final context applied.
    """
    context = Context.EMPTY
    merged = {}
    context, result_type = context.makeVariable()

    for environment, tp in environmentsAndTypes:
        context, environment, tp = instantiate(context, environment, tp)
        context = context.unify(result_type, tp)
        for variable, variable_type in environment.items():
            if variable in merged:
                context = context.unify(merged[variable], variable_type)
            else:
                merged[variable] = variable_type

    resolved = {}
    for variable, variable_type in merged.items():
        resolved[variable] = variable_type.apply(context)
    return resolved, result_type.apply(context)
class Union(Program):
    """Program node standing for a set of alternatives.

    `elements` is frozen into a frozenset. Unless `canBeEmpty` is set, the
    constructor asserts that there is more than one distinct element.
    """

    def __init__(self, elements, canBeEmpty=False):
        self.elements = frozenset(elements)
        if canBeEmpty:
            return
        assert len(self.elements) > 1

    @property
    def isUnion(self):
        return True

    def __eq__(self, o):
        if not isinstance(o, Union):
            return False
        return self.elements == o.elements

    def __hash__(self):
        return hash(self.elements)

    def __str__(self):
        rendered = [str(member) for member in self.elements]
        return "{%s}" % ", ".join(rendered)

    def show(self, isFunction):
        # The rendering does not depend on isFunction.
        return str(self)

    def __repr__(self):
        return str(self)

    def __iter__(self):
        return iter(self.elements)
# Store of version-space nodes. Each distinct node is kept once in self.expressions
# and is referred to everywhere else by its integer position in that list.
class VersionTable():
# The three flags are only stored here; they are read later by the substitution and
# inversion routines (typed, factored) and by recursiveInversion (identity).
def __init__(self, typed=True, identity=True, factored=False):
self.factored = factored
self.identity = identity
self.typed = typed
self.debug = False
if self.debug:
print("WARNING: running version spaces in debug mode. Will be substantially slower.")
# index -> node
self.expressions = []
# Per-node memo of recursiveInversion results (grown in step with self.expressions)
self.recursiveTable = []
# Memo of _substitutions results, keyed by (node index, n)
self.substitutionTable = {}
# node -> index, the reverse of self.expressions
self.expression2index = {}
# NOTE(review): not read anywhere in the visible part of this file.
self.maximumShift = []
# Table containing (minimum cost, set of minimum cost programs)
self.inhabitantTable = []
# Table containing (minimum cost, set of minimum cost programs NOT starting w/ abstraction)
self.functionInhabitantTable = []
# Memo of superVersionSpace results, keyed by node index only
self.superCache = {}
# Memo of haveOverlap answers: overlapTable[a][b]
self.overlapTable = {}
# The two distinguished nodes are created while the table is still empty, so they
# receive the first two indices. All tables above must exist before these calls.
self.universe = self.incorporate(Primitive("U",t0,None))
self.empty = self.incorporate(Union([], canBeEmpty=True))
def __len__(self):
    """Return the number of nodes stored in the table."""
    stored = self.expressions
    return len(stored)
def clearOverlapTable(self):
    """Discard every memoised haveOverlap answer."""
    self.overlapTable = dict()
def visualize(self, j):
    """Draw the version space rooted at node `j` with graphviz and open it.

    The universe and empty nodes are labelled 'universe' and 'nil'; leaves
    are labelled with their string form; abstractions, applications and
    unions are labelled "lambda", "@" and "U" with edges to their children.
    """
    from graphviz import Digraph

    graph = Digraph()
    drawn = set()

    def draw(node):
        if node in drawn:
            return
        name = str(node)

        if node == self.universe:
            graph.node(name, 'universe')
        elif node == self.empty:
            graph.node(name, 'nil')
        else:
            expression = self.expressions[node]
            if expression.isIndex or expression.isPrimitive or expression.isInvented:
                graph.node(name, str(expression))
            elif expression.isAbstraction:
                graph.node(name, "lambda")
                draw(expression.body)
                graph.edge(name, str(expression.body))
            elif expression.isApplication:
                graph.node(name, "@")
                draw(expression.f)
                draw(expression.x)
                graph.edge(name, str(expression.f), label='f')
                graph.edge(name, str(expression.x), label='x')
            elif expression.isUnion:
                graph.node(name, "U")
                for child in expression:
                    draw(child)
                    graph.edge(name, str(child))
            else:
                assert False

        # A node is only marked once it and all of its children have been drawn.
        drawn.add(node)

    draw(j)
    graph.render(view=True)
def branchingFactor(self, j):
    """Return the size of the largest union found anywhere below node `j`.

    Leaves contribute 0.
    """
    node = self.expressions[j]
    if node.isApplication:
        function_factor = self.branchingFactor(node.f)
        argument_factor = self.branchingFactor(node.x)
        return max(function_factor, argument_factor)
    if node.isUnion:
        widest = len(node.elements)
        for member in node:
            widest = max(widest, self.branchingFactor(member))
        return widest
    if node.isAbstraction:
        return self.branchingFactor(node.body)
    return 0
def intention(self, j, isFunction=False):
    """Rebuild node `j` as a program tree whose children are programs.

    Leaves are returned as stored. Abstractions, applications and unions are
    rebuilt around the recursively converted children. `isFunction` is
    accepted but not consulted.
    """
    node = self.expressions[j]
    if node.isIndex or node.isPrimitive or node.isInvented:
        return node
    if node.isAbstraction:
        body = self.intention(node.body)
        return Abstraction(body)
    if node.isApplication:
        function = self.intention(node.f)
        argument = self.intention(node.x)
        return Application(function, argument)
    if node.isUnion:
        return Union(self.intention(member) for member in node)
    assert False
def walk(self, j):
    """yields every subversion space of j"""
    seen = set()
    # Explicit stack; children are pushed in reverse so they are popped in
    # the same left-to-right, depth-first order as a recursive traversal.
    pending = [j]
    while pending:
        n = pending.pop()
        if n in seen:
            continue
        seen.add(n)
        node = self.expressions[n]
        yield node

        children = []
        if node.isApplication:
            children += [node.f, node.x]
        if node.isAbstraction:
            children.append(node.body)
        if node.isUnion:
            children.extend(node)
        pending.extend(reversed(children))
def incorporate(self, p):
    """Store program `p` (and, recursively, its children) and return its index.

    Children of abstractions, applications and non-empty unions are stored
    first and replaced by their indices before the node itself is stored.
    """
    #assert isinstance(p,Union)# or p.wellTyped()
    if p.isIndex or p.isPrimitive or p.isInvented:
        node = p
    elif p.isAbstraction:
        node = Abstraction(self.incorporate(p.body))
    elif p.isApplication:
        function = self.incorporate(p.f)
        argument = self.incorporate(p.x)
        node = Application(function, argument)
    elif p.isUnion:
        node = p
        # An empty union is stored as it is.
        if len(p.elements) > 0:
            node = Union([self.incorporate(member) for member in p])
    else:
        assert False
    return self._incorporate(node)
def _incorporate(self, p):
    """Return the index of node `p`, storing it first if it is new.

    When a node is added, the per-node memo tables each receive a None slot
    so they stay the same length as self.expressions.
    """
    try:
        return self.expression2index[p]
    except KeyError:
        pass

    fresh = len(self.expressions)
    self.expressions.append(p)
    self.expression2index[p] = fresh
    for table in (self.recursiveTable,
                  self.inhabitantTable,
                  self.functionInhabitantTable):
        table.append(None)
    return fresh
def extract(self, j):
    """Yield every program contained in the version space at node `j`.

    Applications yield each combination of a function program with an
    argument program; unions yield the programs of each member in turn.
    """
    node = self.expressions[j]
    if node.isAbstraction:
        yield from (Abstraction(body) for body in self.extract(node.body))
    elif node.isApplication:
        for function in self.extract(node.f):
            for argument in self.extract(node.x):
                yield Application(function, argument)
    elif node.isIndex or node.isPrimitive or node.isInvented:
        yield node
    elif node.isUnion:
        for member in node:
            yield from self.extract(member)
    else:
        assert False
def reachable(self, heads):
    """Return the set of node indices reachable from any index in `heads`."""
    seen = set()
    pending = list(heads)
    while pending:
        j = pending.pop()
        if j in seen:
            continue
        seen.add(j)
        node = self.expressions[j]
        if node.isUnion:
            pending.extend(node)
        elif node.isAbstraction:
            pending.append(node.body)
        elif node.isApplication:
            pending.extend((node.f, node.x))
    return seen
def size(self, j):
    """Return the number of leaves below node `j`, counting shared nodes each time.

    Abstractions, applications and unions add nothing themselves; every
    other node counts as 1.
    """
    node = self.expressions[j]
    if node.isApplication:
        children = (node.f, node.x)
    elif node.isAbstraction:
        children = (node.body,)
    elif node.isUnion:
        children = tuple(node)
    else:
        return 1
    return sum(self.size(child) for child in children)
def union(self, elements):
    """Return the index of the union of the version spaces in `elements`.

    The universe absorbs everything. Nested unions are flattened one level
    and the empty node is dropped. Zero remaining members give the empty
    node, a single member is returned as it is, otherwise a Union node is
    stored.
    """
    if self.universe in elements:
        return self.universe

    flattened = []
    for element in elements:
        node = self.expressions[element]
        if node.isUnion:
            flattened.extend(node)
        elif element != self.empty:
            flattened.append(element)

    members = frozenset(flattened)
    if not members:
        return self.empty
    if len(members) == 1:
        (only,) = members
        return only
    return self._incorporate(Union(members))
def apply(self, f, x):
    """Return the index of the application of space `f` to space `x`.

    If either side is the empty node, the empty node is returned.
    """
    for side in (f, x):
        if side == self.empty:
            return side
    return self._incorporate(Application(f, x))
def abstract(self, b):
    """Return the index of the abstraction whose body is space `b`.

    Abstracting the empty node gives the empty node.
    """
    if b != self.empty:
        return self._incorporate(Abstraction(b))
    return self.empty
def index(self, i):
    """Return the table index of the node Index(i)."""
    node = Index(i)
    return self._incorporate(node)
def intersection(self, a, b):
    """Return the index of the intersection of version spaces `a` and `b`.

    The empty node annihilates and the universe is the identity. Matching
    abstractions and applications are intersected child by child, unions
    are distributed over, and any other mismatch gives the empty node.
    """
    if self.empty in (a, b):
        return self.empty
    if a == self.universe:
        return b
    if b == self.universe or a == b:
        return a

    left = self.expressions[a]
    right = self.expressions[b]

    if left.isAbstraction and right.isAbstraction:
        body = self.intersection(left.body, right.body)
        return self.abstract(body)

    if left.isApplication and right.isApplication:
        function = self.intersection(left.f, right.f)
        argument = self.intersection(left.x, right.x)
        return self.apply(function, argument)

    if left.isUnion and right.isUnion:
        pairwise = [self.intersection(l_, r_)
                    for l_ in left
                    for r_ in right]
        return self.union(pairwise)
    if left.isUnion:
        return self.union([self.intersection(l_, b) for l_ in left])
    if right.isUnion:
        return self.union([self.intersection(a, r_) for r_ in right])

    return self.empty
def haveOverlap(self, a, b):
    """Return True when version spaces `a` and `b` share at least one program.

    The empty node overlaps with nothing and the universe overlaps with
    every non-empty space. Matching abstractions and applications overlap
    when their children do; a union overlaps when any member does. Answers
    are memoised in self.overlapTable[a][b].
    """
    if a == self.empty or b == self.empty:
        return False
    if a == self.universe or b == self.universe:
        return True
    if a == b:
        return True

    cached = self.overlapTable.setdefault(a, {})
    if b in cached:
        return cached[b]

    x = self.expressions[a]
    y = self.expressions[b]
    if x.isAbstraction and y.isAbstraction:
        overlap = self.haveOverlap(x.body, y.body)
    elif x.isApplication and y.isApplication:
        overlap = (self.haveOverlap(x.f, y.f) and
                   self.haveOverlap(x.x, y.x))
    elif x.isUnion and y.isUnion:
        # Compare members pairwise exactly once. Previously this result was
        # computed and then overwritten by the member-vs-`b` check below,
        # which gives the same answer but repeats all of the recursion.
        overlap = any(self.haveOverlap(x_, y_)
                      for x_ in x
                      for y_ in y)
    elif x.isUnion:
        overlap = any(self.haveOverlap(x_, b) for x_ in x)
    elif y.isUnion:
        overlap = any(self.haveOverlap(a, y_) for y_ in y)
    else:
        overlap = False

    self.overlapTable[a][b] = overlap
    return overlap
def minimalInhabitants(self, j):
    """Returns (minimal size, set of singleton version spaces)"""
    assert isinstance(j, int)
    memo = self.inhabitantTable[j]
    if memo is not None:
        return memo

    node = self.expressions[j]
    if node.isAbstraction:
        body_cost, bodies = self.minimalInhabitants(node.body)
        best = body_cost + epsilon
        winners = set()
        for body in bodies:
            winners.add(self.abstract(body))
    elif node.isApplication:
        # The function position must not itself be an abstraction.
        function_cost, functions = self.minimalFunctionInhabitants(node.f)
        argument_cost, arguments = self.minimalInhabitants(node.x)
        best = function_cost + argument_cost + epsilon
        winners = set()
        for function in functions:
            for argument in arguments:
                winners.add(self.apply(function, argument))
    elif node.isUnion:
        options = [self.minimalInhabitants(member) for member in node]
        best = min(cost for cost, _ in options)
        winners = set()
        for cost, members in options:
            if cost == best:
                winners.update(members)
    else:
        assert node.isIndex or node.isInvented or node.isPrimitive
        best, winners = 1, {j}

    self.inhabitantTable[j] = (best, winners)
    return best, winners
def minimalFunctionInhabitants(self, j):
    """Returns (minimal size, set of singleton version spaces)"""
    assert isinstance(j, int)
    memo = self.functionInhabitantTable[j]
    if memo is not None:
        return memo

    node = self.expressions[j]
    if node.isAbstraction:
        # An abstraction never qualifies here: infinite cost, no members.
        best, winners = POSITIVEINFINITY, set()
    elif node.isApplication:
        function_cost, functions = self.minimalFunctionInhabitants(node.f)
        argument_cost, arguments = self.minimalInhabitants(node.x)
        best = function_cost + argument_cost + epsilon
        winners = set()
        for function in functions:
            for argument in arguments:
                winners.add(self.apply(function, argument))
    elif node.isUnion:
        options = [self.minimalFunctionInhabitants(member) for member in node]
        best = min(cost for cost, _ in options)
        winners = set()
        for cost, members in options:
            if cost == best:
                winners.update(members)
    else:
        assert node.isIndex or node.isInvented or node.isPrimitive
        best, winners = 1, {j}

    self.functionInhabitantTable[j] = (best, winners)
    return best, winners
def shiftFree(self, j, n, c=0):
    """Return space `j` with its free indices lowered by `n`.

    `c` counts the abstractions entered so far. Indices below `c` are left
    alone, indices of at least n + c are lowered by n, and an index in
    between has no valid shift, so the empty node is returned for it.
    """
    if n == 0:
        return j

    node = self.expressions[j]
    if node.isUnion:
        shifted = [self.shiftFree(member, n, c) for member in node]
        return self.union(shifted)
    if node.isApplication:
        function = self.shiftFree(node.f, n, c)
        argument = self.shiftFree(node.x, n, c)
        return self.apply(function, argument)
    if node.isAbstraction:
        return self.abstract(self.shiftFree(node.body, n, c + 1))
    if not node.isIndex:
        assert node.isPrimitive or node.isInvented
        return j

    if node.i < c:
        return j
    if node.i >= n + c:
        return self.index(node.i - n)
    return self.empty
def substitutions(self, j):
    """Yield (value space, body space) pairs from _substitutions(j, 0).

    In typed mode the keys of that mapping are (value, type) pairs and the
    type is dropped here.
    """
    if not self.typed:
        yield from self._substitutions(j, 0).items()
        return
    for key, body in self._substitutions(j, 0).items():
        value, _type = key
        yield value, body
# Memoised worker behind substitutions(). Returns a dict whose keys are value spaces
# (or (value space, type) pairs when self.typed) and whose values are body spaces.
# inversion() and recursiveInversion() turn each entry (v, b) into apply(abstract(b), v).
# n is the number of abstractions entered so far; results are cached per (j, n).
def _substitutions(self,j,n):
if (j,n) in self.substitutionTable: return self.substitutionTable[(j,n)]
# s is j with its free indices lowered by n; it is the empty node when that is impossible.
s = self.shiftFree(j,n)
if self.debug:
assert set(self.extract(s)) == set( e.shift(-n)
for e in self.extract(j)
if all( f >= n for f in e.freeVariables() )),\
"shiftFree_%d: %s"%(n,set(self.extract(s)))
# Base entry: take the whole (shifted) expression as the value, leaving just the index n as the body.
if s == self.empty: m = {}
else:
if self.typed:
# NOTE(review): self.infer and self.bottom are not defined in the visible part of
# this class - confirm they exist before relying on typed=True.
principalType = self.infer(s)
if principalType == self.bottom:
print(self.infer(j))
print(list(self.extract(j)))
print(list(self.extract(s)))
assert False
m = {(s, self.infer(s)[1].canonical()): self.index(n)}
else:
m = {s: self.index(n)}
l = self.expressions[j]
# Leaves: under the universe key the expression is kept, except that an index of at
# least n is raised by one.
if l.isPrimitive or l.isInvented:
m[(self.universe,t0) if self.typed else self.universe] = j
elif l.isIndex:
m[(self.universe,t0) if self.typed else self.universe] = \
j if l.i < n else self.index(l.i + 1)
# Abstraction: recurse under one more binder and re-wrap every body.
elif l.isAbstraction:
for v,b in self._substitutions(l.body, n + 1).items():
m[v] = self.abstract(b)
# Unfactored application: pair every entry for f with every entry for x, key the pair
# by the intersection of the two values, and union all bodies that share a key.
elif l.isApplication and not self.factored:
newMapping = {}
fm = self._substitutions(l.f,n)
xm = self._substitutions(l.x,n)
for v1,f in fm.items():
if self.typed: v1,nType1 = v1
for v2,x in xm.items():
if self.typed: v2,nType2 = v2
a = self.apply(f,x)
# See if the types that they assigned to $n are consistent
if self.typed:
if self.infer(a) == self.bottom: continue
try:
nType = canonicalUnification(nType1, nType2,
self.infer(a)[0].get(n,t0))
except UnificationFailure:
continue
v = self.intersection(v1,v2)
if v == self.empty: continue
if self.typed and self.infer(v) == self.bottom: continue
key = (v,nType) if self.typed else v
if key in newMapping:
newMapping[key].append(a)
else:
newMapping[key] = [a]
for v in newMapping:
newMapping[v] = self.union(newMapping[v])
# Entries already in m (the base entry) replace any pairwise entry with the same key.
newMapping.update(m)
m = newMapping
# print(f"substitutions: |{len(fm)}|x|{len(xm)}| = {len(m)}\t{len(m) <= len(fm)+len(xm)}")
# Factored application: per key, collect the set of function bodies and the set of
# argument bodies separately, then apply the union of the former to the union of the latter.
elif l.isApplication and self.factored:
newMapping = {}
fm = self._substitutions(l.f,n)
xm = self._substitutions(l.x,n)
for v1,f in fm.items():
if self.typed: v1,nType1 = v1
for v2,x in xm.items():
if self.typed: v2,nType2 = v2
v = self.intersection(v1,v2)
if v == self.empty: continue
if v in newMapping:
newMapping[v] = ({f} | newMapping[v][0],
{x} | newMapping[v][1])
else:
newMapping[v] = ({f},{x})
# NOTE(review): in this branch the result is keyed by v alone even when self.typed,
# unlike the (v, type) keys used elsewhere - confirm this is intended.
for v,(fs,xs) in newMapping.items():
fs = self.union(list(fs))
xs = self.union(list(xs))
m[v] = self.apply(fs,xs)
# print(f"substitutions: |{len(fm)}|x|{len(xm)}| = {len(m)}\t{len(m) <= len(fm)+len(xm)}")
# Union: merge the mappings of all members, taking the union of bodies that share a key.
elif l.isUnion:
newMapping = {}
for e in l:
for v,b in self._substitutions(e,n).items():
if v in newMapping:
newMapping[v].append(b)
else:
newMapping[v] = [b]
for v in newMapping:
newMapping[v] = self.union(newMapping[v])
newMapping.update(m)
m = newMapping
else: assert False
self.substitutionTable[(j,n)] = m
return m
# Returns the union of apply(abstract(b), v) over every substitution (v, b) of j,
# skipping the entry whose value is the universe.
def inversion(self,j):
i = self.union([self.apply(self.abstract(b),v)
for v,b in self.substitutions(j)
if v != self.universe])
# Debug-only consistency check: compares the inferred type of the result with that of j
# and prints both extensions when they differ.
# NOTE(review): self.infer is not defined in the visible part of this class.
if self.debug and self.typed:
if not (self.infer(i) == self.infer(j)):
print("inversion produced space with a different type!")
print("the original type was",self.infer(j))
print("the type of the rewritten expressions is",self.infer(i))
print("the original extension was")
# n holds the beta-normal form seen so far; every program is asserted to share it.
n = None
for e in self.extract(j):
print(e, e.infer())
# print(f"\t{e.betaNormalForm()} : {e.betaNormalForm().infer()}")
assert n is None or e.betaNormalForm() == n
n = e.betaNormalForm()
print("the rewritten extension is")
for e in self.extract(i):
print(e, e.infer())
# print(f"\t{e.betaNormalForm()} : {e.betaNormalForm().infer()}")
assert n is None or e.betaNormalForm() == n
assert self.infer(i) == self.infer(j)
assert False
return i
# One step of inversion applied at the root of j and inside each of its subterms.
# Results for non-union nodes are memoised in self.recursiveTable.
def recursiveInversion(self,j):
if self.recursiveTable[j] is not None: return self.recursiveTable[j]
l = self.expressions[j]
# A union is handled member by member; this result is returned without being stored
# in self.recursiveTable.
if l.isUnion:
return self.union([self.recursiveInversion(e) for e in l ])
# Rewrites at the root. When self.identity is False, entries whose body is exactly
# index 0 are left out.
t = [self.apply(self.abstract(b),v)
for v,b in self.substitutions(j)
if v != self.universe and (self.identity or b != self.index(0))]
# Debug-only consistency check on the inferred type of the root rewrites.
# NOTE(review): self.infer is not defined in the visible part of this class.
if self.debug and self.typed:
ru = self.union(t)
if not (self.infer(ru) == self.infer(j)):
print("inversion produced space with a different type!")
print("the original type was",self.infer(j))
print("the type of the rewritten expressions is",self.infer(ru))
print("the original extension was")
n = None
for e in self.extract(j):
print(e, e.infer())
# print(f"\t{e.betaNormalForm()} : {e.betaNormalForm().infer()}")
assert n is None or e.betaNormalForm() == n
n = e.betaNormalForm()
print("the rewritten extension is")
for e in self.extract(ru):
print(e, e.infer())
# print(f"\t{e.betaNormalForm()} : {e.betaNormalForm().infer()}")
assert n is None or e.betaNormalForm() == n
assert self.infer(ru) == self.infer(j)
# Rewrites inside subterms: one child is rewritten while the other is kept as it is.
if l.isApplication:
t.append(self.apply(self.recursiveInversion(l.f),l.x))
t.append(self.apply(l.f,self.recursiveInversion(l.x)))
elif l.isAbstraction:
t.append(self.abstract(self.recursiveInversion(l.body)))
ru = self.union(t)
self.recursiveTable[j] = ru
return ru
def repeatedExpansion(self, j, n):
    """Return [j, R(j), R(R(j)), ...] with n applications of recursiveInversion.

    The list always has n + 1 entries and starts with `j` itself.
    """
    chain = [j]
    current = j
    for _ in range(n):
        current = self.recursiveInversion(current)
        chain.append(current)
    return chain
def rewriteReachable(self, heads, n):
    """Map every node reachable from `heads` to its repeatedExpansion(node, n)."""
    expansions = {}
    for vertex in self.reachable(heads):
        expansions[vertex] = self.repeatedExpansion(vertex, n)
    return expansions
def properVersionSpace(self, j, n):
    """Return the union of `j` and its first `n` repeated expansions."""
    expansions = self.repeatedExpansion(j, n)
    return self.union(expansions)
def superVersionSpace(self, j, n):
    """Construct decorated tree and then merge version spaces with subtrees via union operator"""
    # The cache is keyed by `j` only; `n` is not part of the key.
    try:
        return self.superCache[j]
    except KeyError:
        pass

    expansions = self.rewriteReachable({j}, n)

    def merged(i):
        assert i in expansions
        node = self.expressions[i]
        parts = [i] + expansions[i]
        if node.isIndex or node.isPrimitive or node.isInvented:
            pass
        elif node.isAbstraction:
            parts.append(self.abstract(merged(node.body)))
        elif node.isApplication:
            function = merged(node.f)
            argument = merged(node.x)
            parts.append(self.apply(function, argument))
        else:
            # Union nodes (and anything unrecognised) are not expected here.
            assert False
        return self.union(parts)

    result = merged(j)
    self.superCache[j] = result
    return self.superCache[j]
def loadEquivalences(self, g, spaces):
    """Tell `g` which classes are equivalent according to the rewrite spaces.

    `spaces` maps a vertex (node index) to a list of version spaces, one per
    rewrite step. For every vertex, the class of the vertex's own program
    and the classes extracted from each of its rewrite spaces are merged
    with g.makeEquivalent, but only between classes of the same type
    (g.typeOfClass).

    `g` must provide incorporate, typeOfClass, makeEquivalent, setOfClasses,
    abstractClass and applyClass. Returns None. An empty `spaces` is a no-op.
    """
    if not spaces:
        # Nothing to relate; also avoids StopIteration from next(iter(...)) below.
        return

    versionClasses = [None] * len(self.expressions)

    def extract(j):
        """Return the set of classes of `g` denoted by version space `j` (memoised)."""
        if versionClasses[j] is not None:
            return versionClasses[j]
        l = self.expressions[j]
        if l.isAbstraction:
            ks = g.setOfClasses(g.abstractClass(b)
                                for b in extract(l.body))
        elif l.isApplication:
            fs = extract(l.f)
            xs = extract(l.x)
            ks = g.setOfClasses(g.applyClass(f, x)
                                for x in xs for f in fs)
        elif l.isUnion:
            ks = g.setOfClasses(e for u in l for e in extract(u))
        else:
            ks = g.setOfClasses({g.incorporate(l)})
        versionClasses[j] = ks
        return ks

    # Every vertex is expected to carry the same number of rewrite steps.
    N = len(next(iter(spaces.values())))
    vertices = sorted(spaces.keys(), key=lambda v: self.size(v))

    # maps from a vertex to a map from types to classes
    # the idea is to only enforce equivalence between terms of the same type
    typedClassesOfVertex = {v: {} for v in vertices}

    for n in range(N):
        for v in vertices:
            expressions = list(self.extract(v))
            assert len(expressions) == 1
            k = g.incorporate(expressions[0])
            if k is None:
                continue

            classesByType = typedClassesOfVertex[v]
            vertexType = g.typeOfClass[k]
            if vertexType not in classesByType:
                classesByType[vertexType] = k

            # Materialise first: makeEquivalent may change g while we iterate.
            for e in list(extract(spaces[v][n])):
                t = g.typeOfClass[e]
                if t in classesByType:
                    g.makeEquivalent(classesByType[t], e)
                else:
                    # Register under the TYPE so that later classes of this
                    # type are found by the lookup above. The previous code
                    # used the class itself as the key, which that lookup
                    # could never match.
                    classesByType[t] = e
# Ranks candidate inventions. Candidates are minimal inhabitants of nodes reachable
# from the given version spaces; the returned list is sorted by ascending score, where
# the score sums, over the entries of `versions`, the cheapest cost of any of that
# entry's version spaces when the candidate is available.
def bestInventions(self, versions, bs=25):
"""versions: [[version index]]"""
"""bs: beam size"""
"""returns: list of (indices to) candidates"""
import gc
# A proposal counts as nontrivial if it has more than one primitive/invented leaf, or
# exactly one together with an index value (tree.i - d) that occurs more than once.
def nontrivial(proposal):
primitives = 0
collisions = 0
indices = set()
for d, tree in proposal.walk():
if tree.isPrimitive or tree.isInvented: primitives += 1
elif tree.isIndex:
i = tree.i - d
if i in indices: collisions += 1
indices.add(i)
return primitives > 1 or (primitives == 1 and collisions > 0)
with timing("calculated candidates from version space"):
# One set of candidates per entry of `versions`. This also fills inhabitantTable and
# functionInhabitantTable for every reachable node, which class B below relies on.
candidates = [{j
for k in self.reachable(hs)
for _,js in [self.minimalInhabitants(k), self.minimalFunctionInhabitants(k)]
for j in js }
for hs in versions]
from collections import Counter
# Keep candidates that occur in at least two of those sets and are nontrivial.
candidates = Counter(k for ks in candidates for k in ks)
candidates = {k for k,f in candidates.items() if f >= 2 and nontrivial(next(self.extract(k))) }
# candidates = [k for k in candidates if next(self.extract(k)).isBetaLong()]
eprint(len(candidates),"candidates from version space")
# Calculate the number of free variables for each candidate invention
# This is important because, if a candidate has free variables,
# then whenever we use it we will have to apply it to those free variables;
# thus using a candidate with free variables is more expensive
candidateCost = {k: len(set(next(self.extract(k)).freeVariables())) + 1
for k in candidates }
inhabitTable = self.inhabitantTable
functionTable = self.functionInhabitantTable
# Beam for one node: the cost of that node given a particular candidate invention
# (relativeCost / relativeFunctionCost), falling back to the node's own minimal cost.
class B():
def __init__(self, j):
cost, inhabitants = inhabitTable[j]
functionCost, functionInhabitants = functionTable[j]
self.relativeCost = {inhabitant: candidateCost[inhabitant]
for inhabitant in inhabitants
if inhabitant in candidates}
self.relativeFunctionCost = {inhabitant: candidateCost[inhabitant]
# INTENTIONALLY, do not use function inhabitants
for inhabitant in inhabitants
if inhabitant in candidates}
self.defaultCost = cost
self.defaultFunctionCost = functionCost
@property
def domain(self):
return set(self.relativeCost.keys())
@property
def functionDomain(self):
return set(self.relativeFunctionCost.keys())
# Keep only the bs cheapest entries of each table.
def restrict(self):
if len(self.relativeCost) > bs:
self.relativeCost = dict(sorted(self.relativeCost.items(),
key=lambda rk: rk[1])[:bs])
if len(self.relativeFunctionCost) > bs:
self.relativeFunctionCost = dict(sorted(self.relativeFunctionCost.items(),
key=lambda rk: rk[1])[:bs])
def getCost(self, given):
return self.relativeCost.get(given, self.defaultCost)
def getFunctionCost(self, given):
return self.relativeFunctionCost.get(given, self.defaultFunctionCost)
# Lower the stored cost for `given` if `cost` is smaller.
def relax(self, given, cost):
self.relativeCost[given] = min(cost,
self.getCost(given))
def relaxFunction(self, given, cost):
self.relativeFunctionCost[given] = min(cost,
self.getFunctionCost(given))
# Plain-dict form of the beam, used as the return value of the parallel workers.
def unobject(self):
return {'relativeCost': self.relativeCost, 'defaultCost': self.defaultCost,
'relativeFunctionCost': self.relativeFunctionCost, 'defaultFunctionCost': self.defaultFunctionCost}
beamTable = [None]*len(self.expressions)
# Memoised bottom-up computation of the beam of node j.
def costs(j):
if beamTable[j] is not None:
return beamTable[j]
beamTable[j] = B(j)
e = self.expressions[j]
if e.isIndex or e.isPrimitive or e.isInvented:
pass
elif e.isAbstraction:
# Only the non-function cost is propagated through an abstraction.
b = costs(e.body)
for i,c in b.relativeCost.items():
beamTable[j].relax(i, c + epsilon)
elif e.isApplication:
f = costs(e.f)
x = costs(e.x)
for i in f.functionDomain | x.domain:
beamTable[j].relax(i, f.getFunctionCost(i) + x.getCost(i) + epsilon)
beamTable[j].relaxFunction(i, f.getFunctionCost(i) + x.getCost(i) + epsilon)
elif e.isUnion:
for z in e:
cz = costs(z)
for i,c in cz.relativeCost.items(): beamTable[j].relax(i, c)
for i,c in cz.relativeFunctionCost.items(): beamTable[j].relaxFunction(i, c)
else: assert False
beamTable[j].restrict()
return beamTable[j]
with timing("beamed version spaces"):
beams = parallelMap(numberOfCPUs(),
lambda hs: [ costs(h).unobject() for h in hs ],
versions,
memorySensitive=True,
chunksize=1,
maxtasksperchild=1)
# This can get pretty memory intensive - clean up the garbage
beamTable = None
gc.collect()
# Only candidates that survived in some top-level beam are scored.
candidates = {d
for _bs in beams
for b in _bs
for d in b['relativeCost'].keys() }
def score(candidate):
return sum(min(min(b['relativeCost'].get(candidate, b['defaultCost']),
b['relativeFunctionCost'].get(candidate, b['defaultFunctionCost']))
for b in _bs )
for _bs in beams )
candidates = sorted(candidates, key=score)
return candidates
def rewriteWithInvention(self, i, js):
    """Rewrites list of indices in beta long form using invention"""
    self.clearOverlapTable()

    class Rewritten():
        """rewritten cost/expression either as a function or argument"""

        def __init__(self, f, fc, a, ac):
            assert not (fc < ac)
            self.f = f
            self.fc = fc
            self.a = a
            self.ac = ac

    # The invention's version space must contain exactly one program.
    inventions = list(self.extract(i))
    assert len(inventions) == 1
    invention = inventions[0]

    memo = {}

    def rewrite(j):
        if j in memo:
            return memo[j]

        node = self.expressions[j]
        if self.haveOverlap(i, j):
            # This space contains the invention: use it, at unit cost.
            result = Rewritten(f=invention, fc=1, a=invention, ac=1)
        elif node.isPrimitive or node.isInvented or node.isIndex:
            result = Rewritten(f=node, fc=1, a=node, ac=1)
        elif node.isApplication:
            function = rewrite(node.f)
            argument = rewrite(node.x)
            cost = function.fc + argument.ac + epsilon
            if cost < POSITIVEINFINITY:
                program = Application(function.f, argument.a)
            else:
                program = None
            result = Rewritten(f=program, fc=cost, a=program, ac=cost)
        elif node.isAbstraction:
            body = rewrite(node.body)
            cost = body.ac + epsilon
            if cost < POSITIVEINFINITY:
                program = Abstraction(body.a)
            else:
                program = None
            # An abstraction is never usable in function position.
            result = Rewritten(f=None, fc=POSITIVEINFINITY, a=program, ac=cost)
        elif node.isUnion:
            children = [rewrite(member) for member in node]
            as_function = ((child.f, child.fc) for child in children)
            as_argument = ((child.a, child.ac) for child in children)
            best_f, best_fc = min(as_function, key=cindex(1))
            best_a, best_ac = min(as_argument, key=cindex(1))
            result = Rewritten(f=best_f, fc=best_fc, a=best_a, ac=best_ac)
        else:
            assert False

        memo[j] = result
        return result

    rewritten = []
    for j in js:
        rewritten.append(rewrite(j).a)
    self.clearOverlapTable()
    return rewritten
def addInventionToGrammar(self, candidate, g0, frontiers,
                          pseudoCounts=1.):
    """Rewrite `frontiers` with the invention `candidate` and refit the grammar.

    Every distinct frontier program is rewritten via rewriteWithInvention,
    using the version space already stored for it in self.superCache, and
    then passed through a RewriteWithInventionVisitor. A program whose
    rewrite comes back falsy is kept unchanged. A uniform grammar over the
    invention plus g0's primitives is then fitted with insideOutside.

    Returns (new grammar, rescored rewritten frontiers).
    """
    source = next(self.extract(candidate))
    visitor = RewriteWithInventionVisitor(source)
    invention = visitor.invention

    programs = list({entry.program
                     for frontier in frontiers
                     for entry in frontier})
    spaces = []
    for program in programs:
        spaces.append(self.superCache[self.incorporate(program)])
    rewrites = self.rewriteWithInvention(candidate, spaces)
    rewriteMapping = dict(zip(programs, rewrites))

    def tryRewrite(program, request=None):
        rewritten = visitor.execute(rewriteMapping[program], request=request)
        return rewritten or program

    rewrittenFrontiers = []
    for frontier in frontiers:
        entries = []
        for entry in frontier:
            entries.append(FrontierEntry(
                program=tryRewrite(entry.program, request=frontier.task.request),
                logLikelihood=entry.logLikelihood,
                logPrior=0.))
        rewrittenFrontiers.append(Frontier(entries, frontier.task))
    frontiers = rewrittenFrontiers

    uniform = Grammar.uniform([invention] + g0.primitives,
                              continuationType=g0.continuationType)
    g = uniform.insideOutside(frontiers, pseudoCounts=pseudoCounts)
    frontiers = [g.rescoreFrontier(frontier) for frontier in frontiers]
    return g, frontiers
class CloseInventionVisitor():
    """normalize free variables - e.g., if $1 & $3 occur free then rename them to $0, $1
    then wrap in enough lambdas so that there are no free variables and finally wrap in invention"""

    def __init__(self, p):
        self.p = p
        distinct = sorted(set(p.freeVariables()))
        # original free variable -> its new, densely numbered index
        self.mapping = {original: position
                        for position, original in enumerate(distinct)}

    def index(self, e, d):
        # d is the number of lambdas between this index and the root of p.
        free = e.i - d
        if free in self.mapping:
            return Index(self.mapping[free] + d)
        return e

    def abstraction(self, e, d):
        body = e.body.visit(self, d + 1)
        return Abstraction(body)

    def application(self, e, d):
        function = e.f.visit(self, d)
        argument = e.x.visit(self, d)
        return Application(function, argument)

    def primitive(self, e, d):
        return e

    def invented(self, e, d):
        return e

    def execute(self):
        closed = self.p.visit(self, 0)
        # One lambda per distinct free variable.
        for _ in self.mapping:
            closed = Abstraction(closed)
        return Invented(closed)
class RewriteWithInventionVisitor():
    """Replace occurrences of a program `p` by its closed invention applied to p's free variables.

    `invention` is the result of CloseInventionVisitor(p).execute();
    `appliedInvention` is that invention applied to the original free
    variables, one application per variable.
    """

    def __init__(self, p):
        closer = CloseInventionVisitor(p)
        self.original = p
        # new dense index -> original free variable
        self.mapping = {position: original
                        for original, position in closer.mapping.items()}
        self.invention = closer.execute()
        self.appliedInvention = self.invention
        for position in reversed(range(len(self.mapping))):
            argument = Index(self.mapping[position])
            self.appliedInvention = Application(self.appliedInvention, argument)

    def tryRewrite(self, e):
        if e == self.original:
            return self.appliedInvention
        return None

    def index(self, e):
        return e

    def primitive(self, e):
        return e

    def invented(self, e):
        return e

    def abstraction(self, e):
        return self.tryRewrite(e) or Abstraction(e.body.visit(self))

    def application(self, e):
        return self.tryRewrite(e) or Application(e.f.visit(self),
                                                 e.x.visit(self))

    def execute(self, e, request=None):
        """Rewrite `e` and eta-expand the result; return None when that fails."""
        try:
            rewritten = e.visit(self)
            return EtaLongVisitor(request=request).execute(rewritten)
        except (UnificationFailure, EtaExpandFailure):
            return None
# Greedy grammar induction: repeatedly propose inventions from version spaces, score
# each by `objective`, adopt the best one while it does not lower the score, and return
# (grammar, frontiers) once no candidate does.
# `a` is the number of rewrite steps used when building version spaces; topK limits the
# programs kept per frontier while scoring; topI limits the candidates scored per round.
def induceGrammar_Beta(g0, frontiers, _=None,
pseudoCounts=1.,
a=3,
aic=1.,
topK=2,
topI=50,
structurePenalty=1.,
CPUs=1):
"""grammar induction using only version spaces"""
from dreamcoder.fragmentUtilities import primitiveSize
import gc
originalFrontiers = frontiers
frontiers = [frontier for frontier in frontiers if not frontier.empty]
eprint("Inducing a grammar from", len(frontiers), "frontiers")
arity = a
# Reads g0 and frontiers from the enclosing scope at call time, so it sees the
# values rebound at the bottom of the loop below.
def restrictFrontiers():
return parallelMap(1,#CPUs,
lambda f: g0.rescoreFrontier(f).topK(topK),
frontiers,
memorySensitive=True,
chunksize=1,
maxtasksperchild=1)
restrictedFrontiers = restrictFrontiers()
# Score of a grammar: summed frontierMDL, minus a penalty on the size of the
# primitives, minus aic per production.
def objective(g, fs):
ll = sum(g.frontierMDL(f) for f in fs )
sp = structurePenalty * sum(primitiveSize(p) for p in g.primitives)
return ll - sp - aic*len(g.productions)
# v is rebound to a fresh VersionTable on every loop iteration; scoreCandidate reads it
# through the closure.
v = None
def scoreCandidate(candidate, currentFrontiers, currentGrammar):
try:
newGrammar, newFrontiers = v.addInventionToGrammar(candidate, currentGrammar, currentFrontiers,
pseudoCounts=pseudoCounts)
except InferenceFailure:
# And this can occur if the candidate is not well typed:
# it is expected that this can occur;
# in practice, it is more efficient to filter out the ill typed terms,
# then it is to construct the version spaces so that they only contain well typed terms.
return NEGATIVEINFINITY
o = objective(newGrammar, newFrontiers)
#eprint("+", end='')
eprint(o,'\t',newGrammar.primitives[0],':',newGrammar.primitives[0].tp)
# eprint(next(v.extract(candidate)))
# for f in newFrontiers:
# for e in f:
# eprint(e.program)
return o
with timing("Estimated initial grammar production probabilities"):
g0 = g0.insideOutside(restrictedFrontiers, pseudoCounts)
oldScore = objective(g0, restrictedFrontiers)
eprint("Starting grammar induction score",oldScore)
while True:
# Untyped table; identity=False is passed through to VersionTable.
v = VersionTable(typed=False, identity=False)
with timing("constructed %d-step version spaces"%arity):
versions = [[v.superVersionSpace(v.incorporate(e.program), arity) for e in f]
for f in restrictedFrontiers ]
eprint("Enumerated %d distinct version spaces"%len(v.expressions))
# Bigger beam because I feel like it
candidates = v.bestInventions(versions, bs=3*topI)[:topI]
eprint("Only considering the top %d candidates"%len(candidates))
# Clean caches that are no longer needed
# (superCache is kept: addInventionToGrammar reads it.)
v.recursiveTable = [None]*len(v)
v.inhabitantTable = [None]*len(v)
v.functionInhabitantTable = [None]*len(v)
v.substitutionTable = {}
gc.collect()
with timing("scored the candidate inventions"):
scoredCandidates = parallelMap(CPUs,
lambda candidate: \
(candidate, scoreCandidate(candidate, restrictedFrontiers, g0)),
candidates,
memorySensitive=True,
chunksize=1,
maxtasksperchild=1)
if len(scoredCandidates) > 0:
bestNew, bestScore = max(scoredCandidates, key=lambda sc: sc[1])
# Stop when there is nothing to try or the best candidate scores strictly lower than
# the current grammar (an equal score is accepted).
if len(scoredCandidates) == 0 or bestScore < oldScore:
eprint("No improvement possible.")
# eprint("Runner-up:")
# eprint(next(v.extract(bestNew)))
# Return all of the frontiers, which have now been rewritten to use the
# new fragments
# Frontiers that were filtered out as empty are returned as originally given.
frontiers = {f.task: f for f in frontiers}
frontiers = [frontiers.get(f.task, f)
for f in originalFrontiers]
return g0, frontiers
# This is subtle: at this point we have not calculated
# versions bases for programs outside the restricted
# frontiers; but here we are rewriting the entire frontier in
# terms of the new primitive. So we have to recalculate
# version spaces for everything.
with timing("constructed versions bases for entire frontiers"):
for f in frontiers:
for e in f:
v.superVersionSpace(v.incorporate(e.program), arity)
newGrammar, newFrontiers = v.addInventionToGrammar(bestNew, g0, frontiers,
pseudoCounts=pseudoCounts)
eprint("Improved score to", bestScore, "(dS =", bestScore-oldScore, ") w/ invention",newGrammar.primitives[0],":",newGrammar.primitives[0].infer())
oldScore = bestScore
for f in newFrontiers:
eprint(f.summarizeFull())
# Adopt the new grammar and rewritten frontiers, then re-derive the restricted
# frontiers from them for the next round.
g0, frontiers = newGrammar, newFrontiers
restrictedFrontiers = restrictFrontiers()
def testTyping(p):
    """Check typed inversion of `p` against three rounds of untyped inversion.

    Prints the number of programs found each way, asserts that the typed
    result is exactly the well-typed part of the untyped one, and then
    exits the interpreter.
    """
    typedTable = VersionTable()
    root = typedTable.incorporate(p)
    wellTyped = set(typedTable.extract(typedTable.inversion(root)))
    print(len(wellTyped))

    untypedTable = VersionTable(typed=False)
    space = untypedTable.incorporate(p)
    for _ in range(3):
        space = untypedTable.recursiveInversion(space)
    arbitrary = set(untypedTable.extract(space))
    print(len(arbitrary))

    assert wellTyped <= arbitrary
    typedSubset = {e for e in arbitrary if e.wellTyped()}
    assert wellTyped == typedSubset
    assert all(e.wellTyped() for e in wellTyped)

    import sys
    sys.exit()
# Plots version-space size against the memory needed to store the programs explicitly,
# using the hard-coded measurements below, and saves the figure to /tmp/vs.eps.
# The function always ends in `assert False`.
def testSharing(projection=2):
# NOTE(review): `source` is only used by the commented-out measurement code below.
source = "(+ 1 1)"
N = 4 # maximum number of refactorings
L = 6 # maximum size of expression
# The commented-out block below would compute the `smart` and `dumb` tables that the
# projection == 3 branch further down reads.
# def literalSize(v,j):
# hs = []
# vp = VersionTable(typed=False)
# for i in v.extract(j):
# hs.append(vp.incorporate(i))
# return len(set(vp.reachable(hs)))
# smart = {}
# dumb = {}
# for l in range(L):
# for n in range(N):
# v = VersionTable(typed=False)
# j = v.properVersionSpace(v.incorporate(Program.parse(source)),n)
# smart[(l,n)] = len(v.reachable({j}))
# dumb[(l,n)] = literalSize(v,j)
# print(f"vs l={l}\tn={n} sz={smart[(l,n)]}")
# print(f"db l={l}\tn={n} sz={dumb[(l,n)]}")
# # increase the size of the expression
# source = "(+ 1 %s)"%source
# print("Increased size to",l + 1)
import numpy as np
# Hard-coded results, indexed [l, n]. Column n = 0 is never filled in and stays zero.
distinct_programs = np.zeros((L,N))
version_size = np.zeros((L,N))
program_memory = np.zeros((L,N))
version_size[0,1] = 24
distinct_programs[0,1] = 8
program_memory[0,1] = 28
version_size[0,2] = 155
distinct_programs[0,2] = 63
program_memory[0,2] = 201
version_size[0,3] = 1126
distinct_programs[0,3] = 534
program_memory[0,3] = 1593
version_size[1,1] = 48
distinct_programs[1,1] = 24
program_memory[1,1] = 78
version_size[1,2] = 526
distinct_programs[1,2] = 457
program_memory[1,2] = 1467
version_size[1,3] = 6639
distinct_programs[1,3] = 8146
program_memory[1,3] = 26458
version_size[2,1] = 74
distinct_programs[2,1] = 57
program_memory[2,1] = 193
version_size[2,2] = 1095
distinct_programs[2,2] = 2234
program_memory[2,2] = 7616
version_size[2,3] = 19633
distinct_programs[2,3] = 74571
program_memory[2,3] = 260865
version_size[3,1] = 101
distinct_programs[3,1] = 123
program_memory[3,1] = 438
version_size[3,2] = 1751
distinct_programs[3,2] = 9209
program_memory[3,2] = 32931
version_size[3,3] = 38781
distinct_programs[3,3] = 540315
program_memory[3,3] = 1984171
version_size[4,1] = 129
distinct_programs[4,1] = 254
program_memory[4,1] = 942
version_size[4,2] = 2488
distinct_programs[4,2] = 35011
program_memory[4,2] = 129513
version_size[4,3] = 63271
distinct_programs[4,3] = 3477046
program_memory[4,3] = 13179440
version_size[5,1] = 158
distinct_programs[5,1] = 514
program_memory[5,1] = 1962
version_size[5,2] = 3308
distinct_programs[5,2] = 128319
program_memory[5,2] = 485862
version_size[5,3] = 93400
distinct_programs[5,3] = 21042591
program_memory[5,3] = 81433633
import matplotlib.pyplot as plot
from matplotlib import rcParams
rcParams.update({'figure.autolayout': True})
# NOTE(review): within this function `smart` and `dumb` are only assigned in the
# commented-out code above, so this branch looks like it would raise NameError as
# written - confirm before calling with projection=3.
if projection == 3:
f = plot.figure()
a = f.add_subplot(111, projection='3d')
X = np.arange(0,N)
Y = np.arange(0,L)
X,Y = np.meshgrid(X,Y)
Z = np.zeros((L,N))
for l in range(L):
for n in range(N):
Z[l,n] = smart[(l,n)]
a.plot_surface(X,
Y,
np.log10(Z),
color='blue',
alpha=0.3)
for l in range(L):
for n in range(N):
Z[l,n] = dumb[(l,n)]
a.plot_surface(X,
Y,
np.log10(Z),
color='red',
alpha=0.3)
else:
plot.figure(figsize=(3.5,3))
plot.tight_layout()
logarithmic = False
if logarithmic: P = plot.semilogy
else: P = plot.plot
# range(1, 2) means only the n = 1 column is drawn.
for n in range(1, 2):
# x positions 3, 5, ..., one per row of the tables
xs = np.array(range(L))*2 + 3
P(xs,
[version_size[l,n] for l in range(L) ],
'purple',
label=None if n > 1 else 'version space')
P(xs,
[program_memory[l,n] for l in range(L) ],
'green',
label=None if n > 1 else 'no version space')
# dy is only used by the commented-out plot.text calls below.
if n > 1: dy = 1
if n == 1 and logarithmic: dy = 0.6
if n == 1 and not logarithmic: dy = 1
# plot.text(xs[-1], dy*version_size[L - 1,n], "n=%d"%n)
# plot.text(xs[-1], dy*program_memory[L - 1,n], "n=%d"%n)
plot.legend()
plot.xlabel('Size of program being refactored')
plot.ylabel('Size of VS (purple) or progs (green)')
# Adds one extra tick past the last point; only the first tick and the tick at
# position L - 1 get a text label.
plot.xticks(list(xs) + [xs[-1] + 2],
[ str(x) if j == 0 or j == L - 1 else ''
for j,x in enumerate(list(xs) + [xs[-1] + 2])])
# if not logarithmic:
# plot.ylim([0,100000])
plot.savefig('/tmp/vs.eps')
# Deliberately stops the program here.
assert False
if __name__ == "__main__":
    # Ad-hoc driver: parse a handful of hand-written recursive programs and
    # run version-space grammar induction over them.
    #
    # The wildcard imports below are only legal at module level, which is why
    # this driver stays top-level code instead of living in a main() function.
    # Their relative order is kept because later star imports can rebind names
    # brought in by earlier ones.
    from dreamcoder.domains.arithmetic.arithmeticPrimitives import *
    from dreamcoder.domains.list.listPrimitives import *
    from dreamcoder.fragmentGrammar import *

    # Return values are discarded here; McCarthyPrimitives() is called again
    # further down where its result is used. Presumably these first calls are
    # made for their side effects before any parsing -- TODO confirm.
    bootstrapTarget_extra()
    McCarthyPrimitives()

    # NOTE(review): the plotting routine defined just above this block ends in
    # `assert False`; if that routine is testSharing, nothing after this call
    # ever executes -- confirm before relying on the induction run below.
    testSharing()

    # Disabled eta-long experiment:
    # p = Program.parse("(#(lambda (lambda (lambda (fold $0 empty ($1 $2))))) cons (lambda (lambda (lambda ($2 (+ (+ 5 5) (+ $1 $1)) $0)))))")
    # print(EtaLongVisitor().execute(p))

    # BOOTSTRAP
    # Each active entry is a (program source, request type) pair; every active
    # entry currently passes None as its type. Commented-out entries are
    # alternative inputs kept around for experimentation.
    programs = [
        # "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ ($1 (cdr $0)) 1))))))",
        # "(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ ($1 (cdr $0)) 1))))))",
        # "(lambda (+ $0 1))",
        # "(lambda (+ (car $0) 1))",
        # "(lambda (+ $0 (+ 1 1)))",
        # "(lambda (- $0 1))",
        # "(lambda (- $0 (+ 1 1)))",
        # "(lambda (- (car $0) 1))",
        ("(lambda (fix1 $0 (lambda (lambda (if (eq? 0 $0) empty (cons (- 0 $0) ($1 (+ 1 $0))))))))",
         None),
        # ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) empty (cons (cdr $0) ($1 (cdr $0))))))))",arrow(tlist(tint),tlist(tlist(tint)))),
        # drop the last element
        # ("(lambda (fix1 $0 (lambda (lambda (if (empty? (cdr $0)) empty (cons (car $0) ($1 (cdr $0))))))))",arrow(tlist(tint),tlist(tint))),
        # take in till 1
        # ("(lambda (fix1 $0 (lambda (lambda (if (eq? (car $0) 1) empty (cons (car $0) ($1 (cdr $0))))))))",arrow(tlist(tint),tlist(tint))),
        # "(lambda (lambda (fix2 $1 $0 (lambda (lambda (lambda (if (eq? $1 0) (car $0) ($2 (- $1 1) (cdr $0)))))))))",
        # "(lambda (lambda (fix2 $1 $0 (lambda (lambda (lambda (if (eq? $1 0) (car $0) ($2 (- $1 1) (cdr $0)))))))))",
        ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 0 (+ (car $0) ($1 (cdr $0))))))))",
         None),
        ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) 1 (- (car $0) ($1 (cdr $0))))))))",
         None),
        ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) (cons 0 empty) (cons (car $0) ($1 (cdr $0))))))))",
         None),
        ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) (empty? empty) (if (car $0) ($1 (cdr $0)) (eq? 1 0)))))))",
         None),
        # "(lambda (lambda (fix2 $1 $0 (lambda (lambda (lambda (if (empty? $1) $0 (cons (car $1) ($2 (cdr $1) $0)))))))))",
        # ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) empty (cons (+ (car $0) (car $0)) ($1 (cdr $0))))))))",None),
        # ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) empty (cons (+ (car $0) 1) ($1 (cdr $0))))))))",None),
        # ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) empty (cons (- (car $0) 1) ($1 (cdr $0))))))))",None),
        # ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) empty (cons (cons (car $0) empty) ($1 (cdr $0))))))))",arrow(tlist(tint),tlist(tlist(tint)))),
        # ("(lambda (fix1 $0 (lambda (lambda (if (empty? $0) empty (cons (- 0 (car $0)) ($1 (cdr $0))))))))",None)
    ]
    # Swap each source string for its parsed Program, keeping the paired type.
    programs = [(Program.parse(text), request) for text, request in programs]

    # Forwarded to induceGrammar_Beta below as its `a` keyword argument.
    N = 3

    primitives = McCarthyPrimitives()
    # Disabled alternative: extend the primitive set with every primitive
    # occurring in the programs themselves.
    # for p, _ in programs:
    #     for _, s in p.walk():
    #         if s.isPrimitive:
    #             primitives.add(s)

    # Starting grammar: uniform over the primitives gathered above.
    g0 = Grammar.uniform(list(primitives))
    print(g0)

    # Disabled alternative back ends / replay of a pickled invocation:
    # with timing("RUST test"):
    #     g = induceGrammar(g0, [Frontier.dummy(p, tp=tp) for p, tp in programs],
    #                       CPUs=1,
    #                       a=N,
    #                       backend="vs")
    #     eprint(g)
    # with open('vs.pickle','rb') as handle:
    #     a,kw = pickle.load(handle)
    #     induceGrammar_Beta(*a,**kw)

    with timing("induced DSL"):
        # Frontier construction stays inside the timed region, as before.
        frontiers = [Frontier.dummy(program, tp=request)
                     for program, request in programs]
        induceGrammar_Beta(
            g0,
            frontiers,
            CPUs=1,
            a=N,
            structurePenalty=0.,
        )
# if __name__ == "__main__":
# import argparse
# parser = argparse.ArgumentParser(description = "Version-space based compression")
# parser.add_argument("--CPUs", type=int, default=1)
# parser.add_argument("--arity", type=int, default=3)
# parser.add_argument("--bs", type=int, default=25,
# help="beam size")
# parser.add_argument("--topK", type=int, default=2)
# parser.add_argument("--topI", type=int, default=None,
# help="defaults to beam size")
# parser.add_argument("--pseudoCounts",
# type=float,
# default=1.)
# parser.add_argument("--structurePenalty",
# type=float, default=1.)
# arguments = parser.parse_args()