arjunanand13 committed on
Commit
0545ca0
1 Parent(s): e21e9b2

Create app.py

Files changed (1)
  1. app.py +127 -0
app.py ADDED
@@ -0,0 +1,127 @@
import torch
from torch import cuda, bfloat16
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
from langchain.llms import HuggingFacePipeline
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
import gradio as gr
from langchain.embeddings import HuggingFaceEmbeddings
import os


class Chatbot:
    def __init__(self):
        self.HF_TOKEN = os.environ.get("HF_TOKEN", None)
        self.model_id = "mistralai/Mistral-7B-Instruct-v0.2"
        self.device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'

        # Load the model with 4-bit NF4 quantization so the 7B model fits in GPU memory.
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=bfloat16
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, token=self.HF_TOKEN)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id, device_map="auto", token=self.HF_TOKEN, quantization_config=self.bnb_config
        )

        # Sequences that should terminate generation early.
        self.stop_list = ['\nHuman:', '\n```\n']
        self.stop_token_ids = [self.tokenizer(x, add_special_tokens=False)['input_ids'] for x in self.stop_list]
        self.stop_token_ids = [torch.LongTensor(x).to(self.device) for x in self.stop_token_ids]
        self.stopping_criteria = StoppingCriteriaList([self.StopOnTokens(self.stop_token_ids)])

        self.generate_text = pipeline(
            model=self.model,
            tokenizer=self.tokenizer,
            return_full_text=True,
            task='text-generation',
            temperature=0.1,
            max_new_tokens=2048,
            stopping_criteria=self.stopping_criteria,
        )
        self.llm = HuggingFacePipeline(pipeline=self.generate_text)

        # Load the prebuilt FAISS index with the same embedding model used to create it.
        try:
            self.vectorstore = FAISS.load_local(
                'faiss_index',
                HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2", model_kwargs={"device": "cuda"})
            )
            print("Loaded embedding successfully")
        except ImportError as e:
            print("FAISS could not be imported. Make sure FAISS is installed correctly.")
            raise e

        self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.vectorstore.as_retriever(), return_source_documents=True)
        self.chat_history = []

    class StopOnTokens(StoppingCriteria):
        """Stops generation as soon as one of the configured stop sequences is produced."""

        def __init__(self, stop_token_ids):
            super().__init__()
            self.stop_token_ids = stop_token_ids

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            for stop_ids in self.stop_token_ids:
                if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                    return True
            return False

    def format_prompt(self, query):
        prompt = f"""
You are a knowledgeable assistant with access to a comprehensive database.
I need you to answer my question and provide related information in a specific format.
I have provided four related JSON files; choose the most suitable chunks for answering the query.
Here's what I need:
Include a final answer without additional comments, sign-offs, or extra phrases. Be direct and to the point.

Here's my question:
{query}
Solution==>
Example1
Query: "How to use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM",
Solution: "To use IPU1_0 instead of A15_0 to process NDK in TDA2x-EVM, you need to modify the configuration file of the NDK application. Specifically, change the processor reference from 'A15_0' to 'IPU1_0'.",

Example2
Query: "Can BQ25896 support I2C interface?",
Solution: "Yes, the BQ25896 charger supports the I2C interface for communication.",
"""
        return prompt

    def qa_infer(self, query):
        content = ""
        formatted_prompt = self.format_prompt(query)
        result = self.chain({"question": formatted_prompt, "chat_history": self.chat_history})

        # Collect the retrieved source chunks so they can be shown next to the answer.
        for doc in result['source_documents']:
            content += "-" * 50 + "\n"
            content += doc.page_content + "\n"
        print(content)
        print("#" * 100)
        print(result['answer'])

        # Write the query, answer, and source chunks to a downloadable text file.
        output_file = "output.txt"
        with open(output_file, "w") as f:
            f.write("Query:\n")
            f.write(query + "\n\n")
            f.write("Answer:\n")
            f.write(result['answer'] + "\n\n")
            f.write("Source Documents:\n")
            f.write(content + "\n")

        download_link = f'<a href="file/{output_file}" download>Download Output File</a>'
        return result['answer'], content, download_link

    def launch_interface(self):
        css_code = """
        .gradio-container {
            background-color: #daccdb;
        }

        /* Button styling for all buttons */
        button {
            background-color: #927fc7; /* Default color for all other buttons */
            color: black;
            border: 1px solid black;
            padding: 10px;
            margin-right: 10px;
            font-size: 16px; /* Increase font size */
            font-weight: bold; /* Make text bold */
        }
        """
        EXAMPLES = [
            "TDA4 product planning and datasheet release progress?",
            "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for CortexM4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
            "Master core in TDA2XX is A15 and in TDA3XX it is M4, so we have to shift all modules that are being used by A15 in TDA2XX to M4 in TDA3XX."
        ]

        demo = gr.Interface(
            fn=self.qa_infer,
            inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
            allow_flagging='never',
            examples=EXAMPLES,
            cache_examples=False,
            outputs=[gr.Textbox(label="SOLUTION"), gr.Textbox(label="RELATED QUERIES"), gr.HTML()],
            css=css_code
        )
        demo.launch()


# Instantiate and launch the chatbot
chatbot = Chatbot()
chatbot.launch_interface()
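
Note: app.py loads a prebuilt faiss_index directory via FAISS.load_local(...), but this commit does not include the code that builds it. Below is a minimal sketch of how such an index could be created with the same embedding model; the docs/ folder, glob pattern, and chunk sizes are illustrative assumptions, not part of this commit.

# Illustrative sketch only: builds the 'faiss_index' folder that app.py expects.
# The docs/ source folder, glob pattern, and chunk sizes are assumptions.
from langchain.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Load raw documents and split them into retrievable chunks.
loader = DirectoryLoader("docs", glob="**/*.txt", loader_cls=TextLoader)
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(loader.load())

# Use the same embedding model as app.py so query and index vectors are comparable.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

# Build the FAISS index and persist it where app.py's FAISS.load_local('faiss_index', ...) will find it.
FAISS.from_documents(chunks, embeddings).save_local("faiss_index")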