File size: 4,638 Bytes
e20ef71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import ast
import importlib
import io
import os
import re
import string
import time
from functools import partial
from typing import List
import pysnooper
# Signature template that appears in processed traces (it replaces the
# execution signature below when a trace is cleaned up).
FUNCTION_HEAD = 'def execute_command({input_type}) -> {output_type}:'

# Signature template the program is actually compiled and called with; the
# extra parameters are the helpers passed in by run_program_with_trace.
EXEC_FUNCTION_HEAD = (
    'def execute_command({input_type}, possible_answers, query, ImagePatch, VideoSegment,'
    ' llm_query, bool_to_yesno, distance, best_image_match):'
)
class CompileTimeError:
    """Marker object returned (not raised) when the program source fails to parse."""
class ProgramRuntimeError:
    """Marker object returned (not raised) when the executed program throws."""
def process_trace(text, function_head, execution_function_head):
    """Trim a raw execution trace down to the ``execute_command`` portion.

    The steps, in order:

    1. Drop everything before the last line that contains
       ``execution_function_head`` and rewrite that head to ``function_head``;
       if that line is indented, strip the same indent from the kept lines.
    2. Drop everything from the first ``Source path:`` line that ends with this
       module's ``__file__``, or from the first ``Elapsed time`` line.
    3. Strip the 16-character prefix from every line that starts with a digit.
    4. Collapse ``tensor([[[ ... ]]])`` reprs to ``tensor([[[...]]])``.

    Args:
        text: Raw trace text (may be empty).
        function_head: Signature string to show in the result.
        execution_function_head: Signature string to look for in ``text``.

    Returns:
        The cleaned trace joined with ``'\\n'``; ``''`` when
        ``execution_function_head`` does not occur in ``text``.
    """

    def remove_indent(lines):
        # The indent to drop is the number of leading spaces on the first line.
        first = lines[0]
        n_space = len(first) - len(first.lstrip(' '))
        # startswith() rather than line[0]: blank lines must pass through
        # untouched instead of raising IndexError.
        return [line[n_space:] if line.startswith(' ') else line for line in lines]

    def remove_pre_context(lines: List[str]):
        # Scan backwards so the *last* occurrence of the head wins.
        for i in range(len(lines) - 1, -1, -1):
            line = lines[i]
            if execution_function_head in line:
                # assert "call" in line  # TODO: further double-check?
                content = [line.replace(execution_function_head, function_head)] + lines[i + 1:]
                if line.startswith(' '):
                    return remove_indent(content)
                return content
        return []

    def remove_post_context(lines):
        for i, line in enumerate(lines):
            # Either marker means the trace has left the traced program.
            if line.startswith("Source path:") and line.endswith(__file__):
                return lines[:i]
            if line.startswith("Elapsed time"):
                return lines[:i]
        return lines

    def remove_timestamp(lines):
        # Presumably the prefix is "HH:MM:SS.ffffff " (15 chars + one space);
        # confirm against the tracer's output format if it ever changes.
        timestamp_width = 16
        return [
            line[timestamp_width:] if line and line[0] in string.digits else line
            for line in lines
        ]

    def remove_tensor(line):
        # Tensor contents are noise in the trace; keep only a placeholder.
        return re.sub(r"tensor\(\[\[\[.*?\]\]\]\)", "tensor([[[...]]])", line)

    lines = text.splitlines()
    lines = remove_pre_context(lines)
    lines = remove_post_context(lines)
    lines = remove_timestamp(lines)
    lines = [remove_tensor(line) for line in lines]
    return '\n'.join(lines)
# Running count of programs written to disk; run_program_with_trace increments
# it and uses the value to name each temporary module (x1, x2, ...).
cnt = 0
def run_program_with_trace(code, image, input_type_, output_type_):
    """Compile and run a generated ``execute_command`` program and capture its trace.

    The program text is attached to the execution signature
    (``EXEC_FUNCTION_HEAD``), normalised through ``ast``, written to a
    temporary module ``x<N>.py`` in the current directory, imported, and called
    under ``pysnooper.snoop``. The temporary file is always removed afterwards.

    Args:
        code: Program source; converted with ``str``. If it starts with the
            formatted ``FUNCTION_HEAD`` that head is stripped before the
            execution signature is prepended.
        image: Passed as the first positional argument to ``execute_command``.
        input_type_: Substituted for ``{input_type}`` in both signature templates.
        output_type_: Substituted for ``{output_type}`` in ``FUNCTION_HEAD``.

    Returns:
        A 3-tuple ``(result, error, trace)``:

        * ``(None, CompileTimeError(), None)`` when the source does not parse.
        * Otherwise ``result`` is what ``execute_command`` returned (``None`` if
          it raised or the module could not be imported), ``error`` is ``None``
          or a ``ProgramRuntimeError`` instance, and ``trace`` is the output of
          ``process_trace`` on at most the first 100000 characters of the trace.
    """
    global cnt
    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno

    function_head = FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)
    execution_function_head = EXEC_FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)

    code = str(code)
    if code.startswith("\ndef"):
        code = code[1:]  # TODO: just a temporary fix
    if code.startswith('def'):
        if code.startswith(function_head):
            code = code.replace(function_head, '')
        else:
            print("--- Code with invalid format\n")
            print(code)
    code = execution_function_head + code

    try:
        # Round-trip through the AST: validates the syntax and normalises formatting.
        code = ast.unparse(ast.parse(code))
    except Exception:
        # SyntaxError, ValueError, RecursionError, ...; KeyboardInterrupt is
        # deliberately not swallowed.
        return None, CompileTimeError(), None

    cnt += 1
    name = f'x{cnt}'
    module_path = f'{name}.py'
    # Python source is decoded as UTF-8 on import, so write it as UTF-8
    # regardless of the platform's default encoding.
    with open(module_path, 'w', encoding='utf-8') as f:
        f.write(code)

    try:
        # Stays None if the import never succeeds; the call below then fails
        # and is reported as a ProgramRuntimeError.
        module = None
        for _ in range(20):
            # The file was created after the interpreter started, so the path
            # finders' cached directory listings may not include it yet.
            importlib.invalidate_caches()
            try:
                module = importlib.import_module(name)
            except ModuleNotFoundError:
                print("Errrr, import error. Wait a bit while.")
                time.sleep(60)  # kept as a fallback, e.g. for slow/shared file systems
            except Exception as e:
                print("Import has error:", e)
                break
            else:
                break

        queues = [None, None]
        image_patch_partial = partial(ImagePatch, queues=queues)
        video_segment_partial = None
        llm_query_partial = partial(llm_query, queues=queues)

        # NOTE: no timeout is enforced on the generated program; a
        # SIGALRM-based timeout was tried and did not work.
        with io.StringIO() as f:
            with pysnooper.snoop(output=f, color=False, depth=2, max_variable_length=1000):
                result = None
                error = None
                try:
                    result = module.execute_command(image, None, '', image_patch_partial, video_segment_partial,
                                                    llm_query_partial, bool_to_yesno, distance, best_image_match)
                except (Exception, SystemExit):
                    # Generated code may call exit(); that is still a program
                    # failure. Ctrl-C (KeyboardInterrupt) is allowed to propagate.
                    error = ProgramRuntimeError()
            f.seek(0)
            traced = f.read(100000)  # cap the amount of trace text kept
    finally:
        # Always clean up the temporary module file, even on interruption.
        os.remove(module_path)

    traced_processed = process_trace(traced, function_head, execution_function_head)
    return result, error, traced_processed
|