chasetank commited on
Commit
43d88cd
·
1 Parent(s): 4cf7340

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -257
app.py CHANGED
@@ -1,257 +1,25 @@
1
- import plotly.graph_objs as go
2
- from sklearn.cluster import KMeans
3
- from sklearn.decomposition import PCA
4
- import plotly.express as px
5
- import numpy as np
6
- import os
7
- import pprint
8
- import codecs
9
- import chardet
10
- import gradio as gr
11
- from langchain.llms import HuggingFacePipeline, OpenAIChat
12
- from langchain.text_splitter import RecursiveCharacterTextSplitter
13
- from langchain.embeddings import HuggingFaceEmbeddings
14
- from langchain.vectorstores import FAISS
15
- from langchain import OpenAI, ConversationChain, LLMChain, PromptTemplate
16
- from langchain.chains.conversation.memory import ConversationBufferMemory
17
- from EdgeGPT import Chatbot
18
- import whisper
19
- from datetime import datetime
20
- import json
21
- import requests
22
- from langchain.chains.question_answering import load_qa_chain
23
- import langchain
24
-
25
-
26
- class ChatbotClass:
27
- def __init__(self):
28
- FOLDER_PATH = './data/eqe-manual'
29
- QUERY = 'How do I charge my vehicle?'
30
- K = 10
31
- self.whisper_model = whisper.load_model(name='tiny')
32
-
33
- self.embeddings = HuggingFaceEmbeddings()
34
-
35
- self.index = FAISS.load_local(
36
- folder_path=FOLDER_PATH, embeddings=self.embeddings
37
- )
38
-
39
- self.llm = OpenAIChat(temperature=0)
40
-
41
- self.memory = ConversationBufferMemory(
42
- memory_key="chat_history", input_key="human_input", return_messages=True
43
- )
44
-
45
- self.keyword_chain = self.init_keyword_chain()
46
- self.context_chain = self.init_context_chain()
47
- self.document_retrieval_chain = self.init_document_retrieval()
48
- self.conversation_chain = self.init_conversation()
49
-
50
- def format_history(self, memory):
51
- history = memory.chat_memory.messages
52
- if len(history) == 0:
53
- return []
54
-
55
- formatted_history = []
56
-
57
- for h in history:
58
- if isinstance(h, langchain.schema.HumanMessage):
59
- user_response = h.content
60
- elif isinstance(h, langchain.schema.AIMessage):
61
- ai_response = h.content
62
- formatted_history.append((user_response, ai_response))
63
-
64
- return formatted_history
65
-
66
- def init_document_retrieval(self):
67
- retrieve_documents_template = """This function retrieves exerts from a Vehicle Owner's Manual. The function is useful for adding vehicle-specific context to answer questions. Based on a request, determine if vehicle specific information is needed. Respond with "Yes" or "No". If the answer is both, respond with "Yes":\nrequest: How do I change the tire?\nresponse: Yes\nrequest: Hello\nresponse: No\nrequest: I was in an accident. What should I do?\nresponse: Yes\nrequest: {request}\nresponse:"""
68
-
69
- prompt_template = PromptTemplate(
70
- input_variables=["request"],
71
- template=retrieve_documents_template
72
- )
73
-
74
- document_retrieval_chain = LLMChain(
75
- llm=self.llm, prompt=prompt_template, verbose=True
76
- )
77
-
78
- return document_retrieval_chain
79
-
80
- def init_keyword_chain(self):
81
- keyword_template = """You are a vehicle owner searching for content in your vehicle's owner manual. Your job is to come up with keywords to use when searching inside your manual, based on a question you have.
82
- Question: {question}
83
- Keywords:"""
84
-
85
- prompt_template = PromptTemplate(
86
- input_variables=["question"], template=keyword_template
87
- )
88
-
89
- keyword_chain = LLMChain(
90
- llm=self.llm, prompt=prompt_template, verbose=True)
91
-
92
- return keyword_chain
93
-
94
- def init_context_chain(self):
95
- context_template = """You are a friendly and helpful chatbot having a conversation with a human.
96
- Given the following extracted parts of a long document and a question, create a final answer.
97
-
98
- {context}
99
-
100
- {chat_history}
101
- Human: {human_input}
102
- Chatbot:"""
103
-
104
- context_prompt = PromptTemplate(
105
- input_variables=["chat_history", "human_input", "context"],
106
- template=context_template
107
- )
108
-
109
- self.memory = ConversationBufferMemory(
110
- memory_key="chat_history", input_key="human_input", return_messages=True
111
- )
112
-
113
- context_chain = load_qa_chain(
114
- self.llm, chain_type="stuff", memory=self.memory, prompt=context_prompt
115
- )
116
-
117
- return context_chain
118
-
119
- def init_conversation(self):
120
- template = """You are a chatbot having a conversation with a human.
121
-
122
- {chat_history}
123
- Human: {human_input}
124
- Chatbot:"""
125
-
126
- prompt = PromptTemplate(
127
- input_variables=["chat_history", "human_input"],
128
- template=template
129
- )
130
-
131
- conversation_chain = LLMChain(
132
- llm=self.llm,
133
- prompt=prompt,
134
- verbose=True,
135
- memory=self.memory,
136
- )
137
-
138
- return conversation_chain
139
-
140
-
141
- def transcribe_audio(self, audio_file, model):
142
- result = self.whisper_model.transcribe(audio_file)
143
- return result['text']
144
-
145
-
146
- def ask_question(self, query, k=4):
147
- tool_usage = self.document_retrieval_chain.run(query)
148
- print('\033[1;32m' f'search manual?: {tool_usage}' "\033[0m")
149
-
150
- chat_history = self.format_history(self.memory)
151
-
152
- if tool_usage == 'Yes':
153
- keywords = self.keyword_chain.run(question=query)
154
- print('\033[1;32m' f'keywords:{keywords}' "\033[0m")
155
-
156
- context = self.index.similarity_search(query=keywords, k=k)
157
-
158
- result = self.context_chain.run(
159
- input_documents=context, human_input=query
160
- )
161
- else:
162
- result = self.conversation_chain.run(query)
163
-
164
- return [(query, result)], chat_history
165
-
166
-
167
- def invoke_exh_api(self, bot_response, bot_name='Zippy', voice_name='Fiona', idle_url='https://ugc-idle.s3-us-west-2.amazonaws.com/4a6a607a466bdf6605bbd97ef146751b.mp4', animation_pipeline='high_quality', bearer_token='eyJhbGciOiJIUzUxMiJ9.eyJ1c2VybmFtZSI6IndlYiJ9.LSzIQx6h61l5FXs52s0qcY8WqauET6z9nnxgSzvoNBx8RYEKm8OpOohcK8wjuwteV4ZGug4NOjoGQoUZIKH84A'):
168
- if len(bot_response) > 200:
169
- print('Input is over 200 characters. Shorten the message')
170
-
171
- url = 'https://api.exh.ai/animations/v1/generate_lipsync'
172
- headers = {
173
- 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.46',
174
- 'authority': 'api.exh.ai',
175
- 'accept': '*/*',
176
- 'accept-encoding': 'gzip, deflate, br',
177
- 'accept-language': 'en-US,en;q=0.9',
178
- 'authorization': f'Bearer {bearer_token}',
179
- 'content-type': 'application/json',
180
- 'origin': 'https://admin.exh.ai',
181
- 'referer': 'https://admin.exh.ai/',
182
- 'sec-ch-ua': '"Chromium";v="110", "Not A(Brand";v="24", "Microsoft Edge";v="110"',
183
- 'sec-ch-ua-mobile': '?0',
184
- 'sec-ch-ua-platform': '"Windows"',
185
- 'sec-fetch-dest': 'empty',
186
- 'sec-fetch-mode': 'cors',
187
- 'sec-fetch-site': 'same-site',
188
- }
189
-
190
- data = {
191
- 'bot_name': bot_name,
192
- 'bot_response': bot_response,
193
- 'voice_name': voice_name,
194
- 'idle_url': idle_url,
195
- 'animation_pipeline': animation_pipeline,
196
- }
197
-
198
- r = requests.post(url, headers=headers, data=json.dumps(data))
199
-
200
- timestamp = datetime.now().strftime('%Y%m%d_%H%M%S%f')
201
- outfile = f'talking_head_{timestamp}.mp4'
202
-
203
- with open(outfile, 'wb') as f:
204
- f.write(r.content)
205
-
206
- return outfile
207
-
208
-
209
- def predict(self, input_data, state=[], k=4, input_type='audio'):
210
- if input_type == 'audio':
211
- txt = self.transcribe_audio(input_data[0], self.whisper_model)
212
- else:
213
- txt = input_data[1]
214
- result, chat_history = self.ask_question(txt, k=k)
215
- state.append(chat_history)
216
- return result, state
217
-
218
-
219
- def predict_wrapper(self, input_text=None, input_audio=None):
220
- if input_audio is not None:
221
- result, state = self.predict(
222
- input_data=(input_audio,), input_type='audio')
223
- else:
224
- result, state = self.predict(
225
- input_data=('', input_text), input_type='text')
226
-
227
- response = result[0][1][:195]
228
- avatar = self.invoke_exh_api(response)
229
-
230
- return result,avatar
231
-
232
-
233
- man_chatbot = ChatbotClass()
234
-
235
- iface = gr.Interface(
236
- fn=man_chatbot.predict_wrapper,
237
- inputs=[gr.inputs.Textbox(label="Text Input"),
238
- gr.inputs.Audio(source="microphone", type='filepath')],
239
- outputs=[gr.outputs.Textbox(label="Result"),
240
- gr.outputs.Video().style(width=360, height=360, container=True)]
241
- )
242
- iface.launch()
243
-
244
- '''
245
- iface.launch()
246
- with gr.Blocks() as demo:
247
- chatbot = gr.Chatbot()
248
- state = gr.State([])
249
-
250
- with gr.Row():
251
- txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(
252
- container=False)
253
- k_slider = gr.Slider(minimum=1, maximum=10, default=4,label='k')
254
- txt.submit(man_chatbot.predict, [txt, state,k_slider],[chatbot,state])
255
-
256
- demo.launch()
257
- '''
 
1
+ from InnovationHub.llm.vector_store import *
2
+ from InnovationHub.llm.chain import *
3
+
4
+
5
+ def start_ui():
6
+ chatbot_interface = gradio.Interface(
7
+ fn=chat,
8
+ inputs=["text",
9
+ gradio.inputs.Dropdown(vehicle_options, label="Select a Mercedes-Benz Owner's Manual")
10
+ #gradio.inputs.Slider(minimum=1, maximum=10, step=1, label="k"),
11
+ #gradio.inputs.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature")
12
+ ],
13
+ outputs="text",
14
+ title="Mercedes-Benz Owner's Manual",
15
+ description="Ask questions to get answers from the Mercedes-Benz Owner's manual. <br/><font size='-2'><u>Disclaimer:</u> THIS IS NOT OFFICIAL AND MAY NOT BE AVAILABLE ALL THE TIME. ALWAYS LOOK AT THE OFFICIAL DOCUMENTATION at https://www.mbusa.com/en/owners/manuals</font>",
16
+ examples=[["What are the different features of the dashboard console?", "2023 S-Class", 10, 0.01],
17
+ ["What is flacon?", "2023 S-Class", 10, 0.01],
18
+ ["What is hyperscreen?", "2023 EQS", 10, 0.01]],
19
+ article = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=kaushikdatta.owner-manual" alt="visitor badge"/></center>'
20
+ )
21
+
22
+ chatbot_interface.launch()
23
+
24
+ if __name__ == '__main__':
25
+ start_ui()