Spaces:
Sleeping
Sleeping
import contextlib
import io
import os
import sys
import tempfile

import gradio as gr
import pytest
import requests
# Base URL of the Anyscale Endpoints API; requests to it use the
# chat-completions request/response shape (messages / choices / finish_reason).
api_base = "https://api.endpoints.anyscale.com/v1"
url = "/".join((api_base, "chat/completions"))
# Looked up eagerly: a missing key raises KeyError when the module is imported,
# not on the first request.
token = os.environ["OPENAI_API_KEY"]
def generate_test(code):
    """Ask the Llama-2-70b chat model to write unit tests for *code*.

    The request is POSTed to the module-level ``url`` and authenticated with
    the module-level ``token``.

    Args:
        code: Source code of the function(s) to be tested.

    Returns:
        The contents of the single fenced (```) code block in the model's
        reply, with a leading ``python``/``py``/``python3`` language tag
        removed and exactly one trailing newline, so it can be appended to
        *code* and run directly.

    Raises:
        requests.HTTPError: if the endpoint answers with an error status.
        ValueError: if the reply was cut off before finishing, or does not
            contain exactly one complete fenced code block.
    """
    message = "Write me a test of this function\n{}".format(code)
    system_prompt = (
        "You are a helpful coding assistant.\n"
        "Your job is to help people write unit tests for the python code.\n"
        "If inputs and outputs are provided, please return a set of unit tests that will\n"
        "verify that the function will produce the correct outputs. Also provide tests to\n"
        "handle base and edge cases.\n"
    )
    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": 0.7,
    }
    # A one-off POST: no Session object is created, so nothing is left open.
    resp = requests.post(
        url,
        headers={"Authorization": f"Bearer {token}"},
        json=body,
        timeout=120,  # completions can be slow, but never block the UI forever
    )
    # Fail with a clear HTTP error instead of a KeyError on "choices".
    resp.raise_for_status()
    choice = resp.json()["choices"][0]
    if choice["finish_reason"] != "stop":
        raise ValueError("Please try again -- response was not finished!")

    # Splitting on the fence marker yields [prose, code, prose] exactly when
    # the reply contains one complete fenced block.
    parts = choice["message"]["content"].split("```")
    if len(parts) != 3:
        problem = "too many code blocks" if len(parts) > 3 else "no complete code block"
        raise ValueError(f"Please try again -- response generated {problem}!")

    test_code = parts[1]
    # Drop the info string on the opening fence line (e.g. ```python), which
    # would otherwise be executed as a bare name.
    first_line, newline, rest = test_code.partition("\n")
    if newline and first_line.strip().lower() in {"python", "py", "python3"}:
        test_code = rest
    return test_code.strip("\n") + "\n"
def execute_code(code, test):
    """Run the *test* source against the *code* source with pytest.

    Both snippets are written, in that order, to a temporary ``.py`` file.
    The file is then collected by an in-process ``pytest.main`` run that stops
    at the first failure (``-x``). The file is deleted afterwards.

    Args:
        code: Source of the function(s) under test.
        test: Source of the tests; it may use names defined in *code*.

    Returns:
        Everything pytest printed to standard output during the run.
    """
    captured = io.StringIO()
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as f:
        f.write(code)
        # Separator so the last line of *code* cannot run into the first line
        # of *test* when *code* has no trailing newline.
        f.write("\n\n")
        f.write(test)
        temp_path = f.name
    # The file is closed before pytest opens it; a still-open
    # NamedTemporaryFile cannot be reopened on Windows.
    try:
        # redirect_stdout restores sys.stdout even if pytest raises.
        with contextlib.redirect_stdout(captured):
            pytest.main(["-x", temp_path])
    finally:
        os.remove(temp_path)
    return captured.getvalue()
# Default contents of the code editor: a small self-contained function to
# generate tests for. The indentation inside the literal is significant; the
# snippet must be valid Python because it is executed together with the
# generated tests.
example = """
def prime_factors(n):
    i = 2
    factors = []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors
"""
# UI: the user edits the code on top. "Generate test" first asks the model for
# tests (shown bottom-left), then runs them with pytest (report bottom-right).
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Llama_test: generate unit test for your Python code</center></h1>")
    with gr.Row():
        code_input = gr.Code(example, label="Provide the code of the function you want to test")
        generate_btn = gr.Button("Generate test")
    with gr.Row():
        code_output = gr.Code(label="Generated test")
        code_output2 = gr.Code(label="pytest output")
    # execute_code needs both the user's code and the generated test, so it
    # chains off generate_test. .success() skips the run when generation
    # raised, instead of re-running a stale test from a previous click.
    generate_btn.click(generate_test, inputs=code_input, outputs=code_output).success(
        execute_code, inputs=[code_input, code_output], outputs=code_output2
    )

if __name__ == "__main__":
    demo.launch()