Spaces:
Build error
Build error
File size: 4,820 Bytes
8cd1f1e 9975133 f9da573 8cd1f1e f9da573 a64e1a1 f9da573 8cd1f1e f9da573 9975133 f9da573 8cd1f1e a7b0635 9975133 b19bb41 8cd1f1e 9975133 8cd1f1e 9975133 c5f41e6 9975133 c5f41e6 9975133 c5f41e6 9975133 c5f41e6 9975133 8cd1f1e 8d46199 9975133 8cd1f1e 9975133 8cd1f1e 9975133 8cd1f1e b19bb41 8cd1f1e b19bb41 8cd1f1e 9975133 fbd690d 9975133 e514fa8 fbd690d e514fa8 c5f41e6 e514fa8 872808d e514fa8 8cd1f1e 9975133 8cd1f1e 2a99161 9975133 8cd1f1e 40eb760 9975133 8cd1f1e 40eb760 9975133 8cd1f1e 9975133 f9da573 9975133 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import pinecone
import streamlit as st
st.set_page_config(layout="wide")
import streamlit_scrollable_textbox as stx
import openai
from utils import (
get_data,
get_mpnet_embedding_model,
get_sgpt_embedding_model,
get_flan_t5_model,
get_t5_model,
save_key,
)
from utils import (
retrieve_transcript,
query_pinecone,
format_query,
sentence_id_combine,
text_lookup,
generate_prompt,
gpt_model,
)
# Page heading followed by a one-line summary of the transcript corpus that
# the retrieval step searches over.
APP_TITLE = "Abstractive Question Answering"
CORPUS_DESCRIPTION = "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."

st.title(APP_TITLE)
st.write(CORPUS_DESCRIPTION)
# Two equal-width columns: col1 holds the question inputs and the retrieved
# context; col2 holds the model prompt and the generated answer.
col1, col2 = st.columns([3, 3], gap="medium")

# Options offered in the filter dropdowns (the corpus covers 2016-2020).
years_choice = ["2020", "2019", "2018", "2017", "2016"]
ticker_choice = [
    "AAPL",
    "CSCO",
    "MSFT",
    "ASML",
    "NVDA",
    "GOOGL",
    "MU",
    "INTC",
    "AMZN",
    "AMD",
]

# All question widgets go in the left column. One `with` block replaces the
# original run of repeated `with col1:` blocks; on-screen order is unchanged.
with col1:
    st.subheader("Question")
    query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
    year = st.selectbox("Year", years_choice)
    quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])
    ticker = st.selectbox("Company", ticker_choice)
# Selectable encoder (used for retrieval) and decoder (used to produce the
# answer) models.
encoder_models_choice = ["MPNET", "SGPT"]
decoder_models_choice = [
    "GPT3 - (text-davinci-003)",
    "T5",
    "FLAN-T5",
]

with st.sidebar:
    st.subheader("Select Options:")
    # Number of matches requested from query_pinecone, limited to 1..5.
    num_results = int(st.number_input("Number of Results to query", 1, 5, value=5))
    # Choose encoder model
    encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
    # Choose decoder model
    decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
# Per-encoder Pinecone settings:
#   (st.secrets key holding the API key, index name, embedding-model factory).
# Each encoder gets its own API key and index, so the index and the
# retriever model are always selected together from one entry.
PINECONE_ENVIRONMENT = "us-east1-gcp"
ENCODER_CONFIG = {
    "MPNET": ("pinecone_mpnet", "week2-all-mpnet-base", get_mpnet_embedding_model),
    "SGPT": ("pinecone_sgpt", "week2-sgpt-125m", get_sgpt_embedding_model),
}

pinecone_secret_name, pinecone_index_name, load_retriever_model = ENCODER_CONFIG[
    encoder_model
]
# Connect to pinecone environment
pinecone.init(
    api_key=st.secrets[pinecone_secret_name], environment=PINECONE_ENVIRONMENT
)
pinecone_index = pinecone.Index(pinecone_index_name)
retriever_model = load_retriever_model()
with st.sidebar:
    # Passed to sentence_id_combine as `lag` (0..5); presumably the number of
    # neighbouring sentences merged around each match. Confirm in utils.
    window = int(st.number_input("Sentence Window Size", 0, 5, value=3))
    # Passed to query_pinecone; presumably the minimum similarity score a
    # match needs in order to be kept. Also selects the context-building path
    # further down in the script.
    threshold = float(
        st.number_input(
            label="Similarity Score Threshold", step=0.05, format="%.2f", value=0.35
        )
    )
# Load the transcript data, retrieve matches for the question (filtered by
# year/quarter/ticker and the score threshold), and build the decoder prompt.
data = get_data()
query_results = query_pinecone(
    query_text,
    num_results,
    retriever_model,
    pinecone_index,
    year,
    quarter,
    ticker,
    threshold,
)

# Build the list of context passages: windowed sentences (via
# sentence_id_combine) at thresholds up to 0.90, otherwise the raw matches.
# NOTE(review): this branch is keyed on the similarity threshold rather than
# on `window`; kept as-is, but confirm that is intended.
if threshold <= 0.90:
    context_list = sentence_id_combine(data, query_results, lag=window)
else:
    context_list = format_query(query_results)

prompt = generate_prompt(query_text, context_list)
def _summarize_contexts(summarizer, contexts):
    """Return a two-stage summary of ``contexts`` produced by ``summarizer``.

    Each passage is summarized on its own. The partial summaries are joined
    with ". ", and the joined text is summarized once more to give the answer.

    Args:
        summarizer: pipeline object called as ``summarizer(text)``; its result
            is read as ``[0]["summary_text"]``.
        contexts: non-empty sequence of retrieved passage strings.

    Returns:
        The final summary string.
    """
    partial_summaries = [summarizer(text)[0]["summary_text"] for text in contexts]
    return summarizer(". ".join(partial_summaries))[0]["summary_text"]


if decoder_model == "GPT3 - (text-davinci-003)":
    with col2:
        # The user can edit the prompt before sending it. gpt_model is only
        # called once the form's Submit button has been pressed (see below).
        with st.form("my_form"):
            edited_prompt = st.text_area(label="Model Prompt", value=prompt, height=270)
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")
    if submitted:
        # NOTE(review): an empty key field is passed straight to save_key;
        # confirm save_key handles that case before the API call.
        api_key = save_key(openai_key)
        openai.api_key = api_key
        generated_text = gpt_model(edited_prompt)
        with col2:
            st.subheader("Answer:")
            st.write(generated_text)
elif decoder_model in ("T5", "FLAN-T5"):
    # Both T5 variants use the same summarize-then-summarize flow; only the
    # pipeline differs.
    # NOTE(review): FLAN-T5 output is read via "summary_text". If
    # get_flan_t5_model() returns a text2text-generation pipeline, the key
    # would be "generated_text" instead; confirm in utils.
    if context_list:
        summarizer = get_t5_model() if decoder_model == "T5" else get_flan_t5_model()
        answer = _summarize_contexts(summarizer, context_list)
    else:
        # Without this guard an empty string would be fed to the pipeline.
        answer = None
    with col2:
        st.subheader("Answer:")
        if answer is None:
            st.info(
                "No passages were retrieved for this query, so there is nothing "
                "to summarize. Try lowering the similarity score threshold."
            )
        else:
            st.write(answer)
with col1:
    # The retrieved passages (context_list) that were used to build the prompt.
    with st.expander("See Retrieved Text"):
        for context_text in context_list:
            st.markdown(f"- {context_text}")

    # Full transcript for the selected year / quarter / company.
    file_text = retrieve_transcript(data, year, quarter, ticker)
    with st.expander("See Transcript"):
        stx.scrollableTextbox(
            file_text, height=700, border=False, fontFamily="Helvetica"
        )
|