# import os | |
# import re | |
# import streamlit as st | |
# import ast | |
# import json | |
# import openai | |
# from llama_index.llms.openai import OpenAI | |
# from llama_index.core.llms import ChatMessage | |
# from llama_index.llms.anthropic import Anthropic | |
# from llama_index.llms.mistralai import MistralAI | |
# import nest_asyncio | |
# nest_asyncio.apply() | |
# # import ollama | |
# # from llama_index.llms.ollama import Ollama | |
# # from llama_index.core.llms import ChatMessage | |
# # OpenAI credentials | |
# # key = os.getenv('OPENAI_API_KEY') | |
# # openai.api_key = key | |
# # os.environ["OPENAI_API_KEY"] = key | |
# # Anthropic credentials | |
# # key = os.getenv('CLAUDE_API_KEY') | |
# # os.environ["ANTHROPIC_API_KEY"] = key | |
# # Mistral | |
# key = os.getenv('MISTRAL_API_KEY') | |
# os.environ["MISTRAL_API_KEY"] = key | |
# # Streamlit UI | |
# st.title("Auto Test Case Generation using LLM") | |
# uploaded_files = st.file_uploader("Upload a python or Java file", type=[".py","java"], accept_multiple_files=True) | |
# if uploaded_files: | |
# for uploaded_file in uploaded_files: | |
# with open(f"./data/{uploaded_file.name}", 'wb') as f: | |
# f.write(uploaded_file.getbuffer()) | |
# st.success("File uploaded...") | |
# # Check file type | |
# _, file_extension = os.path.splitext(uploaded_file.name) | |
# print(file_extension) | |
# st.success("Fetching list of functions...") | |
# file_path = f"./data/{uploaded_file.name}" | |
# def extract_functions_from_file(file_path, file_extension): | |
# if file_extension == '.py': | |
# with open(file_path, "r") as file: | |
# file_content = file.read() | |
# parsed_content = ast.parse(file_content) | |
# methods = {} | |
# for node in ast.walk(parsed_content): | |
# if isinstance(node, ast.FunctionDef): | |
# func_name = node.name | |
# func_body = ast.get_source_segment(file_content, node) | |
# methods[func_name] = func_body | |
# elif file_extension == '.java': | |
# with open(file_path, 'r') as file: | |
# lines = file.readlines() | |
# methods = {} | |
# inside_method = False | |
# method_name = None | |
# method_body = [] | |
# brace_count = 0 | |
# method_signature_pattern = re.compile(r'((?:public|protected|private|static|\s)*)\s+[\w<>\[\]]+\s+(\w+)\s*\([^)]*\)\s*\{') | |
# for line in lines: | |
# if not inside_method: | |
# match = method_signature_pattern.search(line) | |
# if match: | |
# modifiers, method_name = match.groups() | |
# inside_method = True | |
# method_body.append(line) | |
# brace_count = line.count('{') - line.count('}') | |
# else: | |
# method_body.append(line) | |
# brace_count += line.count('{') - line.count('}') | |
# if brace_count == 0: | |
# inside_method = False | |
# methods[method_name] = ''.join(method_body) | |
# method_body = [] | |
# if 'main' in methods.keys(): | |
# del(methods['main']) | |
# return methods | |
# functions = extract_functions_from_file(file_path, file_extension) | |
# list_of_functions = list(functions.keys()) | |
# st.write(list_of_functions) | |
# def res(prompt, model=None): | |
# # response = openai.chat.completions.create( | |
# # model=model, | |
# # messages=[ | |
# # {"role": "user", | |
# # "content": prompt, | |
# # } | |
# # ] | |
# # ) | |
# # return response.choices[0].message.content | |
# response = [ | |
# ChatMessage(role="system", content="You are a sincere and helpful coding assistant"), | |
# ChatMessage(role="user", content=prompt), | |
# ] | |
# # resp = Anthropic(model=model).chat(response) | |
# resp = MistralAI(model).chat(response) | |
# return resp | |
# # Initialize session state for chat messages | |
# if "messages" not in st.session_state: | |
# st.session_state.messages = [] | |
# # Display chat messages from history on app rerun | |
# for message in st.session_state.messages: | |
# with st.chat_message(message["role"]): | |
# st.markdown(message["content"]) | |
# # Accept user input | |
# if func := st.chat_input("Enter the function name for generating test cases:"): | |
# st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"}) | |
# st.success(f"Generating test cases for {func}") | |
# func = ''.join(func.split()) | |
# if func not in list_of_functions: | |
# st.write("Incorrect function name") | |
# else: | |
# snippet = functions[func] | |
# # Generation | |
# # model = "gpt-3.5-turbo" | |
# # model = "claude-3-haiku-20240307" | |
# # model = "claude-3-sonnet-20240229" | |
# # model = "claude-3-opus-20240229" | |
# model = "codestral-latest" | |
# # Generation | |
# # resp = ollama.generate(model='codellama', | |
# # prompt=f""" Your task is to generate unit test cases for this function : {snippet}\ | |
# # \n\n Politely refuse if the function is not suitable for generating test cases. | |
# # \n\n Generate atleast 5 unit test case. Include couple of edge cases as well. | |
# # \n\n There should be no duplicate test cases. | |
# # \n\n Avoid generating repeated statements. | |
# # """) | |
# prompt=f""" Your task is to generate unit test cases for this function : \n\n{snippet}\ | |
# \n\n Generate between 3 to 8 unique unit test cases. Include couple of edge cases as well. | |
# \n\n All the test cases should have the mandatory assert statement. | |
# \n\n Every test case should be defined as a method inside the class. | |
# \n\n All the test cases should have textual description. | |
# \n\n Politely refuse if the function is not suitable for generating test cases. | |
# \n\n There should be no duplicate and incomplete test case. | |
# \n\n Avoid generating repeated statements. | |
# \n\n Recheck your response before generating. | |
# \n\n Do not share the last Test Case. | |
# """ | |
# # print(prompt) | |
# resp = res(prompt = prompt, model = model) | |
# # Post Processing | |
# post_prompt = f"""Except the last test case, display everything that is present in this end to end: \n\n{resp}\ | |
# \n\n Do not add anything extra. Just copy and paste everything except the last test case. | |
# \n\n Do not mention the count of total number of test cases in the response. | |
# \n\n Do not mention this sentence - "I have excluded the last test case as per your request" | |
# """ | |
# post_resp = res(prompt = post_prompt, model = model) | |
# st.session_state.messages.append({"role": "assistant", "content": f"{post_resp}"}) | |
# st.markdown(post_resp) | |
# # st.session_state.messages.append({"role": "assistant", "content": f"{resp['response']}"}) | |
# # st.markdown(resp['response']) | |
import os | |
import re | |
import ast | |
import streamlit as st | |
from llama_index.llms.openai import OpenAI | |
from llama_index.core.llms import ChatMessage | |
from llama_index.llms.anthropic import Anthropic | |
from llama_index.llms.mistralai import MistralAI | |
import nest_asyncio | |
class TestCaseGenerator:
    """Streamlit front end that asks an LLM to write unit tests for uploaded code.

    Uploaded ``.py`` / ``.java`` files are saved under ``./data``, their
    functions/methods are collected into ``self.functions`` (name -> source
    text), and the user then types a function name into the chat box to get
    generated test cases for it.
    """

    def __init__(self):
        """Apply ``nest_asyncio``, read the Mistral API key and set defaults."""
        nest_asyncio.apply()
        self.key = os.getenv('MISTRAL_API_KEY')
        # os.environ only accepts str values; assigning None raises TypeError,
        # so only write the key back when it is actually present.
        if self.key is not None:
            os.environ["MISTRAL_API_KEY"] = self.key
        self.model = "codestral-latest"
        # Maps function/method name -> its source text.
        self.functions = {}
        # Names offered to the user; mirrors the keys of self.functions.
        self.list_of_functions = []

    def setup_streamlit_ui(self):
        """Render the title and file uploader, then process every uploaded file."""
        st.title("Auto Test Case Generation using LLM")
        uploaded_files = st.file_uploader(
            "Upload a python or Java file",
            type=["py", "java"],
            accept_multiple_files=True,
        )
        if uploaded_files:
            for uploaded_file in uploaded_files:
                self.process_uploaded_file(uploaded_file)

    def process_uploaded_file(self, uploaded_file):
        """Save one uploaded file under ``./data`` and list the functions found in it.

        Args:
            uploaded_file: object returned by ``st.file_uploader``; only its
                ``name`` attribute and ``getbuffer()`` method are used here.
        """
        # The name comes from the client: keep only the final path component
        # so it cannot point outside the upload directory.
        safe_name = os.path.basename(uploaded_file.name)
        # The upload directory may not exist on a fresh checkout.
        os.makedirs("./data", exist_ok=True)
        file_path = f"./data/{safe_name}"
        with open(file_path, 'wb') as f:
            f.write(uploaded_file.getbuffer())
        st.success("File uploaded...")
        _, file_extension = os.path.splitext(safe_name)
        st.success("Fetching list of functions...")
        # Lower-cased so "Foo.PY" / "Foo.JAVA" are dispatched as well.
        self.extract_functions_from_file(file_path, file_extension.lower())
        st.write(self.list_of_functions)

    def extract_functions_from_file(self, file_path, file_extension):
        """Populate ``self.functions`` / ``self.list_of_functions`` from a file.

        Args:
            file_path: path of the source file to read.
            file_extension: ``'.py'`` or ``'.java'``; any other value extracts
                nothing.

        An entry named ``main`` is always dropped from the result.
        """
        if file_extension == '.py':
            self.extract_python_functions(file_path)
        elif file_extension == '.java':
            self.extract_java_functions(file_path)
        self.functions.pop('main', None)
        self.list_of_functions = list(self.functions)

    def extract_python_functions(self, file_path):
        """Add every ``def`` / ``async def`` in a Python file to ``self.functions``.

        The whole tree is walked, so nested functions and methods are included;
        a later definition with the same name overwrites an earlier one.

        Raises:
            SyntaxError: propagated from ``ast.parse`` when the file is not
                valid Python.
        """
        with open(file_path, "r", encoding="utf-8") as file:
            file_content = file.read()
        parsed_content = ast.parse(file_content)
        for node in ast.walk(parsed_content):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                self.functions[node.name] = ast.get_source_segment(file_content, node)

    def extract_java_functions(self, file_path):
        """Add every Java method found by a line-based scan to ``self.functions``.

        A method starts on a line matching the signature pattern (signature
        and opening brace on the same line) and ends on the line where the
        running ``{`` / ``}`` balance returns to zero. Braces are counted
        textually, so braces inside string literals or comments are counted too.
        """
        with open(file_path, 'r', encoding="utf-8") as file:
            lines = file.readlines()

        method_signature_pattern = re.compile(r'((?:public|protected|private|static|\s)*)\s+[\w<>\[\]]+\s+(\w+)\s*\([^)]*\)\s*\{')
        inside_method = False
        method_name = None
        method_body = []
        brace_count = 0

        for line in lines:
            if not inside_method:
                match = method_signature_pattern.search(line)
                if not match:
                    continue
                method_name = match.group(2)
                inside_method = True
                method_body = [line]
                brace_count = line.count('{') - line.count('}')
            else:
                method_body.append(line)
                brace_count += line.count('{') - line.count('}')

            # Checked for the signature line as well, so a one-line method
            # such as `int one() { return 1; }` is closed immediately instead
            # of absorbing the methods that follow it.
            if brace_count <= 0:
                self.functions[method_name] = ''.join(method_body)
                inside_method = False
                method_body = []

    def generate_response(self, prompt):
        """Send ``prompt`` to the configured Mistral model.

        Args:
            prompt: text sent as the user message.

        Returns:
            Whatever ``MistralAI.chat`` returns (presumably a LlamaIndex chat
            response object; confirm against the library version in use).
        """
        messages = [
            ChatMessage(role="system", content="You are tasked with generating unit test cases, including descriptions and assert statements, for a given function. Follow these instructions carefully:"),
            ChatMessage(role="user", content=prompt),
        ]
        return MistralAI(model=self.model).chat(messages)

    def generate_test_cases(self, func):
        """Generate test cases for the extracted function named ``func``.

        Two model calls are made: the first produces the test cases, the
        second asks the model to repeat that output without its last test case.

        Args:
            func: name of a function present in ``self.list_of_functions``.

        Returns:
            The second model response, or ``None`` (after writing
            "Incorrect function name" to the page) when ``func`` is unknown.
        """
        if func not in self.list_of_functions:
            st.write("Incorrect function name")
            return None

        snippet = self.functions[func]
        # Built from explicit literals so the method's indentation does not
        # leak into the prompt text.
        prompt = (
            f" Your task is to generate unit test cases for this function : \n\n{snippet}"
            "\n\n Generate between 3 to 8 unique unit test cases. Include couple of edge cases as well.\n"
            "\n\n All the test cases should have the mandatory assert statement.\n"
            "\n\n All the test cases should have textual description.\n"
            "\n\n Politely refuse if the function is not suitable for generating test cases.\n"
            "\n\n There should be no duplicate and incomplete test case.\n"
            "\n\n Avoid generating repeated statements.\n"
            "\n\n Recheck your response before generating.\n"
            "\n\n Do not share the last Test Case.\n"
        )
        resp = self.generate_response(prompt)

        post_prompt = (
            f"Except the last test case, display everything that is present in this end to end: \n\n{resp}"
            "\n\n - Do not add anything extra. Just copy and paste everything except the last test case.\n"
            "\n\n - Do not mention the count of total number of test cases in the response.\n"
            '\n\n - Do not mention this sentence - "I have excluded the last test case as per your request"\n'
        )
        return self.generate_response(post_prompt)

    def run(self):
        """Render the page: uploader, chat history, and the chat input handler."""
        self.setup_streamlit_ui()

        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Replay the stored conversation on every rerun of the script.
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        if func := st.chat_input("Enter the function name for generating test cases:"):
            st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"})
            st.success(f"Generating test cases for {func}")
            # Remove all whitespace so "my func " still matches "myfunc".
            func = ''.join(func.split())
            test_cases = self.generate_test_cases(func)
            if test_cases is None:
                # Unknown name: generate_test_cases already told the user, so
                # do not render or store the literal string "None".
                return
            st.session_state.messages.append({"role": "assistant", "content": f"{test_cases}"})
            st.markdown(test_cases)
# Script entry point: build the app object and render the page.
if __name__ == "__main__":
    TestCaseGenerator().run()