File size: 1,217 Bytes
9b4edaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from prompt import prompt
import re
from utils import generate_response, run_code


def post_process_code(code, question):
    """Pair the parameters of ``solution``'s signature with numbers from *question*.

    The parameter list is read from the first line of *code*, which is expected
    to be a ``def solution(...)`` header. Numeric literals are extracted from
    *question* in order of appearance and matched positionally to the
    parameters. Surplus numbers and surplus parameters are dropped.

    Args:
        code: Program text whose first line is the ``def solution`` header.
        question: Natural-language question to mine numeric literals from.

    Returns:
        A list of ``(parameter_name, value)`` tuples. A value is an ``int``
        for a whole-number literal and a ``float`` for a literal containing a
        decimal point. The list is empty when the header declares no
        parameters or the first line is not a ``def solution(...)`` header.
    """
    header = code.split("\n")[0]
    match = re.search(r"def solution\s*\((.*)\)", header)
    if match is None:
        return []

    parameters = []
    for raw in match.group(1).split(","):
        # Keep only the bare name. Defaults ("n=3") and annotations ("n: int")
        # are stripped so the caller can safely emit "name=value" arguments.
        name = raw.split("=")[0].split(":")[0].strip()
        # Skip blanks (e.g. from a trailing comma) and star markers ("*",
        # "*args", "**kw"). None of these can be passed as "name=value".
        if name and not name.startswith("*"):
            parameters.append(name)

    literals = re.findall(r"[-+]?\d*\.\d+|\d+", question)[:len(parameters)]
    # The pattern deliberately matches decimals such as "4.5", which int()
    # would reject, so those are converted with float() instead.
    values = [float(v) if "." in v else int(v) for v in literals]
    return list(zip(parameters, values))


def solve_pal(question, token):
    """Generate a ``solution()`` program for *question* and execute it.

    The question is inserted into ``prompt`` and sent to ``generate_response``.
    The returned text is normalised to start with a ``def solution():`` header.
    A ``print(solution(...))`` call is appended, and the result is executed
    with ``run_code``.

    Args:
        question: The natural-language question; surrounding whitespace is
            stripped.
        token: Credential forwarded unchanged to ``generate_response``.

    Returns:
        A ``(pred, code)`` tuple. *code* is the full program text that was
        built. *pred* is whatever ``run_code`` returned, or ``None`` when the
        program contains ``input(`` or ``run_code`` raised an exception.
    """
    question = question.strip()
    query = prompt.format(question=question).strip()
    # NOTE(review): 0.9 is passed positionally to generate_response; presumably
    # a sampling temperature — confirm against utils.generate_response.
    completion = generate_response(query, 0.9, token)

    # Keep the text after the last "def solution():" marker (or the whole
    # completion if the marker is absent) and re-attach a canonical header.
    body = completion.split("def solution():")[-1].strip()
    code = "def solution():\n" + body

    # Bind numbers from the question to solution's parameters as keyword
    # arguments.
    # NOTE(review): the header above is always "def solution():", so
    # post_process_code finds no parameters here and the call is always
    # "solution()".
    arguments = post_process_code(code, question)
    arg_string = ",".join(f"{param}={val}" for param, val in arguments)
    code += f"\nprint(solution({arg_string}))"

    # Skip execution when the program reads from stdin; it would block.
    if "input(" in code:
        return None, code
    try:
        pred = run_code(code)
    except Exception:
        # Best-effort: any failure of the generated program yields no prediction.
        return None, code
    return pred, code


if __name__ == "__main__":
    # Demo entry point. solve_pal requires a token for generate_response,
    # so it is taken from the command line instead of being omitted
    # (omitting it raised TypeError).
    import sys

    if len(sys.argv) < 2:
        sys.exit(f"usage: python {sys.argv[0]} TOKEN")

    q = "What is the 7th Fibonacci number?"
    print(solve_pal(q, sys.argv[1]))