Spaces:
Sleeping
Sleeping
File size: 8,437 Bytes
3087bb1 b9e5451 3087bb1 14e166f 3087bb1 b9e5451 14e166f b9e5451 14e166f b9e5451 3087bb1 cf544db 3087bb1 cf544db 53733d9 cf544db 53733d9 cf544db 3087bb1 14e166f 4feb357 14e166f 53733d9 be1b538 cf544db 3087bb1 53733d9 3087bb1 14e166f 3087bb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 |
import gradio as gr
import tempfile
import pytest
import io
import sys
import os
import requests
import ast
# Base URL of the Anyscale Endpoints API used by this module.
api_base = "https://api.endpoints.anyscale.com/v1"
# Module-level API key, used by generate_test(). It is read leniently so that
# importing the app does not raise KeyError when the variable is unset; main()
# receives its key as an argument instead.
token = os.environ.get("OPENAI_API_KEY", "")
# Full chat-completions endpoint derived from the base URL above.
url = f"{api_base}/chat/completions"
def extract_functions_from_file(filename):
    """Given a file written to disk, extract all functions from it into a list.

    The file is read once. Every ``ast.FunctionDef`` found anywhere in the
    module (top-level functions, methods, nested functions) is returned as the
    exact source lines it spans, with the original indentation preserved.

    Args:
        filename: Path to a Python source file.

    Returns:
        A list of source strings, one per function, in ``ast.walk`` order.

    Raises:
        SyntaxError: If the file content cannot be parsed as Python.
    """
    with open(filename, "r") as file:
        source_lines = file.readlines()
    tree = ast.parse("".join(source_lines))

    functions = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue
        start_line = node.lineno
        # Fall back to a single line when the node carries no end_lineno.
        end_line = getattr(node, "end_lineno", start_line)
        # AST line numbers are 1-based while list indices are 0-based.
        functions.append("".join(source_lines[start_line - 1:end_line]))
    return functions
def extract_tests_from_list(l):
"""Given a list of strings, extract all functions from it into a list."""
return [t for t in l if t.startswith("def")]
def remove_leading_whitespace(func_str):
    """Given a string representing a function, remove the leading whitespace from each
    line such that the function definition is left-aligned and all following lines
    follow Python's whitespace formatting rules.

    Blank (whitespace-only) lines are dropped. A line indented less than the
    first line only loses the whitespace it actually has, so no code or
    comment characters are ever cut off.

    Args:
        func_str: Source text of one function; its first line is the ``def``.

    Returns:
        The dedented source, joined with ``"\\n"`` and without blank lines.
    """
    lines = func_str.split("\n")
    # Find the amount of whitespace before 'def' (the function signature)
    leading_whitespace = len(lines[0]) - len(lines[0].lstrip())

    new_lines = []
    for line in lines:
        if not line.strip():
            continue
        own_indent = len(line) - len(line.lstrip())
        # Never strip more than the line's own indentation.
        new_lines.append(line[min(leading_whitespace, own_indent):])
    return "\n".join(new_lines)
def main(fxn: str, openai_api_key, examples: str = "", temperature: float = 0.7):
    """Generate pytest tests for ``fxn`` with an LLM and sort them by outcome.

    Requires Anyscale Endpoints Alpha API access.
    If examples is not a empty string, it is appended to the prompt as
    input/output pairs for the model to use.

    Args:
        fxn: Source code of the function to be tested.
        openai_api_key: API key sent as the bearer token.
        examples: Optional text describing example input/output pairs.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        A tuple ``(passed_tests, failed_tests)`` of newline-joined test
        sources. A test whose pytest exit code is neither ``OK`` nor
        ``TESTS_FAILED`` appears in neither string.

    Raises:
        ValueError: If the response did not finish, or did not contain exactly
            one fenced code block.
        requests.HTTPError: If the API answers with an error status.
    """
    # Allow an environment override, otherwise use the module-level default.
    base_url = os.environ.get("OPENAI_API_BASE", api_base)
    endpoint = f"{base_url}/chat/completions"

    message = "Write me a test of this function\n{}".format(fxn)
    if examples:
        message += "\nExample input output pairs:\n{}".format(examples)
    system_prompt = """
You are a helpful coding assistant.
Your job is to help people write unit tests for their python code. Please write all
unit tests in the format expected by pytest. If inputs and outputs are provided,
return a set of unit tests that will verify that the function will produce the
corect outputs. Also provide tests to handle base and edge cases. It is very
important that the code is formatted correctly for pytest.
"""
    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": temperature,
    }
    # Context managers close both the session and the response.
    with requests.Session() as s, s.post(
        endpoint, headers={"Authorization": f"Bearer {openai_api_key}"}, json=body
    ) as resp:
        # Surface API errors directly instead of a KeyError on "choices".
        resp.raise_for_status()
        response = resp.json()["choices"][0]
    if response["finish_reason"] != "stop":
        raise ValueError("Print please try again -- response was not finished!")

    # Parse the response to get the tests out: exactly one fenced block yields
    # three pieces when splitting on the fence marker.
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError("Please try again -- response generated too many code blocks!")
    all_tests = split_response[1]

    # Writes out all tests to a file. Then, extracts each individual test out into a
    # list.
    with tempfile.NamedTemporaryFile(
        prefix="all_tests_", suffix=".py", mode="w"
    ) as temp:
        temp.write(all_tests)
        temp.flush()
        parsed_tests = extract_functions_from_file(temp.name)

    # Loop through test, run pytest, and return two lists of tests.
    # NOTE: this executes model-generated code in-process; only run it in an
    # environment where that is acceptable.
    passed_tests, failed_tests = [], []
    for test in parsed_tests:
        test_formatted = remove_leading_whitespace(test)
        print("testing: \n {}".format(test_formatted))
        with tempfile.NamedTemporaryFile(
            prefix="test_", suffix=".py", mode="w"
        ) as temp:
            # Writes out each test to a file. Then, runs pytest on that file.
            full_test_file = "#!/usr/bin/env python\n\nimport pytest\n{}\n{}".format(
                fxn, test_formatted
            )
            temp.write(full_test_file)
            temp.flush()
            # Must run while the file still exists (it is deleted on close).
            retcode = pytest.main(["-x", temp.name])
        print(retcode.name)
        if retcode.name == "TESTS_FAILED":
            failed_tests.append(test)
            print("test failed")
        elif retcode.name == "OK":
            passed_tests.append(test)
            print("test passed")
    return "\n".join(passed_tests), "\n".join(failed_tests)
def generate_test(code):
    """Ask the model to write unit tests for ``code`` and return them.

    Uses the module-level ``url`` and ``token`` for the request.

    Args:
        code: Source code of the function to be tested.

    Returns:
        The text found between the two code fences of the model's reply. It is
        returned as-is, so it may begin with a language tag if the model
        emitted one.

    Raises:
        ValueError: If the response did not finish, or did not contain exactly
            one fenced code block.
        requests.HTTPError: If the API answers with an error status.
    """
    message = "Write me a test of this function\n{}".format(code)
    system_prompt = """
You are a helpful coding assistant.
Your job is to help people write unit tests for the python code.
If inputs and outputs are provided, please return a set of unit tests that will
verify that the function will produce the corect outputs. Also provide tests to
handle base and edge cases.
"""
    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": 0.7,
    }
    # Context managers close both the session and the response.
    with requests.Session() as s, s.post(
        url, headers={"Authorization": f"Bearer {token}"}, json=body
    ) as resp:
        # Surface API errors directly instead of a KeyError on "choices".
        resp.raise_for_status()
        response = resp.json()["choices"][0]
    if response["finish_reason"] != "stop":
        raise ValueError("Print please try again -- response was not finished!")
    # Exactly one fenced block yields three pieces when splitting on the fence.
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError("Please try again -- response generated too many code blocks!")
    # Previously the parsed block was computed but never handed back.
    return split_response[1]
def execute_code(code, test):
    """Run ``test`` against ``code`` with pytest and return the captured stdout.

    ``code`` and ``test`` are written, in that order, to one temporary ``.py``
    file which pytest then runs with ``-x``. The temporary file is removed
    afterwards and ``sys.stdout`` is restored even if pytest raises.

    Args:
        code: Source of the code under test (a string or iterable of strings).
        test: Source of the tests (a string or iterable of strings).

    Returns:
        Everything written to standard output while pytest ran.
    """
    import contextlib

    # Accept either a string or an iterable of strings, as writelines() did.
    code_text = code if isinstance(code, str) else "".join(code)
    test_text = test if isinstance(test, str) else "".join(test)

    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.py') as f:
        f.write(code_text)
        # Keep the first test line from being glued onto the last code line.
        if code_text and not code_text.endswith("\n"):
            f.write("\n")
        f.write(test_text)
        temp_path = f.name

    # Capture the standard output in a StringIO object
    captured = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured):
            pytest.main(["-x", temp_path])
    finally:
        # delete=False leaves cleanup to us.
        os.remove(temp_path)
    return captured.getvalue()
# Sample functions offered in the UI. Each entry is a complete Python snippet,
# since a snippet is written verbatim into the test file that pytest runs.
examples = [
    """
def prime_factors(n):
    i = 2
    factors = []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors
""",
    """
import numpy as np
def matrix_multiplication(A, B):
    return np.dot(A, B)
""",
    """
import numpy as np
def efficient_is_semipositive_definite(matrix):
    try:
        # Attempt Cholesky decomposition
        np.linalg.cholesky(matrix)
        return True
    except np.linalg.LinAlgError:
        return False
""",
    """
import numpy as np
def is_semipositive_definite(matrix):
    # Compute the eigenvalues of the matrix
    eigenvalues = np.linalg.eigvals(matrix)
    # Check if all eigenvalues are non-negative
    return all(val >= 0 for val in eigenvalues)
""",
]
# Snippet pre-filled in the code editor.
example = examples[0]
# Gradio UI: an API-key box, a code editor pre-filled with an example, and a
# button that sends the code to main() and shows passed and failed tests.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Llama_test: generate unit test for your Python code</center></h1>")
    openai_api_key = gr.Textbox(
        show_label=False,
        placeholder="Set your Anyscale API key here.",
        lines=1,
        type="password",
    )
    code_input = gr.Code(example, language="python", label="Provide the code of the function you want to test")
    gr.Examples(
        examples=examples,
        inputs=code_input,
    )
    generate_btn = gr.Button("Generate test")
    with gr.Row():
        code_output = gr.Code(language="python", label="Passed tests")
        code_output2 = gr.Code(language="python", label="Failed tests")
    # main() returns (passed, failed), matching the order of the two outputs.
    generate_btn.click(main, inputs=[code_input, openai_api_key], outputs=[code_output, code_output2])

if __name__ == "__main__":
    demo.launch()