awinml commited on
Commit
5ef1f60
·
1 Parent(s): bd9fae2

Upload 16 files (#23)

Browse files

- Upload 16 files (26a27349113d5597933a40955d7fd0212a032e35)

Files changed (2) hide show
  1. app.py +2 -2
  2. utils/retriever.py +184 -71
app.py CHANGED
@@ -91,7 +91,7 @@ with st.sidebar:
91
  ["Single-Company", "Compare Companies"],
92
  )
93
 
94
-
95
  corpus, bm25 = get_bm25_model(data)
96
 
97
  tokenized_query = preprocess_text(query_text).split()
@@ -382,7 +382,7 @@ with st.sidebar:
382
  )
383
  )
384
 
385
- data = get_data()
386
 
387
  if document_type == "Single-Document":
388
  if encoder_model in ["Hybrid SGPT - SPLADE", "Hybrid Instructor - SPLADE"]:
 
91
  ["Single-Company", "Compare Companies"],
92
  )
93
 
94
+ data = get_data()
95
  corpus, bm25 = get_bm25_model(data)
96
 
97
  tokenized_query = preprocess_text(query_text).split()
 
382
  )
383
  )
384
 
385
+
386
 
387
  if document_type == "Single-Document":
388
  if encoder_model in ["Hybrid SGPT - SPLADE", "Hybrid Instructor - SPLADE"]:
utils/retriever.py CHANGED
@@ -32,20 +32,188 @@ def query_pinecone(
32
  if year == "All":
33
  if quarter == "All":
34
  if indices != None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  xc = index.query(
36
  vector=dense_vec,
37
  top_k=top_k,
38
  filter={
39
- "Year": {
40
- "$in": [
41
- int("2020"),
42
- int("2019"),
43
- int("2018"),
44
- int("2017"),
45
- int("2016"),
46
- ]
47
- },
48
- "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
49
  "Ticker": {"$eq": ticker},
50
  "QA_Flag": {"$eq": participant},
51
  "Keywords": {"$in": keywords},
@@ -58,42 +226,25 @@ def query_pinecone(
58
  vector=dense_vec,
59
  top_k=top_k,
60
  filter={
61
- "Year": {
62
- "$in": [
63
- int("2020"),
64
- int("2019"),
65
- int("2018"),
66
- int("2017"),
67
- int("2016"),
68
- ]
69
- },
70
- "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
71
  "Ticker": {"$eq": ticker},
72
  "QA_Flag": {"$eq": participant},
73
- "Keywords": {"$in": keywords},
74
  },
75
  include_metadata=True,
76
  )
77
  else:
78
- if indices != None:
79
  xc = index.query(
80
  vector=dense_vec,
81
  top_k=top_k,
82
  filter={
83
- "Year": {
84
- "$in": [
85
- int("2020"),
86
- int("2019"),
87
- int("2018"),
88
- int("2017"),
89
- int("2016"),
90
- ]
91
- },
92
  "Quarter": {"$eq": quarter},
93
  "Ticker": {"$eq": ticker},
94
  "QA_Flag": {"$eq": participant},
95
  "Keywords": {"$in": keywords},
96
- "index": {"$in": indices},
97
  },
98
  include_metadata=True,
99
  )
@@ -102,51 +253,13 @@ def query_pinecone(
102
  vector=dense_vec,
103
  top_k=top_k,
104
  filter={
105
- "Year": {
106
- "$in": [
107
- int("2020"),
108
- int("2019"),
109
- int("2018"),
110
- int("2017"),
111
- int("2016"),
112
- ]
113
- },
114
  "Quarter": {"$eq": quarter},
115
  "Ticker": {"$eq": ticker},
116
  "QA_Flag": {"$eq": participant},
117
- "Keywords": {"$in": keywords},
118
  },
119
  include_metadata=True,
120
  )
121
- else:
122
- # search pinecone index for context passage with the answer
123
- if indices != None:
124
- xc = index.query(
125
- vector=dense_vec,
126
- top_k=top_k,
127
- filter={
128
- "Year": int(year),
129
- "Quarter": {"$eq": quarter},
130
- "Ticker": {"$eq": ticker},
131
- "QA_Flag": {"$eq": participant},
132
- "Keywords": {"$in": keywords},
133
- "index": {"$in": indices},
134
- },
135
- include_metadata=True,
136
- )
137
- else:
138
- xc = index.query(
139
- vector=dense_vec,
140
- top_k=top_k,
141
- filter={
142
- "Year": int(year),
143
- "Quarter": {"$eq": quarter},
144
- "Ticker": {"$eq": ticker},
145
- "QA_Flag": {"$eq": participant},
146
- "Keywords": {"$in": keywords},
147
- },
148
- include_metadata=True,
149
- )
150
  # filter the context passages based on the score threshold
151
  filtered_matches = []
152
  for match in xc["matches"]:
 
32
  if year == "All":
33
  if quarter == "All":
34
  if indices != None:
35
+ if keywords != None:
36
+ xc = index.query(
37
+ vector=dense_vec,
38
+ top_k=top_k,
39
+ filter={
40
+ "Year": {
41
+ "$in": [
42
+ int("2020"),
43
+ int("2019"),
44
+ int("2018"),
45
+ int("2017"),
46
+ int("2016"),
47
+ ]
48
+ },
49
+ "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
50
+ "Ticker": {"$eq": ticker},
51
+ "QA_Flag": {"$eq": participant},
52
+ "Keywords": {"$in": keywords},
53
+ "index": {"$in": indices},
54
+ },
55
+ include_metadata=True,
56
+ )
57
+ else:
58
+ xc = index.query(
59
+ vector=dense_vec,
60
+ top_k=top_k,
61
+ filter={
62
+ "Year": {
63
+ "$in": [
64
+ int("2020"),
65
+ int("2019"),
66
+ int("2018"),
67
+ int("2017"),
68
+ int("2016"),
69
+ ]
70
+ },
71
+ "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
72
+ "Ticker": {"$eq": ticker},
73
+ "QA_Flag": {"$eq": participant},
74
+ "index": {"$in": indices},
75
+ },
76
+ include_metadata=True,
77
+ )
78
+ else:
79
+ if keywords != None:
80
+ xc = index.query(
81
+ vector=dense_vec,
82
+ top_k=top_k,
83
+ filter={
84
+ "Year": {
85
+ "$in": [
86
+ int("2020"),
87
+ int("2019"),
88
+ int("2018"),
89
+ int("2017"),
90
+ int("2016"),
91
+ ]
92
+ },
93
+ "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
94
+ "Ticker": {"$eq": ticker},
95
+ "QA_Flag": {"$eq": participant},
96
+ "Keywords": {"$in": keywords},
97
+ },
98
+ include_metadata=True,
99
+ )
100
+ else:
101
+ xc = index.query(
102
+ vector=dense_vec,
103
+ top_k=top_k,
104
+ filter={
105
+ "Year": {
106
+ "$in": [
107
+ int("2020"),
108
+ int("2019"),
109
+ int("2018"),
110
+ int("2017"),
111
+ int("2016"),
112
+ ]
113
+ },
114
+ "Quarter": {"$in": ["Q1", "Q2", "Q3", "Q4"]},
115
+ "Ticker": {"$eq": ticker},
116
+ "QA_Flag": {"$eq": participant},
117
+ },
118
+ include_metadata=True,
119
+ )
120
+ else:
121
+ if indices != None:
122
+ if keywords != None:
123
+ xc = index.query(
124
+ vector=dense_vec,
125
+ top_k=top_k,
126
+ filter={
127
+ "Year": {
128
+ "$in": [
129
+ int("2020"),
130
+ int("2019"),
131
+ int("2018"),
132
+ int("2017"),
133
+ int("2016"),
134
+ ]
135
+ },
136
+ "Quarter": {"$eq": quarter},
137
+ "Ticker": {"$eq": ticker},
138
+ "QA_Flag": {"$eq": participant},
139
+ "Keywords": {"$in": keywords},
140
+ "index": {"$in": indices},
141
+ },
142
+ include_metadata=True,
143
+ )
144
+ else:
145
+ xc = index.query(
146
+ vector=dense_vec,
147
+ top_k=top_k,
148
+ filter={
149
+ "Year": {
150
+ "$in": [
151
+ int("2020"),
152
+ int("2019"),
153
+ int("2018"),
154
+ int("2017"),
155
+ int("2016"),
156
+ ]
157
+ },
158
+ "Quarter": {"$eq": quarter},
159
+ "Ticker": {"$eq": ticker},
160
+ "QA_Flag": {"$eq": participant},
161
+ "index": {"$in": indices},
162
+ },
163
+ include_metadata=True,
164
+ )
165
+ else:
166
+ if keywords != None:
167
+ xc = index.query(
168
+ vector=dense_vec,
169
+ top_k=top_k,
170
+ filter={
171
+ "Year": {
172
+ "$in": [
173
+ int("2020"),
174
+ int("2019"),
175
+ int("2018"),
176
+ int("2017"),
177
+ int("2016"),
178
+ ]
179
+ },
180
+ "Quarter": {"$eq": quarter},
181
+ "Ticker": {"$eq": ticker},
182
+ "QA_Flag": {"$eq": participant},
183
+ "Keywords": {"$in": keywords},
184
+ },
185
+ include_metadata=True,
186
+ )
187
+ else:
188
+ xc = index.query(
189
+ vector=dense_vec,
190
+ top_k=top_k,
191
+ filter={
192
+ "Year": {
193
+ "$in": [
194
+ int("2020"),
195
+ int("2019"),
196
+ int("2018"),
197
+ int("2017"),
198
+ int("2016"),
199
+ ]
200
+ },
201
+ "Quarter": {"$eq": quarter},
202
+ "Ticker": {"$eq": ticker},
203
+ "QA_Flag": {"$eq": participant},
204
+ },
205
+ include_metadata=True,
206
+ )
207
+ else:
208
+ # search pinecone index for context passage with the answer
209
+ if indices != None:
210
+ if keywords != None:
211
  xc = index.query(
212
  vector=dense_vec,
213
  top_k=top_k,
214
  filter={
215
+ "Year": int(year),
216
+ "Quarter": {"$eq": quarter},
 
 
 
 
 
 
 
 
217
  "Ticker": {"$eq": ticker},
218
  "QA_Flag": {"$eq": participant},
219
  "Keywords": {"$in": keywords},
 
226
  vector=dense_vec,
227
  top_k=top_k,
228
  filter={
229
+ "Year": int(year),
230
+ "Quarter": {"$eq": quarter},
 
 
 
 
 
 
 
 
231
  "Ticker": {"$eq": ticker},
232
  "QA_Flag": {"$eq": participant},
233
+ "index": {"$in": indices},
234
  },
235
  include_metadata=True,
236
  )
237
  else:
238
+ if keywords != None:
239
  xc = index.query(
240
  vector=dense_vec,
241
  top_k=top_k,
242
  filter={
243
+ "Year": int(year),
 
 
 
 
 
 
 
 
244
  "Quarter": {"$eq": quarter},
245
  "Ticker": {"$eq": ticker},
246
  "QA_Flag": {"$eq": participant},
247
  "Keywords": {"$in": keywords},
 
248
  },
249
  include_metadata=True,
250
  )
 
253
  vector=dense_vec,
254
  top_k=top_k,
255
  filter={
256
+ "Year": int(year),
 
 
 
 
 
 
 
 
257
  "Quarter": {"$eq": quarter},
258
  "Ticker": {"$eq": ticker},
259
  "QA_Flag": {"$eq": participant},
 
260
  },
261
  include_metadata=True,
262
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
  # filter the context passages based on the score threshold
264
  filtered_matches = []
265
  for match in xc["matches"]: