import os
import re
import ast
import streamlit as st
from llama_index.core.llms import ChatMessage
from llama_index.llms.mistralai import MistralAI
import nest_asyncio
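
# Streamlit app: upload Python or Java source files, list the functions they
# define, and ask Mistral's codestral model to draft unit test cases (with an
# estimated coverage percentage) for a chosen function.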
class TestCaseGenerator:
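    """Extract functions from uploaded Python/Java files and generate
    unit test cases for them via a Mistral code model."""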
def __init__(self):
nest_asyncio.apply()
        self.key = os.getenv('MISTRAL_API_KEY')
        # Fail early with a clear error instead of a TypeError when the key is missing
        if not self.key:
            raise EnvironmentError("MISTRAL_API_KEY is not set")
        os.environ["MISTRAL_API_KEY"] = self.key
self.model = "codestral-latest"
self.functions = {}
        self.list_of_functions = []
        self.file_extension = ""
def setup_streamlit_ui(self):
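        """Render the uploader and hand each uploaded file to the processor."""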
st.title("Auto Test Case Generation using LLM")
        uploaded_files = st.file_uploader("Upload a Python or Java file", type=["py", "java"], accept_multiple_files=True)
if uploaded_files:
for uploaded_file in uploaded_files:
self.process_uploaded_file(uploaded_file)
def process_uploaded_file(self, uploaded_file):
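        """Persist the upload under ./data and extract its functions."""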
with open(f"./data/{uploaded_file.name}", 'wb') as f:
f.write(uploaded_file.getbuffer())
st.success("File uploaded...")
_, file_extension = os.path.splitext(uploaded_file.name)
print(file_extension)
st.success("Fetching list of functions...")
file_path = f"./data/{uploaded_file.name}"
self.extract_functions_from_file(file_path, file_extension)
st.write(self.list_of_functions)
def extract_functions_from_file(self, file_path, file_extension):
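        """Dispatch to the language-specific extractor and refresh the function list."""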
if file_extension == '.py':
self.extract_python_functions(file_path)
elif file_extension == '.java':
self.extract_java_functions(file_path)
        # Skip 'main'; entry points are not meaningful unit-test targets here
        if 'main' in self.functions:
            del self.functions['main']
self.list_of_functions = list(self.functions.keys())
def extract_python_functions(self, file_path):
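        """Walk the Python AST and record each function's name and source segment."""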
with open(file_path, "r") as file:
file_content = file.read()
parsed_content = ast.parse(file_content)
for node in ast.walk(parsed_content):
if isinstance(node, ast.FunctionDef):
func_name = node.name
func_body = ast.get_source_segment(file_content, node)
self.functions[func_name] = func_body
def extract_java_functions(self, file_path):
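        """Extract Java methods with a single-line signature regex plus brace counting."""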
with open(file_path, 'r') as file:
lines = file.readlines()
inside_method = False
method_name = None
method_body = []
brace_count = 0
method_signature_pattern = re.compile(r'((?:public|protected|private|static|\s)*)\s+[\w<>\[\]]+\s+(\w+)\s*\([^)]*\)\s*\{')
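        # Heuristic parser: a method starts where the signature regex matches,
        # and ends when the running count of '{' minus '}' returns to zero.
        # Signatures split across multiple lines are not detected.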
for line in lines:
if not inside_method:
match = method_signature_pattern.search(line)
if match:
modifiers, method_name = match.groups()
inside_method = True
method_body.append(line)
brace_count = line.count('{') - line.count('}')
else:
method_body.append(line)
brace_count += line.count('{') - line.count('}')
if brace_count == 0:
inside_method = False
self.functions[method_name] = ''.join(method_body)
method_body = []
def generate_response(self, prompt):
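        """Send the prompt to the configured Mistral chat model and return the reply text."""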
        messages = [
            ChatMessage(role="system", content="You are tasked with generating unit test cases, including descriptions and assert statements, for a given function. You will also calculate the test coverage percentage. Follow these instructions carefully:"),
            ChatMessage(role="user", content=prompt),
        ]
        resp = MistralAI(model=self.model).chat(messages)
        return resp.message.content
def generate_test_cases(self, func):
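        """Build the structured prompt for the chosen function and return the model's test cases."""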
        if func not in self.list_of_functions:
            st.error("Incorrect function name")
            return None
        snippet = self.functions[func]
        # Use the extension stored during upload ('py' or 'java') as the language tag
        lang = self.file_extension.lstrip('.')
prompt = f"""1. You will be provided with the following inputs:
<function_code>
{snippet}
</function_code>
<language>{lang}</language>
2. Analyze the provided function:
- Identify the function name, parameters, and return type
- Determine the main logic and branches in the function
- Note any potential edge cases or boundary conditions
3. Generate unit test cases:
- Create at least 3-5 test cases that cover different scenarios
- Include normal cases, edge cases, and potential error conditions
- Ensure that the test cases collectively cover all branches and logic paths in the function
4. For each test case, provide:
- A brief description of the test scenario
- Input values for the function parameters
- The expected output or behavior
- An assert statement in the appropriate syntax for the given programming language
5. Calculate the test coverage percentage:
- Determine the number of code paths or branches in the function
- Count how many of these paths are covered by your test cases
- Calculate the percentage: (covered paths / total paths) * 100
6. Present your output in the following format:
<unit_tests>
<test_case>
<description>Description of the test case</description>
<input>Input values for the function</input>
<expected_output>Expected output or behavior</expected_output>
<assert_statement>Assert statement in the appropriate language syntax</assert_statement>
</test_case>
<!-- Repeat for each test case -->
</unit_tests>
<test_coverage>
<percentage>Calculated test coverage percentage</percentage>
<explanation>Brief explanation of how the percentage was calculated</explanation>
</test_coverage>
Remember to adapt your assert statements and syntax to the specific programming language provided. If you're unsure about the exact syntax for a particular language, use a general pseudocode format that clearly conveys the assertion logic.
"""
resp = self.generate_response(prompt)
# post_prompt = f"""Except the last test case, display everything that is present in this end to end: \n\n{resp}\
# \n\n - Do not add anything extra. Just copy and paste everything except the last test case.
# \n\n - Do not mention the count of total number of test cases in the response.
# \n\n - Do not mention this sentence - "I have excluded the last test case as per your request"
# """
# post_resp = self.generate_response(post_prompt)
return resp
def run(self):
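        """Entry point: set up the UI, replay chat history, and handle new requests."""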
self.setup_streamlit_ui()
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
        if func := st.chat_input("Enter the function name for generating test cases:"):
            func = ''.join(func.split())
            st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"})
            st.success(f"Generating test cases for {func}")
            test_cases = self.generate_test_cases(func)
            if test_cases is not None:
                st.session_state.messages.append({"role": "assistant", "content": f"{test_cases}"})
                st.markdown(test_cases)
if __name__ == "__main__":
test_case_generator = TestCaseGenerator()
test_case_generator.run()
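
# Assumed runtime dependencies (not pinned in this file): streamlit,
# nest-asyncio, llama-index-core, and llama-index-llms-mistralai, plus a
# valid MISTRAL_API_KEY in the environment.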