|
from prompt import prompt |
|
import re |
|
from utils import generate_response, run_code |
|
|
|
|
|
def post_process_code(code, question):
    """Pair the parameters of ``solution`` with numbers found in *question*.

    The first line of *code* is expected to be a ``def solution(...)``
    header. Its parameter names are matched positionally, left to right,
    with the numeric literals that appear in *question*.

    Args:
        code: Source text whose first line is the ``solution`` header.
        question: Text to pull numbers from.

    Returns:
        A list of ``(name, value)`` tuples, truncated to the shorter of the
        parameter list and the numbers found. Integers become ``int``,
        decimals become ``float``. Returns ``[]`` when the first line is not a
        ``def solution(...)`` header or the header takes no parameters.
    """
    header = code.split("\n")[0]
    match = re.search(r"def\s+solution\s*\((.*)\)", header)
    if match is None:
        return []

    parameters = []
    for raw in match.group(1).split(","):
        # Drop annotations and defaults ("y: int = 0" -> "y"); skip "*args",
        # "**kw", "/" and "*" markers, which cannot be passed as name=value.
        name = re.split(r"[:=]", raw, maxsplit=1)[0].strip()
        if name.isidentifier():
            parameters.append(name)

    # NOTE(review): the optional sign binds only to the decimal alternative,
    # so "-2.5" keeps its sign but "-3" yields 3. Left unchanged because a
    # sign-aware pattern would read "5-3" as 5 and -3.
    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", question)[:len(parameters)]
    # int() alone raised ValueError on decimals such as "3.5".
    values = [float(n) if "." in n else int(n) for n in numbers]
    return list(zip(parameters, values))
|
|
|
|
|
def solve_pal(question, token, temperature=0.9):
    """Answer *question* by generating and running a Python ``solution``.

    The question is formatted into the ``prompt`` template and sent to
    ``generate_response``. Everything after the last ``def solution():`` in
    the response is used as the function body. A ``print(solution(...))``
    call is appended, and the program is executed with ``run_code``.

    Args:
        question: Natural-language question; surrounding whitespace is
            stripped before formatting.
        token: Forwarded unchanged to ``generate_response`` (presumably an
            API credential -- confirm against ``utils``).
        temperature: Forwarded to ``generate_response`` as its second
            argument; defaults to the previously hard-coded 0.9.

    Returns:
        A ``(pred, code)`` tuple. ``code`` is the full program that was (or
        would have been) run. ``pred`` is whatever ``run_code`` returned, or
        ``None`` when the program contains ``input(`` or ``run_code`` raised.
    """
    question = question.strip()
    query = prompt.format(question=question).strip()
    response = generate_response(query, temperature, token)

    # Keep only what follows the last "def solution():" header. Strip line
    # breaks only: a plain .strip() here removed the first body line's
    # indentation and made the program an IndentationError.
    body = response.split("def solution():")[-1].lstrip("\r\n").rstrip()
    code = "def solution():\n" + body

    # With the header rebuilt as a zero-argument "def solution():", this
    # currently yields no parameters, so solution() is called with no args.
    arguments = post_process_code(code, question)
    arg_string = ",".join(f"{param}={val}" for param, val in arguments)
    code += f"\nprint(solution({arg_string}))"

    # A program that reads stdin would block, so it is not executed.
    if "input(" in code:
        return None, code

    # SECURITY NOTE(review): this executes model-generated code; how isolated
    # that is depends entirely on run_code -- confirm it sandboxes.
    try:
        pred = run_code(code)
    except Exception:
        # Deliberate best-effort: a failing program yields no prediction,
        # and the caller still receives the code for inspection.
        return None, code
    return pred, code
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # solve_pal() requires a token for generate_response; the original call
    # solve_pal(q) omitted it and raised TypeError before doing any work.
    # Take the token from the command line, optionally followed by a question.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} TOKEN [QUESTION ...]")
    token = sys.argv[1]
    q = " ".join(sys.argv[2:]) or "What is the 7th Fibonacci number?"
    print(solve_pal(q, token))
|
|
|
|