datawithsuman committed
Commit 04a7b4c · verified · 1 Parent(s): bb5d964

Update app.py

Files changed (1)
  1. app.py +93 -3
app.py CHANGED
@@ -1,11 +1,101 @@
 import streamlit as st
 import ast
 import json
-import ollama
-from llama_index.llms.ollama import Ollama
 from llama_index.core.llms import ChatMessage
+from transformers import AutoTokenizer
+import transformers
+import torch
+
+# import ollama
+# from llama_index.llms.ollama import Ollama
+# from llama_index.core.llms import ChatMessage

 # Streamlit UI
 st.title("Auto Test Case Generation using LLM")

-uploaded_files = st.file_uploader("Upload a python(.py) file", type=".py", accept_multiple_files=True)
+uploaded_files = st.file_uploader("Upload a Python (.py) file", type=".py", accept_multiple_files=True)
+
+if uploaded_files:
+    for uploaded_file in uploaded_files:
+        # Save each upload to disk (assumes ./data exists) so it can be parsed below
+        with open(f"./data/{uploaded_file.name}", 'wb') as f:
+            f.write(uploaded_file.getbuffer())
+        st.success("File uploaded...")
+
+    st.success("Fetching list of functions...")
+    file_path = f"./data/{uploaded_file.name}"
+
+    def extract_functions_from_file(file_path):
+        # Map each function name in the file to its full source segment
+        with open(file_path, "r") as file:
+            file_content = file.read()
+
+        parsed_content = ast.parse(file_content)
+        functions = {}
+
+        for node in ast.walk(parsed_content):
+            if isinstance(node, ast.FunctionDef):
+                func_name = node.name
+                func_body = ast.get_source_segment(file_content, node)
+                functions[func_name] = func_body
+
+        return functions
+
+    functions = extract_functions_from_file(file_path)
+
+    list_of_functions = list(functions.keys())
+    st.write(list_of_functions)
+
+    # Initialize session state for chat messages
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    # Display chat messages from history on app rerun
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # Accept user input
+    if func := st.chat_input("Enter the function name for generating test cases:"):
+        st.session_state.messages.append({"role": "assistant", "content": f"Generating test cases for {func}"})
+        st.success(f"Generating test cases for {func}")
+
+        # Strip all whitespace from the entered function name
+        func = ''.join(func.split())
+
+        if func not in list_of_functions:
+            st.write("Incorrect function name")
+
+        else:
+            snippet = functions[func]
+
+            model = "codellama/CodeLlama-7b-Instruct-hf"
+
+            tokenizer = AutoTokenizer.from_pretrained(model)
+            pipeline = transformers.pipeline(
+                "text-generation",
+                model=model,
+                torch_dtype=torch.float16,
+                device_map="auto",
+            )
+
+            # Generation
+            # resp = ollama.generate(model='codellama',
+            #                        prompt=f"""You are a helpful coding assistant. Your task is to generate unit test cases for this function : {snippet}\
+            #                        \n\nPolitely refuse if the function is not suitable for generating test cases.
+            #                        \n\nGenerate at least 5 unit test cases. Include a couple of edge cases as well.
+            #                        \n\nThere should be no duplicate test cases. Avoid generating repeated statements.
+            #                        """)
+            resp = pipeline(
+                f"""You are a helpful coding assistant. Your task is to generate unit test cases for this function : {snippet}\
+                \n\nPolitely refuse if the function is not suitable for generating test cases.
+                \n\nGenerate at least 5 unit test cases. Include a couple of edge cases as well.
+                \n\nThere should be no duplicate test cases. Avoid generating repeated statements.
+                """,
+                do_sample=True,
+                top_k=10,
+                temperature=0.1,
+                top_p=0.95,
+                num_return_sequences=1,
+                eos_token_id=tokenizer.eos_token_id,
+                max_new_tokens=512,  # explicit cap on generated tokens; the library default is very short
+            )
+            # pipeline() returns a list of dicts; join the generated texts into one response
+            resp_list = [n['generated_text'] for n in resp]
+            response = " ".join(resp_list)
+            st.session_state.messages.append({"role": "assistant", "content": response})
+            st.markdown(response)
+
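The AST-based function extraction introduced here can be checked in isolation. Below is a minimal sketch mirroring `extract_functions_from_file` from the diff; the `sample.py` file is a hypothetical stand-in for an upload, not part of the commit.

```python
import ast

def extract_functions_from_file(file_path):
    # Mirrors app.py: map each function name to its full source segment
    with open(file_path, "r") as file:
        file_content = file.read()
    functions = {}
    for node in ast.walk(ast.parse(file_content)):
        if isinstance(node, ast.FunctionDef):
            functions[node.name] = ast.get_source_segment(file_content, node)
    return functions

# Hypothetical stand-in for an uploaded file
with open("sample.py", "w") as f:
    f.write("def add(a, b):\n    return a + b\n")

print(extract_functions_from_file("sample.py"))
# -> {'add': 'def add(a, b):\n    return a + b'}
```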
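The new transformers-based generation path (replacing the ollama call) can likewise be smoke-tested outside Streamlit. A minimal sketch using the same model id and sampling parameters as the diff; the toy prompt, `snippet`, and the `max_new_tokens` value are illustrative, and the first run downloads roughly 13 GB of fp16 weights.

```python
import torch
import transformers
from transformers import AutoTokenizer

model = "codellama/CodeLlama-7b-Instruct-hf"
tokenizer = AutoTokenizer.from_pretrained(model)

# Same pipeline configuration as app.py: fp16 weights, automatic device placement
pipe = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Toy function standing in for an uploaded snippet (illustrative only)
snippet = "def add(a, b):\n    return a + b"

out = pipe(
    f"Generate unit test cases for this function: {snippet}",
    do_sample=True,
    top_k=10,
    temperature=0.1,
    top_p=0.95,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=256,  # illustrative cap so the demo terminates quickly
)
print(out[0]["generated_text"])
```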