rohan13 committed on
Commit 129d2e2 • 1 Parent(s): 5a7a629

multiple models

Files changed (4)
  1. app.py +39 -16
  2. main.py +6 -5
  3. requirements.txt +3 -2
  4. utils.py +87 -27
app.py CHANGED
@@ -7,19 +7,22 @@ from transformers import pipeline
 
 p = pipeline("automatic-speech-recognition")
 
-"""Use text to call chat method from main_old.py"""
-def add_text(history, text):
+"""Use text to call chat method from main.py"""
+
+models = ["GPT-3.5", "Flan UL2", "GPT-4", "Flan T5"]
+
+def add_text(history, text, model):
     print("Question asked: " + text)
-    response = run_model(text)
+    response = run_model(text, model)
     history = history + [(text, response)]
     print(history)
     return history, ""
 
 
-def run_model(text):
+def run_model(text, model):
     start_time = time.time()
     print("start time:" + str(start_time))
-    response = run(question=text)
+    response = run(text, model)
     end_time = time.time()
     # If response contains string `SOURCES:`, then add a \n before `SOURCES`
     if "SOURCES:" in response:
@@ -31,12 +34,12 @@ def run_model(text):
 
 
 
-def get_output(history, audio):
+def get_output(history, audio, model):
 
     txt = p(audio)["text"]
     # history.append(( (audio, ) , txt))
     audio_path = 'response.wav'
-    response = run_model(txt)
+    response = run_model(txt, model)
     # Remove all text from SOURCES: to the end of the string
     trimmed_response = response.split("SOURCES:")[0]
     myobj = gTTS(text=trimmed_response, lang='en', slow=False)
@@ -48,34 +51,54 @@ def get_output(history, audio):
     print(history)
     return history
 
+def set_model(history, model):
+    print("Model selected: " + model)
+    history = get_first_message(history)
+    index(model)
+    return history
+
+
+def get_first_message(history):
+    history = [(None,
+                'Learn about <a href="https://www.coursera.org/learn/3d-printing-revolution/home">3D printing Revolution</a> course with referred sources. Try out the new voice to voice Q&A on the course! ')]
+    return history
+
+
 def bot(history):
     return history
 
 with gr.Blocks() as demo:
-    index()
-    chatbot = gr.Chatbot([(None, 'Learn about <a href="https://www.coursera.org/learn/3d-printing-revolution/home">3D printing Revolution</a> course with referred sources. Try out the new voice to voice Q&A on the course! ')], elem_id="chatbot").style(height=750)
+    chatbot = gr.Chatbot(get_first_message([]), elem_id="chatbot").style(height=600)
 
+    with gr.Row():
+        # Create radio button to select model
+        radio = gr.Radio(models, label="Choose a model", value="GPT-3.5", type="value")
     with gr.Row():
-        with gr.Column(scale=0.85):
+        with gr.Column(scale=0.75):
             txt = gr.Textbox(
                 label="Coursera Voice Q&A Bot",
                 placeholder="Enter text and press enter, or upload an image", lines=1
             ).style(container=False)
 
-        with gr.Column(scale=0.15):
-            audio = gr.Audio(source="microphone", type="filepath")
+        with gr.Column(scale=0.25):
+            audio = gr.Audio(source="microphone", type="filepath").style(container=False)
 
-    txt.submit(add_text, [chatbot, txt], [chatbot, txt], postprocess=False).then(
+    txt.submit(add_text, [chatbot, txt, radio], [chatbot, txt], postprocess=False).then(
         bot, chatbot, chatbot
     )
 
-
-    audio.change(fn=get_output, inputs=[chatbot, audio], outputs=[chatbot]).then(
+    audio.change(fn=get_output, inputs=[chatbot, audio, radio], outputs=[chatbot]).then(
        bot, chatbot, chatbot
    )
 
-    print(audio)
+    radio.change(fn=set_model, inputs=[chatbot, radio], outputs=[chatbot]).then(bot, chatbot, chatbot)
+
     audio.change(lambda:None, None, audio)
 
+    set_model(chatbot, radio.value)
+
 if __name__ == "__main__":
+    demo.queue()
+    demo.queue(concurrency_count=5)
     demo.launch(debug=True)
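Taken together, the app.py changes thread the radio's current value into every handler, so one Blocks app can dispatch to any of the four models. A minimal standalone sketch of that radio-as-input pattern, against the same Gradio 3.x API the commit uses (the handler body and labels here are illustrative, not part of the commit):

import gradio as gr

models = ["GPT-3.5", "Flan UL2", "GPT-4", "Flan T5"]

def add_text(history, text, model):
    # Stand-in for run_model(text, model): just echo which model was routed
    history = history + [(text, "[" + model + "] would answer here")]
    return history, ""

with gr.Blocks() as sketch:
    chatbot = gr.Chatbot()
    radio = gr.Radio(models, label="Choose a model", value="GPT-3.5", type="value")
    txt = gr.Textbox(placeholder="Enter text and press enter")
    # Passing radio as an input hands its current value to the handler on every submit
    txt.submit(add_text, [chatbot, txt, radio], [chatbot, txt])

if __name__ == "__main__":
    sketch.launch()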
main.py CHANGED
@@ -1,9 +1,10 @@
-from utils import get_search_index, get_qa_chain, generate_answer
+from utils import get_search_index, generate_answer, set_model_and_embeddings
 
-def index():
-    get_search_index()
+def index(model):
+    set_model_and_embeddings(model)
+    get_search_index(model)
     return True
 
-def run(question):
-    index()
+def run(question, model):
+    index(model)
     return generate_answer(question)
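With this change, callers pass the model name through to run; a hypothetical invocation (the question text is illustrative, the model names come from app.py's radio) would look like:

from main import run

# index(model) runs first, so set_model_and_embeddings and get_search_index
# have the right LLM and FAISS index in place before the question is answered
answer = run("How does fused deposition modeling work?", "GPT-3.5")
print(answer)  # answer text with "\nSOURCES: ..." appended by generate_answer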
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-langchain==0.0.166
+langchain
 openai
 faiss-cpu==1.7.3
 unstructured==0.5.8
@@ -6,4 +6,5 @@ ffmpeg-python
 transformers
 gtts
 torch
-tiktoken
+tiktoken
+huggingface-hub
utils.py CHANGED
@@ -2,10 +2,11 @@ import os
 import pickle
 
 import faiss
+from langchain import HuggingFaceHub
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chat_models import ChatOpenAI
 from langchain.document_loaders import DirectoryLoader, TextLoader, UnstructuredHTMLLoader
-from langchain.embeddings import OpenAIEmbeddings
+from langchain.embeddings import OpenAIEmbeddings, HuggingFaceHubEmbeddings
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.prompts.chat import (
     ChatPromptTemplate,
@@ -16,24 +17,26 @@ from langchain.text_splitter import CharacterTextSplitter
 from langchain.vectorstores.faiss import FAISS
 
 
-pickle_file = "open_ai.pkl"
-index_file = "open_ai.index"
+global model_name
 
+models = ["GPT-3.5", "Flan UL2", "GPT-4", "Flan T5"]
 
+pickle_file = "_vs.pkl"
+index_file = "_vs.index"
+models_folder = "models/"
 
-gpt_3_5 = ChatOpenAI(model_name='gpt-4',temperature=0.1)
+llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
 
 embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
 
-
 chat_history = []
 
-memory = ConversationBufferWindowMemory(memory_key="chat_history")
+memory = ConversationBufferWindowMemory(memory_key="chat_history", k=10)
 
-gpt_3_5_index = None
+vectorstore_index = None
 
 system_template = """You are Coursera QA Bot. Have a conversation with a human, answering the following questions as best you can.
-You are a teaching assistant for a Coursera Course: The 3D Printing Evolution and can answer any question about that using vectorstore.
+You are a teaching assistant for a Coursera Course: The 3D Printing Evolution and can answer any question about that using vectorstore or context.
 Use the following pieces of context to answer the users question.
 ----------------
 {context}"""
@@ -44,32 +47,82 @@ messages = [
 ]
 CHAT_PROMPT = ChatPromptTemplate.from_messages(messages)
 
-def get_search_index():
-    global gpt_3_5_index
-    if os.path.isfile(pickle_file) and os.path.isfile(index_file) and os.path.getsize(pickle_file) > 0:
+
+def set_model_and_embeddings(model):
+    global chat_history
+    set_model(model)
+    # set_embeddings(model)
+    chat_history = []
+
+
+def set_model(model):
+    global llm
+    print("Setting model to " + str(model))
+    if model == "GPT-3.5":
+        print("Loading GPT-3.5")
+        llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.1)
+    elif model == "GPT-4":
+        print("Loading GPT-4")
+        llm = ChatOpenAI(model_name="gpt-4", temperature=0.1)
+    elif model == "Flan UL2":
+        print("Loading Flan-UL2")
+        llm = HuggingFaceHub(repo_id="google/flan-ul2", model_kwargs={"temperature": 0.1, "max_new_tokens": 500})
+    elif model == "Flan T5":
+        print("Loading Flan T5")
+        llm = HuggingFaceHub(repo_id="google/flan-t5-base", model_kwargs={"temperature": 0.1})
+    else:
+        print("Loading GPT-3.5 from else")
+        llm = ChatOpenAI(model_name="text-davinci-002", temperature=0.1)
+
+
+def set_embeddings(model):
+    global embeddings
+    if model == "GPT-3.5" or model == "GPT-4":
+        print("Loading OpenAI embeddings")
+        embeddings = OpenAIEmbeddings(model='text-embedding-ada-002')
+    elif model == "Flan UL2" or model == "Flan T5":
+        print("Loading Hugging Face embeddings")
+        embeddings = HuggingFaceHubEmbeddings(repo_id="sentence-transformers/all-MiniLM-L6-v2")
+
+
+def get_search_index(model):
+    global vectorstore_index
+    if os.path.isfile(get_file_path(model, pickle_file)) and os.path.isfile(
+            get_file_path(model, index_file)) and os.path.getsize(get_file_path(model, pickle_file)) > 0:
         # Load index from pickle file
-        with open(pickle_file, "rb") as f:
+        with open(get_file_path(model, pickle_file), "rb") as f:
             search_index = pickle.load(f)
+        print("Loaded index")
     else:
-        search_index = create_index()
+        search_index = create_index(model)
+        print("Created index")
 
-    gpt_3_5_index = search_index
+    vectorstore_index = search_index
     return search_index
 
 
-def create_index():
+def create_index(model):
     source_chunks = create_chunk_documents()
     search_index = search_index_from_docs(source_chunks)
-    faiss.write_index(search_index.index, index_file)
+    faiss.write_index(search_index.index, get_file_path(model, index_file))
     # Save index to pickle file
-    with open(pickle_file, "wb") as f:
+    with open(get_file_path(model, pickle_file), "wb") as f:
         pickle.dump(search_index, f)
     return search_index
 
 
+def get_file_path(model, file):
+    # If model is GPT3.5 or GPT4 return models_folder + openai + file else return models_folder + hf + file
+    if model == "GPT-3.5" or model == "GPT-4":
+        return models_folder + "openai" + file
+    else:
+        return models_folder + "hf" + file
+
+
 def search_index_from_docs(source_chunks):
     # print("source chunks: " + str(len(source_chunks)))
     # print("embeddings: " + str(embeddings))
+
     search_index = FAISS.from_documents(source_chunks, embeddings)
     return search_index
 
@@ -83,7 +136,7 @@ def get_html_files():
 
 def fetch_data_for_embeddings():
     document_list = get_text_files()
     document_list.extend(get_html_files())
-    print("document list" + str(len(document_list)))
+    print("document list: " + str(len(document_list)))
     return document_list
 
@@ -100,20 +153,26 @@ def create_chunk_documents():
 
     source_chunks = splitter.split_documents(sources)
 
-    print("sources" + str(len(source_chunks)))
+    print("chunks: " + str(len(source_chunks)))
 
     return source_chunks
 
 
-def get_qa_chain(gpt_3_5_index):
-    global gpt_3_5
+def get_qa_chain(vectorstore_index):
+    global llm, model_name
+    print(llm)
+
     # embeddings_filter = EmbeddingsFilter(embeddings=embeddings, similarity_threshold=0.76)
     # compression_retriever = ContextualCompressionRetriever(base_compressor=embeddings_filter, base_retriever=gpt_3_5_index.as_retriever())
-    chain = ConversationalRetrievalChain.from_llm(gpt_3_5, gpt_3_5_index.as_retriever(), return_source_documents=True,
+    retriever = vectorstore_index.as_retriever(search_type="similarity_score_threshold",
+                                               search_kwargs={"score_threshold": .5})
+
+    chain = ConversationalRetrievalChain.from_llm(llm, retriever, return_source_documents=True,
                                                   verbose=True, get_chat_history=get_chat_history,
                                                   combine_docs_chain_kwargs={"prompt": CHAT_PROMPT})
     return chain
 
+
 def get_chat_history(inputs) -> str:
     res = []
     for human, ai in inputs:
@@ -122,18 +181,19 @@ def get_chat_history(inputs) -> str:
 
 
 def generate_answer(question) -> str:
-    global chat_history, gpt_3_5_index
-    gpt_3_5_chain = get_qa_chain(gpt_3_5_index)
+    global chat_history, vectorstore_index
+    chain = get_qa_chain(vectorstore_index)
 
-    result = gpt_3_5_chain(
+    result = chain(
         {"question": question, "chat_history": chat_history, "vectordbkwargs": {"search_distance": 0.6}})
     chat_history = [(question, result["answer"])]
     sources = []
-    print(result['answer'])
+    print(result)
 
     for document in result['source_documents']:
         source = document.metadata['source']
         sources.append(source.split('/')[-1].split('.')[0])
+    print(sources)
 
     source = ',\n'.join(set(sources))
-    return result['answer'] + '\nSOURCES: ' + source
+    return result['answer'] + '\nSOURCES: ' + source
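One consequence of the new get_file_path helper is that OpenAI-backed and Hugging Face-backed models now persist their vectorstores to separate files under models/; the resulting paths follow directly from the constants above:

get_file_path("GPT-4", pickle_file)     # "models/" + "openai" + "_vs.pkl"  -> "models/openai_vs.pkl"
get_file_path("Flan UL2", index_file)   # "models/" + "hf" + "_vs.index"    -> "models/hf_vs.index"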