chasetank committed on
Commit
7db35c9
1 Parent(s): bd73b6b

first commit

.gitattributes CHANGED
@@ -32,3 +32,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/eqe-manual/index.faiss filter=lfs diff=lfs merge=lfs -text
+ data/eqs-manual/index.faiss filter=lfs diff=lfs merge=lfs -text
+ data/s-class-manual/index.faiss filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,254 @@
+ import json
+ import requests
+ from datetime import datetime
+
+ import gradio as gr
+ import whisper
+ from langchain import LLMChain, PromptTemplate
+ from langchain.llms import OpenAIChat
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.chains.question_answering import load_qa_chain
+ from langchain.chains.conversation.memory import ConversationBufferMemory
+ from langchain.schema import HumanMessage, AIMessage
+
+
+ class ChatbotClass:
+     def __init__(self):
+         FOLDER_PATH = './data/eqe-manual'
+
+         # Whisper handles speech-to-text; the FAISS index holds the
+         # embedded owner's-manual chunks shipped with this commit.
+         self.whisper_model = whisper.load_model(name='tiny')
+         self.embeddings = HuggingFaceEmbeddings()
+         self.index = FAISS.load_local(
+             folder_path=FOLDER_PATH, embeddings=self.embeddings
+         )
+
+         self.llm = OpenAIChat(temperature=0)
+
+         self.memory = ConversationBufferMemory(
+             memory_key="chat_history", input_key="human_input", return_messages=True
+         )
+
+         self.keyword_chain = self.init_keyword_chain()
+         self.context_chain = self.init_context_chain()
+         self.document_retrieval_chain = self.init_document_retrieval()
+         self.conversation_chain = self.init_conversation()
+
+     def format_history(self, memory):
+         # Collapse stored messages into (human, ai) pairs for Gradio's chatbot.
+         formatted_history = []
+         user_response = None
+         for h in memory.chat_memory.messages:
+             if isinstance(h, HumanMessage):
+                 user_response = h.content
+             elif isinstance(h, AIMessage):
+                 formatted_history.append((user_response, h.content))
+
+         return formatted_history
+
+     def init_document_retrieval(self):
+         # Router prompt: decide whether a request needs manual excerpts at all.
+         retrieve_documents_template = """This function retrieves excerpts from a Vehicle Owner's Manual. The function is useful for adding vehicle-specific context to answer questions. Based on a request, determine if vehicle-specific information is needed. Respond with "Yes" or "No". If the answer is both, respond with "Yes":\nrequest: How do I change the tire?\nresponse: Yes\nrequest: Hello\nresponse: No\nrequest: I was in an accident. What should I do?\nresponse: Yes\nrequest: {request}\nresponse:"""
+
+         prompt_template = PromptTemplate(
+             input_variables=["request"],
+             template=retrieve_documents_template
+         )
+
+         return LLMChain(llm=self.llm, prompt=prompt_template, verbose=True)
+
+     def init_keyword_chain(self):
+         # Turn a natural-language question into search keywords for the index.
+         keyword_template = """You are a vehicle owner searching for content in your vehicle's owner manual. Your job is to come up with keywords to use when searching inside your manual, based on a question you have.
+ Question: {question}
+ Keywords:"""
+
+         prompt_template = PromptTemplate(
+             input_variables=["question"], template=keyword_template
+         )
+
+         return LLMChain(llm=self.llm, prompt=prompt_template, verbose=True)
+
+     def init_context_chain(self):
+         # "Stuff" QA chain: answer from retrieved manual excerpts plus chat history.
+         context_template = """You are a friendly and helpful chatbot having a conversation with a human.
+ Given the following extracted parts of a long document and a question, create a final answer.
+
+ {context}
+
+ {chat_history}
+ Human: {human_input}
+ Chatbot:"""
+
+         context_prompt = PromptTemplate(
+             input_variables=["chat_history", "human_input", "context"],
+             template=context_template
+         )
+
+         return load_qa_chain(
+             self.llm, chain_type="stuff", memory=self.memory, prompt=context_prompt
+         )
+
+     def init_conversation(self):
+         # Plain conversational fallback when no manual lookup is needed.
+         template = """You are a chatbot having a conversation with a human.
+
+ {chat_history}
+ Human: {human_input}
+ Chatbot:"""
+
+         prompt = PromptTemplate(
+             input_variables=["chat_history", "human_input"],
+             template=template
+         )
+
+         return LLMChain(
+             llm=self.llm,
+             prompt=prompt,
+             verbose=True,
+             memory=self.memory,
+         )
+
+     def transcribe_audio(self, audio_file):
+         result = self.whisper_model.transcribe(audio_file)
+         return result['text']
+
+     def ask_question(self, query, k=4):
+         # First ask the router whether the manual should be searched at all.
+         tool_usage = self.document_retrieval_chain.run(query).strip()
+         print(f'\033[1;32msearch manual?: {tool_usage}\033[0m')
+
+         chat_history = self.format_history(self.memory)
+
+         if tool_usage == 'Yes':
+             keywords = self.keyword_chain.run(question=query)
+             print(f'\033[1;32mkeywords: {keywords}\033[0m')
+
+             context = self.index.similarity_search(query=keywords, k=k)
+             result = self.context_chain.run(
+                 input_documents=context, human_input=query
+             )
+         else:
+             result = self.conversation_chain.run(query)
+
+         return [(query, result)], chat_history
+
+     def invoke_exh_api(self, bot_response, bot_name='Zippy', voice_name='Fiona', idle_url='https://ugc-idle.s3-us-west-2.amazonaws.com/4a6a607a466bdf6605bbd97ef146751b.mp4', animation_pipeline='high_quality', bearer_token='eyJhbGciOiJIUzUxMiJ9.eyJ1c2VybmFtZSI6IndlYiJ9.LSzIQx6h61l5FXs52s0qcY8WqauET6z9nnxgSzvoNBx8RYEKm8OpOohcK8wjuwteV4ZGug4NOjoGQoUZIKH84A'):
+         # Render the bot's reply as a talking-head video via the exh.ai API.
+         if len(bot_response) > 200:
+             print('Input is over 200 characters. Shorten the message')
+
+         url = 'https://api.exh.ai/animations/v1/generate_lipsync'
+         headers = {
+             'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.46',
+             'authority': 'api.exh.ai',
+             'accept': '*/*',
+             'accept-encoding': 'gzip, deflate, br',
+             'accept-language': 'en-US,en;q=0.9',
+             'authorization': f'Bearer {bearer_token}',
+             'content-type': 'application/json',
+             'origin': 'https://admin.exh.ai',
+             'referer': 'https://admin.exh.ai/',
+             'sec-ch-ua': '"Chromium";v="110", "Not A(Brand";v="24", "Microsoft Edge";v="110"',
+             'sec-ch-ua-mobile': '?0',
+             'sec-ch-ua-platform': '"Windows"',
+             'sec-fetch-dest': 'empty',
+             'sec-fetch-mode': 'cors',
+             'sec-fetch-site': 'same-site',
+         }
+
+         data = {
+             'bot_name': bot_name,
+             'bot_response': bot_response,
+             'voice_name': voice_name,
+             'idle_url': idle_url,
+             'animation_pipeline': animation_pipeline,
+         }
+
+         r = requests.post(url, headers=headers, data=json.dumps(data))
+
+         timestamp = datetime.now().strftime('%Y%m%d_%H%M%S%f')
+         outfile = f'talking_head_{timestamp}.mp4'
+
+         with open(outfile, 'wb') as f:
+             f.write(r.content)
+
+         return outfile
+
+     def predict(self, input_data, state=[], k=4, input_type='audio'):
+         if input_type == 'audio':
+             txt = self.transcribe_audio(input_data[0])
+         else:
+             txt = input_data[1]
+         result, chat_history = self.ask_question(txt, k=k)
+         state.append(chat_history)
+         return result, state
+
+     def predict_wrapper(self, input_text=None, input_audio=None):
+         # Audio input wins if both are supplied.
+         if input_audio is not None:
+             result, state = self.predict(
+                 input_data=(input_audio,), input_type='audio')
+         else:
+             result, state = self.predict(
+                 input_data=('', input_text), input_type='text')
+
+         # The avatar API caps messages around 200 characters, so truncate.
+         response = result[0][1][:195]
+         avatar = self.invoke_exh_api(response)
+
+         return result, avatar
+
+
+ man_chatbot = ChatbotClass()
+
+ iface = gr.Interface(
+     fn=man_chatbot.predict_wrapper,
+     inputs=[gr.inputs.Textbox(label="Text Input"),
+             gr.inputs.Audio(source="microphone", type='filepath')],
+     outputs=[gr.outputs.Textbox(label="Result"),
+              gr.outputs.Video().style(height=100, container=True)]
+ )
+ iface.launch()
+
+ '''
+ Alternative Blocks-based UI, kept for reference:
+
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot()
+     state = gr.State([])
+
+     with gr.Row():
+         txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(
+             container=False)
+         k_slider = gr.Slider(minimum=1, maximum=10, default=4, label='k')
+     txt.submit(man_chatbot.predict, [txt, state, k_slider], [chatbot, state])
+
+ demo.launch()
+ '''
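For a quick sanity check of the retrieval side of app.py, the committed FAISS index can be queried directly, with no Gradio UI or OpenAI key involved. A minimal sketch, not part of this commit (the query string is illustrative):

# Standalone sketch: load the committed index and run the same
# similarity_search call app.py makes inside ask_question().
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings()
index = FAISS.load_local(folder_path='./data/eqe-manual', embeddings=embeddings)

# k bounds how many manual chunks come back.
for doc in index.similarity_search(query='charging the vehicle', k=3):
    print(doc.page_content[:120], '...')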
data/eqe-manual/eqe-manual.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/eqe-manual/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f6cbbed9a851a19b27e78a27bea7157a10dce186fdfbd21dbd9c3bd5c2caa3e
+ size 3013677
data/eqe-manual/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:132e3b6c570654ae9a884ebf78f71a95ab9b7cda5ba2526f5fef5e3a1f1c39a4
+ size 1035885
data/eqs-manual/eqs-manual.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/eqs-manual/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5e11585c41125f77a3cee43f28ea0ed061271e4976ec1efff02609ab9fe575df
+ size 3207213
data/eqs-manual/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b06c7df0a01e5454a333578672ef7b836fdcc197447a15ee711e2ebb08be5a51
+ size 1054884
data/s-class-manual/index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2cc3da5a2b2c309ca81cc41a1cf1192c619c91ee1a1bfac39cf58b99bc8995fa
+ size 3062829
data/s-class-manual/index.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c1749b9feae4721bbc0390b20cd9ceef84cb531d8fda7a44b89797e65b144af5
+ size 1023700
data/s-class-manual/s-class-manual.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio
+ langchain
+ huggingface_hub
+ faiss-cpu
+ chardet
+ openapi-codec
+ pprintpp
+ EdgeGPT
+ sentence_transformers
+ plotly
+ openai
+ openai-whisper
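Note that app.py needs OpenAI's Whisper, which is published on PyPI as openai-whisper; a different, unrelated package occupies the plain whisper name, and installing it would break whisper.load_model(). A quick check, as a sketch:

# Sketch: verify the installed `whisper` module is openai-whisper,
# which exposes load_model(), rather than the unrelated PyPI package.
import whisper

assert hasattr(whisper, 'load_model')
print(whisper.load_model(name='tiny'))  # same call app.py makes at startup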