arjunanand13 committed on
Commit
09e2eff
1 Parent(s): 50365d1

Create app.py

Files changed (1)
  1. app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
+ import torch
+ import json
+ from torch import cuda, bfloat16
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
+ from langchain.llms import HuggingFacePipeline
+ from langchain.chains import ConversationalRetrievalChain  # needed for the retrieval chain below
+ import gradio as gr
+ import os
+ import faiss
+ import numpy as np
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+ class Chatbot:
+     def __init__(self):
+         self.HF_TOKEN = os.environ.get("HF_TOKEN", None)
+         self.model_id = "mistralai/Mistral-7B-Instruct-v0.2"
+         self.device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
+         # 4-bit NF4 quantization so the 7B model fits in a single GPU's memory
+         self.bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type='nf4',
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_compute_dtype=bfloat16
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=self.HF_TOKEN)
+         self.model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto", token=self.HF_TOKEN, quantization_config=self.bnb_config)
+         # Token-id sequences that should terminate generation
+         self.stop_list = ['\nHuman:', '\n```\n']
+         self.stop_token_ids = [self.tokenizer(x)['input_ids'] for x in self.stop_list]
+         self.stop_token_ids = [torch.LongTensor(x).to(self.device) for x in self.stop_token_ids]
+         self.stopping_criteria = StoppingCriteriaList([self.StopOnTokens(self.stop_token_ids)])
+
+         self.generate_text = pipeline(
+             model=self.model,
+             tokenizer=self.tokenizer,
+             return_full_text=True,
+             task='text-generation',
+             temperature=0.1,
+             max_new_tokens=2048,
+             stopping_criteria=self.stopping_criteria,  # enforce the custom stop sequences
+         )
+         self.llm = HuggingFacePipeline(pipeline=self.generate_text)
+
+         # Initialize the embedding model used to embed queries for retrieval
+         self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"})
+
+         try:
+             # Load the prebuilt FAISS index from disk and move it onto the GPU
+             cpu_index = faiss.read_index('faiss_index_new_model3.index')
+             gpu_resource = faiss.StandardGpuResources()
+             self.vectorstore = faiss.index_cpu_to_gpu(gpu_resource, 0, cpu_index)
+             print("Loaded FAISS index successfully")
+         except Exception as e:
+             print("FAISS index could not be loaded.")
+             raise e
+
+         # NOTE: as_retriever() is a LangChain vector-store method, not part of the raw
+         # faiss index API; this assumes the loaded index is wrapped in (or behaves like)
+         # a LangChain-compatible vector store.
+         self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vectorstore.as_retriever(), return_source_documents=True)
+         self.chat_history = []
+
+     class StopOnTokens(StoppingCriteria):
+         """Stop generation as soon as any configured stop-token sequence is produced."""
+         def __init__(self, stop_token_ids):
+             super().__init__()
+             self.stop_token_ids = stop_token_ids
+
+         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+             for stop_ids in self.stop_token_ids:
+                 if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
+                     return True
+             return False
+
+     def format_prompt(self, query):
+         prompt = f"""
+ You are a knowledgeable assistant with access to a comprehensive database.
+ I need you to answer my question and provide related information in a specific format.
+ I have provided four related JSON files; choose the most suitable chunks for answering the query.
+ Here's what I need:
+ Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.
+
+ Here's my question:
+ Query: {query}
+ Solution==>
+ Example 1
+ Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
+ Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",
+
+ Example 2
+ Query: "Can BQ25896 support I2C interface?",
+ Solution: "Yes, the BQ25896 charger supports the I2C interface for communication.",
+ """
+         return prompt
+
+     def qa_infer(self, query):
+         content = ""
+         formatted_prompt = self.format_prompt(query)
+
+         # Embed the formatted query with the sentence-transformer model
+         query_embedding = self.embeddings.embed_query(formatted_prompt)
+
+         # Search the FAISS index for the 5 nearest chunks
+         # (faiss expects a float32 matrix of shape (n_queries, dim))
+         distances, indices = self.vectorstore.search(np.array([query_embedding], dtype=np.float32), k=5)
+
+         # Collect the retrieved chunks for display
+         # NOTE: get_document() is not part of the raw faiss API; this assumes the index
+         # object also exposes the original document chunks (e.g. via a custom wrapper).
+         for idx in indices[0]:
+             doc = self.vectorstore.get_document(idx)
+             content += "-" * 50 + "\n"
+             content += doc.page_content + "\n"
+
+         # Run the conversational retrieval chain to produce the final answer
+         result = self.chain({"question": formatted_prompt, "chat_history": self.chat_history})
+         print(content)
+         print("#" * 100)
+         print(result['answer'])
+
+         # Write the query, answer and retrieved sources to a downloadable text file
+         output_file = "output.txt"
+         with open(output_file, "w") as f:
+             f.write("Query:\n")
+             f.write(query + "\n\n")
+             f.write("Answer:\n")
+             f.write(result['answer'] + "\n\n")
+             f.write("Source Documents:\n")
+             f.write(content + "\n")
+
+         download_link = f'<a href="file/{output_file}" download>Download Output File</a>'
+         return result['answer'], content, download_link
+
+     def launch_interface(self):
+         css_code = """
+         .gradio-container {
+             background-color: #daccdb;
+         }
+         /* Button styling for all buttons */
+         button {
+             background-color: #927fc7; /* Default color for all buttons */
+             color: black;
+             border: 1px solid black;
+             padding: 10px;
+             margin-right: 10px;
+             font-size: 16px; /* Increase font size */
+             font-weight: bold; /* Make text bold */
+         }
+         """
+         EXAMPLES = [
+             "TDA4 product planning and datasheet release progress?",
+             "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
+             "Master core in TDA2XX is A15 and in TDA3XX it is M4, so we have to shift all modules that are being used by A15 in TDA2XX to M4 in TDA3XX."
+         ]
+
+         # Load the sample ticket names that populate the FAQ dropdown
+         file_path = "ticketNames.txt"
+         with open(file_path, "r") as file:
+             content = file.read()
+         ticket_names = json.loads(content)
+         dropdown = gr.Dropdown(label="Sample queries", choices=ticket_names)
+
+         # Tab 1: free-form query input; Tab 2: predefined FAQ queries
+         tab1 = gr.Interface(fn=self.qa_infer, inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")], allow_flagging='never', examples=EXAMPLES, cache_examples=False, outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES"), gr.HTML()], css=css_code)
+         tab2 = gr.Interface(fn=self.qa_infer, inputs=[dropdown], allow_flagging='never', outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES"), gr.HTML()], css=css_code)
+
+         # # Add dummy outputs to each interface
+         # tab1.outputs = dummy_outputs
+         # tab2.outputs = dummy_outputs
+
+         gr.TabbedInterface([tab1, tab2], ["Textbox Input", "FAQs"], title="TI E2E FORUM", css=css_code).launch(debug=True)
+
+
+ # Instantiate and launch the chatbot
+ chatbot = Chatbot()
+ chatbot.launch_interface()