|
from prompt import TA_prompt |
|
import re |
|
from utils import generate_response, run_code |
|
|
|
|
|
def post_process_code(code, question):
    """Append a ``print`` call that runs the first function defined in *code*.

    The function's parameter names are read from its ``def`` statement and
    bound, in order, to the numbers found in *question*. For example, the code
    ``def add(a, b): ...`` with the question "What is 3 plus 4?" gets
    ``print(add(a=3, b=4))`` appended.

    Args:
        code: Python source expected to define at least one function.
        question: Free text whose numeric literals supply the arguments.

    Returns:
        *code* with the call appended on a new line, or *code* unchanged when
        no ``def`` statement can be found.
    """
    # Search for the first ``def`` anywhere in the code. The first line may be
    # an import, and the signature may carry a return annotation.
    match = re.search(r"^[ \t]*def[ \t]+(\w+)[ \t]*\(([^)]*)\)", code, re.MULTILINE)
    if match is None:
        return code
    func_name, raw_params = match.groups()

    # Keep only parameters that can be passed by keyword. Drop the bare "*"
    # and "/" markers and *args/**kwargs, and strip annotations and defaults
    # ("h: float = 1.0" -> "h") so the generated call is valid syntax.
    param_names = []
    for raw in raw_params.split(","):
        name = raw.split("=")[0].split(":")[0].strip()
        if name and not name.startswith(("*", "/")):
            param_names.append(name)

    # NOTE(review): because of regex alternation precedence, a sign is only
    # captured for decimals ("-2.5"); integers are always matched unsigned.
    numbers = re.findall(r"[-+]?\d*\.\d+|\d+", question)[:len(param_names)]
    # Decimals used to crash int(). Keep integers as int so "3" prints as 3.
    values = [float(n) if "." in n else int(n) for n in numbers]

    args = ", ".join(f"{name}={value}" for name, value in zip(param_names, values))
    return code + f"\nprint({func_name}({args}))"
|
|
|
|
|
def solve_ta(question):
    """Answer *question* with the TA prompt, executing generated code if present.

    The question is appended to ``TA_prompt`` as a ``Human:`` turn. The result
    is passed to ``generate_response`` with ``0.9`` as its second argument. The
    first ``len(TA_prompt.strip())`` characters of the response are discarded,
    and only the text before the first ``-----`` separator is kept.

    - If the reply contains a Markdown code fence, the fenced code is
      extracted. ``post_process_code`` then appends a call to its first
      function, and the result is executed with ``run_code``.
    - Otherwise, the text after ``Assistant:`` (up to the next ``Human:``) is
      returned as a plain-text answer.

    Args:
        question: The user question; surrounding whitespace is ignored.

    Returns:
        A ``(prediction, code)`` tuple.

        - Code path: ``prediction`` is ``run_code``'s result, or ``None`` if
          the program calls ``input(`` or raises. ``code`` is the program that
          was (or would have been) executed.
        - Text path: ``prediction`` is the stripped answer text and ``code``
          is ``""``.
    """
    question = "Human: " + question.strip()
    query = (TA_prompt + question).strip() + "\n"

    response = generate_response(query, 0.9)
    # NOTE(review): assumes generate_response echoes the prompt. Slicing off
    # its length drops the few-shot examples so only the new exchange is
    # parsed.
    reply = response[len(TA_prompt.strip()):].strip().split("-----")[0]

    if "```" in reply:
        return _run_generated_code(reply, question)
    return _extract_assistant_text(reply), ""


def _run_generated_code(reply, question):
    """Extract the first fenced code block from *reply*, complete it, and run it."""
    fence = "```python" if "```python" in reply else "```"
    code = reply.split(fence)[1].split("```")[0].strip()
    code = post_process_code(code, question)
    print(code)

    # Presumably skipped so that a program waiting on stdin cannot block.
    if "input(" in code:
        return None, code
    try:
        pred = run_code(code)
    except Exception:
        # Best-effort: any failure in the generated program means "no answer".
        return None, code
    return pred, code


def _extract_assistant_text(reply):
    """Return the text after the first ``Assistant:`` marker, up to the next ``Human:``."""
    match = re.search(r"Assistant:(.*)", reply, re.DOTALL)
    # Fall back to the whole reply instead of raising IndexError when the
    # model omitted the "Assistant:" marker.
    text = match.group(1) if match else reply
    return text.split("Human:")[0].strip()
|
|
|
|