vtiyyal1 committed on
Commit
f8e7b59
1 Parent(s): 4c7c1f7

Upload rerank.py

Browse files

fixed rerank to use dates

Files changed (1) hide show
  1. rerank.py +6 -4
rerank.py CHANGED
@@ -100,6 +100,7 @@ def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
100
  contents = articles['content'].tolist()
101
  uuids = articles['uuid'].tolist()
102
  titles = articles['title'].tolist()
 
103
 
104
  # biencoder retrieval does not have domain
105
  if 'domain' not in articles:
@@ -109,7 +110,7 @@ def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
109
 
110
  cross_inp = [[question, content] for content in contents]
111
  cross_scores = cross_encoder.predict(cross_inp)
112
- scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain))
113
  scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
114
 
115
  out_values = scores_sentences[:top_n]
@@ -132,7 +133,7 @@ def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) ->
132
  contents = articles['content'].tolist()
133
  uuids = articles['uuid'].tolist()
134
  titles = articles['title'].tolist()
135
-
136
  if 'domain' not in articles:
137
  domain = [""] * len(contents)
138
  else:
@@ -142,16 +143,17 @@ def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) ->
142
  new_uuids = []
143
  new_titles = []
144
  new_domains = []
 
145
  for idx in range(len(contents)):
146
  sents = sent_tokenize(contents[idx])
147
  sentences.extend(sents)
148
  new_uuids.extend([uuids[idx]] * len(sents))
149
  new_titles.extend([titles[idx]] * len(sents))
150
  new_domains.extend([domain[idx]] * len(sents))
151
-
152
  cross_inp = [[question, sent] for sent in sentences]
153
  cross_scores = cross_encoder.predict(cross_inp)
154
- scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
155
  scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
156
 
157
  out_values = scores_sentences[:top_n]
 
100
  contents = articles['content'].tolist()
101
  uuids = articles['uuid'].tolist()
102
  titles = articles['title'].tolist()
103
+ published_dates = articles['published_date'].tolist()
104
 
105
  # biencoder retrieval does not have domain
106
  if 'domain' not in articles:
 
110
 
111
  cross_inp = [[question, content] for content in contents]
112
  cross_scores = cross_encoder.predict(cross_inp)
113
+ scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain, published_dates))
114
  scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
115
 
116
  out_values = scores_sentences[:top_n]
 
133
  contents = articles['content'].tolist()
134
  uuids = articles['uuid'].tolist()
135
  titles = articles['title'].tolist()
136
+ published_dates = articles['published_date'].tolist()
137
  if 'domain' not in articles:
138
  domain = [""] * len(contents)
139
  else:
 
143
  new_uuids = []
144
  new_titles = []
145
  new_domains = []
146
+ new_published_dates = []
147
  for idx in range(len(contents)):
148
  sents = sent_tokenize(contents[idx])
149
  sentences.extend(sents)
150
  new_uuids.extend([uuids[idx]] * len(sents))
151
  new_titles.extend([titles[idx]] * len(sents))
152
  new_domains.extend([domain[idx]] * len(sents))
153
+ new_published_dates.extend([published_dates[idx]] * len(sents))
154
  cross_inp = [[question, sent] for sent in sentences]
155
  cross_scores = cross_encoder.predict(cross_inp)
156
+ scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains, new_published_dates))
157
  scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
158
 
159
  out_values = scores_sentences[:top_n]