awinml commited on
Commit
654c927
1 Parent(s): 2a6af36

Upload 16 files

Browse files
Files changed (2) hide show
  1. app.py +90 -24
  2. utils/models.py +1 -1
app.py CHANGED
@@ -22,9 +22,9 @@ from utils.models import (
22
  get_data,
23
  get_flan_alpaca_xl_model,
24
  get_flan_t5_model,
 
25
  get_mpnet_embedding_model,
26
  get_sgpt_embedding_model,
27
- get_instructor_embedding_model,
28
  get_spacy_model,
29
  get_splade_sparse_embedding_model,
30
  get_t5_model,
@@ -248,7 +248,13 @@ with st.sidebar:
248
 
249
  # Choose encoder model
250
 
251
- encoder_models_choice = ["MPNET", "Instructor", "SGPT", "Hybrid MPNET - SPLADE"]
 
 
 
 
 
 
252
  with st.sidebar:
253
  encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
254
 
@@ -285,12 +291,32 @@ elif encoder_model == "SGPT":
285
  elif encoder_model == "Instructor":
286
  # Connect to pinecone environment
287
  pinecone.init(
288
- api_key=st.secrets["pinecone_instructor"], environment="us-west4-gcp-free"
 
289
  )
290
  pinecone_index_name = "week13-instructor-xl"
291
  pinecone_index = pinecone.Index(pinecone_index_name)
292
  retriever_model = get_instructor_embedding_model()
293
- instruction = "Represent the financial question for retrieving supporting documents:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
 
295
  elif encoder_model == "Hybrid MPNET - SPLADE":
296
  pinecone.init(
@@ -332,10 +358,15 @@ with st.sidebar:
332
  data = get_data()
333
 
334
  if document_type == "Single-Document":
335
- if encoder_model == "Hybrid SGPT - SPLADE":
336
- dense_query_embedding = create_dense_embeddings(
337
- query_text, retriever_model
338
- )
 
 
 
 
 
339
  sparse_query_embedding = create_sparse_embeddings(
340
  query_text, sparse_retriever_model, sparse_retriever_tokenizer
341
  )
@@ -383,10 +414,18 @@ else:
383
  # Multi-Document Retreival
384
  # Single Company
385
  if multi_company_choice == "Single-Company":
386
- if encoder_model == "Hybrid SGPT - SPLADE":
387
- dense_query_embedding = create_dense_embeddings(
388
- query_text, retriever_model
389
- )
 
 
 
 
 
 
 
 
390
  sparse_query_embedding = create_sparse_embeddings(
391
  query_text, sparse_retriever_model, sparse_retriever_tokenizer
392
  )
@@ -448,10 +487,18 @@ else:
448
  multi_doc_context = generate_multi_doc_context(context_group)
449
  # Companies Comparison
450
  else:
451
- if encoder_model == "Hybrid SGPT - SPLADE":
452
- dense_query_embedding = create_dense_embeddings(
453
- query_text, retriever_model
454
- )
 
 
 
 
 
 
 
 
455
  sparse_query_embedding = create_sparse_embeddings(
456
  query_text, sparse_retriever_model, sparse_retriever_tokenizer
457
  )
@@ -766,22 +813,41 @@ with tab2:
766
  for year, quarter in year_quarter_list:
767
  file_text = retrieve_transcript(data, year, quarter, ticker)
768
  with st.expander(f"See Transcript - {quarter} {year}"):
769
- st.subheader("Earnings Call Transcript - {quarter} {year}:")
 
 
770
  stx.scrollableTextbox(
771
- file_text, height=700, border=False, fontFamily="Helvetica"
 
 
 
772
  )
773
  else:
774
  for year, quarter in year_quarter_list:
775
- file_text = retrieve_transcript(data, year, quarter, ticker_first)
 
 
776
  with st.expander(f"See Transcript - {quarter} {year}"):
777
- st.subheader("Earnings Call Transcript - {quarter} {year}:")
 
 
778
  stx.scrollableTextbox(
779
- file_text, height=700, border=False, fontFamily="Helvetica"
 
 
 
780
  )
781
  for year, quarter in year_quarter_list:
782
- file_text = retrieve_transcript(data, year, quarter, ticker_second)
 
 
783
  with st.expander(f"See Transcript - {quarter} {year}"):
784
- st.subheader("Earnings Call Transcript - {quarter} {year}:")
 
 
785
  stx.scrollableTextbox(
786
- file_text, height=700, border=False, fontFamily="Helvetica"
 
 
 
787
  )
 
22
  get_data,
23
  get_flan_alpaca_xl_model,
24
  get_flan_t5_model,
25
+ get_instructor_embedding_model,
26
  get_mpnet_embedding_model,
27
  get_sgpt_embedding_model,
 
28
  get_spacy_model,
29
  get_splade_sparse_embedding_model,
30
  get_t5_model,
 
248
 
249
  # Choose encoder model
250
 
251
+ encoder_models_choice = [
252
+ "MPNET",
253
+ "Instructor",
254
+ "Hybrid Instructor - SPLADE",
255
+ "SGPT",
256
+ "Hybrid MPNET - SPLADE",
257
+ ]
258
  with st.sidebar:
259
  encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
260
 
 
291
  elif encoder_model == "Instructor":
292
  # Connect to pinecone environment
293
  pinecone.init(
294
+ api_key=st.secrets["pinecone_instructor"],
295
+ environment="us-west4-gcp-free",
296
  )
297
  pinecone_index_name = "week13-instructor-xl"
298
  pinecone_index = pinecone.Index(pinecone_index_name)
299
  retriever_model = get_instructor_embedding_model()
300
+ instruction = (
301
+ "Represent the financial question for retrieving supporting documents:"
302
+ )
303
+
304
+ elif encoder_model == "Hybrid Instructor - SPLADE":
305
+ # Connect to pinecone environment
306
+ pinecone.init(
307
+ api_key=st.secrets["pinecone_instructor_splade"],
308
+ environment="us-west4-gcp-free",
309
+ )
310
+ pinecone_index_name = "week13-splade-instructor-xl"
311
+ pinecone_index = pinecone.Index(pinecone_index_name)
312
+ retriever_model = get_instructor_embedding_model()
313
+ (
314
+ sparse_retriever_model,
315
+ sparse_retriever_tokenizer,
316
+ ) = get_splade_sparse_embedding_model()
317
+ instruction = (
318
+ "Represent the financial question for retrieving supporting documents:"
319
+ )
320
 
321
  elif encoder_model == "Hybrid MPNET - SPLADE":
322
  pinecone.init(
 
358
  data = get_data()
359
 
360
  if document_type == "Single-Document":
361
+ if encoder_model in ["Hybrid SGPT - SPLADE", "Hybrid Instructor - SPLADE"]:
362
+ if encoder_model == "Hybrid Instructor - SPLADE":
363
+ dense_query_embedding = create_dense_embeddings(
364
+ query_text, retriever_model, instruction
365
+ )
366
+ else:
367
+ dense_query_embedding = create_dense_embeddings(
368
+ query_text, retriever_model
369
+ )
370
  sparse_query_embedding = create_sparse_embeddings(
371
  query_text, sparse_retriever_model, sparse_retriever_tokenizer
372
  )
 
414
  # Multi-Document Retreival
415
  # Single Company
416
  if multi_company_choice == "Single-Company":
417
+ if encoder_model in [
418
+ "Hybrid SGPT - SPLADE",
419
+ "Hybrid Instructor - SPLADE",
420
+ ]:
421
+ if encoder_model == "Hybrid Instructor - SPLADE":
422
+ dense_query_embedding = create_dense_embeddings(
423
+ query_text, retriever_model, instruction
424
+ )
425
+ else:
426
+ dense_query_embedding = create_dense_embeddings(
427
+ query_text, retriever_model
428
+ )
429
  sparse_query_embedding = create_sparse_embeddings(
430
  query_text, sparse_retriever_model, sparse_retriever_tokenizer
431
  )
 
487
  multi_doc_context = generate_multi_doc_context(context_group)
488
  # Companies Comparison
489
  else:
490
+ if encoder_model in [
491
+ "Hybrid SGPT - SPLADE",
492
+ "Hybrid Instructor - SPLADE",
493
+ ]:
494
+ if encoder_model == "Hybrid Instructor - SPLADE":
495
+ dense_query_embedding = create_dense_embeddings(
496
+ query_text, retriever_model, instruction
497
+ )
498
+ else:
499
+ dense_query_embedding = create_dense_embeddings(
500
+ query_text, retriever_model
501
+ )
502
  sparse_query_embedding = create_sparse_embeddings(
503
  query_text, sparse_retriever_model, sparse_retriever_tokenizer
504
  )
 
813
  for year, quarter in year_quarter_list:
814
  file_text = retrieve_transcript(data, year, quarter, ticker)
815
  with st.expander(f"See Transcript - {quarter} {year}"):
816
+ st.subheader(
817
+ "Earnings Call Transcript - {quarter} {year}:"
818
+ )
819
  stx.scrollableTextbox(
820
+ file_text,
821
+ height=700,
822
+ border=False,
823
+ fontFamily="Helvetica",
824
  )
825
  else:
826
  for year, quarter in year_quarter_list:
827
+ file_text = retrieve_transcript(
828
+ data, year, quarter, ticker_first
829
+ )
830
  with st.expander(f"See Transcript - {quarter} {year}"):
831
+ st.subheader(
832
+ "Earnings Call Transcript - {quarter} {year}:"
833
+ )
834
  stx.scrollableTextbox(
835
+ file_text,
836
+ height=700,
837
+ border=False,
838
+ fontFamily="Helvetica",
839
  )
840
  for year, quarter in year_quarter_list:
841
+ file_text = retrieve_transcript(
842
+ data, year, quarter, ticker_second
843
+ )
844
  with st.expander(f"See Transcript - {quarter} {year}"):
845
+ st.subheader(
846
+ "Earnings Call Transcript - {quarter} {year}:"
847
+ )
848
  stx.scrollableTextbox(
849
+ file_text,
850
+ height=700,
851
+ border=False,
852
+ fontFamily="Helvetica",
853
  )
utils/models.py CHANGED
@@ -8,8 +8,8 @@ import spacy
8
  import spacy_transformers
9
  import streamlit_scrollable_textbox as stx
10
  import torch
11
- from sentence_transformers import SentenceTransformer
12
  from InstructorEmbedding import INSTRUCTOR
 
13
  from tqdm import tqdm
14
  from transformers import (
15
  AutoModelForMaskedLM,
 
8
  import spacy_transformers
9
  import streamlit_scrollable_textbox as stx
10
  import torch
 
11
  from InstructorEmbedding import INSTRUCTOR
12
+ from sentence_transformers import SentenceTransformer
13
  from tqdm import tqdm
14
  from transformers import (
15
  AutoModelForMaskedLM,