# import os | |
# import re | |
# import streamlit as st | |
# import ast | |
# import json | |
# import openai | |
# from llama_index.llms.openai import OpenAI | |
# from llama_index.core.llms import ChatMessage | |
# from llama_index.llms.anthropic import Anthropic | |
# from llama_index.llms.mistralai import MistralAI | |
# import nest_asyncio | |
# nest_asyncio.apply() | |
# # import ollama | |
# # from llama_index.llms.ollama import Ollama | |
# # from llama_index.core.llms import ChatMessage | |
# # OpenAI credentials | |
# # key = os.getenv('OPENAI_API_KEY') | |
# # openai.api_key = key | |
# # os.environ["OPENAI_API_KEY"] = key | |
# # Anthropic credentials | |
# # key = os.getenv('CLAUDE_API_KEY') | |
# # os.environ["ANTHROPIC_API_KEY"] = key | |
# # Mistral | |
# key = os.getenv('MISTRAL_API_KEY') | |
# os.environ["MISTRAL_API_KEY"] = key | |
# # Streamlit UI | |
# st.title("Auto Test Case Generation using LLM") | |
# uploaded_files = st.file_uploader("Upload a python or Java file", type=[".py","java"], accept_multiple_files=True) | |
# if uploaded_files: | |
# for uploaded_file in uploaded_files: | |
# with open(f"./data/{uploaded_file.name}", 'wb') as f: | |
# f.write(uploaded_file.getbuffer()) | |
# st.success("File uploaded...") | |
# # Check file type | |
# _, file_extension = os.path.splitext(uploaded_file.name) | |
# print(file_extension) | |
# st.success("Fetching list of functions...") | |
# file_path = f"./data/{uploaded_file.name}" | |
# def extract_functions_from_file(file_path, file_extension): | |
# if file_extension == '.py': | |
# with open(file_path, "r") as file: | |
# file_content = file.read() | |
# parsed_content = ast.parse(file_content) | |
# methods = {} | |
# for node in ast.walk(parsed_content): | |
# if isinstance(node, ast.FunctionDef): | |
# func_name = node.name | |
# func_body = ast.get_source_segment(file_content, node) | |
# methods[func_name] = func_body | |
# elif file_extension == '.java': | |
# with open(file_path, 'r') as file: | |
# lines = file.readlines() | |
# methods = {} | |
# inside_method = False | |
# method_name = None | |
# method_body = [] | |
# brace_count = 0 | |
# method_signature_pattern = re.compile(r'((?:public|protected|private|static|\s)*)\s+[\w<>\[\]]+\s+(\w+)\s*\([^)]*\)\s*\{') | |
# for line in lines: | |
# if not inside_method: | |
# match = method_signature_pattern.search(line) | |
# if match: | |
# modifiers, method_name = match.groups() | |
# inside_method = True | |
# method_body.append(line) | |
# brace_count = line.count('{') - line.count('}') | |
# else: | |
# method_body.append(line) | |
# brace_count += line.count('{') - line.count('}') | |
# if brace_count == 0: | |
# inside_method = False | |
# methods[method_name] = ''.join(method_body) | |
# method_body = [] | |
# if 'main' in methods.keys(): | |
# del(methods['main']) | |
# return methods | |
# functions = extract_functions_from_file(file_path, file_extension) | |
# list_of_functions = list(functions.keys()) | |
# st.write(list_of_functions) | |
# def res(prompt, model=None): | |
# # response = openai.chat.completions.create( | |
# # model=model, | |
# # messages=[ | |
# # {"role": "user", | |
# # "content": prompt, | |
# # } | |
# # ] | |
# # ) | |
# # return response.choices[0].message.content | |
# response = [ | |
# ChatMessage(role="system", content="You are a sincere and helpful coding assistant"), | |
# ChatMessage(role="user", content=prompt), | |
# ] | |
# # resp = Anthropic(model=model).chat(response) | |
# resp = MistralAI(model).chat(response) | |
# return resp | |
# # Initialize session state for chat messages | |
# if "messages" not in st.session_state: | |
# st.session_state.messages = [] | |
# # Display chat messages from history on app rerun | |
# for message in st.session_state.messages: | |
# with st.chat_message(message["role"]): | |
# st.markdown(message["content"]) | |
# # Accept user input | |
# if func := st.chat_input("Enter the function name for generating test cases:"): | |
# st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"}) | |
# st.success(f"Generating test cases for {func}") | |
# func = ''.join(func.split()) | |
# if func not in list_of_functions: | |
# st.write("Incorrect function name") | |
# else: | |
# snippet = functions[func] | |
# # Generation | |
# # model = "gpt-3.5-turbo" | |
# # model = "claude-3-haiku-20240307" | |
# # model = "claude-3-sonnet-20240229" | |
# # model = "claude-3-opus-20240229" | |
# model = "codestral-latest" | |
# # Generation | |
# # resp = ollama.generate(model='codellama', | |
# # prompt=f""" Your task is to generate unit test cases for this function : {snippet}\ | |
# # \n\n Politely refuse if the function is not suitable for generating test cases. | |
# # \n\n Generate atleast 5 unit test case. Include couple of edge cases as well. | |
# # \n\n There should be no duplicate test cases. | |
# # \n\n Avoid generating repeated statements. | |
# # """) | |
# prompt=f""" Your task is to generate unit test cases for this function : \n\n{snippet}\ | |
# \n\n Generate between 3 to 8 unique unit test cases. Include couple of edge cases as well. | |
# \n\n All the test cases should have the mandatory assert statement. | |
# \n\n Every test case should be defined as a method inside the class. | |
# \n\n All the test cases should have textual description. | |
# \n\n Politely refuse if the function is not suitable for generating test cases. | |
# \n\n There should be no duplicate and incomplete test case. | |
# \n\n Avoid generating repeated statements. | |
# \n\n Recheck your response before generating. | |
# \n\n Do not share the last Test Case. | |
# """ | |
# # print(prompt) | |
# resp = res(prompt = prompt, model = model) | |
# # Post Processing | |
# post_prompt = f"""Except the last test case, display everything that is present in this end to end: \n\n{resp}\ | |
# \n\n Do not add anything extra. Just copy and paste everything except the last test case. | |
# \n\n Do not mention the count of total number of test cases in the response. | |
# \n\n Do not mention this sentence - "I have excluded the last test case as per your request" | |
# """ | |
# post_resp = res(prompt = post_prompt, model = model) | |
# st.session_state.messages.append({"role": "assistant", "content": f"{post_resp}"}) | |
# st.markdown(post_resp) | |
# # st.session_state.messages.append({"role": "assistant", "content": f"{resp['response']}"}) | |
# # st.markdown(resp['response']) | |
import os | |
import re | |
import ast | |
import streamlit as st | |
from llama_index.llms.openai import OpenAI | |
from llama_index.core.llms import ChatMessage | |
from llama_index.llms.anthropic import Anthropic | |
from llama_index.llms.mistralai import MistralAI | |
import nest_asyncio | |
class TestCaseGenerator:
    """Streamlit app: extract functions from uploaded Python/Java files and ask
    a Mistral model (via llama_index) to generate unit tests for one of them."""

    # Directory where uploaded files are saved before being parsed.
    DATA_DIR = "./data"
    # Heuristic single-line Java method signature: optional modifiers, a return
    # type, the method name (group 2), a parameter list, and an opening brace
    # on the same line. Signatures split across lines are not recognised.
    _JAVA_METHOD_SIGNATURE = re.compile(r'((?:public|protected|private|static|\s)*)\s+[\w<>\[\]]+\s+(\w+)\s*\([^)]*\)\s*\{')

    def __init__(self):
        nest_asyncio.apply()
        self.key = os.getenv('MISTRAL_API_KEY')
        if not self.key:
            # Don't write the key back into os.environ: it already lives there,
            # and assigning None would raise TypeError. Just tell the user.
            st.warning("MISTRAL_API_KEY is not set; test case generation will fail.")
        self.model = "codestral-latest"
        # function name -> source text of its definition
        self.functions = {}
        # function name -> language tag ("py" / "java") of the file it came from
        self.function_languages = {}
        # names shown to the user and accepted by generate_test_cases
        self.list_of_functions = []

    def setup_streamlit_ui(self):
        """Render the title and uploader, then process every uploaded file."""
        st.title("Auto Test Case Generation using LLM")
        uploaded_files = st.file_uploader("Upload a python or Java file", type=[".py", "java"], accept_multiple_files=True)
        if uploaded_files:
            for uploaded_file in uploaded_files:
                self.process_uploaded_file(uploaded_file)

    def process_uploaded_file(self, uploaded_file):
        """Save one uploaded file under DATA_DIR and extract its functions.

        Parse failures are reported in the UI instead of crashing the app.
        """
        os.makedirs(self.DATA_DIR, exist_ok=True)
        # The file name comes from the client; keep only the final component
        # so it cannot escape DATA_DIR (e.g. "../../etc/passwd").
        file_name = os.path.basename(uploaded_file.name)
        file_path = os.path.join(self.DATA_DIR, file_name)
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
        st.success("File uploaded...")
        _, file_extension = os.path.splitext(file_name)
        st.success("Fetching list of functions...")
        try:
            self.extract_functions_from_file(file_path, file_extension)
        except (SyntaxError, ValueError) as err:
            # SyntaxError: invalid Python; ValueError covers UnicodeDecodeError
            # and null bytes in the source.
            st.error(f"Could not parse {file_name}: {err}")
            return
        st.write(self.list_of_functions)

    def extract_functions_from_file(self, file_path, file_extension):
        """Extract functions from *file_path* into ``self.functions``.

        The extension (case-insensitive) selects the parser; unsupported
        extensions add nothing. Each extracted function's language tag
        (extension without the dot) is recorded in ``self.function_languages``.
        ``main`` is dropped (presumably an entry point rather than a unit
        under test), and ``self.list_of_functions`` is refreshed.
        """
        extension = file_extension.lower()
        if extension == '.py':
            found = self.extract_python_functions(file_path)
        elif extension == '.java':
            found = self.extract_java_functions(file_path)
        else:
            found = {}
        language = extension[1:]
        for name in found:
            self.function_languages[name] = language
        self.functions.pop('main', None)
        self.function_languages.pop('main', None)
        self.list_of_functions = list(self.functions)

    def extract_python_functions(self, file_path):
        """Collect every (async) function definition in a Python file.

        Nested functions and methods are included. On duplicate names the
        last one visited by ``ast.walk`` wins. Results are merged into
        ``self.functions`` and also returned as a ``{name: source}`` dict.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            file_content = file.read()
        parsed_content = ast.parse(file_content)
        found = {}
        for node in ast.walk(parsed_content):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                found[node.name] = ast.get_source_segment(file_content, node)
        self.functions.update(found)
        return found

    def extract_java_functions(self, file_path):
        """Collect Java method bodies using a line-based brace counter.

        A method starts on a line matching ``_JAVA_METHOD_SIGNATURE`` and ends
        on the line where the running ``{``/``}`` balance returns to zero.
        This also covers one-line methods. Limitations: braces inside strings
        or comments are counted, and overloads overwrite each other. Results
        are merged into ``self.functions`` and also returned.
        """
        with open(file_path, 'r', encoding="utf-8") as file:
            lines = file.readlines()
        found = {}
        method_name = None  # None while outside a method body
        method_body = []
        brace_count = 0
        for line in lines:
            if method_name is None:
                match = self._JAVA_METHOD_SIGNATURE.search(line)
                if not match:
                    continue
                method_name = match.group(2)
                method_body = [line]
                brace_count = line.count('{') - line.count('}')
            else:
                method_body.append(line)
                brace_count += line.count('{') - line.count('}')
            # Checked for the signature line too, so `int f() { return 1; }`
            # closes immediately instead of absorbing the following lines.
            if brace_count <= 0:
                found[method_name] = ''.join(method_body)
                method_name = None
                method_body = []
        self.functions.update(found)
        return found

    def generate_response(self, prompt):
        """Send *prompt* to the configured Mistral model; return its ChatResponse."""
        response = [
            ChatMessage(role="system", content="You are tasked with generating unit test cases, including descriptions and assert statements, for a given function. You will also calculate the test coverage percentage. Follow these instructions carefully:"),
            ChatMessage(role="user", content=prompt),
        ]
        return MistralAI(model=self.model).chat(response)

    def generate_test_cases(self, func):
        """Generate unit tests for the extracted function *func*.

        Returns the model's reply text, or None (after showing an error in
        the UI) when *func* was not extracted from any uploaded file.
        """
        if func not in self.list_of_functions:
            st.write("Incorrect function name")
            return None
        snippet = self.functions[func]
        # Language of the file the function came from ("py" or "java").
        lang = self.function_languages.get(func, "")
        prompt = f"""1. You will be provided with the following inputs:
<function_code>
{snippet}
</function_code>
<language>{lang}</language>
2. Analyze the provided function:
- Identify the function name, parameters, and return type
- Determine the main logic and branches in the function
- Note any potential edge cases or boundary conditions
3. Generate unit test cases:
- Create at least 3-5 test cases that cover different scenarios
- Include normal cases, edge cases, and potential error conditions
- Ensure that the test cases collectively cover all branches and logic paths in the function
4. For each test case, provide:
- A brief description of the test scenario
- Input values for the function parameters
- The expected output or behavior
- An assert statement in the appropriate syntax for the given programming language
5. Calculate the test coverage percentage:
- Determine the number of code paths or branches in the function
- Count how many of these paths are covered by your test cases
- Calculate the percentage: (covered paths / total paths) * 100
6. Present your output in the following format:
<unit_tests>
<test_case>
<description>Description of the test case</description>
<input>Input values for the function</input>
<expected_output>Expected output or behavior</expected_output>
<assert_statement>Assert statement in the appropriate language syntax</assert_statement>
</test_case>
<!-- Repeat for each test case -->
</unit_tests>
<test_coverage>
<percentage>Calculated test coverage percentage</percentage>
<explanation>Brief explanation of how the percentage was calculated</explanation>
</test_coverage>
Remember to adapt your assert statements and syntax to the specific programming language provided. If you're unsure about the exact syntax for a particular language, use a general pseudocode format that clearly conveys the assertion logic.
"""
        resp = self.generate_response(prompt)
        # Return the reply text rather than the ChatResponse object, whose
        # string form carries an "assistant: " role prefix.
        return resp.message.content or ""

    def run(self):
        """Render the UI, replay the chat history and handle one request."""
        self.setup_streamlit_ui()
        if "messages" not in st.session_state:
            st.session_state.messages = []
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])
        if func := st.chat_input("Enter the function name for generating test cases:"):
            st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"})
            st.success(f"Generating test cases for {func}")
            # Drop any whitespace the user typed in or around the name.
            func = ''.join(func.split())
            test_cases = self.generate_test_cases(func)
            if test_cases is None:
                # Unknown name: the error was already shown; don't store "None".
                return
            st.session_state.messages.append({"role": "assistant", "content": f"{test_cases}"})
            st.markdown(test_cases)
# Script entry point (Streamlit executes this file with __name__ == "__main__"
# on every rerun): build a fresh generator and render the app.
if __name__ == "__main__":
    TestCaseGenerator().run()