awinml commited on
Commit
4b19adc
·
1 Parent(s): d93448a

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +283 -85
app.py CHANGED
@@ -40,6 +40,7 @@ from utils.prompts import (
40
  generate_gpt_j_two_shot_prompt_2,
41
  generate_gpt_prompt_alpaca,
42
  generate_gpt_prompt_alpaca_multi_doc,
 
43
  generate_gpt_prompt_original,
44
  generate_multi_doc_context,
45
  get_context_list_prompt,
@@ -74,6 +75,11 @@ with st.sidebar:
74
  document_type = st.selectbox(
75
  "Select Query Type", ["Single-Document", "Multi-Document"]
76
  )
 
 
 
 
 
77
 
78
  if ner_choice == "Spacy":
79
  ner_model = get_spacy_model()
@@ -86,11 +92,16 @@ with col1:
86
  value="What was discussed regarding Wearables revenue performance?",
87
  )
88
  else:
89
- query_text = st.text_area(
90
- "Input Query",
91
- value="How has revenue from Wearables performed over the past 2 years?",
92
- )
93
-
 
 
 
 
 
94
 
95
  years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
96
  quarters_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
@@ -145,32 +156,76 @@ if document_type == "Single-Document":
145
 
146
  else:
147
  # Multi-Document Case
148
-
149
  with col1:
150
- # Hardcoding the defaults for a question without metadata
151
- if (
152
- query_text
153
- == "How has revenue from Wearables performed over the past 2 years?"
154
- ):
155
- start_year = st.selectbox("Start Year", years_choice, index=2)
156
- start_quarter = st.selectbox(
157
- "Start Quarter", quarters_choice, index=0
158
- )
 
 
159
 
160
- end_year = st.selectbox("End Year", years_choice, index=0)
161
- end_quarter = st.selectbox("End Quarter", quarters_choice, index=0)
 
 
162
 
163
- ticker = st.selectbox("Company", ticker_choice, index=0)
164
- else:
165
- start_year = st.selectbox("Start Year", years_choice, index=2)
166
- start_quarter = st.selectbox(
167
- "Start Quarter", quarters_choice, index=0
168
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- end_year = st.selectbox("End Year", years_choice, index=0)
171
- end_quarter = st.selectbox("End Quarter", quarters_choice, index=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
- ticker = st.selectbox("Company", ticker_choice, index=0)
 
 
 
 
 
 
 
 
 
 
174
 
175
  participant_type = st.selectbox(
176
  "Speaker", ["Company Speaker", "Analyst"]
@@ -186,7 +241,7 @@ with st.sidebar:
186
  )
187
  else:
188
  num_results = int(
189
- st.number_input("Number of Results to query", 1, 15, value=2)
190
  )
191
 
192
 
@@ -252,7 +307,7 @@ with st.sidebar:
252
  )
253
  )
254
  else:
255
- window = int(st.number_input("Sentence Window Size", 0, 10, value=0))
256
 
257
  threshold = float(
258
  st.number_input(
@@ -310,69 +365,191 @@ if document_type == "Single-Document":
310
 
311
  else:
312
  # Multi-Document Retreival
313
- if encoder_model == "Hybrid SGPT - SPLADE":
314
- dense_query_embedding = create_dense_embeddings(
315
- query_text, retriever_model
316
- )
317
- sparse_query_embedding = create_sparse_embeddings(
318
- query_text, sparse_retriever_model, sparse_retriever_tokenizer
319
- )
320
- dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
321
- dense_query_embedding, sparse_query_embedding, 0
322
- )
323
- year_quarter_list = year_quarter_range(
324
- start_quarter, start_year, end_quarter, end_year
325
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
- context_group = []
328
- for year, quarter in year_quarter_list:
329
- query_results = query_pinecone_sparse(
330
- dense_query_embedding,
331
- sparse_query_embedding,
332
- num_results,
333
- pinecone_index,
334
- year,
335
- quarter,
336
- ticker,
337
- participant_type,
338
- threshold,
339
  )
340
- results_list = sentence_id_combine(data, query_results, lag=window)
341
- context_group.append((results_list, year, quarter))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
 
 
343
  else:
344
- dense_query_embedding = create_dense_embeddings(
345
- query_text, retriever_model
346
- )
347
- year_quarter_list = year_quarter_range(
348
- start_quarter, start_year, end_quarter, end_year
349
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
- context_group = []
352
- for year, quarter in year_quarter_list:
353
- query_results = query_pinecone(
354
- dense_query_embedding,
355
- num_results,
356
- pinecone_index,
357
- year,
358
- quarter,
359
- ticker,
360
- participant_type,
361
- threshold,
362
  )
363
- results_list = sentence_id_combine(data, query_results, lag=window)
364
- context_group.append((results_list, year, quarter))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
 
366
- multi_doc_context = generate_multi_doc_context(context_group)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
 
 
 
 
 
 
368
 
369
  if decoder_model == "GPT-3.5 Turbo":
370
  if document_type == "Single-Document":
371
  prompt = generate_gpt_prompt_alpaca(query_text, context_list)
372
  else:
373
- prompt = generate_gpt_prompt_alpaca_multi_doc(
374
- query_text, context_group
375
- )
 
 
 
 
 
376
 
377
  with col2:
378
  with st.form("my_form"):
@@ -527,6 +704,11 @@ with tab1:
527
  else:
528
  with st.expander("See Retrieved Text"):
529
  st.subheader("Retrieved Text:")
 
 
 
 
 
530
  sections = [
531
  s.strip()
532
  for s in multi_doc_context.split("Document: ")
@@ -554,10 +736,26 @@ with tab2:
554
  file_text, height=700, border=False, fontFamily="Helvetica"
555
  )
556
  else:
557
- for year, quarter in year_quarter_list:
558
- file_text = retrieve_transcript(data, year, quarter, ticker)
559
- with st.expander(f"See Transcript - {quarter} {year}"):
560
- st.subheader("Earnings Call Transcript - {quarter} {year}:")
561
- stx.scrollableTextbox(
562
- file_text, height=700, border=False, fontFamily="Helvetica"
563
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  generate_gpt_j_two_shot_prompt_2,
41
  generate_gpt_prompt_alpaca,
42
  generate_gpt_prompt_alpaca_multi_doc,
43
+ generate_gpt_prompt_alpaca_multi_doc_multi_company,
44
  generate_gpt_prompt_original,
45
  generate_multi_doc_context,
46
  get_context_list_prompt,
 
75
  document_type = st.selectbox(
76
  "Select Query Type", ["Single-Document", "Multi-Document"]
77
  )
78
+ if document_type == "Multi-Document":
79
+ multi_company_choice = st.selectbox(
80
+ "Select Company Query Type",
81
+ ["Single-Company", "Compare Companies"],
82
+ )
83
 
84
  if ner_choice == "Spacy":
85
  ner_model = get_spacy_model()
 
92
  value="What was discussed regarding Wearables revenue performance?",
93
  )
94
  else:
95
+ if multi_company_choice == "Single-Company":
96
+ query_text = st.text_area(
97
+ "Input Query",
98
+ value="What was the reported revenue for Wearables over the last 2 years?",
99
+ )
100
+ else:
101
+ query_text = st.text_area(
102
+ "Input Query",
103
+ value="How was AAPL's capex spend compared to GOOGL?",
104
+ )
105
 
106
  years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
107
  quarters_choice = ["Q1", "Q2", "Q3", "Q4", "All"]
 
156
 
157
  else:
158
  # Multi-Document Case
 
159
  with col1:
160
+ # Single Company Summary
161
+ if multi_company_choice == "Single-Company":
162
+ # Hardcoding the defaults for a question without metadata
163
+ if (
164
+ query_text
165
+ == "What was the reported revenue for Wearables over the last 2 years?"
166
+ ):
167
+ start_year = st.selectbox("Start Year", years_choice, index=2)
168
+ start_quarter = st.selectbox(
169
+ "Start Quarter", quarters_choice, index=0
170
+ )
171
 
172
+ end_year = st.selectbox("End Year", years_choice, index=0)
173
+ end_quarter = st.selectbox(
174
+ "End Quarter", quarters_choice, index=0
175
+ )
176
 
177
+ ticker = st.selectbox("Company", ticker_choice, index=0)
178
+ else:
179
+ start_year = st.selectbox("Start Year", years_choice, index=2)
180
+ start_quarter = st.selectbox(
181
+ "Start Quarter", quarters_choice, index=0
182
+ )
183
+
184
+ end_year = st.selectbox("End Year", years_choice, index=0)
185
+ end_quarter = st.selectbox(
186
+ "End Quarter", quarters_choice, index=0
187
+ )
188
+
189
+ ticker = st.selectbox("Company", ticker_choice, index=0)
190
+
191
+ # Single Company Summary
192
+ if multi_company_choice == "Compare Companies":
193
+ # Hardcoding the defaults for a question without metadata
194
+ if query_text == "How was AAPL's capex spend compared to GOOGL?":
195
+ start_year = st.selectbox("Start Year", years_choice, index=1)
196
+ start_quarter = st.selectbox(
197
+ "Start Quarter", quarters_choice, index=0
198
+ )
199
 
200
+ end_year = st.selectbox("End Year", years_choice, index=0)
201
+ end_quarter = st.selectbox(
202
+ "End Quarter", quarters_choice, index=0
203
+ )
204
+
205
+ ticker_first = st.selectbox(
206
+ "First Company", ticker_choice, index=0
207
+ )
208
+ ticker_second = st.selectbox(
209
+ "Second Company", ticker_choice, index=5
210
+ )
211
+
212
+ else:
213
+ start_year = st.selectbox("Start Year", years_choice, index=2)
214
+ start_quarter = st.selectbox(
215
+ "Start Quarter", quarters_choice, index=0
216
+ )
217
 
218
+ end_year = st.selectbox("End Year", years_choice, index=0)
219
+ end_quarter = st.selectbox(
220
+ "End Quarter", quarters_choice, index=0
221
+ )
222
+
223
+ ticker_first = st.selectbox(
224
+ "First Company", ticker_choice, index=0
225
+ )
226
+ ticker_second = st.selectbox(
227
+ "Second Company", ticker_choice, index=1
228
+ )
229
 
230
  participant_type = st.selectbox(
231
  "Speaker", ["Company Speaker", "Analyst"]
 
241
  )
242
  else:
243
  num_results = int(
244
+ st.number_input("Number of Results to query", 1, 15, value=4)
245
  )
246
 
247
 
 
307
  )
308
  )
309
  else:
310
+ window = int(st.number_input("Sentence Window Size", 0, 10, value=1))
311
 
312
  threshold = float(
313
  st.number_input(
 
365
 
366
  else:
367
  # Multi-Document Retreival
368
+ # Single Company
369
+ if multi_company_choice == "Single-Company":
370
+ if encoder_model == "Hybrid SGPT - SPLADE":
371
+ dense_query_embedding = create_dense_embeddings(
372
+ query_text, retriever_model
373
+ )
374
+ sparse_query_embedding = create_sparse_embeddings(
375
+ query_text, sparse_retriever_model, sparse_retriever_tokenizer
376
+ )
377
+ dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
378
+ dense_query_embedding, sparse_query_embedding, 0
379
+ )
380
+ year_quarter_list = year_quarter_range(
381
+ start_quarter, start_year, end_quarter, end_year
382
+ )
383
+
384
+ context_group = []
385
+ for year, quarter in year_quarter_list:
386
+ query_results = query_pinecone_sparse(
387
+ dense_query_embedding,
388
+ sparse_query_embedding,
389
+ num_results,
390
+ pinecone_index,
391
+ year,
392
+ quarter,
393
+ ticker,
394
+ participant_type,
395
+ threshold,
396
+ )
397
+ results_list = sentence_id_combine(
398
+ data, query_results, lag=window
399
+ )
400
+ context_group.append((results_list, year, quarter, ticker))
401
 
402
+ else:
403
+ dense_query_embedding = create_dense_embeddings(
404
+ query_text, retriever_model
 
 
 
 
 
 
 
 
 
405
  )
406
+ year_quarter_list = year_quarter_range(
407
+ start_quarter, start_year, end_quarter, end_year
408
+ )
409
+
410
+ context_group = []
411
+ for year, quarter in year_quarter_list:
412
+ query_results = query_pinecone(
413
+ dense_query_embedding,
414
+ num_results,
415
+ pinecone_index,
416
+ year,
417
+ quarter,
418
+ ticker,
419
+ participant_type,
420
+ threshold,
421
+ )
422
+ results_list = sentence_id_combine(
423
+ data, query_results, lag=window
424
+ )
425
+ context_group.append((results_list, year, quarter, ticker))
426
 
427
+ multi_doc_context = generate_multi_doc_context(context_group)
428
+ # Companies Comparison
429
  else:
430
+ if encoder_model == "Hybrid SGPT - SPLADE":
431
+ dense_query_embedding = create_dense_embeddings(
432
+ query_text, retriever_model
433
+ )
434
+ sparse_query_embedding = create_sparse_embeddings(
435
+ query_text, sparse_retriever_model, sparse_retriever_tokenizer
436
+ )
437
+ dense_query_embedding, sparse_query_embedding = hybrid_score_norm(
438
+ dense_query_embedding, sparse_query_embedding, 0
439
+ )
440
+ year_quarter_list = year_quarter_range(
441
+ start_quarter, start_year, end_quarter, end_year
442
+ )
443
+
444
+ # First Company Context
445
+ context_group_first = []
446
+ for year, quarter in year_quarter_list:
447
+ query_results = query_pinecone_sparse(
448
+ dense_query_embedding,
449
+ sparse_query_embedding,
450
+ num_results,
451
+ pinecone_index,
452
+ year,
453
+ quarter,
454
+ ticker_first,
455
+ participant_type,
456
+ threshold,
457
+ )
458
+ results_list = sentence_id_combine(
459
+ data, query_results, lag=window
460
+ )
461
+ context_group_first.append(
462
+ (results_list, year, quarter, ticker_first)
463
+ )
464
+
465
+ # Second Company Context
466
+ context_group_second = []
467
+ for year, quarter in year_quarter_list:
468
+ query_results = query_pinecone_sparse(
469
+ dense_query_embedding,
470
+ sparse_query_embedding,
471
+ num_results,
472
+ pinecone_index,
473
+ year,
474
+ quarter,
475
+ ticker_second,
476
+ participant_type,
477
+ threshold,
478
+ )
479
+ results_list = sentence_id_combine(
480
+ data, query_results, lag=window
481
+ )
482
+ context_group_second.append(
483
+ (results_list, year, quarter, ticker_second)
484
+ )
485
 
486
+ else:
487
+ dense_query_embedding = create_dense_embeddings(
488
+ query_text, retriever_model
 
 
 
 
 
 
 
 
489
  )
490
+ year_quarter_list = year_quarter_range(
491
+ start_quarter, start_year, end_quarter, end_year
492
+ )
493
+
494
+ # First Company Context
495
+ context_group_first = []
496
+ for year, quarter in year_quarter_list:
497
+ query_results = query_pinecone(
498
+ dense_query_embedding,
499
+ num_results,
500
+ pinecone_index,
501
+ year,
502
+ quarter,
503
+ ticker_first,
504
+ participant_type,
505
+ threshold,
506
+ )
507
+ results_list = sentence_id_combine(
508
+ data, query_results, lag=window
509
+ )
510
+ context_group_first.append(
511
+ (results_list, year, quarter, ticker_first)
512
+ )
513
 
514
+ # Second Company Context
515
+ context_group_second = []
516
+ for year, quarter in year_quarter_list:
517
+ query_results = query_pinecone(
518
+ dense_query_embedding,
519
+ num_results,
520
+ pinecone_index,
521
+ year,
522
+ quarter,
523
+ ticker_second,
524
+ participant_type,
525
+ threshold,
526
+ )
527
+ results_list = sentence_id_combine(
528
+ data, query_results, lag=window
529
+ )
530
+ context_group_second.append(
531
+ (results_list, year, quarter, ticker_second)
532
+ )
533
 
534
+ multi_doc_context_first = generate_multi_doc_context(
535
+ context_group_first
536
+ )
537
+ multi_doc_context_second = generate_multi_doc_context(
538
+ context_group_second
539
+ )
540
 
541
  if decoder_model == "GPT-3.5 Turbo":
542
  if document_type == "Single-Document":
543
  prompt = generate_gpt_prompt_alpaca(query_text, context_list)
544
  else:
545
+ if multi_company_choice == "Single-Company":
546
+ prompt = generate_gpt_prompt_alpaca_multi_doc(
547
+ query_text, context_group
548
+ )
549
+ else:
550
+ prompt = generate_gpt_prompt_alpaca_multi_doc_multi_company(
551
+ query_text, context_group_first, context_group_second
552
+ )
553
 
554
  with col2:
555
  with st.form("my_form"):
 
704
  else:
705
  with st.expander("See Retrieved Text"):
706
  st.subheader("Retrieved Text:")
707
+ if multi_company_choice == "Compare Companies":
708
+ multi_doc_context = (
709
+ multi_doc_context_first + multi_doc_context_second
710
+ )
711
+
712
  sections = [
713
  s.strip()
714
  for s in multi_doc_context.split("Document: ")
 
736
  file_text, height=700, border=False, fontFamily="Helvetica"
737
  )
738
  else:
739
+ if multi_company_choice == "Single-Company":
740
+ for year, quarter in year_quarter_list:
741
+ file_text = retrieve_transcript(data, year, quarter, ticker)
742
+ with st.expander(f"See Transcript - {quarter} {year}"):
743
+ st.subheader("Earnings Call Transcript - {quarter} {year}:")
744
+ stx.scrollableTextbox(
745
+ file_text, height=700, border=False, fontFamily="Helvetica"
746
+ )
747
+ else:
748
+ for year, quarter in year_quarter_list:
749
+ file_text = retrieve_transcript(data, year, quarter, ticker_first)
750
+ with st.expander(f"See Transcript - {quarter} {year}"):
751
+ st.subheader("Earnings Call Transcript - {quarter} {year}:")
752
+ stx.scrollableTextbox(
753
+ file_text, height=700, border=False, fontFamily="Helvetica"
754
+ )
755
+ for year, quarter in year_quarter_list:
756
+ file_text = retrieve_transcript(data, year, quarter, ticker_second)
757
+ with st.expander(f"See Transcript - {quarter} {year}"):
758
+ st.subheader("Earnings Call Transcript - {quarter} {year}:")
759
+ stx.scrollableTextbox(
760
+ file_text, height=700, border=False, fontFamily="Helvetica"
761
+ )