Commit: Upload rerank.py (fixed rerank to use dates)

rerank.py (CHANGED)
@@ -100,6 +100,7 @@ def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
     contents = articles['content'].tolist()
     uuids = articles['uuid'].tolist()
     titles = articles['title'].tolist()
+    published_dates = articles['published_date'].tolist()
 
     # biencoder retrieval does not have domain
     if 'domain' not in articles:
@@ -109,7 +110,7 @@ def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
 
     cross_inp = [[question, content] for content in contents]
     cross_scores = cross_encoder.predict(cross_inp)
-    scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain))
+    scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain, published_dates))
     scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
 
     out_values = scores_sentences[:top_n]
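For reference, a minimal sketch of what crossencoder_rerank_answer might look like after this change. Only the lines inside the hunks above are confirmed; the pandas CSV loading, the CrossEncoder checkpoint, the else branch for domain, and the return statement fall outside the diff and are assumptions.

import pandas as pd
from sentence_transformers import CrossEncoder

# Assumed model; the checkpoint actually used by rerank.py is not shown in the diff.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def crossencoder_rerank_answer(csv_path: str, question: str, top_n=4) -> list:
    articles = pd.read_csv(csv_path)  # assumed: retrieved articles arrive as a CSV
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    published_dates = articles['published_date'].tolist()

    # biencoder retrieval does not have domain
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()  # assumed else branch, outside the hunk

    # Score every (question, content) pair with the cross-encoder.
    cross_inp = [[question, content] for content in contents]
    cross_scores = cross_encoder.predict(cross_inp)

    # The date rides along as the last tuple field; sorting still keys on the score.
    scores_sentences = list(zip(cross_scores, contents, uuids, titles, domain, published_dates))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    out_values = scores_sentences[:top_n]
    return out_values  # assumed: the top_n scored tuples are returned

Appending published_dates as the final zip field leaves the sort key (x[0], the cross-encoder score) untouched, so ranking behavior is unchanged; callers simply receive one extra tuple element.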
@@ -132,7 +133,7 @@ def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
     contents = articles['content'].tolist()
     uuids = articles['uuid'].tolist()
     titles = articles['title'].tolist()
-
+    published_dates = articles['published_date'].tolist()
     if 'domain' not in articles:
         domain = [""] * len(contents)
     else:
@@ -142,16 +143,17 @@ def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
     new_uuids = []
     new_titles = []
     new_domains = []
+    new_published_dates = []
     for idx in range(len(contents)):
         sents = sent_tokenize(contents[idx])
         sentences.extend(sents)
         new_uuids.extend([uuids[idx]] * len(sents))
         new_titles.extend([titles[idx]] * len(sents))
         new_domains.extend([domain[idx]] * len(sents))
-
+        new_published_dates.extend([published_dates[idx]] * len(sents))
     cross_inp = [[question, sent] for sent in sentences]
     cross_scores = cross_encoder.predict(cross_inp)
-    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains))
+    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains, new_published_dates))
     scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)
 
     out_values = scores_sentences[:top_n]
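The sentence-wise variant splits each article into sentences and fans every article-level field out to each sentence, so the published date has to be replicated per sentence as well. A sketch under the same assumptions as above (sent_tokenize is presumed to be NLTK's, and the sentences accumulator is initialized outside the hunks shown):

import pandas as pd
from nltk.tokenize import sent_tokenize  # may need a one-time nltk.download('punkt')
from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # assumed model

def crossencoder_rerank_sentencewise(csv_path: str, question: str, top_n=10) -> list:
    articles = pd.read_csv(csv_path)  # assumed
    contents = articles['content'].tolist()
    uuids = articles['uuid'].tolist()
    titles = articles['title'].tolist()
    published_dates = articles['published_date'].tolist()
    if 'domain' not in articles:
        domain = [""] * len(contents)
    else:
        domain = articles['domain'].tolist()  # assumed

    sentences = []  # assumed initialization, outside the hunks
    new_uuids = []
    new_titles = []
    new_domains = []
    new_published_dates = []
    for idx in range(len(contents)):
        # Each sentence inherits the uuid, title, domain, and date of its source article.
        sents = sent_tokenize(contents[idx])
        sentences.extend(sents)
        new_uuids.extend([uuids[idx]] * len(sents))
        new_titles.extend([titles[idx]] * len(sents))
        new_domains.extend([domain[idx]] * len(sents))
        new_published_dates.extend([published_dates[idx]] * len(sents))

    # Score every (question, sentence) pair and keep the best top_n.
    cross_inp = [[question, sent] for sent in sentences]
    cross_scores = cross_encoder.predict(cross_inp)
    scores_sentences = list(zip(cross_scores, sentences, new_uuids, new_titles, new_domains, new_published_dates))
    scores_sentences = sorted(scores_sentences, key=lambda x: x[0], reverse=True)

    out_values = scores_sentences[:top_n]
    return out_values  # assumed return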
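Hypothetical usage, just to show that callers now unpack six-tuples instead of five (the file name and question here are made up):

results = crossencoder_rerank_answer("retrieved_articles.csv", "example question", top_n=4)
for score, content, uuid, title, dom, published_date in results:
    print(f"{score:.3f} [{published_date}] {title}")

Any downstream code that unpacked the old five-element tuples needs the extra published_date field added, or it will raise a ValueError on unpacking.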