awinml committed on
Commit
9c49e99
·
1 Parent(s): e375940

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +64 -22
  2. utils.py +126 -5
app.py CHANGED
@@ -1,17 +1,22 @@
1
  import openai
 
2
  import streamlit_scrollable_textbox as stx
3
 
4
- import pinecone
5
  import streamlit as st
6
  from utils import (
 
7
  create_dense_embeddings,
8
  create_sparse_embeddings,
 
9
  format_query,
10
- generate_prompt,
 
 
11
  get_data,
12
  get_flan_t5_model,
13
  get_mpnet_embedding_model,
14
  get_sgpt_embedding_model,
 
15
  get_splade_sparse_embedding_model,
16
  get_t5_model,
17
  gpt_model,
@@ -24,7 +29,7 @@ from utils import (
24
  text_lookup,
25
  )
26
 
27
- st.set_page_config(layout="wide")
28
 
29
 
30
  st.title("Abstractive Question Answering")
@@ -36,21 +41,31 @@ st.write(
36
 
37
  col1, col2 = st.columns([3, 3], gap="medium")
38
 
 
 
 
39
  with col1:
40
  st.subheader("Question")
41
  query_text = st.text_input(
42
  "Input Query",
43
- value="What was discussed regarding Wearables revenue performance?",
44
  )
45
 
 
 
 
 
 
46
  with col1:
47
  years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
48
 
49
  with col1:
50
- year = st.selectbox("Year", years_choice)
51
 
52
  with col1:
53
- quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4", "All"])
 
 
54
 
55
  with col1:
56
  participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
@@ -69,7 +84,7 @@ ticker_choice = [
69
  ]
70
 
71
  with col1:
72
- ticker = st.selectbox("Company", ticker_choice)
73
 
74
  with st.sidebar:
75
  st.subheader("Select Options:")
@@ -189,9 +204,8 @@ else:
189
  context_list = format_query(query_results)
190
 
191
 
192
- prompt = generate_prompt(query_text, context_list)
193
-
194
  if decoder_model == "GPT3 - (text-davinci-003)":
 
195
  with col2:
196
  with st.form("my_form"):
197
  edited_prompt = st.text_area(
@@ -208,29 +222,57 @@ if decoder_model == "GPT3 - (text-davinci-003)":
208
  api_key = save_key(openai_key)
209
  openai.api_key = api_key
210
  generated_text = gpt_model(edited_prompt)
211
- with col2:
212
- st.subheader("Answer:")
213
- st.write(generated_text)
214
 
215
  elif decoder_model == "T5":
 
216
  t5_pipeline = get_t5_model()
217
  output_text = []
218
- for context_text in context_list:
219
- output_text.append(t5_pipeline(context_text)[0]["summary_text"])
220
  with col2:
221
- st.subheader("Answer:")
222
- for text in output_text:
223
- st.markdown(f"- {text}")
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  elif decoder_model == "FLAN-T5":
 
226
  flan_t5_pipeline = get_flan_t5_model()
227
  output_text = []
228
- for context_text in context_list:
229
- output_text.append(flan_t5_pipeline(context_text)[0]["summary_text"])
230
  with col2:
231
- st.subheader("Answer:")
232
- for text in output_text:
233
- st.markdown(f"- {text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  with col1:
236
  with st.expander("See Retrieved Text"):
 
1
  import openai
2
+ import pinecone
3
  import streamlit_scrollable_textbox as stx
4
 
 
5
  import streamlit as st
6
  from utils import (
7
+ clean_entities,
8
  create_dense_embeddings,
9
  create_sparse_embeddings,
10
+ extract_entities,
11
  format_query,
12
+ generate_flant5_prompt,
13
+ generate_gpt_prompt,
14
+ get_context_list_prompt,
15
  get_data,
16
  get_flan_t5_model,
17
  get_mpnet_embedding_model,
18
  get_sgpt_embedding_model,
19
+ get_spacy_model,
20
  get_splade_sparse_embedding_model,
21
  get_t5_model,
22
  gpt_model,
 
29
  text_lookup,
30
  )
31
 
32
+ st.set_page_config(layout="wide") # isort: skip
33
 
34
 
35
  st.title("Abstractive Question Answering")
 
41
 
42
  col1, col2 = st.columns([3, 3], gap="medium")
43
 
44
+
45
+ spacy_model = get_spacy_model()
46
+
47
  with col1:
48
  st.subheader("Question")
49
  query_text = st.text_input(
50
  "Input Query",
51
+ value="What was discussed regarding Wearables revenue performance in Q1 2020?",
52
  )
53
 
54
+ company_ent, quarter_ent, year_ent = extract_entities(query_text, spacy_model)
55
+ ticker_index, quarter_index, year_index = clean_entities(
56
+ company_ent, quarter_ent, year_ent
57
+ )
58
+
59
  with col1:
60
  years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
61
 
62
  with col1:
63
+ year = st.selectbox("Year", years_choice, index=year_index)
64
 
65
  with col1:
66
+ quarter = st.selectbox(
67
+ "Quarter", ["Q1", "Q2", "Q3", "Q4", "All"], index=quarter_index
68
+ )
69
 
70
  with col1:
71
  participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
 
84
  ]
85
 
86
  with col1:
87
+ ticker = st.selectbox("Company", ticker_choice, ticker_index)
88
 
89
  with st.sidebar:
90
  st.subheader("Select Options:")
 
204
  context_list = format_query(query_results)
205
 
206
 
 
 
207
  if decoder_model == "GPT3 - (text-davinci-003)":
208
+ prompt = generate_gpt_prompt(query_text, context_list)
209
  with col2:
210
  with st.form("my_form"):
211
  edited_prompt = st.text_area(
 
222
  api_key = save_key(openai_key)
223
  openai.api_key = api_key
224
  generated_text = gpt_model(edited_prompt)
225
+ st.subheader("Answer:")
226
+ st.write(generated_text)
227
+
228
 
229
  elif decoder_model == "T5":
230
+ prompt = generate_flant5_prompt(query_text, context_list)
231
  t5_pipeline = get_t5_model()
232
  output_text = []
 
 
233
  with col2:
234
+ with st.form("my_form"):
235
+ edited_prompt = st.text_area(
236
+ label="Model Prompt", value=prompt, height=270
237
+ )
238
+ context_list = get_context_list_prompt(edited_prompt)
239
+ submitted = st.form_submit_button("Submit")
240
+ if submitted:
241
+ for context_text in context_list:
242
+ output_text.append(
243
+ t5_pipeline(context_text)[0]["summary_text"]
244
+ )
245
+ st.subheader("Answer:")
246
+ for text in output_text:
247
+ st.markdown(f"- {text}")
248
 
249
  elif decoder_model == "FLAN-T5":
250
+ prompt = generate_flant5_prompt(query_text, context_list)
251
  flan_t5_pipeline = get_flan_t5_model()
252
  output_text = []
 
 
253
  with col2:
254
+ with st.form("my_form"):
255
+ edited_prompt = st.text_area(
256
+ label="Model Prompt", value=prompt, height=270
257
+ )
258
+ context_list = get_context_list_prompt(edited_prompt)
259
+ submitted = st.form_submit_button("Submit")
260
+ if submitted:
261
+ for context_text in context_list:
262
+ output_text.append(
263
+ flan_t5_pipeline(
264
+ "Question:"
265
+ + query_text
266
+ + "\nContext:"
267
+ + context_text
268
+ + "\nAnswer?"
269
+ )[0]["summary_text"]
270
+ )
271
+ st.subheader("Answer:")
272
+ for text in output_text:
273
+ if "(iii)" not in text:
274
+ st.markdown(f"- {text}")
275
+
276
 
277
  with col1:
278
  with st.expander("See Retrieved Text"):
utils.py CHANGED
@@ -1,5 +1,9 @@
 
 
1
  import openai
2
  import pandas as pd
 
 
3
  import streamlit_scrollable_textbox as stx
4
  import torch
5
  from sentence_transformers import SentenceTransformer
@@ -11,7 +15,6 @@ from transformers import (
11
  pipeline,
12
  )
13
 
14
- import pinecone
15
  import streamlit as st
16
 
17
 
@@ -21,6 +24,14 @@ def get_data():
21
  return data
22
 
23
 
 
 
 
 
 
 
 
 
24
  # Initialize models from HuggingFace
25
 
26
 
@@ -33,8 +44,8 @@ def get_t5_model():
33
  def get_flan_t5_model():
34
  return pipeline(
35
  "summarization",
36
- model="google/flan-t5-small",
37
- tokenizer="google/flan-t5-small",
38
  max_length=512,
39
  # length_penalty = 0
40
  )
@@ -320,7 +331,7 @@ def text_lookup(data, sentence_ids):
320
  return context
321
 
322
 
323
- def generate_prompt(query_text, context_list):
324
  context = " ".join(context_list)
325
  prompt = f"""Answer the question in 6 long detailed points as accurately as possible using the provided context. Include as many key details as possible.
326
  Context: {context}
@@ -329,7 +340,7 @@ Answer:"""
329
  return prompt
330
 
331
 
332
- def generate_prompt_2(query_text, context_list):
333
  context = " ".join(context_list)
334
  prompt = f"""
335
  Context information is below:
@@ -342,6 +353,24 @@ def generate_prompt_2(query_text, context_list):
342
  return prompt
343
 
344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  def gpt_model(prompt):
346
  response = openai.Completion.create(
347
  model="text-davinci-003",
@@ -355,6 +384,98 @@ def gpt_model(prompt):
355
  return response.choices[0].text
356
 
357
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  # Transcript Retrieval
359
 
360
 
 
1
+ import re
2
+
3
  import openai
4
  import pandas as pd
5
+ import pinecone
6
+ import spacy
7
  import streamlit_scrollable_textbox as stx
8
  import torch
9
  from sentence_transformers import SentenceTransformer
 
15
  pipeline,
16
  )
17
 
 
18
  import streamlit as st
19
 
20
 
 
24
  return data
25
 
26
 
27
+ # Initialize Spacy Model
28
+
29
+
30
@st.experimental_singleton
def get_spacy_model():
    """Load the ``en_core_web_sm`` spaCy pipeline.

    Wrapped in ``st.experimental_singleton`` so the loaded pipeline object
    is handed back by the decorator rather than rebuilt by this function.
    """
    nlp = spacy.load("en_core_web_sm")
    return nlp
33
+
34
+
35
  # Initialize models from HuggingFace
36
 
37
 
 
44
  def get_flan_t5_model():
45
  return pipeline(
46
  "summarization",
47
+ model="google/flan-t5-xl",
48
+ tokenizer="google/flan-t5-xl",
49
  max_length=512,
50
  # length_penalty = 0
51
  )
 
331
  return context
332
 
333
 
334
+ def generate_gpt_prompt(query_text, context_list):
335
  context = " ".join(context_list)
336
  prompt = f"""Answer the question in 6 long detailed points as accurately as possible using the provided context. Include as many key details as possible.
337
  Context: {context}
 
340
  return prompt
341
 
342
 
343
+ def generate_gpt_prompt_2(query_text, context_list):
344
  context = " ".join(context_list)
345
  prompt = f"""
346
  Context information is below:
 
353
  return prompt
354
 
355
 
356
def generate_flant5_prompt(query_text, context_list):
    """Assemble the FLAN-T5 prompt: instruction, question, then context.

    Args:
        query_text: The question to be answered.
        context_list: Retrieved passages; they are joined with ``" \\n"``
            and placed between two dashed divider lines.

    Returns:
        The prompt as a single newline-separated string.
    """
    divider = "---------------------"
    joined_context = " \n".join(context_list)
    sections = [
        "Given the context information and prior knowledge, answer this question:",
        f"{query_text}",
        "Context information is below:",
        divider,
        f"{joined_context}",
        divider,
    ]
    return "\n".join(sections)
365
+
366
+
367
def get_context_list_prompt(prompt):
    """Recover the list of context passages from a (possibly edited) prompt.

    The prompt is expected to wrap the context between two dashed divider
    lines, with passages separated by ``" \\n"``. Because the prompt comes
    from a user-editable text area, the dividers may have been removed:

    * two or more dividers: the text before the last divider is the context;
    * one divider: the text after it is the context;
    * no divider: the whole prompt is treated as the context.

    Args:
        prompt: Prompt text containing the context block.

    Returns:
        A list of passage strings. Whitespace-only passages are dropped, so
        an empty context gives an empty list.
    """
    parts = prompt.split("---------------------")
    # With both dividers present the context is the second-to-last piece;
    # otherwise the last piece is the best available guess. Indexing [-2]
    # unconditionally would raise IndexError when no divider is left.
    context = parts[-2] if len(parts) >= 3 else parts[-1]
    passages = context.strip().split(" \n")
    return [passage for passage in passages if passage.strip()]
372
+
373
+
374
  def gpt_model(prompt):
375
  response = openai.Completion.create(
376
  model="text-davinci-003",
 
384
  return response.choices[0].text
385
 
386
 
387
+ # Entity Extraction
388
+
389
+
390
def extract_quarter_year(string):
    """Pull a fiscal quarter and a four-digit year out of a date phrase.

    The two fields are searched for independently, so a phrase that names
    only a year ("2019") or only a quarter ("Q3") still yields the part
    that is present instead of discarding both.

    Args:
        string: Free text to scan, e.g. the text of a DATE entity.

    Returns:
        A ``(quarter, year)`` tuple. ``quarter`` is one of ``"Q1"`` to
        ``"Q4"`` and ``year`` is a four-digit string; each is ``None`` when
        it does not occur in ``string``.
    """
    # Exactly four digits: the look-arounds reject slices of longer numbers
    # (e.g. "120000") while still accepting prefixed forms such as "FY2020".
    year_match = re.search(r"(?<!\d)\d{4}(?!\d)", string)
    year = year_match.group() if year_match else None

    # Accept "q1" as well as "Q1"; only quarters 1-4 are meaningful, and a
    # trailing digit (as in "Q12") means this is not a quarter token.
    quarter_match = re.search(r"[Qq]([1-4])(?!\d)", string)
    quarter = f"Q{quarter_match.group(1)}" if quarter_match else None

    return quarter, year
406
+
407
+
408
def extract_entities(query, model):
    """Extract the company, quarter and year mentioned in a query.

    Runs ``model`` over ``query`` and inspects the resulting entities:
    ``ORG`` entities supply the company name (lower-cased) and ``DATE``
    entities are passed to ``extract_quarter_year``.

    Entities are examined from last to first, so the most recently
    mentioned one takes priority. A later DATE entity that carries no
    quarter/year (e.g. "last year") no longer hides an earlier one that
    does (e.g. "Q1 2020").

    Args:
        query: The user's question.
        model: Callable NLP pipeline whose result exposes ``.ents``, each
            entity having ``.label_`` and ``.text``.

    Returns:
        A ``(company, quarter, year)`` tuple; any element is ``None`` when
        nothing suitable was found.
    """
    doc = model(query)

    company = None
    quarter = None
    year = None
    for ent in reversed(tuple(doc.ents)):
        if ent.label_ == "ORG":
            if company is None:
                company = ent.text.lower()
        elif ent.label_ == "DATE":
            if quarter is None or year is None:
                found_quarter, found_year = extract_quarter_year(ent.text)
                # Only fill the gaps; values from later entities are kept.
                if quarter is None:
                    quarter = found_quarter
                if year is None:
                    year = found_year
    return company, quarter, year
424
+
425
+
426
def clean_entities(company, quarter, year):
    """Convert extracted entities into default indices for the select boxes.

    Args:
        company: Company name (any case), e.g. ``"apple"`` or
            ``"Micron Technology"``, or ``None``.
        quarter: Quarter label such as ``"Q1"``, or ``None``.
        year: Year string such as ``"2020"``, or ``None``.

    Returns:
        A ``(ticker_index, quarter_index, year_index)`` tuple of positions
        in the option lists defined below. An unknown or missing company
        maps to index 0; an unknown or missing quarter/year maps to the
        last option (``"All"``).
    """
    company_ticker_map = {
        "apple": "AAPL",
        "amd": "AMD",
        "amazon": "AMZN",
        "cisco": "CSCO",
        "google": "GOOGL",
        "microsoft": "MSFT",
        "nvidia": "NVDA",
        "asml": "ASML",
        "intel": "INTC",
        "micron": "MU",
    }

    # NOTE(review): presumably these lists must stay in the same order as the
    # select-box options in app.py, since the indices returned here are used
    # as their default positions - verify against the caller.
    ticker_choice = [
        "AAPL",
        "CSCO",
        "MSFT",
        "ASML",
        "NVDA",
        "GOOGL",
        "MU",
        "INTC",
        "AMZN",
        "AMD",
    ]
    year_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
    quarter_choice = ["Q1", "Q2", "Q3", "Q4", "All"]

    ticker_index = 0
    if isinstance(company, str):
        # Match on whole words so that multi-word entity texts such as
        # "Apple Inc." or "Micron Technology" still resolve to a ticker.
        words = re.findall(r"[a-z]+", company.lower())
        ticker = next(
            (company_ticker_map[word] for word in words if word in company_ticker_map),
            None,
        )
        if ticker is not None:
            ticker_index = ticker_choice.index(ticker)

    # ``None`` is never a member of the option lists, so a missing value
    # takes the same fallback as an unrecognised one.
    if quarter in quarter_choice:
        quarter_index = quarter_choice.index(quarter)
    else:
        quarter_index = len(quarter_choice) - 1

    if year in year_choice:
        year_index = year_choice.index(year)
    else:
        year_index = len(year_choice) - 1

    return ticker_index, quarter_index, year_index
477
+
478
+
479
  # Transcript Retrieval
480
 
481