File size: 1,217 Bytes
9b4edaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import logging
import re
import sys

from prompt import prompt
from utils import generate_response, run_code
def post_process_code(code, question):
    """Pair the parameters of a ``solution`` header with numbers from *question*.

    The first line of *code* is expected to hold the ``def solution(...)``
    header. Its parameter names are matched positionally (left to right) with
    the numeric literals found in *question*; surplus parameters or surplus
    numbers are ignored.

    Args:
        code: Python source whose first line is the ``solution`` header.
        question: Natural-language question to extract numbers from.

    Returns:
        A list of ``(name, value)`` tuples, where ``value`` is an ``int`` for
        whole-number literals and a ``float`` for decimal literals. Empty when
        the header declares no parameters or no ``solution`` header is found.
    """
    header = code.split("\n", 1)[0]
    match = re.search(r"def\s+solution\s*\((.*?)\)", header)
    if match is None:
        return []

    parameters = []
    for raw in match.group(1).split(","):
        # Reduce "n: int = 3" to "n" and drop surrounding whitespace.
        name = raw.split("=", 1)[0].split(":", 1)[0].strip()
        # Skip blanks (e.g. from a trailing comma) and *args/**kwargs, which
        # cannot be passed as name=value keyword arguments.
        if name and not name.startswith("*"):
            parameters.append(name)

    # NOTE(review): the optional sign only binds to the decimal alternative, so
    # "-5" yields 5 while "-2.5" yields -2.5. Left unchanged; confirm whether
    # signed integers should be supported.
    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", question)[: len(parameters)]
    # Decimal literals are matched by the regex above, so int() alone would
    # raise ValueError on them.
    values = [float(n) if "." in n else int(n) for n in numbers]
    return list(zip(parameters, values))
def solve_pal(question, token):
    """Answer *question* by generating and executing a ``solution()`` program.

    Formats the imported ``prompt`` template with the question and sends it to
    ``generate_response``. The response is treated as the source of a
    ``solution()`` function, a call that prints its result is appended, and
    the program is executed with ``run_code``.

    Args:
        question: Natural-language question; surrounding whitespace is ignored.
        token: Forwarded unchanged to ``generate_response``.

    Returns:
        A ``(prediction, code)`` tuple. ``prediction`` is whatever ``run_code``
        returns, or ``None`` when the program contains ``input(`` or raises
        while running. ``code`` is the program that was (or would have been)
        executed.
    """
    question = question.strip()
    query = prompt.format(question=question).strip()
    # NOTE(review): 0.9 is presumably the sampling temperature; confirm in utils.
    response = generate_response(query, 0.9, token)

    # Keep only what follows the last "def solution():" (or the whole response
    # if it has none) and re-attach a canonical header. Only leading blank
    # lines and trailing whitespace are removed: a full strip() would also eat
    # the first body line's indentation and yield an IndentationError.
    body = response.split("def solution():")[-1]
    body = re.sub(r"\A\s*\n", "", body).rstrip()
    code = "def solution():\n" + body

    # Pass numbers from the question as keyword arguments to any declared
    # parameters; with the parameterless header above this is normally empty.
    arguments = post_process_code(code, question)
    arg_string = ",".join(f"{param}={val}" for param, val in arguments)
    code += f"\nprint(solution({arg_string}))"

    # Do not execute programs that read stdin (presumably to avoid blocking).
    if "input(" in code:
        return None, code
    try:
        pred = run_code(code)
    except Exception as ex:  # best-effort: any failure of generated code -> None
        logging.getLogger(__name__).debug("generated code failed: %r", ex)
        return None, code
    return pred, code
if __name__ == "__main__":
    # solve_pal requires a token argument (forwarded to generate_response);
    # calling it with the question alone raised TypeError. Take it from argv.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} TOKEN")
    q = "What is the 7th Fibonacci number?"
    print(solve_pal(q, sys.argv[1]))
|