NCTC / models /research /neural_gpu /program_utils.py
NCTCMumbai's picture
Upload 2571 files
0b8359d
raw
history blame
13.5 kB
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for generating program synthesis and evaluation data."""
import contextlib
import sys
import random
import os
try:
import StringIO
except ImportError:
from io import StringIO
class ListType(object):
    """The type of a list whose elements have type ``arg``; prints as ``[arg]``."""

    def __init__(self, arg):
        self.arg = arg

    def __str__(self):
        return "[" + str(self.arg) + "]"

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, self.arg)

    def __eq__(self, other):
        if not isinstance(other, ListType):
            return False
        return self.arg == other.arg

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two equal
        # ListTypes would also compare "not equal".
        return not self == other

    def __hash__(self):
        # Consistent with __eq__: equal element types hash equally.
        return hash(self.arg)
class VarType(object):
    """A named type that prints as its bare name ``arg``.

    NOTE(review): appears unused in this module; presumably models a type
    variable -- confirm before relying on it.
    """

    def __init__(self, arg):
        self.arg = arg

    def __str__(self):
        return str(self.arg)

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, self.arg)

    def __eq__(self, other):
        if not isinstance(other, VarType):
            return False
        return self.arg == other.arg

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self == other

    def __hash__(self):
        return hash(self.arg)
class FunctionType(object):
    """Function type written ``args[0] -> ... -> args[-1]``.

    The last entry is the result type and the others are argument types, e.g.
    ``FunctionType(["Int", "Bool"])`` for predicates on ints and
    ``FunctionType(["Int", "Int", "Int"])`` for binary int operators.
    """

    def __init__(self, args):
        self.args = args

    def __str__(self):
        # Join every component; formatting only args[0] and args[1] would
        # drop the result type of binary functions (Int -> Int -> Int).
        return " -> ".join(str(a) for a in self.args)

    def __repr__(self):
        return "%s(%r)" % (type(self).__name__, self.args)

    def __eq__(self, other):
        if not isinstance(other, FunctionType):
            return False
        # Compare as tuples so equality agrees with __hash__ even when one
        # side was built from a tuple and the other from a list.
        return tuple(self.args) == tuple(other.args)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self == other

    def __hash__(self):
        return hash(tuple(self.args))
class Function(object):
    """Signature of one DSL primitive.

    Attributes:
      name: identifier that is written into generated program text.
      arg_types: types of the ordinary (variable) arguments, in order.
      output_type: type of the value the primitive produces.
      fn_arg_types: types of the lambda arguments, which come before the
        ordinary arguments in a call; empty for first-order primitives.
    """

    def __init__(self, name, arg_types, output_type, fn_arg_types=None):
        self.name = name
        self.arg_types = arg_types
        self.output_type = output_type
        # Any falsy value (None, (), []) collapses to a fresh empty list.
        self.fn_arg_types = fn_arg_types if fn_arg_types else []
# Sentinel returned by the list primitives below when the requested element
# does not exist (empty list, index out of range).
Null = 100

## Functions
# Each primitive has a `Function` signature whose name is the identifier used
# in generated program text, followed by its Python implementation.
f_head = Function("c_head", [ListType("Int")], "Int")


def c_head(xs):
    """Return the first element of ``xs``, or ``Null`` if it is empty."""
    if len(xs) == 0:
        return Null
    return xs[0]


f_last = Function("c_last", [ListType("Int")], "Int")


def c_last(xs):
    """Return the last element of ``xs``, or ``Null`` if it is empty."""
    if len(xs) == 0:
        return Null
    return xs[-1]


f_take = Function("c_take", ["Int", ListType("Int")], ListType("Int"))


def c_take(n, xs):
    """Return ``xs[:n]`` (plain slice semantics: negative ``n`` trims the end)."""
    return xs[:n]


f_drop = Function("c_drop", ["Int", ListType("Int")], ListType("Int"))


def c_drop(n, xs):
    """Return ``xs[n:]`` (plain slice semantics)."""
    return xs[n:]


f_access = Function("c_access", ["Int", ListType("Int")], "Int")


def c_access(n, xs):
    """Return ``xs[n]``, or ``Null`` when ``n`` is negative or out of range."""
    if n < 0 or len(xs) <= n:
        return Null
    return xs[n]


f_max = Function("c_max", [ListType("Int")], "Int")


def c_max(xs):
    """Return the largest element of ``xs``, or ``Null`` if it is empty."""
    if len(xs) == 0:
        return Null
    return max(xs)


f_min = Function("c_min", [ListType("Int")], "Int")


def c_min(xs):
    """Return the smallest element of ``xs``, or ``Null`` if it is empty."""
    if len(xs) == 0:
        return Null
    return min(xs)


f_reverse = Function("c_reverse", [ListType("Int")], ListType("Int"))


def c_reverse(xs):
    """Return a new list with the elements of ``xs`` in reverse order."""
    return list(reversed(xs))


# Sorting and summing are served directly by the builtins `sorted` and `sum`,
# so these signatures have no wrapper implementation.
f_sort = Function("sorted", [ListType("Int")], ListType("Int"))
f_sum = Function("sum", [ListType("Int")], "Int")
## Lambdas
# Named functions that generated programs pass to the higher-order
# primitives; gen() registers their names per FunctionType.


# Int -> Int
def plus_one(x):
    return x + 1


def minus_one(x):
    return x - 1


def times_two(x):
    return x * 2


def negate(x):
    # Arithmetic negation. Deliberately not called `neg`: the Int -> Bool
    # predicate `neg` below would silently replace it at import time.
    return x * (-1)


# The div_* helpers compute int(x / k). With Python 3 true division this
# truncates toward zero (div_two(-3) == -1); under Python 2 integer division
# floors first (div_two(-3) == -2).
def div_two(x):
    return int(x / 2)


def sq(x):
    return x ** 2


def times_three(x):
    return x * 3


def div_three(x):
    return int(x / 3)


def times_four(x):
    return x * 4


def div_four(x):
    return int(x / 4)


# Int -> Bool
def pos(x):
    return x > 0


def neg(x):
    return x < 0


def even(x):
    return x % 2 == 0


def odd(x):
    # Python's % is non-negative for a positive modulus, so this also holds
    # for negative odd numbers (-3 % 2 == 1).
    return x % 2 == 1


# Int -> Int -> Int
def add(x, y):
    return x + y


def sub(x, y):
    return x - y


def mul(x, y):
    return x * y
# HOFs
# Higher-order primitives: the lambda argument(s), typed by the FunctionType
# list, are written before the list arguments in program text. `map` and
# `filter` are Python's builtins.
f_map = Function("map", [ListType("Int")],
                 ListType("Int"),
                 [FunctionType(["Int", "Int"])])

f_filter = Function("filter", [ListType("Int")],
                    ListType("Int"),
                    [FunctionType(["Int", "Bool"])])

f_count = Function("c_count", [ListType("Int")],
                   "Int",
                   [FunctionType(["Int", "Bool"])])


def c_count(f, xs):
    """Return how many elements of ``xs`` satisfy the predicate ``f``."""
    return sum(1 for x in xs if f(x))


# The zipped lambda is binary: Int -> Int -> Int.
# TODO(review): this line originally carried a bare "#FIX" marker with no
# explanation; confirm whether anything is still outstanding here.
f_zipwith = Function("c_zipwith", [ListType("Int"), ListType("Int")],
                     ListType("Int"),
                     [FunctionType(["Int", "Int", "Int"])])


def c_zipwith(f, xs, ys):
    """Apply ``f`` pairwise; the result is as long as the shorter input."""
    return [f(left, right) for left, right in zip(xs, ys)]


f_scan = Function("c_scan", [ListType("Int")],
                  ListType("Int"),
                  [FunctionType(["Int", "Int", "Int"])])
def c_scan(f, xs):
    """Running fold of ``xs`` with the binary function ``f``.

    ``out[0] = xs[0]`` and ``out[i] = f(xs[i], out[i - 1])`` for ``i >= 1``.
    The current element is the first argument and the accumulator the second.

    Args:
      f: binary function, e.g. ``add``.
      xs: finite iterable of values.

    Returns:
      A new list. ``xs`` itself is left unmodified.
    """
    # Copy first: scanning xs in place would overwrite the caller's list, so
    # later statements of a generated program reading the same variable
    # would see the scanned values instead of the original input.
    out = list(xs)
    for i in range(1, len(out)):
        out[i] = f(out[i], out[i - 1])
    return out
@contextlib.contextmanager
def stdoutIO(stdout=None):
    """Temporarily redirect ``sys.stdout``.

    Args:
      stdout: file-like object to receive output; when None, a fresh
        in-memory string buffer is created.

    Yields:
      The object ``sys.stdout`` points to inside the ``with`` body.
    """
    if stdout is None:
        # The module-level import binds either the Python 2 `StringIO` module
        # or the Python 3 `io.StringIO` class under one name, so
        # `StringIO.StringIO()` only works on Python 2. Resolve the class here.
        try:
            from StringIO import StringIO as string_buffer_cls  # Python 2
        except ImportError:
            from io import StringIO as string_buffer_cls  # Python 3
        stdout = string_buffer_cls()
    old = sys.stdout
    sys.stdout = stdout
    try:
        yield stdout
    finally:
        # Restore even if the body raises (e.g. an exec'd program fails);
        # otherwise all later output would vanish into the buffer.
        sys.stdout = old


def evaluate(program_str, input_names_to_vals, default="ERROR"):
    """Execute a program string and return what it prints for ``out``.

    Args:
      program_str: program text with statements separated by ``;``; it must
        assign ``out``. A trailing ``;`` is accepted but not required.
      input_names_to_vals: dict mapping input variable names to values; each
        is bound with ``name = str(value)`` before the program runs.
      default: value returned if the program fails to compile or raises.

    Returns:
      ``str(out)`` as printed, without the trailing newline, or ``default``.
    """
    assignments = "".join(
        "%s = %s; " % (name, val) for name, val in input_names_to_vals.items())
    # Newline (not ";") before print: a program ending in ";" would otherwise
    # yield ";;", and one without it would yield "... print(out)" -- both are
    # syntax errors.
    source = assignments + program_str + "\nprint(out)"
    # Run in a private copy of this module's globals so the DSL primitives
    # and lambdas resolve, while the program's own variables do not leak.
    # NOTE: exec runs arbitrary code -- only pass internally generated text.
    namespace = dict(globals())
    with stdoutIO() as s:
        try:
            exec(source, namespace)  # pylint: disable=exec-used
        except Exception:  # pylint: disable=broad-except
            # Any failure of the candidate program maps to `default`.
            return default
        return s.getvalue()[:-1]
class Statement(object):
    """One assignment ``output_var = fn(fn_args..., arg_vars...)``.

    ``fn`` is a Function signature (only its ``name`` is printed),
    ``fn_args`` are lambda names and ``arg_vars`` are variable names.
    """

    def __init__(self, fn, output_var, arg_vars, fn_args=None):
        self.fn = fn
        self.output_var = output_var
        self.arg_vars = arg_vars
        self.fn_args = fn_args if fn_args else []

    def __str__(self):
        # Lambda names come first, then the variable arguments.
        call_args = ", ".join(self.fn_args)
        if self.fn_args:
            call_args += ", "
        call_args += ", ".join(self.arg_vars)
        return "%s = %s(%s)" % (self.output_var, self.fn.name, call_args)

    def substitute(self, env):
        """Rename the output and argument variables via ``env`` (old -> new).

        Names missing from ``env`` are kept; lambda names are never renamed.
        """
        self.output_var = env.get(self.output_var, self.output_var)
        self.arg_vars = [env.get(name, name) for name in self.arg_vars]
class ProgramGrower(object):
    """Grow random, type-correct programs from a set of DSL functions."""

    def __init__(self, functions, types_to_lambdas):
        """Create a grower.

        Args:
          functions: candidate `Function` signatures.
          types_to_lambdas: dict mapping a FunctionType to the names of the
            lambdas of that type.
        """
        self.functions = functions
        self.types_to_lambdas = types_to_lambdas

    def grow_body(self, new_var_name, dependencies, types_to_vars):
        """Create one random statement assigning ``new_var_name``.

        Args:
          new_var_name: variable the statement assigns.
          dependencies: dict var -> list of vars it transitively depends on;
            updated in place with the entry for ``new_var_name``.
          types_to_vars: dict type -> list of variable names of that type;
            ``new_var_name`` is appended under the function's output type.

        Returns:
          The new `Statement`.
        """
        # Only functions whose every argument type has a variable available.
        choices = [f for f in self.functions
                   if all(a in types_to_vars for a in f.arg_types)]
        f = random.choice(choices)
        args = []
        for t in f.arg_types:
            # Choose a variable name directly. Choosing a name and then a
            # random character of it only worked for single-letter names.
            var = random.choice(types_to_vars[t])
            args.append(var)
            dependencies.setdefault(new_var_name, []).extend(
                [var] + dependencies[var])
        fn_args = [random.choice(self.types_to_lambdas[t])
                   for t in f.fn_arg_types]
        types_to_vars.setdefault(f.output_type, []).append(new_var_name)
        return Statement(f, new_var_name, args, fn_args)

    def grow(self, program_len, input_types):
        """Grow a random program.

        Args:
          program_len: number of statements to grow before pruning; values
            of 1 or less still produce the final ``out = ...`` statement.
          input_types: types of the inputs, which are named "a", "b", ...

        Returns:
          A `Program` keeping only the statements ``out`` depends on, with
          intermediate variables renamed to the first letters not used by
          inputs.
        """
        letters = [chr(code) for code in range(97, 123)]  # "a" .. "z"
        # Reversed so pop() hands out "a", "b", ... in order. (A list slice
        # rather than reversed(map(...)), which fails on Python 3.)
        var_names = letters[::-1]
        dependencies = dict()
        types_to_vars = dict()
        input_names = []
        for t in input_types:
            var = var_names.pop()
            dependencies[var] = []
            types_to_vars.setdefault(t, []).append(var)
            input_names.append(var)

        statements = []
        for _ in range(program_len - 1):
            var = var_names.pop()
            statements.append(self.grow_body(var, dependencies, types_to_vars))
        statements.append(self.grow_body("out", dependencies, types_to_vars))

        # Prune statements "out" does not depend on and compact the names of
        # the survivors.
        new_var_names = [c for c in letters if c not in input_names]
        new_var_names.reverse()
        keep_statements = []
        env = dict()
        for s in statements:
            if s.output_var in dependencies["out"]:
                keep_statements.append(s)
                env[s.output_var] = new_var_names.pop()
            if s.output_var == "out":
                keep_statements.append(s)
        for k in keep_statements:
            k.substitute(env)
        return Program(input_names, input_types,
                       ";".join(str(k) for k in keep_statements))
class Program(object):
    """A generated program: named, typed inputs plus a ``;``-separated body."""

    def __init__(self, input_names, input_types, body):
        self.input_names = input_names
        self.input_types = input_types
        self.body = body

    def evaluate(self, inputs):
        """Run the program on ``inputs`` and return what it prints for ``out``.

        Args:
          inputs: one value per input name, in order; each is bound with
            ``name = str(value)``.

        Returns:
          ``str(out)`` without the trailing newline.

        Raises:
          AssertionError: if ``len(inputs) != len(self.input_names)``.
          Any exception raised while executing the body propagates.
        """
        if len(inputs) != len(self.input_names):
            raise AssertionError("inputs and input_names have to "
                                 "have the same len. inp: %s , names: %s" %
                                 (str(inputs), str(self.input_names)))
        inp_str = "".join("%s = %s; " % (name, inp)
                          for name, inp in zip(self.input_names, inputs))
        # Private copy of the module globals: DSL primitives resolve, program
        # variables do not leak. The newline separator also tolerates a body
        # that already ends in ";" (";;" would be a SyntaxError).
        namespace = dict(globals())
        with stdoutIO() as s:
            # pylint: disable=exec-used
            exec(inp_str + self.body + "\nprint(out)", namespace)
            # pylint: enable=exec-used
        return s.getvalue()[:-1]

    def flat_str(self):
        """Return the body on one line with every statement ``;``-terminated."""
        # Splitting on ";" and re-appending ";" to each piece is exactly the
        # body plus one trailing ";".
        return self.body + ";"

    def __str__(self):
        lines = ["%s = %s" % (n, t)
                 for n, t in zip(self.input_names, self.input_types)]
        lines.extend(self.body.split(";"))
        return "".join(line + "\n" for line in lines)
# Global program vocabulary, filled by gen() on first use: prog_vocab is the
# token list and prog_rev_vocab maps each token to its index in it.
prog_vocab = []
prog_rev_vocab = {}


def tokenize(string, tokens=None):
    """Split ``string`` into tokens by greedy longest match.

    Args:
      string: text to tokenize; surrounding whitespace of each token is
        stripped.
      tokens: candidate tokens; defaults to the global ``prog_vocab``.
        Empty strings are ignored.

    Returns:
      The matched tokens, in order.

    Raises:
      ValueError: if the remaining text starts with no known token.
    """
    if tokens is None:
        tokens = prog_vocab
    # Longest first, so e.g. "c_head" wins over "c". Empty tokens are dropped:
    # "" always matches without consuming input and would loop forever.
    tokens = sorted((t for t in tokens if t), key=len, reverse=True)
    out = []
    string = string.strip()
    while string:
        for t in tokens:
            if string.startswith(t):
                out.append(t)
                string = string[len(t):].strip()
                break
        else:
            raise ValueError("Couldn't tokenize this: " + string)
    return out
def clean_up(output, max_val=100):
    """Parse a printed program output and clamp integers to +/- ``max_val``.

    Args:
      output: a value or its printed form (e.g. ``"[3, -200]"``); it is
        round-tripped through ``str`` and ``eval``.
      max_val: magnitude bound applied to every integer, including integers
        nested inside lists.

    Returns:
      The parsed bool, the clamped int, a list of recursively cleaned items,
      or None for any other parsed type.
    """
    # NOTE(review): eval executes arbitrary expressions. Inputs here are the
    # printed outputs of internally generated programs; never pass untrusted
    # text (ast.literal_eval would be the safe choice if only literals occur).
    o = eval(str(output))
    if isinstance(o, bool):  # before int: bool is a subclass of int
        return o
    if isinstance(o, int):
        return min(o, max_val) if o >= 0 else max(o, -max_val)
    if isinstance(o, list):
        # Pass max_val down so nested integers use the caller's bound.
        return [clean_up(item, max_val) for item in o]
    return None
def make_vocab():
    """Populate the global program vocabulary without generating programs.

    gen() fills prog_vocab / prog_rev_vocab on first use; asking it for zero
    programs means its sampling loop never runs.
    """
    gen(max_len=2, how_many=0)
def _save_programs(save_prefix, counter, outcomes_to_programs, tokens,
                   io_tokens):
    """Append collected programs, their I/O and tokenizations to four files.

    The files live under ``save_prefix`` and are named ``prog_``, ``io_``,
    ``prog_tokens_`` and ``io_tokens_`` followed by ``counter`` and ``.txt``.
    Each program appends one line to every file, so the lines correspond.
    """
    print("saving...")
    suffix = str(counter) + ".txt"
    progfilename = os.path.join(save_prefix, "prog_" + suffix)
    iofilename = os.path.join(save_prefix, "io_" + suffix)
    prog_token_filename = os.path.join(save_prefix, "prog_tokens_" + suffix)
    io_token_filename = os.path.join(save_prefix, "io_tokens_" + suffix)
    total = len(outcomes_to_programs)
    with open(progfilename, "a+") as fp, \
            open(iofilename, "a+") as fi, \
            open(prog_token_filename, "a+") as ftp, \
            open(io_token_filename, "a+") as fti:
        for save_counter, (o, p) in enumerate(
                outcomes_to_programs.items(), 1):
            if save_counter % 500 == 0:
                print("saving %d of %d" % (save_counter, total))
            fp.write(p + "\n")
            fi.write(o + "\n")
            ftp.write(str(tokenize(p, tokens)) + "\n")
            fti.write(str(tokenize(o, io_tokens)) + "\n")


def gen(max_len, how_many, save_prefix=None):
    """Generate random programs with pairwise-distinct behaviour.

    On first use this also fills the global prog_vocab / prog_rev_vocab.

    Args:
      max_len: each program is grown with a length drawn from
        ``range(max_len)``, or length 0 when ``max_len`` is 0.
      how_many: target number of distinct programs; at most
        ``100 * how_many`` candidates are drawn.
      save_prefix: optional directory; when set, every 250000 candidates the
        programs collected so far are appended to files there.

    Returns:
      A list of program strings (``Program.flat_str`` form), one per distinct
      outcome on four fixed random inputs, keeping for each outcome the
      program with the fewest tokens.
    """
    functions = [f_head, f_last, f_take, f_drop, f_access, f_max, f_min,
                 f_reverse, f_sort, f_sum, f_map, f_filter, f_count, f_zipwith,
                 f_scan]
    types_to_lambdas = {
        FunctionType(["Int", "Int"]): ["plus_one", "minus_one", "times_two",
                                       "div_two", "sq", "times_three",
                                       "div_three", "times_four", "div_four"],
        FunctionType(["Int", "Bool"]): ["pos", "neg", "even", "odd"],
        FunctionType(["Int", "Int", "Int"]): ["add", "sub", "mul"]
    }

    tokens = [f.name for f in functions]
    for v in types_to_lambdas.values():
        tokens.extend(v)
    tokens.extend(["=", ";", ",", "(", ")", "[", "]", "Int", "out"])
    tokens.extend(chr(c) for c in range(97, 123))
    # A real list (not a lazy Python 3 map) so it can be concatenated below.
    io_tokens = [str(i) for i in range(-220, 220)]
    if not prog_vocab:
        prog_vocab.extend(["_PAD", "_EOS"] + tokens + io_tokens)
        for i, t in enumerate(prog_vocab):
            prog_rev_vocab[t] = i
    io_tokens += [",", "[", "]", ")", "(", "None"]

    grower = ProgramGrower(functions=functions,
                           types_to_lambdas=types_to_lambdas)

    def mk_inp(l):
        return [random.choice(range(-5, 5)) for _ in range(l)]

    tar = [ListType("Int")]
    inps = [[mk_inp(3)], [mk_inp(5)], [mk_inp(7)], [mk_inp(15)]]
    outcomes_to_programs = dict()
    tried = set()
    counter = 0
    choices = [0] if max_len == 0 else range(max_len)
    while counter < 100 * how_many and len(outcomes_to_programs) < how_many:
        counter += 1
        length = random.choice(choices)
        program = grower.grow(length, tar)
        flat = program.flat_str()
        # Dedupe on program text: Program objects compare by identity, so a
        # set of them never matched. Skipping (instead of re-growing until an
        # unseen program appears) keeps the loop bounded when the space of
        # short programs is exhausted.
        if flat not in tried:
            tried.add(flat)
            # NOTE(review): on Python 3 builtin map/filter return iterators,
            # so programs that pass them to list primitives raise here;
            # confirm the intended interpreter.
            outcomes = [clean_up(program.evaluate(i)) for i in inps]
            # list(...) so the key prints the same on Python 2 and 3 (a
            # Python 3 zip object would print as its memory address).
            outcome_str = str(list(zip(inps, outcomes)))
            if outcome_str in outcomes_to_programs:
                outcomes_to_programs[outcome_str] = min(
                    [flat, outcomes_to_programs[outcome_str]],
                    key=lambda x: len(tokenize(x, tokens)))
            else:
                outcomes_to_programs[outcome_str] = flat
        if counter % 5000 == 0:
            print("== proggen: tried: " + str(counter))
            print("== proggen: kept: " + str(len(outcomes_to_programs)))
        if counter % 250000 == 0 and save_prefix is not None:
            _save_programs(save_prefix, counter, outcomes_to_programs, tokens,
                           io_tokens)
    return list(outcomes_to_programs.values())