Spaces:
Build error
Build error
Upload 16 files
Browse files- app.py +90 -24
- utils/models.py +1 -1
app.py
CHANGED
@@ -22,9 +22,9 @@ from utils.models import (
|
|
22 |
get_data,
|
23 |
get_flan_alpaca_xl_model,
|
24 |
get_flan_t5_model,
|
|
|
25 |
get_mpnet_embedding_model,
|
26 |
get_sgpt_embedding_model,
|
27 |
-
get_instructor_embedding_model,
|
28 |
get_spacy_model,
|
29 |
get_splade_sparse_embedding_model,
|
30 |
get_t5_model,
|
@@ -248,7 +248,13 @@ with st.sidebar:
|
|
248 |
|
249 |
# Choose encoder model
|
250 |
|
251 |
-
encoder_models_choice = [
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
with st.sidebar:
|
253 |
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
|
254 |
|
@@ -285,12 +291,32 @@ elif encoder_model == "SGPT":
|
|
285 |
elif encoder_model == "Instructor":
|
286 |
# Connect to pinecone environment
|
287 |
pinecone.init(
|
288 |
-
api_key=st.secrets["pinecone_instructor"],
|
|
|
289 |
)
|
290 |
pinecone_index_name = "week13-instructor-xl"
|
291 |
pinecone_index = pinecone.Index(pinecone_index_name)
|
292 |
retriever_model = get_instructor_embedding_model()
|
293 |
-
instruction =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
elif encoder_model == "Hybrid MPNET - SPLADE":
|
296 |
pinecone.init(
|
@@ -332,10 +358,15 @@ with st.sidebar:
|
|
332 |
data = get_data()
|
333 |
|
334 |
if document_type == "Single-Document":
|
335 |
-
if encoder_model
|
336 |
-
|
337 |
-
|
338 |
-
|
|
|
|
|
|
|
|
|
|
|
339 |
sparse_query_embedding = create_sparse_embeddings(
|
340 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
341 |
)
|
@@ -383,10 +414,18 @@ else:
|
|
383 |
# Multi-Document Retreival
|
384 |
# Single Company
|
385 |
if multi_company_choice == "Single-Company":
|
386 |
-
if encoder_model
|
387 |
-
|
388 |
-
|
389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
390 |
sparse_query_embedding = create_sparse_embeddings(
|
391 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
392 |
)
|
@@ -448,10 +487,18 @@ else:
|
|
448 |
multi_doc_context = generate_multi_doc_context(context_group)
|
449 |
# Companies Comparison
|
450 |
else:
|
451 |
-
if encoder_model
|
452 |
-
|
453 |
-
|
454 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
455 |
sparse_query_embedding = create_sparse_embeddings(
|
456 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
457 |
)
|
@@ -766,22 +813,41 @@ with tab2:
|
|
766 |
for year, quarter in year_quarter_list:
|
767 |
file_text = retrieve_transcript(data, year, quarter, ticker)
|
768 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
769 |
-
st.subheader(
|
|
|
|
|
770 |
stx.scrollableTextbox(
|
771 |
-
file_text,
|
|
|
|
|
|
|
772 |
)
|
773 |
else:
|
774 |
for year, quarter in year_quarter_list:
|
775 |
-
file_text = retrieve_transcript(
|
|
|
|
|
776 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
777 |
-
st.subheader(
|
|
|
|
|
778 |
stx.scrollableTextbox(
|
779 |
-
file_text,
|
|
|
|
|
|
|
780 |
)
|
781 |
for year, quarter in year_quarter_list:
|
782 |
-
file_text = retrieve_transcript(
|
|
|
|
|
783 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
784 |
-
st.subheader(
|
|
|
|
|
785 |
stx.scrollableTextbox(
|
786 |
-
file_text,
|
|
|
|
|
|
|
787 |
)
|
|
|
22 |
get_data,
|
23 |
get_flan_alpaca_xl_model,
|
24 |
get_flan_t5_model,
|
25 |
+
get_instructor_embedding_model,
|
26 |
get_mpnet_embedding_model,
|
27 |
get_sgpt_embedding_model,
|
|
|
28 |
get_spacy_model,
|
29 |
get_splade_sparse_embedding_model,
|
30 |
get_t5_model,
|
|
|
248 |
|
249 |
# Choose encoder model
|
250 |
|
251 |
+
encoder_models_choice = [
|
252 |
+
"MPNET",
|
253 |
+
"Instructor",
|
254 |
+
"Hybrid Instructor - SPLADE",
|
255 |
+
"SGPT",
|
256 |
+
"Hybrid MPNET - SPLADE",
|
257 |
+
]
|
258 |
with st.sidebar:
|
259 |
encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
|
260 |
|
|
|
291 |
elif encoder_model == "Instructor":
|
292 |
# Connect to pinecone environment
|
293 |
pinecone.init(
|
294 |
+
api_key=st.secrets["pinecone_instructor"],
|
295 |
+
environment="us-west4-gcp-free",
|
296 |
)
|
297 |
pinecone_index_name = "week13-instructor-xl"
|
298 |
pinecone_index = pinecone.Index(pinecone_index_name)
|
299 |
retriever_model = get_instructor_embedding_model()
|
300 |
+
instruction = (
|
301 |
+
"Represent the financial question for retrieving supporting documents:"
|
302 |
+
)
|
303 |
+
|
304 |
+
elif encoder_model == "Hybrid Instructor - SPLADE":
|
305 |
+
# Connect to pinecone environment
|
306 |
+
pinecone.init(
|
307 |
+
api_key=st.secrets["pinecone_instructor_splade"],
|
308 |
+
environment="us-west4-gcp-free",
|
309 |
+
)
|
310 |
+
pinecone_index_name = "week13-splade-instructor-xl"
|
311 |
+
pinecone_index = pinecone.Index(pinecone_index_name)
|
312 |
+
retriever_model = get_instructor_embedding_model()
|
313 |
+
(
|
314 |
+
sparse_retriever_model,
|
315 |
+
sparse_retriever_tokenizer,
|
316 |
+
) = get_splade_sparse_embedding_model()
|
317 |
+
instruction = (
|
318 |
+
"Represent the financial question for retrieving supporting documents:"
|
319 |
+
)
|
320 |
|
321 |
elif encoder_model == "Hybrid MPNET - SPLADE":
|
322 |
pinecone.init(
|
|
|
358 |
data = get_data()
|
359 |
|
360 |
if document_type == "Single-Document":
|
361 |
+
if encoder_model in ["Hybrid SGPT - SPLADE", "Hybrid Instructor - SPLADE"]:
|
362 |
+
if encoder_model == "Hybrid Instructor - SPLADE":
|
363 |
+
dense_query_embedding = create_dense_embeddings(
|
364 |
+
query_text, retriever_model, instruction
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
dense_query_embedding = create_dense_embeddings(
|
368 |
+
query_text, retriever_model
|
369 |
+
)
|
370 |
sparse_query_embedding = create_sparse_embeddings(
|
371 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
372 |
)
|
|
|
414 |
# Multi-Document Retreival
|
415 |
# Single Company
|
416 |
if multi_company_choice == "Single-Company":
|
417 |
+
if encoder_model in [
|
418 |
+
"Hybrid SGPT - SPLADE",
|
419 |
+
"Hybrid Instructor - SPLADE",
|
420 |
+
]:
|
421 |
+
if encoder_model == "Hybrid Instructor - SPLADE":
|
422 |
+
dense_query_embedding = create_dense_embeddings(
|
423 |
+
query_text, retriever_model, instruction
|
424 |
+
)
|
425 |
+
else:
|
426 |
+
dense_query_embedding = create_dense_embeddings(
|
427 |
+
query_text, retriever_model
|
428 |
+
)
|
429 |
sparse_query_embedding = create_sparse_embeddings(
|
430 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
431 |
)
|
|
|
487 |
multi_doc_context = generate_multi_doc_context(context_group)
|
488 |
# Companies Comparison
|
489 |
else:
|
490 |
+
if encoder_model in [
|
491 |
+
"Hybrid SGPT - SPLADE",
|
492 |
+
"Hybrid Instructor - SPLADE",
|
493 |
+
]:
|
494 |
+
if encoder_model == "Hybrid Instructor - SPLADE":
|
495 |
+
dense_query_embedding = create_dense_embeddings(
|
496 |
+
query_text, retriever_model, instruction
|
497 |
+
)
|
498 |
+
else:
|
499 |
+
dense_query_embedding = create_dense_embeddings(
|
500 |
+
query_text, retriever_model
|
501 |
+
)
|
502 |
sparse_query_embedding = create_sparse_embeddings(
|
503 |
query_text, sparse_retriever_model, sparse_retriever_tokenizer
|
504 |
)
|
|
|
813 |
for year, quarter in year_quarter_list:
|
814 |
file_text = retrieve_transcript(data, year, quarter, ticker)
|
815 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
816 |
+
st.subheader(
|
817 |
+
"Earnings Call Transcript - {quarter} {year}:"
|
818 |
+
)
|
819 |
stx.scrollableTextbox(
|
820 |
+
file_text,
|
821 |
+
height=700,
|
822 |
+
border=False,
|
823 |
+
fontFamily="Helvetica",
|
824 |
)
|
825 |
else:
|
826 |
for year, quarter in year_quarter_list:
|
827 |
+
file_text = retrieve_transcript(
|
828 |
+
data, year, quarter, ticker_first
|
829 |
+
)
|
830 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
831 |
+
st.subheader(
|
832 |
+
"Earnings Call Transcript - {quarter} {year}:"
|
833 |
+
)
|
834 |
stx.scrollableTextbox(
|
835 |
+
file_text,
|
836 |
+
height=700,
|
837 |
+
border=False,
|
838 |
+
fontFamily="Helvetica",
|
839 |
)
|
840 |
for year, quarter in year_quarter_list:
|
841 |
+
file_text = retrieve_transcript(
|
842 |
+
data, year, quarter, ticker_second
|
843 |
+
)
|
844 |
with st.expander(f"See Transcript - {quarter} {year}"):
|
845 |
+
st.subheader(
|
846 |
+
"Earnings Call Transcript - {quarter} {year}:"
|
847 |
+
)
|
848 |
stx.scrollableTextbox(
|
849 |
+
file_text,
|
850 |
+
height=700,
|
851 |
+
border=False,
|
852 |
+
fontFamily="Helvetica",
|
853 |
)
|
utils/models.py
CHANGED
@@ -8,8 +8,8 @@ import spacy
|
|
8 |
import spacy_transformers
|
9 |
import streamlit_scrollable_textbox as stx
|
10 |
import torch
|
11 |
-
from sentence_transformers import SentenceTransformer
|
12 |
from InstructorEmbedding import INSTRUCTOR
|
|
|
13 |
from tqdm import tqdm
|
14 |
from transformers import (
|
15 |
AutoModelForMaskedLM,
|
|
|
8 |
import spacy_transformers
|
9 |
import streamlit_scrollable_textbox as stx
|
10 |
import torch
|
|
|
11 |
from InstructorEmbedding import INSTRUCTOR
|
12 |
+
from sentence_transformers import SentenceTransformer
|
13 |
from tqdm import tqdm
|
14 |
from transformers import (
|
15 |
AutoModelForMaskedLM,
|