# LLMBB-Agent / benchmark / code_interpreter.py
# Commit 2319518 ("update md") by vlff李飞飞 · 7.75 kB (raw | history | blame)
import base64
import io
import json
import logging
import os
import queue
import re
import subprocess
import sys
import time
import traceback
import uuid
import matplotlib
import PIL.Image
from jupyter_client import BlockingKernelClient
from utils.code_utils import extract_code
# Directory holding the kernel's connection file, the launcher script and
# saved figures; the kernel process also runs with this as its cwd.
# Overridable through the CODE_INTERPRETER_WORK_DIR environment variable.
WORK_DIR = os.environ.get('CODE_INTERPRETER_WORK_DIR', '/tmp/workspace')

# Source of the launcher script that start_kernel() writes to WORK_DIR and
# runs with the current interpreter to boot an ipykernel instance.
LAUNCH_KERNEL_PY = '''
from ipykernel import kernelapp as app
app.launch_new_instance()
'''

# Kernel clients keyed by os.getpid(), so every calling process gets its own
# kernel.
_KERNEL_CLIENTS: dict = {}
# Run this font check before Jupyter starts if matplotlib cannot render CJK fonts.
# The following lines must also be run inside the Jupyter notebook:
# ```python
# import matplotlib.pyplot as plt
# plt.rcParams['font.sans-serif'] = ['SimHei']
# plt.rcParams['axes.unicode_minus'] = False
# ```
def fix_matplotlib_cjk_font_issue():
    """Warn if the SimHei TrueType font is absent from matplotlib's font dir.

    The expected path is ``fonts/ttf/simhei.ttf`` under the directory that
    contains the file reported by ``matplotlib.matplotlib_fname()``. Nothing
    is installed or modified; a missing font only produces a logging warning.
    """
    rc_parent = os.path.abspath(
        os.path.join(matplotlib.matplotlib_fname(), os.path.pardir))
    local_ttf = os.path.join(rc_parent, 'fonts', 'ttf', 'simhei.ttf')
    if os.path.exists(local_ttf):
        return
    logging.warning(
        f'Missing font file `{local_ttf}` for matplotlib. It may cause some error when using matplotlib.'
    )
def start_kernel(pid):
    """Launch a fresh IPython kernel in WORK_DIR and return a connected client.

    The launcher script and the kernel connection file are named after
    ``pid``; stale copies left over from an earlier run are removed first.

    Args:
        pid: Identifier used to name the per-process files (the caller in
            this module passes ``os.getpid()``).

    Returns:
        A ``BlockingKernelClient`` with its channels started, after
        ``wait_for_ready()`` has returned.

    Raises:
        RuntimeError: If the kernel process exits before writing a parseable
            connection file.
    """
    fix_matplotlib_cjk_font_issue()
    connection_file = os.path.join(WORK_DIR,
                                   f'kernel_connection_file_{pid}.json')
    launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py')
    for f in [connection_file, launch_kernel_script]:
        if os.path.exists(f):
            logging.warning(f'{f} already exists')
            os.remove(f)

    os.makedirs(WORK_DIR, exist_ok=True)
    with open(launch_kernel_script, 'w') as fout:
        fout.write(LAUNCH_KERNEL_PY)

    kernel_process = subprocess.Popen(
        [
            sys.executable,
            launch_kernel_script,
            '--IPKernelApp.connection_file',
            connection_file,
            '--matplotlib=inline',
            '--quiet',
        ],
        cwd=WORK_DIR,
    )
    logging.info(f"INFO: kernel process's PID = {kernel_process.pid}")

    _wait_for_connection_file(connection_file, kernel_process)

    # Client
    kc = BlockingKernelClient(connection_file=connection_file)
    kc.load_connection_file()
    kc.start_channels()
    kc.wait_for_ready()
    return kc


def _wait_for_connection_file(connection_file, kernel_process,
                              poll_interval=0.1):
    """Block until ``connection_file`` exists and contains valid JSON.

    Raises:
        RuntimeError: If ``kernel_process`` terminates before that happens.
    """
    while True:
        if os.path.isfile(connection_file):
            # The kernel may still be writing the file; treat an unparseable
            # read as "not ready yet" and retry.
            try:
                with open(connection_file, 'r') as fp:
                    json.load(fp)
                return
            except json.JSONDecodeError:
                pass
        # A kernel that dies during startup (e.g. the launcher's ipykernel
        # import fails) would otherwise leave us polling forever.
        if kernel_process.poll() is not None:
            raise RuntimeError(
                f'Kernel process {kernel_process.pid} exited with code '
                f'{kernel_process.returncode} before writing {connection_file}')
        # Sleep on every retry, including the partially-written case, so the
        # wait never degenerates into a busy loop.
        time.sleep(poll_interval)
# ESC-prefixed (7-bit) or C1 (8-bit) control sequences, e.g. terminal colour
# codes; compiled once at import time instead of on every call.
_ANSI_ESCAPE_RE = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')


def escape_ansi(line):
    """Return ``line`` with every ANSI escape sequence stripped out."""
    return _ANSI_ESCAPE_RE.sub('', line)
def publish_image_to_local(image_base64: str):
    """Decode a base64 PNG and save it in WORK_DIR under a random file name.

    The payload is re-encoded through Pillow rather than written verbatim.

    Args:
        image_base64: Base64-encoded PNG data (in this module, the
            ``image/png`` entry of a kernel output message).

    Returns:
        The path of the written ``<uuid4>.png`` file inside WORK_DIR.
    """
    image_file = f'{uuid.uuid4()}.png'
    local_image_file = os.path.join(WORK_DIR, image_file)
    png_bytes = base64.b64decode(image_base64)
    # The context manager makes Pillow release the image once it is saved.
    with PIL.Image.open(io.BytesIO(png_bytes)) as img:
        img.save(local_image_file, 'png')
    return local_image_file
# Bootstrap source executed in the kernel on first use and after every
# `%reset -f`:
# - installs a SIGALRM handler that raises TimeoutError carrying the marker
#   string M6_CODE_INTERPRETER_TIMEOUT;
# - disables input();
# - moves into ./upload_file/, creating it if needed (a failed chdir would
#   abort the cell and silently skip every import below it);
# - pre-imports the usual libraries with CJK-capable matplotlib settings.
# The embedded code must keep its block indentation to be valid Python.
START_CODE = """
import signal
def _m6_code_interpreter_timeout_handler(signum, frame):
    raise TimeoutError("M6_CODE_INTERPRETER_TIMEOUT")
signal.signal(signal.SIGALRM, _m6_code_interpreter_timeout_handler)
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')
import os
if 'upload_file' not in os.getcwd():
    os.makedirs("./upload_file/", exist_ok=True)
    os.chdir("./upload_file/")
import math
import re
import json
import seaborn as sns
sns.set_theme()
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
import numpy as np
import pandas as pd
from sympy import Eq, symbols, solve
"""
def code_interpreter(action_input_list: list, timeout=30, clear=False):
    """Extract code from each action input and run it in the process's kernel.

    Args:
        action_input_list: Raw action inputs; ``extract_code`` pulls the source
            out of each one, and the pieces are concatenated in order, each
            followed by a newline.
        timeout: Seconds for the kernel-side ``signal.alarm``; a falsy value
            disables it.
        clear: If true, ``%reset -f`` the kernel and re-run START_CODE before
            executing.

    Returns:
        The formatted output string produced by ``_code_interpreter``.
    """
    code = ''.join(extract_code(item) + '\n' for item in action_input_list)

    patched_lines = []
    for src_line in code.split('\n'):
        patched_lines.append(src_line)
        # Re-apply the CJK font settings after a top-level sns.set_theme(...)
        # call, presumably because set_theme overrides these rcParams.
        if src_line.startswith('sns.set_theme('):
            patched_lines.extend([
                'plt.rcParams["font.sans-serif"] = ["SimHei"]',
                'plt.rcParams["axes.unicode_minus"] = False',
            ])
    program = '\n'.join(patched_lines)

    # Code that only defines `solution()` gets an explicit trailing call,
    # presumably so its return value is reported as the cell result.
    if 'def solution()' in program:
        program += '\nsolution()'
    return _code_interpreter(program, timeout, clear)
def _code_interpreter(code: str, timeout, clear=False):
    """Execute ``code`` in this process's kernel and collect its output.

    On first use by the current process, a kernel is started and primed with
    START_CODE. IOPub messages belonging to this execution are gathered until
    the kernel reports ``idle`` for it.

    Args:
        code: Python source to execute.
        timeout: Seconds passed to ``signal.alarm`` in the kernel; falsy
            disables the alarm.
        clear: If true, run ``%reset -f`` and re-run START_CODE before ``code``.

    Returns:
        Markdown text. Each output appears as a ``<type>:`` header followed by
        a fenced block, and each PNG as an image link to the file saved by
        ``publish_image_to_local``. Returns ``''`` for blank ``code``.
    """
    if not code.strip():
        return ''
    if timeout:
        code = f'signal.alarm({timeout})\n{code}'
    if clear:
        code = "get_ipython().run_line_magic('reset', '-f')\n" + START_CODE + code

    pid = os.getpid()
    if pid not in _KERNEL_CLIENTS:
        _KERNEL_CLIENTS[pid] = start_kernel(pid)
        _code_interpreter(START_CODE, timeout=None)
    kc = _KERNEL_CLIENTS[pid]
    kc.wait_for_ready()
    msg_id = kc.execute(code)

    parts = []
    image_idx = 0
    while True:
        text = ''
        image = ''
        finished = False
        msg_type = 'error'
        try:
            msg = kc.get_iopub_msg()
            # IOPub is shared by every request sent to this kernel. Skip
            # messages (including 'idle' status updates) that are not replies
            # to this execute request; otherwise they could end collection
            # early or leak into this result.
            parent = msg.get('parent_header') or {}
            if parent.get('msg_id') != msg_id:
                continue
            msg_type = msg['msg_type']
            content = msg['content']
            if msg_type == 'status':
                if content.get('execution_state') == 'idle':
                    finished = True
            elif msg_type in ('execute_result', 'display_data'):
                data = content['data']
                has_png = 'image/png' in data
                if has_png:
                    image_url = publish_image_to_local(data['image/png'])
                    image_idx += 1
                    image = '![fig-%03d](%s)' % (image_idx, image_url)
                # execute_result always reports its plain text; display_data
                # only when it carries no PNG.
                if msg_type == 'execute_result' or not has_png:
                    text = data.get('text/plain', '')
            elif msg_type == 'stream':
                msg_type = content['name']  # stdout, stderr
                text = content['text']
            elif msg_type == 'error':
                text = escape_ansi('\n'.join(content['traceback']))
                if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                    text = f'Timeout. No response after {timeout} seconds.'
        except queue.Empty:
            # NOTE(review): get_iopub_msg() is called without a timeout, so
            # this branch can only trigger if a timeout is passed to it.
            text = f'Timeout. No response after {timeout} seconds.'
            finished = True
        except Exception:
            text = 'The code interpreter encountered an unexpected error.'
            logging.warning(''.join(
                traceback.format_exception(*sys.exc_info())))
            finished = True
        if text:
            parts.append(f'\n\n{msg_type}:\n\n```\n{text}\n```')
        if image:
            parts.append(f'\n\n{image}')
        if finished:
            break

    result = ''.join(parts).lstrip('\n')
    if timeout:
        # Cancel the pending alarm so it cannot fire during a later call.
        _code_interpreter('signal.alarm(0)', timeout=None)
    return result
def get_multiline_input(hint):
    """Prompt with ``hint`` and read stdin lines until EOF (CTRL-D).

    Args:
        hint: Text printed before reading.

    Returns:
        The entered lines joined with ``'\\n'``, or ``''`` if none were entered.
    """
    print(hint)
    print('// Press ENTER to make a new line. Press CTRL-D to end input.')
    collected = []
    try:
        while True:
            collected.append(input())
    except EOFError:  # CTRL-D
        pass
    print('// Input received.')
    return '\n'.join(collected)
if __name__ == '__main__':
    # Interactive loop: read a multi-line snippet, run it, print the formatted
    # output, and repeat until the process is interrupted.
    while True:
        snippet = get_multiline_input('Enter python code:')
        print(code_interpreter([snippet]))