File size: 1,607 Bytes
9b4edaf 201cfef 9b4edaf 201cfef 076bdf6 9b4edaf 076bdf6 950e174 9b4edaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from prompt import TA_prompt
import re
from utils import generate_response, run_code
def post_process_code(code, question):
    """Append a line that prints the result of the first function in *code*.

    The first ``def`` in *code* is located. Its parameters are paired, in
    order, with the numbers that appear in *question*, and a call of the form
    ``print(func(p1=v1,p2=v2,...))`` is appended to the source.

    Args:
        code: Python source expected to contain a function definition.
        question: Text from which numeric argument values are extracted.

    Returns:
        *code* with the print-call appended, or *code* unchanged when no
        function definition can be found.
    """
    # Locate the first function definition anywhere in the source (not only on
    # the first line), capturing its full name and its parameter list.
    match = re.search(r"^\s*def\s+(\w+)\s*\((.*?)\)", code, re.MULTILINE | re.DOTALL)
    if match is None:
        return code
    func_name, raw_params = match.group(1), match.group(2)

    parameters = []
    for raw in raw_params.split(","):
        # Reduce "x: int = 3" to "x" so the generated call stays valid syntax.
        name = raw.split("=")[0].split(":")[0].strip()
        # Skip empty entries and *args / **kwargs / positional-only markers.
        if name and not name.startswith("*") and name != "/":
            parameters.append(name)

    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", question)[:len(parameters)]
    # The pattern also matches decimals, which int() rejects; keep them as floats.
    values = [float(num) if "." in num else int(num) for num in numbers]

    arg_string = ",".join(f"{param}={val}" for param, val in zip(parameters, values))
    return code + f"\nprint({func_name}({arg_string}))"
def solve_ta(question):
    """Answer *question* using ``TA_prompt``, executing generated code if any.

    The question is appended to ``TA_prompt`` as a ``Human:`` turn and sent to
    ``generate_response`` with temperature 0.9. The response has its first
    ``len(TA_prompt.strip())`` characters removed, and anything after a
    ``-----`` separator is discarded.

    If the remaining text contains a fenced code block, the code is extracted.
    A print-call is then appended via ``post_process_code``, and the result is
    run with ``run_code``. Otherwise the first ``Assistant:`` reply is returned
    as text.

    Args:
        question: The user's question text.

    Returns:
        A ``(prediction, code)`` tuple.
        - Code answer: ``prediction`` is the value from ``run_code``, or
          ``None`` if the code contains ``input(`` or raises. ``code`` is the
          processed source.
        - Text answer: ``prediction`` is the assistant's reply, or ``None``
          when no ``Assistant:`` marker is present. ``code`` is ``""``.
    """
    question = "Human: " + question.strip()
    query = (TA_prompt + question).strip() + "\n"
    response = generate_response(query, 0.9)
    # NOTE(review): slicing off len(TA_prompt.strip()) characters presumes that
    # generate_response echoes the prompt before the completion - confirm.
    response = response[len(TA_prompt.strip()):].strip().split("-----")[0]

    if "```" in response:
        fence = "```python" if "```python" in response else "```"
        code = response.split(fence)[1].split("```")[0].strip()
        code = post_process_code(code, question)
        print(code)
        # Code that reads from stdin would block; do not execute it.
        if "input(" in code:
            return None, code
        try:
            pred = run_code(code)
        except Exception:
            # Best effort: generated code that fails to run yields no prediction.
            return None, code
        return pred, code

    # Plain-text answer: take everything after the first "Assistant:" marker,
    # cut at the next "Human:" turn. Previously a missing marker raised
    # IndexError; report no prediction instead, as the code path does.
    _, marker, reply = response.partition("Assistant:")
    if not marker:
        return None, ""
    return reply.split("Human:")[0].strip(), ""
|