Commit: multiple models
app.py
CHANGED
@@ -7,19 +7,22 @@ from transformers import pipeline
 
 p = pipeline("automatic-speech-recognition")
 
-"""Use text to call chat method from main.py"""
-def add_text(history, text):
+"""Use text to call chat method from main.py"""
+
+models = ["GPT-3.5", "Flan UL2", "GPT-4", "Flan T5"]
+
+def add_text(history, text, model):
     print("Question asked: " + text)
-    response = run_model(text)
+    response = run_model(text, model)
     history = history + [(text, response)]
     print(history)
     return history, ""
 
 
-def run_model(text):
+def run_model(text, model):
     start_time = time.time()
     print("start time:" + str(start_time))
-    response = run(text)
+    response = run(text, model)
     end_time = time.time()
     # If response contains string `SOURCES:`, then add a \n before `SOURCES`
     if "SOURCES:" in response:
@@ -31,12 +34,12 @@ def run_model(text):
 
 
 
-def get_output(history, audio):
+def get_output(history, audio, model):
 
     txt = p(audio)["text"]
     # history.append(( (audio, ) , txt))
     audio_path = 'response.wav'
-    response = run_model(txt)
+    response = run_model(txt, model)
     # Remove all text from SOURCES: to the end of the string
     trimmed_response = response.split("SOURCES:")[0]
     myobj = gTTS(text=trimmed_response, lang='en', slow=False)
@@ -48,34 +51,54 @@ def get_output(history, audio):
     print(history)
     return history
 
+def set_model(history, model):
+    print("Model selected: " + model)
+    history = get_first_message(history)
+    index(model)
+    return history
+
+
+def get_first_message(history):
+    history = [(None,
+                'Learn about <a href="https://www.coursera.org/learn/3d-printing-revolution/home">3D printing Revolution</a> course with referred sources. Try out the new voice to voice Q&A on the course! ')]
+    return history
+
+
 def bot(history):
     return history
 
 with gr.Blocks() as demo:
-    index()
-    chatbot = gr.Chatbot([(None, 'Learn about <a href="https://www.coursera.org/learn/3d-printing-revolution/home">3D printing Revolution</a> course with referred sources. Try out the new voice to voice Q&A on the course! ')], elem_id="chatbot").style(height=750)
 
+    chatbot = gr.Chatbot(get_first_message([]), elem_id="chatbot").style(height=600)
+
+    with gr.Row():
+        # Create radio button to select model
+        radio = gr.Radio(models, label="Choose a model", value="GPT-3.5", type="value")
     with gr.Row():
-        with gr.Column(scale=0.
+        with gr.Column(scale=0.75):
             txt = gr.Textbox(
                 label="Coursera Voice Q&A Bot",
                 placeholder="Enter text and press enter, or upload an image", lines=1
             ).style(container=False)
 
-        with gr.Column(scale=0.
-            audio = gr.Audio(source="microphone", type="filepath")
+        with gr.Column(scale=0.25):
+            audio = gr.Audio(source="microphone", type="filepath").style(container=False)
 
-    txt.submit(add_text, [chatbot, txt], [chatbot, txt], postprocess=False).then(
+    txt.submit(add_text, [chatbot, txt, radio], [chatbot, txt], postprocess=False).then(
         bot, chatbot, chatbot
     )
 
-
-    audio.change(fn=get_output, inputs=[chatbot, audio], outputs=[chatbot]).then(
+    audio.change(fn=get_output, inputs=[chatbot, audio, radio], outputs=[chatbot]).then(
        bot, chatbot, chatbot
    )
 
-
+    radio.change(fn=set_model, inputs=[chatbot, radio], outputs=[chatbot]).then(bot, chatbot, chatbot)
+
     audio.change(lambda:None, None, audio)
 
+    set_model(chatbot, radio.value)
+
 if __name__ == "__main__":
+    demo.queue()
+    demo.queue(concurrency_count=5)
     demo.launch(debug=True)
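Note: the UI change threads the new gr.Radio value through every callback, so the selected model label travels from add_text/get_output into run_model and on to main.run. A minimal sketch of that flow, written as if it ran at the bottom of app.py with the definitions above in scope (the sample question is illustrative, and API keys plus the course data are assumed to be available):

# Sketch: how a typed question flows through the new multi-model path.
history = []                 # chatbot history: list of (user, bot) tuples
model = "Flan UL2"           # one of models = ["GPT-3.5", "Flan UL2", "GPT-4", "Flan T5"]

history = set_model(history, model)        # resets the greeting and runs index(model)
history, _ = add_text(history, "What is 3D printing?", model)   # internally calls run_model(text, model)
print(history[-1])           # (question, answer ending in "\nSOURCES: ...")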
main.py
CHANGED
@@ -1,9 +1,10 @@
-from utils import get_search_index, generate_answer
+from utils import get_search_index, generate_answer, set_model_and_embeddings
 
-def index():
-    get_search_index()
+def index(model):
+    set_model_and_embeddings(model)
+    get_search_index(model)
     return True
 
-def run(question):
-    index()
+def run(question, model):
+    index(model)
     return generate_answer(question)
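Note: run() is now model-aware; each call re-selects the LLM and (re)loads the per-model search index before answering. A minimal usage sketch, assuming the OpenAI / Hugging Face Hub credentials are already set in the environment:

# Sketch: calling the updated entry point in main.py directly.
from main import run

answer = run("Which materials can a consumer 3D printer use?", "GPT-3.5")
print(answer)   # generate_answer() appends "\nSOURCES: ..." to the model's reply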
requirements.txt
CHANGED
@@ -1,4 +1,4 @@
-langchain
+langchain
 openai
 faiss-cpu==1.7.3
 unstructured==0.5.8
@@ -6,4 +6,5 @@ ffmpeg-python
 transformers
 gtts
 torch
-tiktoken
+tiktoken
+huggingface-hub
utils.py
CHANGED
@@ -2,10 +2,11 @@ import os
 import pickle
 
 import faiss
+from langchain import HuggingFaceHub
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
-from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.prompts.chat import (
     ChatPromptTemplate,
@@ -16,24 +17,26 @@ from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores.faiss import FAISS
 
 
-
-index_file = "open_ai.index"
+global model_name
 
+models = ["GPT-3.5", "Flan UL2", "GPT-4", "Flan T5"]
 
+pickle_file = "_vs.pkl"
+index_file = "_vs.index"
+models_folder = "models/"
 
-
+llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
 
 embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
 
-
 chat_history = []
 
-memory = ConversationBufferWindowMemory(memory_key="chat_history")
+memory = ConversationBufferWindowMemory(memory_key="chat_history", k=10)
 
-
+vectorstore_index = None
 
 system_template = """You are Coursera QA Bot. Have a conversation with a human, answering the following questions as best you can.
-You are a teaching assistant for a Coursera Course: The 3D Printing Evolution and can answer any question about that using vectorstore.
+You are a teaching assistant for a Coursera Course: The 3D Printing Evolution and can answer any question about that using vectorstore or context.
 Use the following pieces of context to answer the users question.
 ----------------
 {context}"""
@@ -44,32 +47,82 @@ messages = [
 ]
 CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
 
-def get_search_index():
-
-    if os.path.isfile(pickle_file) and os.path.isfile(index_file) and os.path.getsize(pickle_file) > 0:
+
+def set_model_and_embeddings(model):
+    global chat_history
+    set_model(model)
+    # set_embeddings(model)
+    chat_history = []
+
+
+def set_model(model):
+    global llm
+    print("Setting model to " + str(model))
+    if model == "GPT-3.5":
+        print("Loading GPT-3.5")
+        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.1)
+    elif model == "GPT-4":
+        print("Loading GPT-4")
+        llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
+    elif model == "Flan UL2":
+        print("Loading Flan-UL2")
+        llm = HuggingFaceHub(repo_id="google/flan-ul2", model_kwargs={"temperature": 0.1, "max_new_tokens": 500})
+    elif model == "Flan T5":
+        print("Loading Flan T5")
+        llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.1})
+    else:
+        print("Loading GPT-3.5 from else")
+        llm = ChatOpenAI(model_name="text-davinci-002", temperature=0.1)
+
+
+def set_embeddings(model):
+    global embeddings
+    if model == "GPT-3.5" or model == "GPT-4":
+        print("Loading OpenAI embeddings")
+        embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
+    elif model == "Flan UL2" or model == "Flan T5":
+        print("Loading Hugging Face embeddings")
+        embeddings = HuggingFaceHubEmbeddings(repo_id="sentence-transformers/all-MiniLM-L6-v2")
+
+
+def get_search_index(model):
+    global vectorstore_index
+    if os.path.isfile(get_file_path(model, pickle_file)) and os.path.isfile(
+            get_file_path(model, index_file)) and os.path.getsize(get_file_path(model, pickle_file)) > 0:
         # Load index from pickle file
-        with open(pickle_file, "rb") as f:
+        with open(get_file_path(model, pickle_file), "rb") as f:
             search_index = pickle.load(f)
+            print("Loaded index")
     else:
-        search_index = create_index()
+        search_index = create_index(model)
+        print("Created index")
 
-
+    vectorstore_index = search_index
     return search_index
 
 
-def create_index():
+def create_index(model):
     source_chunks = create_chunk_documents()
     search_index = search_index_from_docs(source_chunks)
-    faiss.write_index(search_index.index, index_file)
+    faiss.write_index(search_index.index, get_file_path(model, index_file))
     # Save index to pickle file
-    with open(pickle_file, "wb") as f:
+    with open(get_file_path(model, pickle_file), "wb") as f:
         pickle.dump(search_index, f)
     return search_index
 
 
+def get_file_path(model, file):
+    # If model is GPT3.5 or GPT4 return models_folder + openai + file else return models_folder + hf + file
+    if model == "GPT-3.5" or model == "GPT-4":
+        return models_folder + "openai" + file
+    else:
+        return models_folder + "hf" + file
+
+
 def search_index_from_docs(source_chunks):
     # print("source chunks: " + str(len(source_chunks)))
     # print("embeddings: " + str(embeddings))
+
     search_index = FAISS.from_documents(source_chunks, embeddings)
     return search_index
 
@@ -83,7 +136,7 @@ def get_html_files():
 def fetch_data_for_embeddings():
     document_list = get_text_files()
     document_list.extend(get_html_files())
-    print("document list" + str(len(document_list)))
+    print("document list: " + str(len(document_list)))
     return document_list
 
 
@@ -100,20 +153,26 @@ def create_chunk_documents():
 
     source_chunks = splitter.split_documents(sources)
 
-    print("
+    print("chunks: " + str(len(source_chunks)))
 
     return source_chunks
 
 
-def get_qa_chain(
-    global
+def get_qa_chain(vectorstore_index):
+    global llm, model_name
+    print(llm)
+
     # embeddings_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
     # compression_retriever = ContextualCompressionRetriever(base_compressor=embeddings_filter, base_retriever=gpt_3_5_index.as_retriever())
-
+    retriever = vectorstore_index.as_retriever(search_type="similarity_score_threshold",
+                                               search_kwargs={"score_threshold": .5})
+
+    chain = ConversationalRetrievalChain.from_llm(llm, retriever, return_source_documents=True,
                                                   verbose=True, get_chat_history=get_chat_history,
                                                   combine_docs_chain_kwargs={"prompt": CHAT_PROMPT})
     return chain
 
+
 def get_chat_history(inputs) -> str:
     res = []
     for human, ai in inputs:
@@ -122,18 +181,19 @@ def get_chat_history(inputs) -> str:
 
 
 def generate_answer(question) -> str:
-    global chat_history,
-
+    global chat_history, vectorstore_index
+    chain = get_qa_chain(vectorstore_index)
 
-    result =
+    result = chain(
         {"question": question, "chat_history": chat_history, "vectordbkwargs": {"search_distance": 0.6}})
     chat_history = [(question, result["answer"])]
     sources = []
-    print(result
+    print(result)
 
     for document in result['source_documents']:
         source = document.metadata['source']
         sources.append(source.split('/')[-1].split('.')[0])
+    print(sources)
 
     source = ',\n'.join(set(sources))
-    return result['answer'] + '\nSOURCES: ' + source
+    return result['answer'] + '\nSOURCES: ' + source
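Note: get_file_path keeps one cached FAISS index/pickle pair per model family under models/, so OpenAI-backed and Hugging Face-backed runs do not overwrite each other. A quick sketch of the resulting paths (the values follow directly from pickle_file, index_file and models_folder above; importing utils also builds the default gpt-4 LLM and OpenAI embeddings, so an OpenAI key is assumed):

# Sketch: cache file names produced by get_file_path in this commit.
from utils import get_file_path, pickle_file, index_file

print(get_file_path("GPT-4", pickle_file))    # models/openai_vs.pkl
print(get_file_path("Flan T5", index_file))   # models/hf_vs.index

As committed, set_embeddings(model) stays commented out inside set_model_and_embeddings, so every index is still built with the OpenAI embeddings even when it is written to an hf-prefixed path.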