program-synthesis / abstract_syntax_tree.py
ayushnoori's picture
Amend weight description
4dbf6ae
'''
ABSTRACT SYNTAX TREE
This file contains the Python class that defines the abstract syntax tree (AST) representation.
'''
class OperatorNode:
'''
Class to represent operator nodes (i.e., an operator and its operands) as an AST.
Args:
operator (object): operator object (e.g., Add, Subtract, etc.)
children (list): list of children nodes (operands)
Example:
add_node = OperatorNode(Add(), [IntegerConstant(7), IntegerConstant(5)])
subtract_node = OperatorNode(Subtract(), [IntegerConstant(3), IntegerConstant(1)])
multiply_node = OperatorNode(Multiply(), [add_node, subtract_node])
multiply_node.evaluate() # returns 24
multiply_node.str() # returns "((7 + 5) * (3 - 1))"
For variable computation, the input arguments are passed to the evaluate() method.
For example, if instead:
add_node = OperatorNode(Add(), [IntegerVariable(0), IntegerConstant(5)])
multiply_node.evaluate([7]) # returns 24
'''
def __init__(self, operator, children):
self.operator = operator # operator object (e.g., Add, Subtract, etc.)
self.children = children # list of children nodes (operands)
self.weight = operator.weight + sum([child.weight for child in children]) # weight of the program
self.type = operator.return_type # return type of the operator object
def evaluate(self, input = None):
# check arity of operator in AST
if len(self.children) != self.operator.arity:
raise ValueError("Invalid number of operands for operator")
# recursively evaluate the operator and its operands
operands = [child.evaluate(input) for child in self.children]
return self.operator.evaluate(*operands, input)
def str(self):
# check arity of operator in AST
if len(self.children) != self.operator.arity:
raise ValueError("Invalid number of operands for operator")
# recursively generate a string representation of the AST
operand_strings = [child.str() for child in self.children]
return self.operator.str(*operand_strings)