File size: 4,629 Bytes
411678e 31b6e92 411678e 3b52176 d98d50d 446f9c9 d98d50d 3b52176 e741287 3b52176 446f9c9 e694dea d50bad2 3b52176 e741287 3b52176 a957eeb 741aa8b a957eeb 741aa8b a957eeb 16975a3 d98d50d a957eeb 446f9c9 a957eeb 446f9c9 d98d50d 741aa8b a957eeb 446f9c9 a957eeb 446f9c9 a957eeb 446f9c9 a957eeb 446f9c9 a957eeb 446f9c9 a957eeb 741aa8b a957eeb 741aa8b 82bf281 a957eeb 953c510 8eb51fc ee6d004 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import streamlit as st
from functions import *
# Browser-tab metadata first, then the sidebar heading and the main page title.
st.set_page_config(
    page_icon="π",
    page_title="Earnings Semantic Search",
)
with st.sidebar:
    st.header("Semantic Search")
st.markdown("## Earnings Semantic Search with SBert")
def gen_sentiment(text):
    '''Return the sentiment label that ``sent_pipe`` assigns to *text*.

    Only the label of the first prediction returned by the pipeline is kept.
    '''
    predictions = sent_pipe(text)
    first_prediction = predictions[0]
    return first_prediction['label']
# Display name shown in the sidebar -> model identifier passed to load_sbert /
# embed_text.
bi_enc_dict = {
    'mpnet-base-v2': "all-mpnet-base-v2",
    'e5-base': 'intfloat/e5-base',
    'instructor-base': 'hkunlp/instructor-base',
    'mpnet-base-dot-v1': 'multi-qa-mpnet-base-dot-v1',
    'setfit-finance': 'nickmuchi/setfit-finetuned-financial-text-classification',
}
# Query box on the main page; model picker and response-size slider in the sidebar.
search_input = st.text_input(
    label='Enter Your Search Query',
    value="What key challenges did the business face?",
    key='search',
)
sbert_model_name = st.sidebar.selectbox(
    "Embedding Model",
    options=list(bi_enc_dict),
    key='sbox',
)
top_k = 2  # how many of the returned search hits are displayed
window_size = st.sidebar.slider(
    "Number of Sentences Generated in Search Response",
    min_value=1,
    max_value=7,
    value=3,
)
def gen_annotated_text(df):
    '''Convert the search-results dataframe into ``annotated_text`` tuples.

    Args:
        df: dataframe with ``Text`` and ``Sentiment`` columns, as built in the
            search section below.

    Returns:
        A list of ``(text, label, colour)`` tuples, one per row and in row
        order: green for 'Positive', red for 'Negative', black for any other
        label.
    '''
    # Read columns by name instead of by tuple position (row[2], row[3]),
    # which would silently break if the column order changed.
    colours = {'Positive': '#8fce00', 'Negative': '#f44336'}
    return [
        (row.Text, row.Sentiment, colours.get(row.Sentiment, '#000000'))
        for row in df.itertuples()
    ]


# Shown when the transcript has not been loaded into the session (or the
# search raised a RuntimeError).
_MISSING_SOURCE_MSG = ('Please ensure you have entered the YouTube URL or '
                       'uploaded the Earnings Call file')

try:
    if not search_input:
        # An empty query is a different problem from a missing transcript, so
        # it gets its own message rather than the URL/upload hint.
        st.write('Please enter a search query')
    elif "sen_df" not in st.session_state or "earnings_passages" not in st.session_state:
        st.write(_MISSING_SOURCE_MSG)
    else:
        passages = chunk_long_text(
            st.session_state['earnings_passages'], 150, window_size=window_size
        )
        embedding_model = bi_enc_dict[sbert_model_name]
        with st.spinner(text=f"Loading {embedding_model} encoder model..."):
            # The returned model is not used in this block; the call is kept
            # for its loading side effect. NOTE(review): presumably it warms a
            # cache that embed_text reuses — confirm in functions.py.
            load_sbert(embedding_model)

        # Each hit carries 'corpus_id' (an index into passages) and
        # 'cross-score'. Assumes embed_text returns hits best-first — confirm.
        hits = embed_text(search_input, passages, embedding_model)
        top_hits = hits[:top_k]
        if not top_hits:
            st.write('No matching passages were found for this query')
        else:
            df = pd.DataFrame(
                [(hit['cross-score'], passages[hit['corpus_id']]) for hit in top_hits],
                columns=['Score', 'Text'],
            )
            df['Score'] = df['Score'].round(2)
            df['Sentiment'] = df['Text'].apply(gen_sentiment)

            text_annotations = gen_annotated_text(df)
            with st.expander(label='Best Search Query Result', expanded=True):
                annotated_text(text_annotations[0])
            # Indexing [1] unconditionally raised an IndexError (not caught by
            # the RuntimeError handler) when fewer than two hits came back.
            if len(text_annotations) > 1:
                with st.expander(label='Alternative Search Query Result'):
                    annotated_text(text_annotations[1])
except RuntimeError:
    st.write(_MISSING_SOURCE_MSG)
|