awinml commited on
Commit
c5f41e6
1 Parent(s): b19bb41

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -3
app.py CHANGED
@@ -59,11 +59,16 @@ def save_key(api_key):
59
  return api_key
60
 
61
 
62
- def query_pinecone(query, top_k, model, index, threshold=0.5):
63
  # generate embeddings for the query
64
  xq = model.encode([query]).tolist()
65
  # search pinecone index for context passage with the answer
66
- xc = index.query(xq, top_k=top_k, include_metadata=True)
 
 
 
 
 
67
  # filter the context passages based on the score threshold
68
  filtered_matches = []
69
  for match in xc["matches"]:
@@ -137,6 +142,27 @@ st.write(
137
 
138
  query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  num_results = int(st.number_input("Number of Results to query", 1, 5, value=3))
141
 
142
 
@@ -180,7 +206,14 @@ threshold = float(
180
  data = get_data()
181
 
182
  query_results = query_pinecone(
183
- query_text, num_results, retriever_model, pinecone_index, threshold
 
 
 
 
 
 
 
184
  )
185
 
186
  if threshold <= 0.60:
 
59
  return api_key
60
 
61
 
62
+ def query_pinecone(query, top_k, model, index, year, quarter, ticker, threshold=0.5):
63
  # generate embeddings for the query
64
  xq = model.encode([query]).tolist()
65
  # search pinecone index for context passage with the answer
66
+ xc = index.query(
67
+ xq,
68
+ top_k=top_k,
69
+ filter={"year": year, "quarter": quarter, "ticker": ticker},
70
+ include_metadata=True,
71
+ )
72
  # filter the context passages based on the score threshold
73
  filtered_matches = []
74
  for match in xc["matches"]:
 
142
 
143
  query_text = st.text_input("Input Query", value="Who is the CEO of Apple?")
144
 
145
+ years_choice = ["2016", "2017", "2018", "2019", "2020"]
146
+
147
+ year = st.selectbox("Year", years_choice)
148
+
149
+ quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])
150
+
151
+ ticker_choice = [
152
+ "AAPL",
153
+ "CSCO",
154
+ "MSFT",
155
+ "ASML",
156
+ "NVDA",
157
+ "GOOGL",
158
+ "MU",
159
+ "INTC",
160
+ "AMZN",
161
+ "AMD",
162
+ ]
163
+
164
+ ticker = st.selectbox("Company", ticker_choice)
165
+
166
  num_results = int(st.number_input("Number of Results to query", 1, 5, value=3))
167
 
168
 
 
206
  data = get_data()
207
 
208
  query_results = query_pinecone(
209
+ query_text,
210
+ num_results,
211
+ retriever_model,
212
+ pinecone_index,
213
+ year,
214
+ quarter,
215
+ ticker,
216
+ threshold,
217
  )
218
 
219
  if threshold <= 0.60: