magnetic committed on
Commit
7803dd9
1 Parent(s): e4b2840

Upload online demo code

Browse files
.gitignore ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore .ckpt files
2
+ ckpt
3
+
4
+ # Ignore Python compiled files
5
+ __pycache__/
6
+ *.py[cod]
7
+
8
+ # Ignore Python virtual environment
9
+ venv/
10
+
11
+ # Ignore Jupyter notebook checkpoints
12
+ .ipynb_checkpoints/
13
+ .git/
14
+ .vscode/
15
+
16
+ # Ignore .DS_Store on MacOS
17
+ .DS_Store
18
+
19
+ rilab_key.txt
20
+ gpt4_custom_code_interpreter/rilab_key.txt
21
+ openai_api_key.txt
22
+
23
+ gpt4_custom_code_interpreter/
24
+ tmp/
25
+ output/
26
+ wandb/
27
+
28
+ utils/const.py
29
+ utils/hf_model_upload.py
30
+ gpt_data_gen/
31
+ *.json
32
+ *.txt
33
+ *.sh
34
+ *.pt
35
+ *.pth
36
+ *.ckpt
37
+ *.tokenizer
38
+
39
+ # eval data
40
+ eval/ds1000_data
41
+ eval/grade-school-math
42
+
43
+ # gradio features
44
+ chatbot_feat.py
45
+ chatbot_feat2.py
46
+ gradio_test.py
47
+
48
+ cache/
49
+ env/
50
+ json_dataset/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Magnetic2014
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
assets/assistant.pic.jpg ADDED
assets/user.pic.jpg ADDED
chatbot.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import gradio as gr
3
+ import os
4
+ import re
5
+ import json
6
+ import logging
7
+
8
+ import torch
9
+ from datetime import datetime
10
+
11
+ from threading import Thread
12
+ from typing import Optional
13
+ from transformers import TextIteratorStreamer
14
+ from functools import partial
15
+ from huggingface_hub import CommitScheduler
16
+ from uuid import uuid4
17
+ from pathlib import Path
18
+
19
+ from code_interpreter.JupyterClient import JupyterNotebook
20
+
21
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
22
+
23
+ import warnings
24
+
25
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
26
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
27
+
28
+
29
+ from code_interpreter.OpenCodeInterpreter import OpenCodeInterpreter
30
+
31
+ JSON_DATASET_DIR = Path("json_dataset")
32
+ JSON_DATASET_DIR.mkdir(parents=True, exist_ok=True)
33
+
34
+ upvote_button_value = "👍 Upvote Conversation"
35
+ downvote_button_value = "👎 Downvote Conversation"
36
+
37
+ scheduler = CommitScheduler(
38
+ repo_id="opencodeinterpreter_user_data",
39
+ repo_type="dataset",
40
+ folder_path=JSON_DATASET_DIR,
41
+ path_in_repo="data",
42
+ private=True
43
+ )
44
+
45
+ logging.basicConfig(level=logging.INFO)
46
+
47
class StreamingOpenCodeInterpreter(OpenCodeInterpreter):
    """OpenCodeInterpreter whose ``generate`` streams tokens through a streamer.

    ``generate`` starts the model in a background thread and returns
    immediately; the generated text is read by iterating ``self.streamer``.
    """

    # Streamer created by the most recent ``generate`` call (``None`` before that).
    streamer: Optional[TextIteratorStreamer] = None

    # Overrides the generate function of the parent class.
    @torch.inference_mode()
    def generate(
        self,
        prompt: str = "",
        max_new_tokens: int = 1024,
        do_sample: bool = False,
        top_p: float = 0.95,
        top_k: int = 50,
    ) -> str:
        """Start generating a completion for *prompt* in a background thread.

        Args:
            prompt: Fully formatted prompt text.
            max_new_tokens: Forwarded to ``self.model.generate``.
            do_sample: Forwarded to ``self.model.generate``.
            top_p: Forwarded to ``self.model.generate``.
            top_k: Forwarded to ``self.model.generate``.

        Returns:
            Always the empty string; consume ``self.streamer`` for the text.
        """
        # NOTE(review): ``Timeout`` (capital T) is reproduced as-is. It looks
        # like a typo for ``timeout`` -- confirm before changing, because a
        # real timeout would make iteration raise when generation stalls.
        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, Timeout=5
        )

        # Tokenize the prompt, truncated to MAX_INPUT_TOKEN_LENGTH, and move
        # the tensors to the model's device.
        encoded = self.tokenizer(
            [prompt],
            return_tensors="pt",
            truncation=True,
            max_length=MAX_INPUT_TOKEN_LENGTH,
        )
        encoded = encoded.to(self.model.device)

        generation_kwargs = dict(
            **encoded,
            streamer=self.streamer,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            eos_token_id=self.tokenizer.eos_token_id
        )

        # The model runs in its own thread so the caller can read the streamer.
        worker = Thread(target=self.model.generate, kwargs=generation_kwargs)
        worker.start()

        return ""
83
+
84
def save_json(dialog, mode, json_file_path, flag, dialog_id) -> None:
    """Append one conversation snapshot as a line of JSON to *json_file_path*.

    Args:
        dialog: The conversation to store (must be JSON-serializable).
        mode: Label stored under the ``"mode"`` key.
        json_file_path: ``pathlib.Path``-like object with an ``open`` method.
        flag: User preference stored under the ``"flag"`` key.
        dialog_id: Identifier stored under the ``"id"`` key.
    """
    record = {
        "id": dialog_id,
        "dialog": dialog,
        "mode": mode,
        "flag": flag,
        "datetime": datetime.now().isoformat(),
    }
    # Serialize first so a non-serializable dialog cannot leave a partial line
    # in the file.
    line = json.dumps(record, ensure_ascii=False) + "\n"
    # The scheduler's lock is held while writing.
    with scheduler.lock:
        # ensure_ascii=False emits raw non-ASCII text, so the encoding must be
        # explicit rather than the platform default.
        with json_file_path.open("a", encoding="utf-8") as f:
            f.write(line)
89
+
90
def convert_history(gradio_history: list[list], interpreter_history: list[dict]):
    """Rebuild the interpreter dialog from a Gradio chat history.

    Only a leading ``"system"`` message of *interpreter_history* is kept; all
    other turns are regenerated from *gradio_history*, whose items are indexed
    as ``[user_text, assistant_text]``. Entries that are ``None`` are skipped.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    starts_with_system = bool(interpreter_history) and interpreter_history[0]["role"] == "system"
    rebuilt = [interpreter_history[0]] if starts_with_system else []

    # Index 0 of each Gradio pair is the user side, index 1 the assistant side.
    speakers = ("user", "assistant")
    for pair in gradio_history or ():
        for slot, speaker in enumerate(speakers):
            if pair[slot] is not None:
                rebuilt.append({"role": speaker, "content": pair[slot]})
    return rebuilt
100
+
101
def reset_dialog_info(dialog_info):
    """Return a fresh ``[conversation_id, flag]`` pair for a new conversation.

    The incoming *dialog_info* is not read; a new random UUID string is
    generated and the flag is reset to ``None``.
    """
    conversation_id = f"{uuid4()}"
    logging.info(f"allocating new uuid {conversation_id} for conversation...")
    fresh_info = [conversation_id, None]
    return fresh_info
105
+
106
def is_valid_python_code(code):
    """Return ``True`` if *code* parses as Python source, ``False`` otherwise.

    Note that the empty string and a bare identifier both parse successfully.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError covers source containing null bytes on Python < 3.12;
        # newer versions report that case as SyntaxError.
        return False
    return True
112
+
113
+
114
class InputFunctionVisitor(ast.NodeVisitor):
    """AST visitor that records whether a bare ``input(...)`` call occurs.

    After ``visit(tree)``, ``found_input`` is ``True`` when some call node's
    callee is the plain name ``input``. Attribute calls such as
    ``obj.input()`` are not counted.
    """

    def __init__(self):
        self.found_input = False

    def visit_Call(self, node):
        callee = node.func
        if isinstance(callee, ast.Name) and callee.id == 'input':
            self.found_input = True
        # Keep descending so calls nested inside this call are inspected too.
        self.generic_visit(node)
122
+
123
def has_input_function_calls(code):
    """Return ``True`` if *code* contains a call to the bare name ``input``.

    Source that cannot be parsed yields ``False``. Attribute calls such as
    ``obj.input()`` are not counted.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError covers source containing null bytes on Python < 3.12.
        return False
    # ast.walk is iterative, so deeply nested code cannot exhaust the
    # recursion limit while being scanned.
    return any(
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id == 'input'
        for node in ast.walk(tree)
    )
131
+
132
def gradio_launch(model_path: str, MAX_TRY: int = 3):
    """Build the OpenCodeInterpreter chat demo with Gradio and launch it.

    Args:
        model_path: Model path or identifier handed to
            ``StreamingOpenCodeInterpreter``.
        MAX_TRY: Upper bound on execute-then-regenerate rounds performed for a
            single user message.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# Online Demo of OpenCodeInterpreter Models")
        gr.Markdown("**NOTE: Please read the disclaimer section in [README.md](https://huggingface.co/spaces/m-a-p/OpenCodeInterpreter_demo/blob/main/README.md) before using this demo!**")
        gr.Markdown("**By using this demo, you acknowledge that you have read this disclaimer, understand its terms, and agree to be bound by them.**")
        chatbot = gr.Chatbot(height=600, label="OpenCodeInterpreter", avatar_images=["assets/user.pic.jpg", "assets/assistant.pic.jpg"], show_copy_button=True)
        with gr.Group():
            with gr.Row():
                msg = gr.Textbox(
                    container=False,
                    show_label=False,
                    label="Message",
                    placeholder="Type a message...",
                    scale=7,
                    autofocus=True
                )
                sub = gr.Button(
                    "Submit",
                    variant="primary",
                    scale=1,
                    min_width=150
                )
                # stop = gr.Button(
                #     "Stop",
                #     variant="stop",
                #     visible=False,
                #     scale=1,
                #     min_width=150
                # )

        with gr.Row():
            # retry = gr.Button("🔄 Retry", variant="secondary")
            # undo = gr.Button("↩️ Undo", variant="secondary")
            upvote = gr.Button(upvote_button_value, variant="secondary")
            downvote = gr.Button(downvote_button_value, variant="secondary")
            clear = gr.Button("🗑️ Clear", variant="secondary")

        # Per-session state: chat history, the Jupyter kernel wrapper, and
        # dialog_info = [conversation id, preference flag (None/True/False)].
        session_state = gr.State([])
        jupyter_state = gr.State(JupyterNotebook())
        dialog_info = gr.State(["", None])
        demo.load(reset_dialog_info, dialog_info, dialog_info)

        def bot(user_message, history, jupyter_state, dialog_info, interpreter):
            """Stream the reply to *user_message*, running generated code.

            Yields ``(history, history, jupyter_state, dialog_info)`` after
            each streamed chunk and around each code execution. While the
            reply contains a ```` ```python ```` block, the block is executed
            and its output fed back to the model, for at most ``MAX_TRY``
            rounds.
            """
            logging.info(f"user message: {user_message}")
            interpreter.dialog = convert_history(gradio_history=history, interpreter_history=interpreter.dialog)
            history.append([user_message, None])

            interpreter.dialog.append({"role": "user", "content": user_message})

            prompt = interpreter.dialog_to_prompt(dialog=interpreter.dialog)

            # generate() only starts generation; the text arrives through
            # interpreter.streamer.
            _ = interpreter.generate(prompt)
            history[-1][1] = ""
            generated_text = ""
            for character in interpreter.streamer:
                history[-1][1] += character
                history[-1][1] = history[-1][1].replace("<|EOT|>","")
                generated_text += character
                yield history, history, jupyter_state, dialog_info

            # A reply that is nothing but parsable Python is wrapped in a code
            # fence so that it is picked up as a code block below.
            # NOTE(review): an empty reply or a single bare word also parses
            # as Python and is therefore wrapped and executed -- confirm this
            # is intended.
            if is_valid_python_code(history[-1][1].strip()):
                history[-1][1] = f"```python\n{history[-1][1].strip()}\n```"
                generated_text = history[-1][1]

            HAS_CODE, generated_code_block = interpreter.extract_code_blocks(
                generated_text
            )

            interpreter.dialog.append(
                {
                    "role": "assistant",
                    "content": generated_text.replace("<unk>_", "")
                    .replace("<unk>", "")
                    .replace("<|EOT|>", ""),
                }
            )

            logging.info(f"saving current dialog to file {dialog_info[0]}.json...")
            logging.info(f"current dialog: {interpreter.dialog}")
            save_json(interpreter.dialog, mode="openci_only", flag=dialog_info[1], json_file_path=JSON_DATASET_DIR/f"{dialog_info[0]}.json", dialog_id=dialog_info[0])

            attempt = 1
            while HAS_CODE and attempt <= MAX_TRY:
                generated_text = ""  # clear generated text

                yield history, history, jupyter_state, dialog_info

                # Strip unknown-token markers from the code before running it.
                generated_code_block = generated_code_block.replace(
                    "<unk>_", ""
                ).replace("<unk>", "")

                if has_input_function_calls(generated_code_block):
                    # input() cannot be answered here, so the code is not run.
                    code_block_output = "Please directly assign the value of inputs instead of using input() function in your code."
                else:
                    (
                        code_block_output,
                        error_flag,
                    ) = interpreter.execute_code_and_return_output(
                        f"{generated_code_block}",
                        jupyter_state
                    )
                    if error_flag == "Timeout":
                        logging.info(f"{dialog_info[0]}: Restart jupyter kernel due to timeout")
                        # Shut the stuck kernel down before replacing it;
                        # otherwise its process is leaked on every timeout.
                        try:
                            jupyter_state.close()
                        except Exception:
                            logging.exception(f"{dialog_info[0]}: failed to shut down timed-out jupyter kernel")
                        jupyter_state = JupyterNotebook()
                code_block_output = interpreter.clean_code_output(code_block_output)

                if code_block_output.strip():
                    code_block_output = "Execution result: \n" + code_block_output
                else:
                    code_block_output = "Code is executed, but result is empty. Please make sure that you include test case in your code."

                # The execution result is shown, and sent to the model, as a
                # user turn.
                history.append([code_block_output, ""])

                interpreter.dialog.append({"role": "user", "content": code_block_output})

                yield history, history, jupyter_state, dialog_info

                prompt = interpreter.dialog_to_prompt(dialog=interpreter.dialog)

                logging.info(f"generating answer for dialog {dialog_info[0]}")
                _ = interpreter.generate(prompt)
                for character in interpreter.streamer:
                    history[-1][1] += character
                    history[-1][1] = history[-1][1].replace("<|EOT|>","")
                    generated_text += character
                    yield history, history, jupyter_state, dialog_info
                logging.info(f"finish generating answer for dialog {dialog_info[0]}")

                HAS_CODE, generated_code_block = interpreter.extract_code_blocks(
                    history[-1][1]
                )

                interpreter.dialog.append(
                    {
                        "role": "assistant",
                        "content": generated_text.replace("<unk>_", "")
                        .replace("<unk>", "")
                        .replace("<|EOT|>", ""),
                    }
                )

                attempt += 1

                logging.info(f"saving current dialog to file {dialog_info[0]}.json...")
                logging.info(f"current dialog: {interpreter.dialog}")
                save_json(interpreter.dialog, mode="openci_only", flag=dialog_info[1], json_file_path=JSON_DATASET_DIR/f"{dialog_info[0]}.json", dialog_id=dialog_info[0])

        def reset_textbox():
            """Clear the message textbox."""
            return gr.update(value="")

        def set_button_variant(upvote_button_variant, downvote_button_variant):
            """Return the upvote and downvote buttons with the given variants."""
            return gr.Button(upvote_button_value, variant=upvote_button_variant), gr.Button(downvote_button_value, variant=downvote_button_variant)

        def reset_button_and_flag(dialog_info):
            """Reset both vote buttons and clear the preference flag."""
            return (*set_button_variant("secondary", "secondary"), [dialog_info[0], None])

        def clear_history(history, jupyter_state, dialog_info, interpreter):
            """Start over: empty chat, new kernel, new conversation id."""
            interpreter.dialog = []
            jupyter_state.close()
            return ([], [], JupyterNotebook(), reset_dialog_info(dialog_info), *set_button_variant("secondary", "secondary"))

        def toggle_preference(button, dialog_info):
            """Store the clicked vote in ``dialog_info[1]`` and highlight it.

            Raises:
                ValueError: If *button* is neither vote button's label.
            """
            if button == upvote_button_value:
                dialog_info[1] = True
            elif button == downvote_button_value:
                dialog_info[1] = False
            else:
                raise ValueError(button)
            logging.info(f"{button} is clicked by {dialog_info[0]}, current flag: {dialog_info[1]}")

            # The flag is always True or False at this point.
            if dialog_info[1]:
                return (*set_button_variant("primary", "secondary"), dialog_info)
            return (*set_button_variant("secondary", "primary"), dialog_info)

        def save_preference(dialog_info, interpreter):
            """Persist the current dialog with its flag, if any dialog exists."""
            if interpreter.dialog:
                save_json(interpreter.dialog, mode="openci_only", flag=dialog_info[1], json_file_path=JSON_DATASET_DIR/f"{dialog_info[0]}.json", dialog_id=dialog_info[0])
            return dialog_info

        # A single interpreter instance is bound into every handler below.
        interpreter = StreamingOpenCodeInterpreter(model_path=model_path)

        sub.click(reset_button_and_flag, dialog_info, [upvote, downvote, dialog_info])
        sub.click(partial(bot, interpreter=interpreter), [msg, session_state, jupyter_state, dialog_info], [chatbot, session_state, jupyter_state, dialog_info])
        sub.click(reset_textbox, [], [msg])

        clear.click(
            partial(clear_history, interpreter=interpreter),
            [session_state, jupyter_state, dialog_info],
            [chatbot, session_state, jupyter_state, dialog_info, upvote, downvote],
            queue=False
        )

        upvote.click(
            toggle_preference,
            [upvote, dialog_info],
            [upvote, downvote, dialog_info]
        ).then(
            partial(save_preference, interpreter=interpreter),
            dialog_info,
            dialog_info
        )

        downvote.click(
            toggle_preference,
            [downvote, dialog_info],
            [upvote, downvote, dialog_info]
        ).then(
            partial(save_preference, interpreter=interpreter),
            dialog_info,
            dialog_info
        )

    demo.queue(max_size=20)
    demo.launch(share=True)
361
+
362
+
363
if __name__ == "__main__":
    import argparse

    # Command-line entry point: pick the model, then start the demo.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--path",
        default="m-a-p/OpenCodeInterpreter-DS-6.7B",
        help="Path to the OpenCodeInterpreter Model.",
        required=False,
        type=str,
    )
    chosen_model_path = cli.parse_args().path

    gradio_launch(model_path=chosen_model_path)
code_interpreter/BaseCodeInterpreter.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+
5
+ prj_root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
6
+ sys.path.append(prj_root_path)
7
+
8
+
9
+ from utils.const import *
10
+
11
class BaseCodeInterpreter:
    """Base class holding a chat dialog and helpers for running code in it."""

    def __init__(self):
        # Every dialog starts with the code-interpreter system prompt.
        system_turn = {
            "role": "system",
            "content": CODE_INTERPRETER_SYSTEM_PROMPT,
        }
        self.dialog = [system_turn]

    @staticmethod
    def extract_code_blocks(text: str):
        """Return the stripped contents of every triple-backtick block in *text*."""
        # An optional 'python\n' language tag is matched but not captured.
        fence = re.compile(r"```(?:python\n)?(.*?)```", re.DOTALL)
        return [found.group(1).strip() for found in fence.finditer(text)]

    def execute_code_and_return_output(self, code_str: str, nb):
        """Run GUARD_CODE and then *code_str* on *nb*.

        Returns:
            The ``(outputs, error_flag)`` pair produced by running *code_str*;
            the result of the guard run is discarded.
        """
        _guard_outputs, _guard_flag = nb.add_and_run(GUARD_CODE)
        run_outputs, run_flag = nb.add_and_run(code_str)
        return run_outputs, run_flag
code_interpreter/JupyterClient.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from jupyter_client import KernelManager
2
+ import threading
3
+ import re
4
+ from utils.const import *
5
+
6
+
7
class JupyterNotebook:
    """One Jupyter kernel that executes code strings and collects their output."""

    def __init__(self):
        self.km = KernelManager()
        self.km.start_kernel()
        self.kc = self.km.client()
        # Run the shared tool imports in the new kernel; the result is ignored.
        _ = self.add_and_run(TOOLS_CODE)

    def clean_output(self, outputs):
        """Flatten collected kernel outputs into a single stripped string.

        Args:
            outputs: Items gathered by ``add_and_run``: dicts (only their
                ``"text/plain"`` entry is used), plain strings, or lists of
                traceback lines (joined and cleaned with the regex below).

        Returns:
            The textual pieces joined with newlines, with surrounding
            whitespace stripped.
        """
        outputs_only_str = []
        for item in outputs:
            if isinstance(item, dict):
                if "text/plain" in item:
                    outputs_only_str.append(item["text/plain"])
            elif isinstance(item, str):
                outputs_only_str.append(item)
            elif isinstance(item, list):
                error_msg = "\n".join(item)
                # Remove terminal colour escape sequences from the traceback.
                error_msg = re.sub(r"\x1b\[.*?m", "", error_msg)
                outputs_only_str.append(error_msg)

        return "\n".join(outputs_only_str).strip()

    def add_and_run(self, code_string):
        """Execute *code_string* on the kernel and return its cleaned output.

        Returns:
            ``(text, error_flag)`` where ``error_flag`` is ``False`` on
            success, ``True`` if the kernel reported an error, or the string
            ``"Timeout"`` if reading did not finish within 20 seconds.
        """
        outputs = []
        error_flag = False

        # This inner function is executed in a separate thread.
        def run_code_in_thread():
            nonlocal error_flag

            # Send the code; the returned id identifies this request.
            msg_id = self.kc.execute(code_string)

            while True:
                try:
                    msg = self.kc.get_iopub_msg(timeout=20)

                    # Ignore messages that belong to a different request so a
                    # stale "idle" status or leftover output cannot end or
                    # pollute this run.
                    parent = msg.get("parent_header") or {}
                    if parent.get("msg_id") != msg_id:
                        continue

                    msg_type = msg["header"]["msg_type"]
                    content = msg["content"]

                    if msg_type == "execute_result":
                        outputs.append(content["data"])
                    elif msg_type == "stream":
                        outputs.append(content["text"])
                    elif msg_type == "error":
                        error_flag = True
                        outputs.append(content["traceback"])

                    # An idle status for this request means the cell finished.
                    if msg_type == "status" and content["execution_state"] == "idle":
                        break
                except Exception:
                    # Best effort: no message within the timeout, or a
                    # malformed message, ends the read loop.
                    break

        # Start the thread to run the code.
        thread = threading.Thread(target=run_code_in_thread)
        thread.start()

        # Wait up to 20 seconds for the thread to finish.
        thread.join(timeout=20)

        # If the thread is still alive after that, report a timeout.
        if thread.is_alive():
            return self.clean_output(["Execution timed out."]), "Timeout"

        return self.clean_output(outputs), error_flag

    def close(self):
        """Shutdown the kernel."""
        self.km.shutdown_kernel()

    def __deepcopy__(self, memo):
        """Deep-copy by creating a brand-new notebook instead of copying this one."""
        if id(self) in memo:
            return memo[id(self)]
        new_copy = type(self)()
        memo[id(self)] = new_copy
        return new_copy
code_interpreter/OpenCodeInterpreter.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ prj_root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
5
+ sys.path.append(prj_root_path)
6
+
7
+ from code_interpreter.BaseCodeInterpreter import BaseCodeInterpreter
8
+ from utils.const import *
9
+
10
+ from typing import List, Tuple, Dict
11
+ import re
12
+
13
+ import torch
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+
16
+
17
+ sys.path.append(os.path.dirname(__file__))
18
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
19
+
20
+ import warnings
21
+
22
+ warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
23
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
24
+
25
+
26
class OpenCodeInterpreter(BaseCodeInterpreter):
    """Code interpreter backed by a causal language model loaded from *model_path*."""

    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
    ):
        """Load the tokenizer and model.

        Args:
            model_path: Passed to ``from_pretrained`` for both tokenizer and model.
            load_in_8bit: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
            load_in_4bit: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        """
        # build tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            padding_side="right",
            trust_remote_code=True
        )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            trust_remote_code=True
        )

        # Make the embedding table match the tokenizer's vocabulary size.
        self.model.resize_token_embeddings(len(self.tokenizer))

        self.model = self.model.eval()

        # NOTE(review): the base-class __init__ is not called, so the dialog
        # starts empty rather than with the system prompt -- confirm intended.
        self.dialog = []
        # Execution output longer than this many characters is shortened by
        # clean_code_output.
        self.MAX_CODE_OUTPUT_LENGTH = 1000

    def dialog_to_prompt(self, dialog: List[Dict]) -> str:
        """Render *dialog* to prompt text with the tokenizer's chat template."""
        full_str = self.tokenizer.apply_chat_template(dialog, tokenize=False)

        return full_str

    def extract_code_blocks(self, prompt: str) -> Tuple[bool, str]:
        """Find the last ```` ```python ```` block in *prompt*.

        Returns:
            ``(True, code)`` with the stripped contents of the last block, or
            ``(False, "")`` when there is none.
        """
        pattern = re.escape("```python") + r"(.*?)" + re.escape("```")
        matches = re.findall(pattern, prompt, re.DOTALL)

        if matches:
            # Return the last matched code block
            return True, matches[-1].strip()
        else:
            return False, ""

    def clean_code_output(self, output: str) -> str:
        """Shorten *output* if it exceeds ``MAX_CODE_OUTPUT_LENGTH`` characters.

        Over-long output is reduced to its first and last
        ``MAX_CODE_OUTPUT_LENGTH // 5`` characters around a truncation marker;
        shorter output is returned unchanged.
        """
        if len(output) <= self.MAX_CODE_OUTPUT_LENGTH:
            return output

        # Compute the kept length once. Negating before the floor division
        # (``-limit // 5``) rounds away from zero and made the tail one
        # character longer than the head for limits not divisible by 5.
        keep = self.MAX_CODE_OUTPUT_LENGTH // 5
        return (
            output[:keep]
            + "\n...(truncated due to length)...\n"
            + output[len(output) - keep:]
        )
requirements.txt ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.21.0
2
+ bitsandbytes==0.41.1
3
+ colorama==0.4.6
4
+ coloredlogs==15.0.1
5
+ colorlog==6.7.0
6
+ datasets==2.12.0
7
+ deepspeed==0.10.1
8
+ diffusers==0.20.0
9
+ einops==0.6.1
10
+ gradio==3.48.0
11
+ ipykernel==6.25.1
12
+ ipython==8.12.2
13
+ jupyter_client==8.3.0
14
+ jupyter_core==5.3.0
15
+ Markdown==3.4.3
16
+ nbclient==0.8.0
17
+ nbconvert==7.7.1
18
+ nbformat==5.8.0
19
+ omegaconf==2.3.0
20
+ openai==0.27.7
21
+ rich==13.7.0
22
+ scikit-learn==1.4.0
23
+ scipy==1.12.0
24
+ seaborn==0.13.2
25
+ sentencepiece==0.1.99
26
+ termcolor==2.3.0
27
+ tqdm==4.66.1
28
+ transformers==4.37.1
29
+ triton==2.0.0
30
+ yfinance==0.2.28
31
+ retrying==1.3.4
32
+ pydantic<2.0.0
utils/cleaner.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+
4
+ PYTHON_PREFIX = os.environ.get("CONDA_PREFIX", "/usr/local")
5
+
6
+ SITE_PKG_ERROR_PREFIX = f'File {PYTHON_PREFIX}/lib/python3.10/'
7
+
8
def get_error_header(traceback_str):
    """Return the first line of *traceback_str* containing ``'Error:'``.

    An empty string is returned when no line contains that marker.
    """
    candidates = (row for row in traceback_str.split('\n') if 'Error:' in row)
    return next(candidates, '')
14
+
15
def clean_error_msg(error_str:str =''):
    """Reduce an execution error report to its cell-level part plus error header.

    The text after the last 'An error occurred while executing the following
    cell' banner and after the last dashed separator line is kept, and ANSI
    escape sequences are removed. Everything before the first occurrence of
    SITE_PKG_ERROR_PREFIX is returned followed by a newline. The first
    ``'Error:'`` line found after the last occurrence of that prefix is
    appended unless it is already part of the result.
    """
    report = str(error_str)
    report = report.rpartition('An error occurred while executing the following cell')[2]
    report = report.rpartition("\n------------------\n")[2]

    # Remove escape sequences for colored text
    plain_report = re.sub(r'\x1b\[[0-?]*[ -/]*[@-~]', '', report)

    segments = plain_report.split(SITE_PKG_ERROR_PREFIX)
    cleaned = f'{segments[0]}\n'

    header = get_error_header(segments[-1])
    if header not in cleaned:
        cleaned += header

    return cleaned
utils/const.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TOOLS_CODE = """
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from scipy import stats
7
+ import os,sys
8
+ import re
9
+ from datetime import datetime
10
+ from sympy import symbols, Eq, solve
11
+ import torch
12
+ import requests
13
+ from bs4 import BeautifulSoup
14
+ import json
15
+ import math
16
+ import yfinance
17
+ import time
18
+ """
19
+
20
+ write_denial_function = 'lambda *args, **kwargs: (_ for _ in ()).throw(PermissionError("Writing to disk operation is not permitted due to safety reasons. Please do not try again!"))'
21
+ read_denial_function = 'lambda *args, **kwargs: (_ for _ in ()).throw(PermissionError("Reading from disk operation is not permitted due to safety reasons. Please do not try again!"))'
22
+ class_denial = """Class Denial:
23
+ def __getattr__(self, name):
24
+ def method(*args, **kwargs):
25
+ return "Using this class is not permitted due to safety reasons. Please do not try again!"
26
+ return method
27
+ """
28
+
29
+ GUARD_CODE = f"""
30
+ import builtins
31
+
32
+ _original_open = open
33
+
34
+ def custom_open(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None):
35
+ if 'w' in mode or 'a' in mode or 'x' in mode or '+' in mode:
36
+ raise PermissionError("Writing operation is not permitted due to safety reasons. Please do not try again!")
37
+ return _original_open(file, mode, buffering, encoding, errors, newline, closefd, opener)
38
+
39
+ builtins.open = custom_open
40
+
41
+ builtins.exit = {write_denial_function}
42
+ builtins.quit = {write_denial_function}
43
+
44
+ import sys
45
+
46
+ blocked_modules = ['pathlib', 'glob', 'ctypes']
47
+
48
+ for module in blocked_modules:
49
+ sys.modules[module] = PermissionError
50
+
51
+ import os
52
+
53
+ os.listdir = {read_denial_function}
54
+ os.scandir = {read_denial_function}
55
+ os.walk = {read_denial_function}
56
+ os.stat = {read_denial_function}
57
+ os.kill = {write_denial_function}
58
+ os.system = {write_denial_function}
59
+ os.putenv = {write_denial_function}
60
+ os.remove = {write_denial_function}
61
+ os.removedirs = {write_denial_function}
62
+ os.rmdir = {write_denial_function}
63
+ os.fchdir = {write_denial_function}
64
+ os.setuid = {write_denial_function}
65
+ os.fork = {write_denial_function}
66
+ os.forkpty = {write_denial_function}
67
+ os.killpg = {write_denial_function}
68
+ os.rename = {write_denial_function}
69
+ os.renames = {write_denial_function}
70
+ os.truncate = {write_denial_function}
71
+ os.replace = {write_denial_function}
72
+ os.unlink = {write_denial_function}
73
+ os.fchmod = {write_denial_function}
74
+ os.fchown = {write_denial_function}
75
+ os.chmod = {write_denial_function}
76
+ os.chown = {write_denial_function}
77
+ os.chroot = {write_denial_function}
78
+ os.fchdir = {write_denial_function}
79
+ os.lchflags = {write_denial_function}
80
+ os.lchmod = {write_denial_function}
81
+ os.lchown = {write_denial_function}
82
+ os.getcwd = {write_denial_function}
83
+ os.chdir = {write_denial_function}
84
+ os.popen = {write_denial_function}
85
+ os.environ = {{}}
86
+ os.getenv = {write_denial_function}
87
+ builtins.open = {write_denial_function}
88
+
89
+ import shutil
90
+
91
+ shutil.rmtree = {write_denial_function}
92
+ shutil.move = {write_denial_function}
93
+ shutil.chown = {write_denial_function}
94
+
95
+ import subprocess
96
+
97
+ subprocess.Popen = {write_denial_function} # type: ignore
98
+
99
+ __builtins__["help"] = {write_denial_function}
100
+
101
+ import sys
102
+
103
+ sys.modules["ipdb"] = {write_denial_function}
104
+ sys.modules["joblib"] = {write_denial_function}
105
+ sys.modules["resource"] = {write_denial_function}
106
+ sys.modules["psutil"] = {write_denial_function}
107
+ sys.modules["tkinter"] = {write_denial_function}
108
+
109
+ get_ipython().system = lambda *args, **kwargs: (_ for _ in ()).throw(PermissionError("Sorry, magic command is disabled due to safety reasons. Please do not try again!"))
110
+ """
111
+
112
+ CODE_INTERPRETER_SYSTEM_PROMPT = """You are an AI code interpreter.
113
+ Your goal is to help users do a variety of jobs by executing Python code.
114
+
115
+ You should:
116
+ 1. Comprehend the user's requirements carefully & to the letter.
117
+ 2. Give a brief description for what you plan to do & call the provided function to run code.
118
+ 3. Provide results analysis based on the execution output.
119
+ 4. If error occurred, try to fix it.
120
+ 5. Response in the same language as the user."""