import gradio as gr import tempfile import pytest import io import sys import os import requests import ast api_base = "https://api.endpoints.anyscale.com/v1" token = os.environ["OPENAI_API_KEY"] url = f"{api_base}/chat/completions" def extract_functions_from_file(filename): """Given a file written to disk, extract all functions from it into a list.""" with open(filename, "r") as file: tree = ast.parse(file.read()) functions = [] for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): start_line = node.lineno end_line = node.end_lineno if hasattr(node, "end_lineno") else start_line with open(filename, "r") as file: function_code = "".join( [ line for i, line in enumerate(file) if start_line <= i + 1 <= end_line ] ) functions.append(function_code) return functions def extract_tests_from_list(l): """Given a list of strings, extract all functions from it into a list.""" return [t for t in l if t.startswith("def")] def remove_leading_whitespace(func_str): """Given a string representing a function, remove the leading whitespace from each line such that the function definition is left-aligned and all following lines follow Python's whitespace formatting rules. """ lines = func_str.split("\n") # Find the amount of whitespace before 'def' (the function signature) leading_whitespace = len(lines[0]) - len(lines[0].lstrip()) # Remove that amount of whitespace from each line new_lines = [line[leading_whitespace:] for line in lines if line.strip()] return "\n".join(new_lines) def main(fxn: str, openai_api_key, examples: str = "", temperature: float = 0.7): """Requires Anyscale Endpoints Alpha API access. If examples is not a empty string, it will be formatted into a list of input/output pairs used to prompt the model. """ s = requests.Session() api_base = os.environ["OPENAI_API_BASE"] token = openai_api_key url = f"{api_base}/chat/completions" message = "Write me a test of this function\n{}".format(fxn) if examples: message += "\nExample input output pairs:\n" system_prompt = """ You are a helpful coding assistant. Your job is to help people write unit tests for their python code. Please write all unit tests in the format expected by pytest. If inputs and outputs are provided, return a set of unit tests that will verify that the function will produce the corect outputs. Also provide tests to handle base and edge cases. It is very important that the code is formatted correctly for pytest. """ body = { "model": "meta-llama/Llama-2-70b-chat-hf", "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": message}, ], "temperature": temperature, } with s.post(url, headers={"Authorization": f"Bearer {token}"}, json=body) as resp: response = resp.json()["choices"][0] if response["finish_reason"] != "stop": raise ValueError("Print please try again -- response was not finished!") # Parse the response to get the tests out. split_response = response["message"]["content"].split("```") if len(split_response) != 3: raise ValueError("Please try again -- response generated too many code blocks!") all_tests = split_response[1] # Writes out all tests to a file. Then, extracts each individual test out into a # list. with tempfile.NamedTemporaryFile( prefix="all_tests_", suffix=".py", mode="w" ) as temp: temp.writelines(all_tests) temp.flush() parsed_tests = extract_functions_from_file(temp.name) # Loop through test, run pytest, and return two lists of tests. passed_tests, failed_tests = [], [] for test in parsed_tests: test_formatted = remove_leading_whitespace(test) print("testing: \n {}".format(test_formatted)) with tempfile.NamedTemporaryFile( prefix="test_", suffix=".py", mode="w" ) as temp: # Writes out each test to a file. Then, runs pytest on that file. full_test_file = "#!/usr/bin/env python\n\nimport pytest\n{}\n{}".format( fxn, test_formatted ) temp.writelines(full_test_file) temp.flush() retcode = pytest.main(["-x", temp.name]) print(retcode.name) if retcode.name == "TESTS_FAILED": failed_tests.append(test) print("test failed") elif retcode.name == "OK": passed_tests.append(test) print("test passed") passed_tests = "\n".join(passed_tests) failed_tests = "\n".join(failed_tests) return passed_tests, failed_tests def generate_test(code): s = requests.Session() message = "Write me a test of this function\n{}".format(code) system_prompt = """ You are a helpful coding assistant. Your job is to help people write unit tests for the python code. If inputs and outputs are provided, please return a set of unit tests that will verify that the function will produce the corect outputs. Also provide tests to handle base and edge cases. """ body = { "model": "meta-llama/Llama-2-70b-chat-hf", "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": message}, ], "temperature": 0.7, } with s.post(url, headers={"Authorization": f"Bearer {token}"}, json=body) as resp: response = resp.json()["choices"][0] if response["finish_reason"] != "stop": raise ValueError("Print please try again -- response was not finished!") split_response = response["message"]["content"].split("```") if len(split_response) != 3: raise ValueError("Please try again -- response generated too many code blocks!") def execute_code(code, test): # Capture the standard output in a StringIO object old_stdout = sys.stdout new_stdout = io.StringIO() sys.stdout = new_stdout with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.py') as f: f.writelines(code) f.writelines(test) f.flush() temp_path = f.name pytest.main(["-x", temp_path]) # Restore the standard output sys.stdout = old_stdout # Get the captured output from the StringIO object output = new_stdout.getvalue() return output examples = [""" def prime_factors(n): i = 2 factors = [] while i * i <= n: if n % i: i += 1 else: n //= i factors.append(i) if n > 1: factors.append(n) return factors """, """ import numpy def matrix_multiplication(A, B): return np.dot(A, B) """, """ import numpy as np def efficient_is_semipositive_definite(matrix): try: # Attempt Cholesky decomposition np.linalg.cholesky(matrix) return True except np.linalg.LinAlgError: return False """, """ import numpy as np def is_semipositive_definite(matrix): # Compute the eigenvalues of the matrix eigenvalues = np.linalg.eigvals(matrix) # Check if all eigenvalues are non-negative return all(val >= 0 for val in eigenvalues) """ ] example = examples[0] with gr.Blocks() as demo: gr.Markdown("