arjunanand13 committed on
Commit c119679
1 Parent(s): 3c0e45f

Update app.py

Files changed (1)
  1. app.py +251 -158
app.py CHANGED
@@ -1,168 +1,261 @@
 import os
- import time
- import spaces
 import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
 import gradio as gr
- from threading import Thread
-
-
- HF_TOKEN = os.environ.get("HF_TOKEN", None)
- MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
- TITLE = "<h1><center>Meta-Llama3.1-8B</center></h1>"
-
- PLACEHOLDER = """
- <center>
- <p>Hi! How can I help you today?</p>
- </center>
- """
-
-
- CSS = """
- .duplicate-button {
-     margin: auto !important;
-     color: white !important;
-     background: black !important;
-     border-radius: 100vh !important;
- }
- h3 {
-     text-align: center;
- }
- """
-
- device = "cuda" # for GPU usage or "cpu" for CPU usage
-
- quantization_config = BitsAndBytesConfig(
-     load_in_4bit=True,
-     bnb_4bit_compute_dtype=torch.bfloat16,
-     bnb_4bit_use_double_quant=True,
-     bnb_4bit_quant_type= "nf4")
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
- model = AutoModelForCausalLM.from_pretrained(
-     MODEL,
-     torch_dtype=torch.bfloat16,
-     device_map="auto",
-     quantization_config=quantization_config)
-
- @spaces.GPU()
- def stream_chat(
-     message: str,
-     history: list,
-     system_prompt: str,
-     temperature: float = 0.8,
-     max_new_tokens: int = 1024,
-     top_p: float = 1.0,
-     top_k: int = 20,
-     penalty: float = 1.2,
- ):
-     print(f'message: {message}')
-     print(f'history: {history}')
-
-     conversation = [
-         {"role": "system", "content": system_prompt}
-     ]
-     for prompt, answer in history:
-         conversation.extend([
-             {"role": "user", "content": prompt},
-             {"role": "assistant", "content": answer},
-         ])
-
-     conversation.append({"role": "user", "content": message})
-
-     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
-
-     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

-     generate_kwargs = dict(
-         input_ids=input_ids,
-         max_new_tokens = max_new_tokens,
-         do_sample = False if temperature == 0 else True,
-         top_p = top_p,
-         top_k = top_k,
-         temperature = temperature,
-         repetition_penalty=penalty,
-         eos_token_id=[128001,128008,128009],
-         streamer=streamer,
-     )
-
-     with torch.no_grad():
-         thread = Thread(target=model.generate, kwargs=generate_kwargs)
-         thread.start()

-     buffer = ""
-     for new_text in streamer:
-         buffer += new_text
-         yield buffer


- chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)
-
- with gr.Blocks(css=CSS, theme="soft") as demo:
-     gr.HTML(TITLE)
-     gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
-     gr.ChatInterface(
-         fn=stream_chat,
-         chatbot=chatbot,
-         fill_height=True,
-         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
-         additional_inputs=[
-             gr.Textbox(
-                 value="You are a helpful assistant",
-                 label="System Prompt",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=0,
-                 maximum=1,
-                 step=0.1,
-                 value=0.8,
-                 label="Temperature",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=128,
-                 maximum=8192,
-                 step=1,
-                 value=1024,
-                 label="Max new tokens",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=0.0,
-                 maximum=1.0,
-                 step=0.1,
-                 value=1.0,
-                 label="top_p",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=1,
-                 maximum=20,
-                 step=1,
-                 value=20,
-                 label="top_k",
-                 render=False,
-             ),
-             gr.Slider(
-                 minimum=0.0,
-                 maximum=2.0,
-                 step=0.1,
-                 value=1.2,
-                 label="Repetition penalty",
-                 render=False,
-             ),
-         ],
-         examples=[
-             ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
-             ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
-             ["Tell me a random fun fact about the Roman Empire."],
-             ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
-         ],
-         cache_examples=False,
-     )


 if __name__ == "__main__":
-     demo.launch()
+ "Single Thread"
+
 import os
+ import multiprocessing
+ import concurrent.futures
+ from langchain.document_loaders import TextLoader, DirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from sentence_transformers import SentenceTransformer
+ import faiss
 import torch
+ import numpy as np
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
+ from datetime import datetime
+ import json
 import gradio as gr
+ import re
+
+ class DocumentRetrievalAndGeneration:
+     def __init__(self, embedding_model_name, lm_model_id, data_folder):
+         self.all_splits = self.load_documents(data_folder)
+         self.embeddings = SentenceTransformer(embedding_model_name)
+         self.gpu_index = self.create_faiss_index()
+         self.llm = self.initialize_llm(lm_model_id)
+
+     def load_documents(self, folder_path):
+         loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
+         documents = loader.load()
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
+         all_splits = text_splitter.split_documents(documents)
+         print('Length of documents:', len(documents))
+         print("LEN of all_splits", len(all_splits))
+         for i in range(5):
+             print(all_splits[i].page_content)
+         return all_splits
+
+     def create_faiss_index(self):
+         all_texts = [split.page_content for split in self.all_splits]
+         embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
+         index = faiss.IndexFlatL2(embeddings.shape[1])
+         index.add(embeddings)
+         gpu_resource = faiss.StandardGpuResources()
+         gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, index)
+         return gpu_index
+
+     def initialize_llm(self, model_id):
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16
+         )
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         generate_text = pipeline(
+             model=model,
+             tokenizer=tokenizer,
+             return_full_text=True,
+             task='text-generation',
+             temperature=0.6,
+             max_new_tokens=256,
+         )
+         return generate_text
+
+     def generate_response_with_timeout(self, model_inputs):
+         try:
+             with concurrent.futures.ThreadPoolExecutor() as executor:
+                 future = executor.submit(self.llm.model.generate, model_inputs, max_new_tokens=1000, do_sample=True)
+                 generated_ids = future.result(timeout=60)  # Timeout set to 60 seconds
+                 return generated_ids
+         except concurrent.futures.TimeoutError:
+             raise TimeoutError("Text generation process timed out")

+     def query_and_generate_response(self, query):
+         query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
+         distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)
+
+         content = ""
+         for idx in indices[0]:
+             content += "-" * 50 + "\n"
+             content += self.all_splits[idx].page_content + "\n"
+             print("CHUNK", idx)
+             print(self.all_splits[idx].page_content)
+             print("############################")
+         prompt = f"""<s>
+ You are a knowledgeable assistant with access to a comprehensive database.
+ I need you to answer my question and provide related information in a specific format.
+ I have provided five relatable json files {content}, choose the most suitable chunks for answering the query
+ Here's what I need:
+ Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
+ content
+ Here's my question:
+ Query:{query}
+ Solution==>
+ RETURN ONLY SOLUTION . IF THEIR IS NO ANSWER RELATABLE IN RETRIEVED CHUNKS , RETURN " NO SOLUTION AVAILABLE"
+ Example1
+ Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
+ Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",
+
+ Example2
+ Query: "Can BQ25896 support I2C interface?",
+ Solution: "Yes, the BQ25896 charger supports the I2C interface for communication."
+ </s>
+ """
+         # prompt = f"Query: {query}\nSolution: {content}\n"
+
+         # Encode and prepare inputs
+         messages = [{"role": "user", "content": prompt}]
+         encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
+         model_inputs = encodeds.to(self.llm.device)

+         # Perform inference and measure time
+         start_time = datetime.now()
+         generated_ids = self.generate_response_with_timeout(model_inputs)
+         # generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
+         elapsed_time = datetime.now() - start_time

+         # Decode and return output
+         decoded = self.llm.tokenizer.batch_decode(generated_ids)
+         generated_response = decoded[0]
+         match1 = re.search(r'\[/INST\](.*?)</s>', generated_response, re.DOTALL)
+
+         match2 = re.search(r'Solution:(.*?)</s>', generated_response, re.DOTALL | re.IGNORECASE)
+         if match1:
+             solution_text = match1.group(1).strip()
+             print(solution_text)
+             if "Solution:" in solution_text:
+                 solution_text = solution_text.split("Solution:", 1)[1].strip()
+         elif match2:
+             solution_text = match2.group(1).strip()
+             print(solution_text)

+         else:
+             solution_text = generated_response
+         print("Generated response:", generated_response)
+         print("Time elapsed:", elapsed_time)
+         print("Device in use:", self.llm.device)

+         return solution_text, content
+
+     def qa_infer_gradio(self, query):
+         response = self.query_and_generate_response(query)
+         return response

 if __name__ == "__main__":
+     # Example usage
+     embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
+     lm_model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+     data_folder = 'sample_embedding_folder2'
+
+     doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder)
+
+     """Dual Interface"""
+
+     def launch_interface():
+         css_code = """
+         .gradio-container {
+             background-color: #daccdb;
+         }
+         /* Button styling for all buttons */
+         button {
+             background-color: #927fc7; /* Default color for all other buttons */
+             color: black;
+             border: 1px solid black;
+             padding: 10px;
+             margin-right: 10px;
+             font-size: 16px; /* Increase font size */
+             font-weight: bold; /* Make text bold */
+         }
+         """
+         EXAMPLES = [
+             "On which devices can the VIP and CSI2 modules operate simultaneously?",
+             "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
+             "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"
+         ]
+
+         file_path = "ticketNames.txt"
+
+         # Read the file content
+         with open(file_path, "r") as file:
+             content = file.read()
+         ticket_names = json.loads(content)
+         dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)
+
+         # Define Gradio interfaces
+         tab1 = gr.Interface(
+             fn=doc_retrieval_gen.qa_infer_gradio,
+             inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
+             allow_flagging='never',
+             examples=EXAMPLES,
+             cache_examples=False,
+             outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
+             css=css_code
+         )
+         tab2 = gr.Interface(
+             fn=doc_retrieval_gen.qa_infer_gradio,
+             inputs=[dropdown],
+             allow_flagging='never',
+             outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
+             css=css_code
+         )
+
+         # Combine interfaces into a tabbed interface
+         gr.TabbedInterface(
+             [tab1, tab2],
+             ["Textbox Input", "FAQs"],
+             title="TI E2E FORUM",
+             css=css_code
+         ).launch(debug=True)
+
+     # Launch the interface
+     launch_interface()
+
+
+
+     """Single Interface"""
+     # def launch_interface():
+     #     css_code = """
+     #     .gradio-container {
+     #         background-color: #daccdb;
+     #     }
+     #     /* Button styling for all buttons */
+     #     button {
+     #         background-color: #927fc7; /* Default color for all other buttons */
+     #         color: black;
+     #         border: 1px solid black;
+     #         padding: 10px;
+     #         margin-right: 10px;
+     #         font-size: 16px; /* Increase font size */
+     #         font-weight: bold; /* Make text bold */
+     #     }
+     #     """
+     #     EXAMPLES = ["On which devices can the VIP and CSI2 modules operate simultaneously? ",
+     #                 "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
+     #                 "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC, considering it supports up to 10 multiplexed input ports and includes 3 dedicated video input modules?"]
+
+     #     file_path = "ticketNames.txt"
+
+     #     # Read the file content
+     #     with open(file_path, "r") as file:
+     #         content = file.read()
+     #     ticket_names = json.loads(content)
+     #     dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)
+
+     #     # Define Gradio interface
+     #     interface = gr.Interface(
+     #         fn=doc_retrieval_gen.qa_infer_gradio,
+     #         inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
+     #         allow_flagging='never',
+     #         examples=EXAMPLES,
+     #         cache_examples=False,
+     #         outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES")],
+     #         css=css_code
+     #     )
+
+     #     # Launch Gradio interface
+     #     interface.launch(debug=True)
+
+     # # Launch the interface
+     # launch_interface()
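
A minimal usage sketch of the new DocumentRetrievalAndGeneration class, assuming the same embedding model, LLM, and data folder as the __main__ block above; the query string is only an illustration (taken from one of the prompt examples), and this runs the retrieval pipeline without the Gradio UI:

    # Sketch: instantiate the RAG pipeline with the values used in the commit above (data folder assumed to exist)
    rag = DocumentRetrievalAndGeneration(
        'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12',
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        'sample_embedding_folder2',
    )
    # query_and_generate_response returns (solution_text, retrieved_chunks); the query here is illustrative
    solution, retrieved_chunks = rag.query_and_generate_response("Can BQ25896 support I2C interface?")
    print(solution)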