datawithsuman committed
Commit 4264bb9 · verified · 1 Parent(s): 5e418aa

Update app.py

Files changed (1): app.py (+292 -155)
app.py CHANGED
@@ -1,191 +1,328 @@
- import os
- import re
- import streamlit as st
- import ast
- import json
- import openai
- from llama_index.llms.openai import OpenAI
- from llama_index.core.llms import ChatMessage
- from llama_index.llms.anthropic import Anthropic
- from llama_index.llms.mistralai import MistralAI

- import nest_asyncio
- nest_asyncio.apply()

- # import ollama
- # from llama_index.llms.ollama import Ollama
- # from llama_index.core.llms import ChatMessage


- # OpenAI credentials
- # key = os.getenv('OPENAI_API_KEY')
- # openai.api_key = key
- # os.environ["OPENAI_API_KEY"] = key

- # Anthropic credentials
- # key = os.getenv('CLAUDE_API_KEY')
- # os.environ["ANTHROPIC_API_KEY"] = key

- # Mistral
- key = os.getenv('MISTRAL_API_KEY')
- os.environ["MISTRAL_API_KEY"] = key


- # Streamlit UI
- st.title("Auto Test Case Generation using LLM")

- uploaded_files = st.file_uploader("Upload a python or Java file", type=[".py","java"], accept_multiple_files=True)

- if uploaded_files:
-     for uploaded_file in uploaded_files:
-         with open(f"./data/{uploaded_file.name}", 'wb') as f:
-             f.write(uploaded_file.getbuffer())
-         st.success("File uploaded...")

-         # Check file type
-         _, file_extension = os.path.splitext(uploaded_file.name)
-         print(file_extension)

-         st.success("Fetching list of functions...")
-         file_path = f"./data/{uploaded_file.name}"

-         def extract_functions_from_file(file_path, file_extension):

-             if file_extension == '.py':
-                 with open(file_path, "r") as file:
-                     file_content = file.read()

-                 parsed_content = ast.parse(file_content)
-                 methods = {}

-                 for node in ast.walk(parsed_content):
-                     if isinstance(node, ast.FunctionDef):
-                         func_name = node.name
-                         func_body = ast.get_source_segment(file_content, node)
-                         methods[func_name] = func_body

-             elif file_extension == '.java':
-                 with open(file_path, 'r') as file:
-                     lines = file.readlines()

-                 methods = {}
-                 inside_method = False
-                 method_name = None
-                 method_body = []
-                 brace_count = 0

-                 method_signature_pattern = re.compile(r'((?:public|protected|private|static|\s)*)\s+[\w<>\[\]]+\s+(\w+)\s*\([^)]*\)\s*\{')

-                 for line in lines:
-                     if not inside_method:
-                         match = method_signature_pattern.search(line)
-                         if match:
-                             modifiers, method_name = match.groups()
-                             inside_method = True
-                             method_body.append(line)
-                             brace_count = line.count('{') - line.count('}')
-                     else:
-                         method_body.append(line)
-                         brace_count += line.count('{') - line.count('}')
-                         if brace_count == 0:
-                             inside_method = False
-                             methods[method_name] = ''.join(method_body)
-                             method_body = []

-             if 'main' in methods.keys():
-                 del(methods['main'])

-             return methods

-         functions = extract_functions_from_file(file_path, file_extension)
-         list_of_functions = list(functions.keys())
-         st.write(list_of_functions)
-
-         def res(prompt, model=None):
-
-             # response = openai.chat.completions.create(
-             #     model=model,
-             #     messages=[
-             #         {"role": "user",
-             #          "content": prompt,
-             #         }
-             #     ]
-             # )
-
-             # return response.choices[0].message.content
-
-             response = [
-                 ChatMessage(role="system", content="You are a sincere and helpful coding assistant"),
-                 ChatMessage(role="user", content=prompt),
-             ]
-             # resp = Anthropic(model=model).chat(response)
-             resp = MistralAI(model).chat(response)
-             return resp
-
-         # Initialize session state for chat messages
          if "messages" not in st.session_state:
              st.session_state.messages = []

-         # Display chat messages from history on app rerun
          for message in st.session_state.messages:
              with st.chat_message(message["role"]):
                  st.markdown(message["content"])

-         # Accept user input
          if func := st.chat_input("Enter the function name for generating test cases:"):
              st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"})
              st.success(f"Generating test cases for {func}")
-
              func = ''.join(func.split())

-             if func not in list_of_functions:
-                 st.write("Incorrect function name")
-
-             else:
-                 snippet = functions[func]
-
-                 # Generation
-                 # model = "gpt-3.5-turbo"
-                 # model = "claude-3-haiku-20240307"
-                 # model = "claude-3-sonnet-20240229"
-                 # model = "claude-3-opus-20240229"
-                 model = "codestral-latest"
-
-                 # Generation
-                 # resp = ollama.generate(model='codellama',
-                 #     prompt=f""" Your task is to generate unit test cases for this function : {snippet}\
-                 #     \n\n Politely refuse if the function is not suitable for generating test cases.
-                 #     \n\n Generate atleast 5 unit test case. Include couple of edge cases as well.
-                 #     \n\n There should be no duplicate test cases.
-                 #     \n\n Avoid generating repeated statements.
-                 #     """)
-
-                 prompt=f""" Your task is to generate unit test cases for this function : \n\n{snippet}\
-                 \n\n Generate between 3 to 8 unique unit test cases. Include couple of edge cases as well.
-                 \n\n All the test cases should have the mandatory assert statement.
-                 \n\n Every test case should be defined as a method inside the class.
-                 \n\n All the test cases should have textual description.
-                 \n\n Politely refuse if the function is not suitable for generating test cases.
-                 \n\n There should be no duplicate and incomplete test case.
-                 \n\n Avoid generating repeated statements.
-                 \n\n Recheck your response before generating.
-                 \n\n Do not share the last Test Case.
-                 """
-
-                 # print(prompt)
-
-                 resp = res(prompt = prompt, model = model)
-
-                 # Post Processing
-                 post_prompt = f"""Except the last test case, display everything that is present in this end to end: \n\n{resp}\
-                 \n\n Do not add anything extra. Just copy and paste everything except the last test case.
-                 \n\n Do not mention the count of total number of test cases in the response.
-                 \n\n Do not mention this sentence - "I have excluded the last test case as per your request"
-                 """
-                 post_resp = res(prompt = post_prompt, model = model)
-                 st.session_state.messages.append({"role": "assistant", "content": f"{post_resp}"})
-                 st.markdown(post_resp)
-                 # st.session_state.messages.append({"role": "assistant", "content": f"{resp['response']}"})
-                 # st.markdown(resp['response'])
-
 
+ # import os
+ # import re
+ # import streamlit as st
+ # import ast
+ # import json
+ # import openai
+ # from llama_index.llms.openai import OpenAI
+ # from llama_index.core.llms import ChatMessage
+ # from llama_index.llms.anthropic import Anthropic
+ # from llama_index.llms.mistralai import MistralAI

+ # import nest_asyncio
+ # nest_asyncio.apply()

+ # # import ollama
+ # # from llama_index.llms.ollama import Ollama
+ # # from llama_index.core.llms import ChatMessage


+ # # OpenAI credentials
+ # # key = os.getenv('OPENAI_API_KEY')
+ # # openai.api_key = key
+ # # os.environ["OPENAI_API_KEY"] = key

+ # # Anthropic credentials
+ # # key = os.getenv('CLAUDE_API_KEY')
+ # # os.environ["ANTHROPIC_API_KEY"] = key

+ # # Mistral
+ # key = os.getenv('MISTRAL_API_KEY')
+ # os.environ["MISTRAL_API_KEY"] = key


+ # # Streamlit UI
+ # st.title("Auto Test Case Generation using LLM")

+ # uploaded_files = st.file_uploader("Upload a python or Java file", type=[".py","java"], accept_multiple_files=True)

+ # if uploaded_files:
+ #     for uploaded_file in uploaded_files:
+ #         with open(f"./data/{uploaded_file.name}", 'wb') as f:
+ #             f.write(uploaded_file.getbuffer())
+ #         st.success("File uploaded...")

+ #         # Check file type
+ #         _, file_extension = os.path.splitext(uploaded_file.name)
+ #         print(file_extension)

+ #         st.success("Fetching list of functions...")
+ #         file_path = f"./data/{uploaded_file.name}"

+ #         def extract_functions_from_file(file_path, file_extension):

+ #             if file_extension == '.py':
+ #                 with open(file_path, "r") as file:
+ #                     file_content = file.read()

+ #                 parsed_content = ast.parse(file_content)
+ #                 methods = {}

+ #                 for node in ast.walk(parsed_content):
+ #                     if isinstance(node, ast.FunctionDef):
+ #                         func_name = node.name
+ #                         func_body = ast.get_source_segment(file_content, node)
+ #                         methods[func_name] = func_body

+ #             elif file_extension == '.java':
+ #                 with open(file_path, 'r') as file:
+ #                     lines = file.readlines()

+ #                 methods = {}
+ #                 inside_method = False
+ #                 method_name = None
+ #                 method_body = []
+ #                 brace_count = 0

+ #                 method_signature_pattern = re.compile(r'((?:public|protected|private|static|\s)*)\s+[\w<>\[\]]+\s+(\w+)\s*\([^)]*\)\s*\{')

+ #                 for line in lines:
+ #                     if not inside_method:
+ #                         match = method_signature_pattern.search(line)
+ #                         if match:
+ #                             modifiers, method_name = match.groups()
+ #                             inside_method = True
+ #                             method_body.append(line)
+ #                             brace_count = line.count('{') - line.count('}')
+ #                     else:
+ #                         method_body.append(line)
+ #                         brace_count += line.count('{') - line.count('}')
+ #                         if brace_count == 0:
+ #                             inside_method = False
+ #                             methods[method_name] = ''.join(method_body)
+ #                             method_body = []

+ #             if 'main' in methods.keys():
+ #                 del(methods['main'])

+ #             return methods

+ #         functions = extract_functions_from_file(file_path, file_extension)
+ #         list_of_functions = list(functions.keys())
+ #         st.write(list_of_functions)
+
+ #         def res(prompt, model=None):
+
+ #             # response = openai.chat.completions.create(
+ #             #     model=model,
+ #             #     messages=[
+ #             #         {"role": "user",
+ #             #          "content": prompt,
+ #             #         }
+ #             #     ]
+ #             # )
+
+ #             # return response.choices[0].message.content
+
+ #             response = [
+ #                 ChatMessage(role="system", content="You are a sincere and helpful coding assistant"),
+ #                 ChatMessage(role="user", content=prompt),
+ #             ]
+ #             # resp = Anthropic(model=model).chat(response)
+ #             resp = MistralAI(model).chat(response)
+ #             return resp
+
+ #         # Initialize session state for chat messages
+ #         if "messages" not in st.session_state:
+ #             st.session_state.messages = []
+
+ #         # Display chat messages from history on app rerun
+ #         for message in st.session_state.messages:
+ #             with st.chat_message(message["role"]):
+ #                 st.markdown(message["content"])
+
+ #         # Accept user input
+ #         if func := st.chat_input("Enter the function name for generating test cases:"):
+ #             st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"})
+ #             st.success(f"Generating test cases for {func}")
+
+ #             func = ''.join(func.split())
+
+ #             if func not in list_of_functions:
+ #                 st.write("Incorrect function name")
+
+ #             else:
+ #                 snippet = functions[func]
+
+ #                 # Generation
+ #                 # model = "gpt-3.5-turbo"
+ #                 # model = "claude-3-haiku-20240307"
+ #                 # model = "claude-3-sonnet-20240229"
+ #                 # model = "claude-3-opus-20240229"
+ #                 model = "codestral-latest"
+
+ #                 # Generation
+ #                 # resp = ollama.generate(model='codellama',
+ #                 #     prompt=f""" Your task is to generate unit test cases for this function : {snippet}\
+ #                 #     \n\n Politely refuse if the function is not suitable for generating test cases.
+ #                 #     \n\n Generate atleast 5 unit test case. Include couple of edge cases as well.
+ #                 #     \n\n There should be no duplicate test cases.
+ #                 #     \n\n Avoid generating repeated statements.
+ #                 #     """)
+
+ #                 prompt=f""" Your task is to generate unit test cases for this function : \n\n{snippet}\
+ #                 \n\n Generate between 3 to 8 unique unit test cases. Include couple of edge cases as well.
+ #                 \n\n All the test cases should have the mandatory assert statement.
+ #                 \n\n Every test case should be defined as a method inside the class.
+ #                 \n\n All the test cases should have textual description.
+ #                 \n\n Politely refuse if the function is not suitable for generating test cases.
+ #                 \n\n There should be no duplicate and incomplete test case.
+ #                 \n\n Avoid generating repeated statements.
+ #                 \n\n Recheck your response before generating.
+ #                 \n\n Do not share the last Test Case.
+ #                 """
+
+ #                 # print(prompt)
+
+ #                 resp = res(prompt = prompt, model = model)
+
+ #                 # Post Processing
+ #                 post_prompt = f"""Except the last test case, display everything that is present in this end to end: \n\n{resp}\
+ #                 \n\n Do not add anything extra. Just copy and paste everything except the last test case.
+ #                 \n\n Do not mention the count of total number of test cases in the response.
+ #                 \n\n Do not mention this sentence - "I have excluded the last test case as per your request"
+ #                 """
+ #                 post_resp = res(prompt = post_prompt, model = model)
+ #                 st.session_state.messages.append({"role": "assistant", "content": f"{post_resp}"})
+ #                 st.markdown(post_resp)
+ #                 # st.session_state.messages.append({"role": "assistant", "content": f"{resp['response']}"})
+ #                 # st.markdown(resp['response'])
+
+
+
+
+ import os
+ import re
+ import ast
+ import streamlit as st
+ from llama_index.llms.openai import OpenAI
+ from llama_index.core.llms import ChatMessage
+ from llama_index.llms.anthropic import Anthropic
+ from llama_index.llms.mistralai import MistralAI
+ import nest_asyncio
+
+ class TestCaseGenerator:
+     def __init__(self):
+         nest_asyncio.apply()
+         self.key = os.getenv('MISTRAL_API_KEY')
+         os.environ["MISTRAL_API_KEY"] = self.key
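+         # Assumes MISTRAL_API_KEY is set in the environment; if os.getenv
+         # returned None, the assignment above would raise a TypeError.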
+         self.model = "codestral-latest"
+         self.functions = {}
+         self.list_of_functions = []
+
214
+ def setup_streamlit_ui(self):
215
+ st.title("Auto Test Case Generation using LLM")
216
+ uploaded_files = st.file_uploader("Upload a python or Java file", type=[".py", "java"], accept_multiple_files=True)
217
+ if uploaded_files:
218
+ for uploaded_file in uploaded_files:
219
+ self.process_uploaded_file(uploaded_file)
220
+
221
+ def process_uploaded_file(self, uploaded_file):
222
+ with open(f"./data/{uploaded_file.name}", 'wb') as f:
223
+ f.write(uploaded_file.getbuffer())
224
+ st.success("File uploaded...")
225
+ _, file_extension = os.path.splitext(uploaded_file.name)
226
+ print(file_extension)
227
+ st.success("Fetching list of functions...")
228
+ file_path = f"./data/{uploaded_file.name}"
229
+ self.extract_functions_from_file(file_path, file_extension)
230
+ st.write(self.list_of_functions)
231
+
232
+ def extract_functions_from_file(self, file_path, file_extension):
233
+ if file_extension == '.py':
234
+ self.extract_python_functions(file_path)
235
+ elif file_extension == '.java':
236
+ self.extract_java_functions(file_path)
237
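+         # Drop 'main' from the collected functions: an entry point is not a
+         # useful unit-test target (same behaviour as the original script).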
+         if 'main' in self.functions.keys():
+             del(self.functions['main'])
+         self.list_of_functions = list(self.functions.keys())
+
+     def extract_python_functions(self, file_path):
+         with open(file_path, "r") as file:
+             file_content = file.read()
+         parsed_content = ast.parse(file_content)
+         for node in ast.walk(parsed_content):
+             if isinstance(node, ast.FunctionDef):
+                 func_name = node.name
+                 func_body = ast.get_source_segment(file_content, node)
+                 self.functions[func_name] = func_body
+
+     def extract_java_functions(self, file_path):
+         with open(file_path, 'r') as file:
+             lines = file.readlines()
+         inside_method = False
+         method_name = None
+         method_body = []
+         brace_count = 0
+         method_signature_pattern = re.compile(r'((?:public|protected|private|static|\s)*)\s+[\w<>\[\]]+\s+(\w+)\s*\([^)]*\)\s*\{')
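+         # The first capture group holds the modifiers, the second the method
+         # name; the brace count below tracks nesting, so a method body ends
+         # once the count returns to zero.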
+         for line in lines:
+             if not inside_method:
+                 match = method_signature_pattern.search(line)
+                 if match:
+                     modifiers, method_name = match.groups()
+                     inside_method = True
+                     method_body.append(line)
+                     brace_count = line.count('{') - line.count('}')
+             else:
+                 method_body.append(line)
+                 brace_count += line.count('{') - line.count('}')
+                 if brace_count == 0:
+                     inside_method = False
+                     self.functions[method_name] = ''.join(method_body)
+                     method_body = []
+
+     def generate_response(self, prompt):
+         response = [
+             ChatMessage(role="system", content="You are a sincere and helpful coding assistant"),
+             ChatMessage(role="user", content=prompt),
+         ]
+         resp = MistralAI(self.model).chat(response)
+         return resp
+
+     def generate_test_cases(self, func):
+         if func not in self.list_of_functions:
+             st.write("Incorrect function name")
+             return
+
+         snippet = self.functions[func]
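+         # Two-pass generation: the first prompt asks the model to withhold the
+         # last test case, and the post-processing prompt then reprints the
+         # output without that final case (presumably to trim a truncated tail).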
+         prompt = f"""Your task is to generate unit test cases for this function : \n\n{snippet}\
+         \n\n Generate between 3 to 8 unique unit test cases. Include couple of edge cases as well.
+         \n\n All the test cases should have the mandatory assert statement.
+         \n\n Every test case should be defined as a method inside the class.
+         \n\n All the test cases should have textual description.
+         \n\n Politely refuse if the function is not suitable for generating test cases.
+         \n\n There should be no duplicate and incomplete test case.
+         \n\n Avoid generating repeated statements.
+         \n\n Recheck your response before generating.
+         \n\n Do not share the last Test Case.
+         """
+         resp = self.generate_response(prompt)
+         post_prompt = f"""Except the last test case, display everything that is present in this end to end: \n\n{resp}\
+         \n\n Do not add anything extra. Just copy and paste everything except the last test case.
+         \n\n Do not mention the count of total number of test cases in the response.
+         \n\n Do not mention this sentence - "I have excluded the last test case as per your request"
+         """
+         post_resp = self.generate_response(post_prompt)
+         return post_resp
+
+     def run(self):
+         self.setup_streamlit_ui()
          if "messages" not in st.session_state:
              st.session_state.messages = []

          for message in st.session_state.messages:
              with st.chat_message(message["role"]):
                  st.markdown(message["content"])

          if func := st.chat_input("Enter the function name for generating test cases:"):
              st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"})
              st.success(f"Generating test cases for {func}")
              func = ''.join(func.split())
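+             # Note: generate_test_cases returns None for an unknown function
+             # name, in which case "None" is what gets rendered below.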
+             test_cases = self.generate_test_cases(func)
+             st.session_state.messages.append({"role": "assistant", "content": f"{test_cases}"})
+             st.markdown(test_cases)
 
+ if __name__ == "__main__":
+     test_case_generator = TestCaseGenerator()
+     test_case_generator.run()
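
When Streamlit executes the script ("streamlit run app.py"), the __name__ == "__main__" guard holds, so the refactored app starts the same way as the original. The extraction logic can also be exercised on its own; the sketch below is illustrative only, assuming app.py and its dependencies are importable from the working directory, and exporting a placeholder key so the constructor's environment lookup succeeds (no model call is made here):

    # sketch: exercise the Python-function extraction without the chat UI
    import os
    os.environ.setdefault("MISTRAL_API_KEY", "placeholder")  # read by __init__

    from app import TestCaseGenerator

    generator = TestCaseGenerator()

    # Write a small sample module to parse.
    with open("sample.py", "w") as f:
        f.write("def add(a, b):\n    return a + b\n\ndef main():\n    print(add(1, 2))\n")

    # Dispatches on the file extension, then drops 'main' from the results.
    generator.extract_functions_from_file("sample.py", ".py")
    print(generator.list_of_functions)   # ['add']
    print(generator.functions["add"])    # source segment of add()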