File size: 4,638 Bytes
e20ef71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import ast
import importlib
import io
import os
import re
import string
import time
from functools import partial
from typing import List

import pysnooper

# Signature a generated program may start with. run_program_with_trace strips it
# from the code before execution, and process_trace substitutes it back into the
# trace in place of EXEC_FUNCTION_HEAD.
FUNCTION_HEAD = "def execute_command({input_type}) -> {output_type}:"
# Signature that is actually executed: the runtime helpers (ImagePatch,
# llm_query, ...) are received as parameters, supplied positionally by
# run_program_with_trace. Only {input_type} is used by this template.
EXEC_FUNCTION_HEAD = 'def execute_command({input_type}, possible_answers, query, ImagePatch, VideoSegment,' \
                     ' llm_query, bool_to_yesno, distance, best_image_match):'


class CompileTimeError:
    """Marker object returned (not raised) when a generated program cannot be parsed."""


class ProgramRuntimeError:
    """Marker object returned (not raised) when a generated program fails to import or raises while running."""


def process_trace(text, function_head, execution_function_head):
    """Reduce raw pysnooper output to the trace of the generated program only.

    Steps, applied in order:
      1. Keep only the text starting at the *last* line containing
         ``execution_function_head``; on that line the executed signature is
         replaced by ``function_head``. If the first kept line is indented,
         that many leading spaces are removed from every kept line.
      2. Cut everything from the first line that starts with ``Source path:``
         and ends with this module's ``__file__``, or starts with
         ``Elapsed time``.
      3. Drop the first 16 characters (the timestamp) of every line that
         starts with an ASCII digit.
      4. Collapse ``tensor([[[...]]])`` reprs to a fixed placeholder.

    Args:
        text: Raw trace text.
        function_head: Signature shown in place of the executed one.
        execution_function_head: Signature that was actually executed.

    Returns:
        The cleaned lines joined with ``'\\n'``; ``''`` when
        ``execution_function_head`` does not occur in ``text``.
    """
    def remove_indent(lines):
        # The first line's leading-space run defines the indent to remove.
        n_space = len(lines[0]) - len(lines[0].lstrip(' '))

        def dedent(line):
            # Remove at most n_space leading spaces: empty lines are fine and
            # a shallower line never loses non-space characters.
            leading = len(line) - len(line.lstrip(' '))
            return line[min(n_space, leading):]

        return [dedent(line) for line in lines]

    def remove_pre_context(lines: List[str]):
        # Scan backwards so the last occurrence of the executed head wins.
        for i in reversed(range(len(lines))):
            line = lines[i]
            if execution_function_head in line:
                # assert "call" in line  # TODO: further double-check?
                content = [line.replace(execution_function_head, function_head)] + lines[i + 1:]
                # No-op when the head line is not indented.
                return remove_indent(content)
        return []

    def remove_post_context(lines):
        for i, line in enumerate(lines):
            returned_to_caller = line.startswith("Source path:") and line.endswith(__file__)
            if returned_to_caller or line.startswith("Elapsed time"):
                return lines[:i]
        return lines

    def remove_timestamp(lines):
        # Lines that begin with a digit carry a 16-character timestamp prefix.
        return [line[16:] if line and line[0] in string.digits else line for line in lines]

    def remove_tensor(line):
        return re.sub(r"tensor\(\[\[\[.*?\]\]\]\)", "tensor([[[...]]])", line)

    lines = text.splitlines()
    lines = remove_pre_context(lines)
    lines = remove_post_context(lines)
    lines = remove_timestamp(lines)
    lines = [remove_tensor(line) for line in lines]

    return '\n'.join(lines)


# Per-process counter; run_program_with_trace increments it to give each
# temporary module file it writes a fresh name (x1.py, x2.py, ...).
cnt = 0


def run_program_with_trace(code, image, input_type_, output_type_):
    """Run a generated ``execute_command`` program under pysnooper and return its trace.

    The program body is wrapped in ``EXEC_FUNCTION_HEAD``, normalized through
    ``ast``, written to ``x<N>.py`` in the current working directory, imported,
    and called on ``image`` while pysnooper records a line-level trace.

    NOTE(review): the generated code is imported and executed in-process with
    no sandboxing and no timeout.

    Args:
        code: Generated program (converted with ``str()``); either a bare body
            or a body prefixed with the ``FUNCTION_HEAD`` signature.
        image: First positional argument passed to ``execute_command``.
        input_type_: Substituted for ``{input_type}`` in both signature templates.
        output_type_: Substituted for ``{output_type}`` in ``FUNCTION_HEAD``.

    Returns:
        ``(result, error, trace)``:
        * ``(None, CompileTimeError(), None)`` if the wrapped code does not parse;
        * ``(None, ProgramRuntimeError(), '')`` if the module cannot be imported;
        * otherwise ``result`` is the program's return value (``None`` if it
          raised), ``error`` is ``None`` or a ``ProgramRuntimeError``, and
          ``trace`` is ``process_trace`` applied to at most 100000 characters
          of pysnooper output.
    """
    import sys

    from image_patch import ImagePatch, llm_query, best_image_match, distance, bool_to_yesno

    function_head = FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)
    execution_function_head = EXEC_FUNCTION_HEAD.format(input_type=input_type_, output_type=output_type_)

    code = str(code)
    if code.startswith("\ndef"):
        code = code[1:]  # TODO: just a temporary fix

    if code.startswith('def'):
        if code.startswith(function_head):
            # Strip only the leading signature; str.replace would also delete
            # later occurrences of the same text inside the body.
            code = code[len(function_head):]
        else:
            print("--- Code with invalid format\n")
            print(code)
    code = execution_function_head + code
    try:
        code = ast.unparse(ast.parse(code))
    except Exception:  # SyntaxError, ValueError, RecursionError, ...
        return None, CompileTimeError(), None

    global cnt
    cnt += 1
    name = f'x{cnt}'
    path = f'{name}.py'
    # Python decodes imported source as UTF-8 by default, so write it that way
    # regardless of the locale's preferred encoding.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(code)

    module = None
    for _ in range(20):
        # Finders cache directory listings; a file created after the cache was
        # filled can stay invisible until the caches are invalidated.
        importlib.invalidate_caches()
        try:
            module = importlib.import_module(name)
        except ModuleNotFoundError as e:
            if e.name != name:
                # The generated code imports something missing; retrying won't help.
                print("Import has error:", e)
                break
            print("Errrr, import error. Wait a bit while.")
            time.sleep(60)  # possibly a file-system delay
        except Exception as e:
            print("Import has error:", e)
            break
        else:
            break

    if module is None:
        os.remove(path)
        return None, ProgramRuntimeError(), ''

    queues = [None, None]

    image_patch_partial = partial(ImagePatch, queues=queues)
    video_segment_partial = None
    llm_query_partial = partial(llm_query, queues=queues)

    # NOTE: no execution timeout; a signal.alarm-based timeout was tried here
    # and did not work.
    try:
        with io.StringIO() as f:
            with pysnooper.snoop(output=f, color=False, depth=2, max_variable_length=1000):
                result = None
                error = None
                try:
                    result = module.execute_command(image, None, '', image_patch_partial, video_segment_partial,
                                                    llm_query_partial, bool_to_yesno, distance, best_image_match)
                except (Exception, SystemExit):
                    # Any failure of the generated program (including exit())
                    # is reported, but KeyboardInterrupt still propagates.
                    error = ProgramRuntimeError()
                # Kept inside the snoop block: this outer-frame line makes
                # pysnooper emit the "Source path:" line that process_trace
                # uses to cut the trace.
                os.remove(path)
            f.seek(0)
            traced = f.read(100000)
            traced_processed = process_trace(traced, function_head, execution_function_head)
    finally:
        # Each run imports a new module; drop it so sys.modules doesn't grow forever.
        sys.modules.pop(name, None)

    return result, error, traced_processed