File size: 8,437 Bytes
3087bb1
 
 
 
 
 
 
b9e5451
 
3087bb1
14e166f
3087bb1
 
 
 
b9e5451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14e166f
b9e5451
 
 
 
 
 
 
 
14e166f
b9e5451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3087bb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf544db
3087bb1
 
 
 
 
 
 
 
 
 
 
 
cf544db
 
 
53733d9
cf544db
53733d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf544db
 
 
3087bb1
 
 
14e166f
 
4feb357
14e166f
 
 
53733d9
be1b538
 
 
cf544db
3087bb1
 
53733d9
 
3087bb1
14e166f
3087bb1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import gradio as gr
import tempfile
import pytest
import io
import sys
import os
import requests
import ast



# Default Anyscale Endpoints configuration used by generate_test().
api_base = "https://api.endpoints.anyscale.com/v1"
# Read the key leniently: a missing OPENAI_API_KEY must not crash the module at
# import time, because the UI path (main) receives the key from the textbox.
token = os.environ.get("OPENAI_API_KEY", "")
url = f"{api_base}/chat/completions"

def extract_functions_from_file(filename):
    """Return the source text of every function defined in a Python file.

    The file is read once, parsed with ``ast``, and every ``ast.FunctionDef``
    node reachable through ``ast.walk`` (so methods and nested functions are
    included) is cut out of the file by its line span.

    Args:
        filename: Path of a Python source file on disk.

    Returns:
        A list of strings, one per function, in ``ast.walk`` order. Each string
        holds the function's lines exactly as they appear in the file, with
        their original indentation and line endings.

    Raises:
        SyntaxError: If the file content is not valid Python.
    """
    # Read the lines a single time; the original re-opened and re-scanned the
    # whole file for every function it found.
    with open(filename, "r") as file:
        lines = file.readlines()

    tree = ast.parse("".join(lines))

    functions = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue
        start_line = node.lineno
        # end_lineno is missing on very old Pythons; fall back to one line.
        end_line = getattr(node, "end_lineno", None) or start_line
        # AST line numbers are 1-based and inclusive.
        functions.append("".join(lines[start_line - 1:end_line]))

    return functions


def extract_tests_from_list(l):
    """Keep only the entries of ``l`` whose text begins with ``def``.

    The relative order of the surviving entries is preserved.
    """
    return list(filter(lambda entry: entry.startswith("def"), l))


def remove_leading_whitespace(func_str):
    """Left-align a function's source so its ``def`` line starts at column 0.

    The indentation of the first line (the function signature) is measured and
    that many leading whitespace characters are removed from every line, so the
    body keeps its indentation relative to the signature. Blank lines are
    dropped.

    Args:
        func_str: Source text of one function; its first line is the signature.

    Returns:
        The re-indented source, joined with ``"\\n"`` and without a trailing
        newline.
    """
    lines = func_str.split("\n")
    # Find the amount of whitespace before 'def' (the function signature)
    leading_whitespace = len(lines[0]) - len(lines[0].lstrip())

    new_lines = []
    for line in lines:
        if not line.strip():
            continue
        # Never strip more than the line's own indentation: a line indented
        # less than the signature (e.g. content of a multi-line string) would
        # otherwise lose real characters to a blind slice.
        own_indent = len(line) - len(line.lstrip())
        new_lines.append(line[min(leading_whitespace, own_indent):])
    return "\n".join(new_lines)


def main(fxn: str, openai_api_key, examples: str = "", temperature: float = 0.7):
    """Generate pytest tests for ``fxn`` with an LLM and sort them by outcome.

    Requires Anyscale Endpoints Alpha API access.

    The function source is sent to the chat-completions endpoint, the single
    fenced code block of the reply is split into individual functions, and each
    function is run on its own with pytest together with ``fxn``.

    NOTE(security): the model's reply is executed in-process by pytest. Only
    run this in an environment where executing generated code is acceptable.

    Args:
        fxn: Source code of the function to test.
        openai_api_key: API key sent as the bearer token.
        examples: Optional input/output examples. If not an empty string, it is
            appended to the prompt under an "Example input output pairs"
            heading.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        A ``(passed_tests, failed_tests)`` tuple of strings; each is the source
        of the matching tests joined with newlines.

    Raises:
        requests.HTTPError: If the endpoint answers with an error status.
        ValueError: If the reply was cut off or does not contain exactly one
            fenced code block.
        SyntaxError: If the generated code block is not valid Python.
    """
    # Fall back to the public Anyscale endpoint instead of failing with a
    # KeyError when OPENAI_API_BASE is not configured.
    api_base = os.environ.get(
        "OPENAI_API_BASE", "https://api.endpoints.anyscale.com/v1"
    )
    token = openai_api_key
    url = f"{api_base}/chat/completions"

    message = "Write me a test of this function\n{}".format(fxn)

    if examples:
        # Actually include the examples; previously only the heading was sent.
        message += "\nExample input output pairs:\n{}".format(examples)

    system_prompt = """
    You are a helpful coding assistant.
    Your job is to help people write unit tests for their python code. Please write all
    unit tests in the format expected by pytest. If inputs and outputs are provided,
    return a set of unit tests that will verify that the function will produce the
    corect outputs. Also provide tests to handle base and edge cases. It is very
    important that the code is formatted correctly for pytest.
    """

    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": temperature,
    }

    # The session is a context manager so its connections are always released.
    with requests.Session() as s:
        with s.post(
            url, headers={"Authorization": f"Bearer {token}"}, json=body
        ) as resp:
            # Surface auth/quota problems as an HTTP error rather than as a
            # confusing KeyError on the missing "choices" field.
            resp.raise_for_status()
            response = resp.json()["choices"][0]

    if response["finish_reason"] != "stop":
        raise ValueError("Please try again -- response was not finished!")

    # Parse the response to get the tests out. Exactly one fenced block yields
    # three pieces: text before, the code, text after.
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError(
            "Please try again -- response did not contain exactly one code block!"
        )

    all_tests = split_response[1]

    # Writes out all tests to a file. Then, extracts each individual test out into a
    # list.
    with tempfile.NamedTemporaryFile(
        prefix="all_tests_", suffix=".py", mode="w"
    ) as temp:
        temp.write(all_tests)
        temp.flush()
        parsed_tests = extract_functions_from_file(temp.name)

    # Loop through test, run pytest, and return two lists of tests.
    passed_tests, failed_tests = [], []
    for test in parsed_tests:
        test_formatted = remove_leading_whitespace(test)

        print("testing: \n {}".format(test_formatted))

        with tempfile.NamedTemporaryFile(
            prefix="test_", suffix=".py", mode="w"
        ) as temp:
            # Writes out each test to a file. Then, runs pytest on that file.
            full_test_file = "#!/usr/bin/env python\n\nimport pytest\n{}\n{}".format(
                fxn, test_formatted
            )
            temp.write(full_test_file)
            temp.flush()

            retcode = pytest.main(["-x", temp.name])

        print(retcode.name)
        if retcode.name == "OK":
            passed_tests.append(test)
            print("test passed")
        elif retcode.name != "NO_TESTS_COLLECTED":
            # TESTS_FAILED, but also collection/usage/internal errors: a test
            # that cannot even be run must not silently vanish from the output.
            failed_tests.append(test)
            print("test failed")
        # NO_TESTS_COLLECTED: the extracted function is not a test (for example
        # a helper), so it belongs to neither list.

    passed_tests = "\n".join(passed_tests)
    failed_tests = "\n".join(failed_tests)
    return passed_tests, failed_tests

def generate_test(code):
    """Ask the model for unit tests covering ``code`` and return them.

    Uses the module-level ``url`` and ``token`` for the request.

    Args:
        code: Source code of the function to test.

    Returns:
        The content of the single fenced code block in the model's reply.
        (Previously the reply was parsed and then discarded, so the function
        always returned ``None``.)

    Raises:
        ValueError: If no API token is configured, the reply was cut off, or
            the reply does not contain exactly one fenced code block.
        requests.HTTPError: If the endpoint answers with an error status.
    """
    if not token:
        raise ValueError("OPENAI_API_KEY is not set -- cannot call the API.")

    message = "Write me a test of this function\n{}".format(code)
    system_prompt = """
    You are a helpful coding assistant.
    Your job is to help people write unit tests for the python code.
    If inputs and outputs are provided, please return a set of unit tests that will
    verify that the function will produce the corect outputs. Also provide tests to
    handle base and edge cases.
    """

    body = {
        "model": "meta-llama/Llama-2-70b-chat-hf",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ],
        "temperature": 0.7,
    }

    # The session is a context manager so its connections are always released.
    with requests.Session() as s:
        with s.post(
            url, headers={"Authorization": f"Bearer {token}"}, json=body
        ) as resp:
            # Fail with a clear HTTP error instead of a KeyError on "choices".
            resp.raise_for_status()
            response = resp.json()["choices"][0]

    if response["finish_reason"] != "stop":
        raise ValueError("Please try again -- response was not finished!")

    # Exactly one fenced block yields three pieces: before, code, after.
    split_response = response["message"]["content"].split("```")
    if len(split_response) != 3:
        raise ValueError(
            "Please try again -- response did not contain exactly one code block!"
        )

    return split_response[1]


def execute_code(code, test):
    """Run ``test`` against ``code`` with pytest and return the printed output.

    Both pieces are written into one temporary ``.py`` file, pytest is run on
    it with ``-x``, and whatever is written to ``sys.stdout`` meanwhile is
    captured.

    Args:
        code: Source of the code under test (a string or an iterable of lines).
        test: Source of the tests (a string or an iterable of lines).

    Returns:
        The text written to standard output during the pytest run.
    """
    with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.py') as f:
        f.writelines(code)
        # Keep the two sources on separate lines even when ``code`` does not
        # end with a newline.
        f.write("\n")
        f.writelines(test)
        temp_path = f.name

    # The file is closed before pytest opens it, so the data is fully flushed.

    # Capture the standard output in a StringIO object
    old_stdout = sys.stdout
    new_stdout = io.StringIO()
    sys.stdout = new_stdout
    try:
        pytest.main(["-x", temp_path])
    finally:
        # Restore the standard output even if pytest raises, and remove the
        # temp file, which delete=False would otherwise leave behind.
        sys.stdout = old_stdout
        os.remove(temp_path)

    # Get the captured output from the StringIO object
    return new_stdout.getvalue()

# Sample functions offered in the UI. Each entry is a complete snippet that is
# sent to the model and prepended to every generated test file, so it has to
# run on its own.
examples = ["""
def prime_factors(n):
    i = 2
    factors = []
    while i * i <= n:
        if n % i:
            i += 1
        else:
            n //= i
            factors.append(i)
    if n > 1:
        factors.append(n)
    return factors
    """,
"""
import numpy as np
def matrix_multiplication(A, B):
    return np.dot(A, B)
""",
"""
import numpy as np

def efficient_is_semipositive_definite(matrix):
    try:
        # Attempt Cholesky decomposition
        np.linalg.cholesky(matrix)
        return True
    except np.linalg.LinAlgError:
        return False
""",
"""
import numpy as np

def is_semipositive_definite(matrix):
    # Compute the eigenvalues of the matrix
    eigenvalues = np.linalg.eigvals(matrix)

    # Check if all eigenvalues are non-negative
    return all(val >= 0 for val in eigenvalues)
"""
    ]
# Snippet shown in the code editor when the page first loads.
example = examples[0]

# Build the Gradio UI: an API-key field, a code editor pre-filled with the
# first example, a button, and two read-only panes for the results.
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>Llama_test: generate unit test for your Python code</center></h1>")
    # Password-type textbox whose value is passed to main() as openai_api_key.
    openai_api_key = gr.Textbox(
                show_label=False,
                placeholder="Set your Anyscale API key here.",
                lines=1,
                type="password"
            )
    # Editor holding the function to test; its value is passed to main() as fxn.
    code_input = gr.Code(example, language="python", label="Provide the code of the function you want to test")
    # Example picker wired to the code editor above.
    gr.Examples(
    examples=examples,
    inputs=code_input,)

    generate_btn = gr.Button("Generate test")
    with gr.Row():
        code_output = gr.Code(language="python", label="Passed tests")
        code_output2 = gr.Code(language="python", label="Failed tests")

    # main() returns (passed_tests, failed_tests), which are shown in the two
    # output panes in that order.
    generate_btn.click(main, inputs=[code_input, openai_api_key], outputs=[code_output, code_output2])
# Start the web app only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()