arjunanand13 committed
Commit
f785793
1 Parent(s): 9233424

Update app.py

Files changed (1)
  1. app.py +98 -129
app.py CHANGED
@@ -1,105 +1,100 @@
- import torch
- import json
- from torch import cuda, bfloat16
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
- from langchain.llms import HuggingFacePipeline
  from langchain.vectorstores import FAISS
- from langchain.chains import ConversationalRetrievalChain
  import gradio as gr
- from langchain.embeddings import HuggingFaceEmbeddings
- import os

- class Chatbot:
-     def __init__(self):
-         self.HF_TOKEN = os.environ.get("HF_TOKEN", None)
-         self.model_id = "mistralai/Mistral-7B-Instruct-v0.2"
-         self.device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
-         self.bnb_config = BitsAndBytesConfig(
              load_in_4bit=True,
-             bnb_4bit_quant_type='nf4',
              bnb_4bit_use_double_quant=True,
-             bnb_4bit_compute_dtype=bfloat16
          )
-         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=self.HF_TOKEN)
-         self.model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto", token=self.HF_TOKEN, quantization_config=self.bnb_config)
-         self.stop_list = ['\nHuman:', '\n```\n']
-         self.stop_token_ids = [self.tokenizer(x)['input_ids'] for x in self.stop_list]
-         self.stop_token_ids = [torch.LongTensor(x).to(self.device) for x in self.stop_token_ids]
-         self.stopping_criteria = StoppingCriteriaList([self.StopOnTokens()])
-
-         self.generate_text = pipeline(
-             model=self.model,
-             tokenizer=self.tokenizer,
              return_full_text=True,
              task='text-generation',
-             temperature=0.1,
              max_new_tokens=2048,
          )
-         self.llm = HuggingFacePipeline(pipeline=self.generate_text)
-
-         try:
-             self.vectorstore = FAISS.load_local('faiss_index', HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"}))
-             # self.vectorstore = FAISS.load_local('faiss_index', HuggingFaceEmbeddings(model_name="flax-sentence-embeddings/all_datasets_v3_MiniLM-L12", model_kwargs={"device": "cuda"}))
-             print("Loaded embedding successfully")
-         except ImportError as e:
-             print("FAISS could not be imported. Make sure FAISS is installed correctly.")
-             raise e
-
-         self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vectorstore.as_retriever(), return_source_documents=True)
-         self.chat_history = []
-
-     class StopOnTokens(StoppingCriteria):
-         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-             for stop_ids in self.stop_token_ids:
-                 if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
-                     return True
-             return False
-
-     def format_prompt(self, query):
-         prompt=f"""
-         You are a knowledgeable assistant with access to a comprehensive database.
-         I need you to answer my question and provide related information in a specific format.
-         I have provided four relatable json files , choose the most suitable chunks for answering the query
-         Here's what I need:
-         Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
-
-         Here's my question:
-         Query:{query}
-         Solution==>
-         Example1
-         Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
-         Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",
-
-         Example2
-         Query: "Can BQ25896 support I2C interface?",
-         Solution: "Yes, the BQ25896 charger supports the I2C interface for communication.",
-         """
-         return prompt

-     def qa_infer(self, query):
          content = ""
-         formatted_prompt = self.format_prompt(query)
-         result = self.chain({"question": formatted_prompt, "chat_history": self.chat_history})
-         for doc in result['source_documents']:
              content += "-" * 50 + "\n"
-             content += doc.page_content + "\n"
-         print(content)
-         print("#" * 100)
-         print(result['answer'])
-
-         output_file = "output.txt"
-         with open(output_file, "w") as f:
-             f.write("Query:\n")
-             f.write(query + "\n\n")
-             f.write("Answer:\n")
-             f.write(result['answer'] + "\n\n")
-             f.write("Source Documents:\n")
-             f.write(content + "\n")
-
-         download_link = f'<a href="file/{output_file}" download>Download Output File</a>'
-         return result['answer'], content, download_link
-
-     def launch_interface(self):
          css_code = """
          .gradio-container {
              background-color: #daccdb;
@@ -119,8 +114,6 @@ class Chatbot:
                    "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
                    "Master core in TDA2XX is a15 and in TDA3XX it is m4,so we have to shift all modules that are being used by a15 in TDA2XX to m4 in TDA3xx."]

-
-
          file_path = "ticketNames.txt"

          # Read the file content
@@ -129,43 +122,19 @@ class Chatbot:
          ticket_names = json.loads(content)
          dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)

-         tab1 = gr.Interface(fn=self.qa_infer, inputs=[gr.Textbox(label="QUERY", placeholder ="Enter your query here")], allow_flagging='never', examples=EXAMPLES, cache_examples=False, outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES"), gr.HTML()], css=css_code)
-         tab2 = gr.Interface(fn=self.qa_infer, inputs=[dropdown], allow_flagging='never', outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES"), gr.HTML()], css=css_code)#, title="FAQs")


-         # # Add dummy outputs to each interface
-         # tab1.outputs = dummy_outputs
-         # tab2.outputs = dummy_outputs
-
-         gr.TabbedInterface([tab1, tab2],["Textbox Input", "FAQs"],title="TI E2E FORUM",css=css_code).launch(debug=True)
-
- # Instantiate and launch the chatbot
- chatbot = Chatbot()
- chatbot.launch_interface()
-
- """Single Tab Input Inference"""
-
- # def launch_interface(self):
- #     css_code = """
- #     .gradio-container {
- #         background-color: #daccdb;
- #     }
-
- #     /* Button styling for all buttons */
- #     button {
- #         background-color: #927fc7; /* Default color for all other buttons */
- #         color: black;
- #         border: 1px solid black;
- #         padding: 10px;
- #         margin-right: 10px;
- #         font-size: 16px; /* Increase font size */
- #         font-weight: bold; /* Make text bold */
- #     }
-
- #     """
- #     EXAMPLES = ["TDA4 product planning and datasheet release progress? ",
- #                 "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
- #                 "Master core in TDA2XX is a15 and in TDA3XX it is m4,so we have to shift all modules that are being used by a15 in TDA2XX to m4 in TDA3xx."]
-
- #     demo = gr.Interface(fn=self.qa_infer, inputs=[gr.Textbox(label="QUERY", placeholder ="Enter your query here")], allow_flagging='never', examples=EXAMPLES, cache_examples=False, outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES"), gr.HTML()], css=css_code)
- #     demo.launch()
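Note on the removed generation setup: the old Chatbot built a StoppingCriteriaList from custom stop strings, but the nested StopOnTokens class was never given those token ids and the criteria were never passed to the pipeline, so the stop sequences could not take effect as written. A minimal self-contained sketch of the intended pattern follows; the stop strings and model id mirror the removed code, while the standalone wiring into generate() is an assumption for illustration, not part of this commit.

# Hypothetical standalone sketch of the stop-on-tokens pattern; not a file in this repo.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids):
        # stop_token_ids: list of 1-D LongTensors, one per stop sequence
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the generated sequence ends with any configured stop sequence
        for stop_ids in self.stop_token_ids:
            stop_ids = stop_ids.to(input_ids.device)
            if input_ids.shape[1] >= len(stop_ids) and torch.eq(input_ids[0, -len(stop_ids):], stop_ids).all():
                return True
        return False

model_id = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

stop_list = ['\nHuman:', '\n```\n']
# add_special_tokens=False keeps the BOS token out of the stop sequences
stop_token_ids = [torch.LongTensor(tokenizer(s, add_special_tokens=False)['input_ids']) for s in stop_list]
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])

inputs = tokenizer("Can BQ25896 support I2C interface?", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=256, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))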
 
+ import os
+ from langchain.document_loaders import TextLoader, DirectoryLoader
  from langchain.vectorstores import FAISS
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ import torch
+ import numpy as np
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
+ from datetime import datetime
  import gradio as gr

+ class DocumentRetrievalAndGeneration:
+     def __init__(self, embedding_model_name, lm_model_id, data_folder, faiss_index_path):
+         self.documents = self.load_documents(data_folder)
+         self.embeddings = SentenceTransformer(embedding_model_name)
+         self.gpu_index = self.load_faiss_index(faiss_index_path)
+         self.llm = self.initialize_llm(lm_model_id)
+
+     def load_documents(self, folder_path):
+         loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
+         documents = loader.load()
+         print('Length of documents:', len(documents))
+         return documents
+
+     def load_faiss_index(self, faiss_index_path):
+         cpu_index = faiss.read_index(faiss_index_path)
+         gpu_resource = faiss.StandardGpuResources()
+         gpu_index = faiss.index_cpu_to_gpu(gpu_resource, 0, cpu_index)
+         return gpu_index
+
+     def initialize_llm(self, model_id):
+         bnb_config = BitsAndBytesConfig(
              load_in_4bit=True,
              bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.bfloat16
          )
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
+         tokenizer = AutoTokenizer.from_pretrained(model_id)
+         generate_text = pipeline(
+             model=model,
+             tokenizer=tokenizer,
              return_full_text=True,
              task='text-generation',
+             temperature=0.6,
              max_new_tokens=2048,
          )
+         return generate_text
+
+     def query_and_generate_response(self, query):
+         query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
+         distances, indices = self.gpu_index.search(np.array([query_embedding]), k=5)

          content = ""
+         for idx in indices[0]:
              content += "-" * 50 + "\n"
+             content += self.documents[idx].page_content + "\n"
+             print(self.documents[idx].page_content)
+             print("############################")
+
+         prompt = f"Query: {query}\nSolution: {content}\n"
+
+         # Encode and prepare inputs
+         messages = [{"role": "user", "content": prompt}]
+         encodeds = self.llm.tokenizer.apply_chat_template(messages, return_tensors="pt")
+         model_inputs = encodeds.to(self.llm.device)
+
+         # Perform inference and measure time
+         start_time = datetime.now()
+         generated_ids = self.llm.model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
+         elapsed_time = datetime.now() - start_time
+
+         # Decode and return output
+         decoded = self.llm.tokenizer.batch_decode(generated_ids)
+         generated_response = decoded[0]
+         print("Generated response:", generated_response)
+         print("Time elapsed:", elapsed_time)
+         print("Device in use:", self.llm.device)
+
+         return generated_response
+
+     def qa_infer_gradio(self, query):
+         response = self.query_and_generate_response(query)
+         return response
+
+ if __name__ == "__main__":
+     # Example usage
+     embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
+     lm_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+     data_folder = 'sample_embedding_folder'
+     faiss_index_path = 'faiss_index_new_model3.index'
+
+     doc_retrieval_gen = DocumentRetrievalAndGeneration(embedding_model_name, lm_model_id, data_folder, faiss_index_path)
+
+     # Define Gradio interface function
+     def launch_interface():
          css_code = """
          .gradio-container {
              background-color: #daccdb;

                    "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
                    "Master core in TDA2XX is a15 and in TDA3XX it is m4,so we have to shift all modules that are being used by a15 in TDA2XX to m4 in TDA3xx."]

          file_path = "ticketNames.txt"

          # Read the file content

          ticket_names = json.loads(content)
          dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)

+         # Define Gradio interface
+         interface = gr.Interface(
+             fn=doc_retrieval_gen.qa_infer_gradio,
+             inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
+             allow_flagging='never',
+             examples=EXAMPLES,
+             cache_examples=False,
+             outputs=gr.Textbox(label="SOLUTION"),
+             css=css_code
+         )

+         # Launch Gradio interface
+         interface.launch(debug=True)

+     # Launch the interface
+     launch_interface()
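The new app.py loads a prebuilt index from faiss_index_new_model3.index and text files from sample_embedding_folder, but the commit does not show how that index is produced. A minimal offline sketch of one way to build it with the same embedding model follows; the script and the choice of a flat L2 index are assumptions for illustration, not files or decisions recorded in this repo.

# build_faiss_index.py -- hypothetical offline step, not part of this commit
import faiss
import numpy as np
from langchain.document_loaders import DirectoryLoader, TextLoader
from sentence_transformers import SentenceTransformer

data_folder = 'sample_embedding_folder'
faiss_index_path = 'faiss_index_new_model3.index'

# Load the same documents the app loads at startup
documents = DirectoryLoader(data_folder, loader_cls=TextLoader).load()
texts = [doc.page_content for doc in documents]

# Embed with the same model the app uses for queries, so vectors are comparable
embedder = SentenceTransformer('flax-sentence-embeddings/all_datasets_v3_MiniLM-L12')
embeddings = embedder.encode(texts, convert_to_numpy=True).astype(np.float32)

# Build a flat L2 index; row order must match the order DirectoryLoader returns documents,
# because query_and_generate_response maps FAISS result positions back into self.documents
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
faiss.write_index(index, faiss_index_path)
print('Indexed', index.ntotal, 'documents ->', faiss_index_path)

Whatever builds the index has to add vectors in the same document order the app sees at startup; if the folder contents or ordering change, the index needs to be rebuilt.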