#
# Pyserini: Reproducible IR research with sparse and dense representations
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import datetime
import glob
import hashlib
import json
import multiprocessing
import os
import pickle
import random
import subprocess
import sys
import time
import uuid
from collections import defaultdict

import lightgbm as lgb
import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append('..')
from pyserini.search.lucene.ltr import *
""" | |
train a LTR model with lambdaRank library and save to pickle for future inference | |
run from python root dir | |
""" | |

def train_data_loader(task='triple', neg_sample=20, random_seed=12345):
    print(f'train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle')
    if os.path.exists(f'./collections/msmarco-ltr-passage/train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle'):
        sampled_train = pd.read_pickle(f'./collections/msmarco-ltr-passage/train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle')
        print(sampled_train.shape)
        print(sampled_train.index.get_level_values('qid').drop_duplicates().shape)
        print(sampled_train.groupby('qid').count().mean())
        print(sampled_train.head(10))
        print(sampled_train.info())
        return sampled_train
    else:
        if task == 'triple':
            train = pd.read_csv('./collections/msmarco-passage/qidpidtriples.train.full.2.tsv', sep="\t",
                                names=['qid', 'pos_pid', 'neg_pid'], dtype=np.int32)
            pos_half = train[['qid', 'pos_pid']].rename(columns={"pos_pid": "pid"}).drop_duplicates()
            pos_half['rel'] = np.int32(1)
            neg_half = train[['qid', 'neg_pid']].rename(columns={"neg_pid": "pid"}).drop_duplicates()
            neg_half['rel'] = np.int32(0)
            del train
            # Keep every positive, but cap the negatives at neg_sample per query.
            sampled_neg_half = []
            for qid, group in tqdm(neg_half.groupby('qid')):
                sampled_neg_half.append(group.sample(n=min(neg_sample, len(group)), random_state=random_seed))
            sampled_train = pd.concat([pos_half] + sampled_neg_half, axis=0, ignore_index=True)
            sampled_train = sampled_train.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])
            print(sampled_train.shape)
            print(sampled_train.index.get_level_values('qid').drop_duplicates().shape)
            print(sampled_train.groupby('qid').count().mean())
            print(sampled_train.head(10))
            print(sampled_train.info())
            sampled_train.to_pickle(f'./collections/msmarco-ltr-passage/train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle')
        elif task == 'rank':
            qrel = defaultdict(list)
            with open("./collections/msmarco-passage/qrels.train.tsv") as f:
                for line in f:
                    topicid, _, docid, rel = line.strip().split('\t')
                    assert rel == "1", line.split(' ')
                    qrel[topicid].append(docid)
            qid2pos = defaultdict(list)
            qid2neg = defaultdict(list)
            with open("./runs/msmarco-passage/run.train.small.tsv") as f:
                for line in tqdm(f):
                    topicid, docid, rank = line.split()
                    assert topicid in qrel
                    if docid in qrel[topicid]:
                        qid2pos[topicid].append(docid)
                    else:
                        qid2neg[topicid].append(docid)
            sampled_train = []
            for topicid, pos_list in tqdm(qid2pos.items()):
                neg_list = random.sample(qid2neg[topicid], min(len(qid2neg[topicid]), neg_sample))
                for positive_docid in pos_list:
                    sampled_train.append((int(topicid), int(positive_docid), 1))
                for negative_docid in neg_list:
                    sampled_train.append((int(topicid), int(negative_docid), 0))
            sampled_train = pd.DataFrame(sampled_train, columns=['qid', 'pid', 'rel'], dtype=np.int32)
            sampled_train = sampled_train.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])
            print(sampled_train.shape)
            print(sampled_train.index.get_level_values('qid').drop_duplicates().shape)
            print(sampled_train.groupby('qid').count().mean())
            print(sampled_train.head(10))
            print(sampled_train.info())
            sampled_train.to_pickle(f'./collections/msmarco-ltr-passage/train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle')
        else:
            raise ValueError(f'unknown task: {task}')
        return sampled_train
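
# Usage sketch (matches the call in __main__ below):
#   sampled_train = train_data_loader(task='triple', neg_sample=10)
# yields a DataFrame indexed by (qid, pid) with a single binary 'rel' column.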

def dev_data_loader(task='anserini'):
    if os.path.exists(f'./collections/msmarco-ltr-passage/dev_{task}.pickle'):
        dev = pd.read_pickle(f'./collections/msmarco-ltr-passage/dev_{task}.pickle')
        print(dev.shape)
        print(dev.index.get_level_values('qid').drop_duplicates().shape)
        print(dev.groupby('qid').count().mean())
        print(dev.head(10))
        print(dev.info())
        dev_qrel = pd.read_pickle('./collections/msmarco-ltr-passage/dev_qrel.pickle')
        return dev, dev_qrel
    else:
        if task == 'rerank':
            dev = pd.read_csv('./collections/msmarco-passage/top1000.dev', sep="\t",
                              names=['qid', 'pid', 'query', 'doc'], usecols=['qid', 'pid'], dtype=np.int32)
        elif task == 'anserini':
            dev = pd.read_csv('./runs/run.msmarco-passage.bm25tuned.txt', sep="\t",
                              names=['qid', 'pid', 'rank'], dtype=np.int32)
        elif task == 'pygaggle':
            # pygaggle BM25 top-1000 input
            dev = pd.read_csv('./collections/msmarco-passage/run.dev.small.tsv', sep="\t",
                              names=['qid', 'pid', 'rank'], dtype=np.int32)
        else:
            raise ValueError(f'unknown task: {task}')
        dev_qrel = pd.read_csv('./collections/msmarco-passage/qrels.dev.small.tsv', sep="\t",
                               names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'], dtype=np.int32)
        # Candidates missing from the qrels are treated as non-relevant.
        dev = dev.merge(dev_qrel, left_on=['qid', 'pid'], right_on=['qid', 'pid'], how='left')
        dev['rel'] = dev['rel'].fillna(0).astype(np.int32)
        dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])
        print(dev.shape)
        print(dev.index.get_level_values('qid').drop_duplicates().shape)
        print(dev.groupby('qid').count().mean())
        print(dev.head(10))
        print(dev.info())
        dev.to_pickle(f'./collections/msmarco-ltr-passage/dev_{task}.pickle')
        dev_qrel.to_pickle('./collections/msmarco-ltr-passage/dev_qrel.pickle')
        return dev, dev_qrel

def query_loader():
    queries = {}
    for path in ['./collections/msmarco-ltr-passage/queries.eval.small.json',
                 './collections/msmarco-ltr-passage/queries.dev.small.json',
                 './collections/msmarco-ltr-passage/queries.train.json']:
        with open(path) as f:
            for line in f:
                query = json.loads(line)
                qid = query.pop('id')
                query['analyzed'] = query['analyzed'].split(" ")
                query['text'] = query['text_unlemm'].split(" ")
                query['text_unlemm'] = query['text_unlemm'].split(" ")
                query['text_bert_tok'] = query['text_bert_tok'].split(" ")
                queries[qid] = query
    return queries
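
# Each entry of `queries` maps a qid to pre-tokenized views of the query
# ('analyzed', 'text', 'text_unlemm', 'text_bert_tok'), each a list of tokens;
# note that 'text' is populated from the unlemmatized text.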

def batch_extract(df, queries, fe):
    tasks = []
    task_infos = []
    group_lst = []
    info_dfs = []
    feature_dfs = []
    group_dfs = []
    for qid, group in tqdm(df.groupby('qid')):
        task = {
            "qid": str(qid),
            "docIds": [],
            "rels": [],
            "query_dict": queries[str(qid)]
        }
        for t in group.reset_index().itertuples():
            task["docIds"].append(str(t.pid))
            task_infos.append((qid, t.pid, t.rel))
        tasks.append(task)
        group_lst.append((qid, len(task['docIds'])))
        if len(tasks) == 10000:
            features = fe.batch_extract(tasks)
            task_infos = pd.DataFrame(task_infos, columns=['qid', 'pid', 'rel'])
            group_df = pd.DataFrame(group_lst, columns=['qid', 'count'])
            print(features.shape)
            print(task_infos.qid.drop_duplicates().shape)
            print(group_df.mean())
            print(features.head(10))
            print(features.info())
            info_dfs.append(task_infos)
            feature_dfs.append(features)
            group_dfs.append(group_df)
            tasks = []
            task_infos = []
            group_lst = []
    # deal with the remainder
    if len(tasks) > 0:
        features = fe.batch_extract(tasks)
        task_infos = pd.DataFrame(task_infos, columns=['qid', 'pid', 'rel'])
        group_df = pd.DataFrame(group_lst, columns=['qid', 'count'])
        print(features.shape)
        print(task_infos.qid.drop_duplicates().shape)
        print(group_df.mean())
        print(features.head(10))
        print(features.info())
        info_dfs.append(task_infos)
        feature_dfs.append(features)
        group_dfs.append(group_df)
    info_dfs = pd.concat(info_dfs, axis=0, ignore_index=True)
    feature_dfs = pd.concat(feature_dfs, axis=0, ignore_index=True, copy=False)
    group_dfs = pd.concat(group_dfs, axis=0, ignore_index=True)
    return info_dfs, feature_dfs, group_dfs
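
# Extraction is flushed every 10,000 query groups, presumably to bound the
# memory held by pending tasks; each flush yields aligned info/feature/group
# frames that are concatenated at the end.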

def hash_df(df):
    h = pd.util.hash_pandas_object(df)
    return hex(h.sum().astype(np.uint64))


def hash_anserini_jar():
    find = glob.glob(os.environ['ANSERINI_CLASSPATH'] + "/*fatjar.jar")
    assert len(find) == 1
    md5Hash = hashlib.md5(open(find[0], 'rb').read())
    return md5Hash.hexdigest()


def hash_fe(fe):
    return hashlib.md5(','.join(sorted(fe.feature_names())).encode()).hexdigest()
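
# The three hashes above are recorded in the experiment metadata (see save_exp)
# so a trained model can be traced back to the exact candidate data, Anserini
# jar, and feature set that produced it.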

def data_loader(task, df, queries, fe):
    df_hash = hash_df(df)
    jar_hash = hash_anserini_jar()
    fe_hash = hash_fe(fe)
    if task == 'train' or task == 'dev':
        info, data, group = batch_extract(df, queries, fe)
        obj = {'info': info, 'data': data, 'group': group,
               'df_hash': df_hash, 'jar_hash': jar_hash, 'fe_hash': fe_hash}
        print(info.shape)
        print(info.qid.drop_duplicates().shape)
        print(group.mean())
        return obj
    else:
        raise ValueError(f'unknown task: {task}')

def gen_dev_group_rel_num(dev_qrel, dev_extracted):
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']
    prev_qid = None
    dev_rel_num_list = []
    for t in dev_extracted['info'].itertuples():
        if prev_qid is None or t.qid != prev_qid:
            prev_qid = t.qid
            dev_rel_num_list.append(dev_rel_num.loc[t.qid])
        else:
            continue
    assert len(dev_rel_num_list) == dev_qrel.qid.drop_duplicates().shape[0]

    def recall_at_200(preds, dataset):
        labels = dataset.get_label()
        groups = dataset.get_group()
        idx = 0
        recall = 0
        assert len(dev_rel_num_list) == len(groups)
        for g, gnum in zip(groups, dev_rel_num_list):
            top_preds = labels[idx:idx + g][np.argsort(preds[idx:idx + g])]
            recall += np.sum(top_preds[-200:]) / gnum
            idx += g
        assert idx == len(preds)
        return 'recall@200', recall / len(groups), True

    return recall_at_200
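
# recall_at_200 follows LightGBM's feval protocol: it receives raw predictions
# plus the Dataset, counts relevant labels among the 200 highest-scored
# candidates per group, divides by the query's total number of relevant
# passages, and returns (name, value, is_higher_better).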

def mrr_at_10(preds, dataset):
    labels = dataset.get_label()
    groups = dataset.get_group()
    idx = 0
    MRR = []
    for g in groups:
        # Labels of the 10 highest-scored candidates, best first.
        top_preds = labels[idx:idx + g][np.argsort(preds[idx:idx + g])][-10:][::-1]
        rank = 0
        while rank < len(top_preds):
            if top_preds[rank] > 0:
                MRR.append(1.0 / (rank + 1))
                break
            rank += 1
        if rank == len(top_preds):
            MRR.append(0.)
        idx += g
    assert idx == len(preds)
    return 'mrr@10', np.mean(MRR).item(), True
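
# Worked example for mrr_at_10: if a query's first relevant passage lands at
# rank 3 of the re-scored top 10, the query contributes 1/3; if no relevant
# passage appears in the top 10, it contributes 0.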

def train(train_extracted, dev_extracted, feature_name, eval_fn):
    lgb_train = lgb.Dataset(train_extracted['data'].loc[:, feature_name],
                            label=train_extracted['info']['rel'],
                            group=train_extracted['group']['count'])
    lgb_valid = lgb.Dataset(dev_extracted['data'].loc[:, feature_name],
                            label=dev_extracted['info']['rel'],
                            group=dev_extracted['group']['count'],
                            free_raw_data=False)
    # max_depth = -1 (unlimited) seems to work better for many settings,
    # although 10 is also good.
    params = {
        'boosting_type': 'goss',
        'objective': 'lambdarank',
        'max_bin': 255,
        'num_leaves': 200,
        'max_depth': -1,
        'min_data_in_leaf': 50,
        'min_sum_hessian_in_leaf': 0,
        'feature_fraction': 1,
        'learning_rate': 0.1,
        'num_boost_round': 1000,
        'early_stopping_round': 200,
        'metric': 'custom',
        # Binary relevance: gain 0 for rel=0, gain 1 for rel=1.
        'label_gain': [0, 1],
        'seed': 12345,
        'num_threads': max(multiprocessing.cpu_count() // 2, 1)
    }
    num_boost_round = params.pop('num_boost_round')
    early_stopping_round = params.pop('early_stopping_round')
    # Note: early_stopping_rounds/verbose_eval are the pre-4.0 LightGBM
    # keyword API (newer versions use callbacks instead).
    gbm = lgb.train(params, lgb_train,
                    valid_sets=lgb_valid,
                    num_boost_round=num_boost_round,
                    early_stopping_rounds=early_stopping_round,
                    feval=eval_fn,
                    feature_name=feature_name,
                    verbose_eval=True)
    del lgb_train
    dev_extracted['info']['score'] = gbm.predict(lgb_valid.get_data())
    best_score = gbm.best_score['valid_0']['mrr@10']
    print(best_score)
    best_iteration = gbm.best_iteration
    print(best_iteration)
    feature_importances = sorted(list(zip(feature_name, gbm.feature_importance().tolist())),
                                 key=lambda x: x[1], reverse=True)
    print(feature_importances)
    params['num_boost_round'] = num_boost_round
    params['early_stopping_round'] = early_stopping_round
    return {'model': [gbm], 'params': params,
            'feature_names': feature_name,
            'feature_importances': feature_importances}
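
# Inference sketch (hypothetical snippet, mirroring what save_exp writes below):
#   with open(f'{dirname}/model.pkl', 'rb') as f:
#       gbm = pickle.load(f)[0]  # the model is stored as a one-element list
#   scores = gbm.predict(features.loc[:, feature_name])
# where `features` stands for a feature DataFrame produced by batch_extract.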

def eval_mrr(dev_data):
    score_tie_counter = 0
    score_tie_query = set()
    MRR = []
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # Stable sort, matching the sort LightGBM uses internally.
        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel > 0:
                MRR.append(1.0 / rank)
                break
            elif rank == 10 or rank == len(group):
                MRR.append(0.)
                break
    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)
    mrr_10 = np.mean(MRR).item()
    print(f'MRR@10: {mrr_10} with {len(MRR)} queries')
    return {'score_tie': score_tie, 'mrr_10': mrr_10}

def eval_recall(dev_qrel, dev_data):
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']
    score_tie_counter = 0
    score_tie_query = set()
    recall_point = [10, 20, 50, 100, 200, 500, 1000]
    recall_curve = {k: [] for k in recall_point}
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        total_rel = dev_rel_num.loc[qid]
        query_recall = [0 for _ in recall_point]
        # Stable sort, matching the sort LightGBM uses internally.
        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel > 0:
                for i, p in enumerate(recall_point):
                    if rank <= p:
                        query_recall[i] += 1
        for i, p in enumerate(recall_point):
            if total_rel > 0:
                recall_curve[p].append(query_recall[i] / total_rel)
            else:
                recall_curve[p].append(0.)
    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)
    res = {'score_tie': score_tie}
    for k, v in recall_curve.items():
        avg = np.mean(v)
        print(f'recall@{k}: {avg}')
        res[f'recall@{k}'] = avg
    return res

def gen_exp_dir():
    dirname = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + '_' + str(uuid.uuid1())
    dirname = './runs/' + dirname
    assert not os.path.exists(dirname)
    os.mkdir(dirname)
    return dirname

def save_exp(dirname,
             train_extracted, dev_extracted,
             train_res, eval_res):
    dev_extracted['info'][['qid', 'pid', 'score']].to_json(f'{dirname}/output.json')
    subprocess.check_output(['gzip', f'{dirname}/output.json'])
    with open(f'{dirname}/model.pkl', 'wb') as f:
        pickle.dump(train_res['model'], f)
    metadata = {
        'train_df_hash': train_extracted['df_hash'],
        'train_jar_hash': train_extracted['jar_hash'],
        'train_fe_hash': train_extracted['fe_hash'],
        'dev_df_hash': dev_extracted['df_hash'],
        'dev_jar_hash': dev_extracted['jar_hash'],
        'dev_fe_hash': dev_extracted['fe_hash'],
        'feature_names': train_res['feature_names'],
        'feature_importances': train_res['feature_importances'],
        'params': train_res['params'],
        'score_tie': eval_res['score_tie'],
        'mrr_10': eval_res['mrr_10']
    }
    with open(f'{dirname}/metadata.json', 'w') as f:
        json.dump(metadata, f)
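
# A saved experiment directory therefore contains output.json.gz (dev scores),
# model.pkl (the trained booster, stored as a one-element list), and
# metadata.json (hashes, params, feature importances, and dev metrics).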

if __name__ == '__main__':
    os.environ["ANSERINI_CLASSPATH"] = "pyserini/resources/jars"
    parser = argparse.ArgumentParser(description='Learning to rank training')
    parser.add_argument('--index', required=True)
    parser.add_argument('--neg-sample', type=int, default=10)
    parser.add_argument('--opt', default='mrr_at_10')
    args = parser.parse_args()
    total_start_time = time.time()
    sampled_train = train_data_loader(task='triple', neg_sample=args.neg_sample)
    dev, dev_qrel = dev_data_loader(task='anserini')
    queries = query_loader()
    fe = FeatureExtractor(args.index,
                          max(multiprocessing.cpu_count() // 2, 1))
    # Optionally fold in neural reranker runs as features:
    # fe.add(RunList('./collections/msmarco-ltr-passage/run.monot5.run_list.whole.trec', 't5'))
    # fe.add(RunList('./collections/msmarco-ltr-passage/run.monobert.run_list.whole.trec', 'bert'))
    for qfield, ifield in [('analyzed', 'contents'),
                           ('text_unlemm', 'text_unlemm'),
                           ('text_bert_tok', 'text_bert_tok')]:
        print(qfield, ifield)
        fe.add(BM25Stat(SumPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield))
        fe.add(BM25Stat(AvgPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield))
        fe.add(BM25Stat(MedianPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield))
        fe.add(BM25Stat(MaxPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield))
        fe.add(BM25Stat(MinPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield))
        fe.add(BM25Stat(MaxMinRatioPooler(), k1=2.0, b=0.75, field=ifield, qfield=qfield))
        fe.add(LmDirStat(SumPooler(), mu=1000, field=ifield, qfield=qfield))
        fe.add(LmDirStat(AvgPooler(), mu=1000, field=ifield, qfield=qfield))
        fe.add(LmDirStat(MedianPooler(), mu=1000, field=ifield, qfield=qfield))
        fe.add(LmDirStat(MaxPooler(), mu=1000, field=ifield, qfield=qfield))
        fe.add(LmDirStat(MinPooler(), mu=1000, field=ifield, qfield=qfield))
        fe.add(LmDirStat(MaxMinRatioPooler(), mu=1000, field=ifield, qfield=qfield))
        fe.add(NormalizedTfIdf(field=ifield, qfield=qfield))
        fe.add(ProbalitySum(field=ifield, qfield=qfield))
        fe.add(DfrGl2Stat(SumPooler(), field=ifield, qfield=qfield))
        fe.add(DfrGl2Stat(AvgPooler(), field=ifield, qfield=qfield))
        fe.add(DfrGl2Stat(MedianPooler(), field=ifield, qfield=qfield))
        fe.add(DfrGl2Stat(MaxPooler(), field=ifield, qfield=qfield))
        fe.add(DfrGl2Stat(MinPooler(), field=ifield, qfield=qfield))
        fe.add(DfrGl2Stat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
        fe.add(DfrInExpB2Stat(SumPooler(), field=ifield, qfield=qfield))
        fe.add(DfrInExpB2Stat(AvgPooler(), field=ifield, qfield=qfield))
        fe.add(DfrInExpB2Stat(MedianPooler(), field=ifield, qfield=qfield))
        fe.add(DfrInExpB2Stat(MaxPooler(), field=ifield, qfield=qfield))
        fe.add(DfrInExpB2Stat(MinPooler(), field=ifield, qfield=qfield))
        fe.add(DfrInExpB2Stat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
        fe.add(DphStat(SumPooler(), field=ifield, qfield=qfield))
        fe.add(DphStat(AvgPooler(), field=ifield, qfield=qfield))
        fe.add(DphStat(MedianPooler(), field=ifield, qfield=qfield))
        fe.add(DphStat(MaxPooler(), field=ifield, qfield=qfield))
        fe.add(DphStat(MinPooler(), field=ifield, qfield=qfield))
        fe.add(DphStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
        fe.add(Proximity(field=ifield, qfield=qfield))
        fe.add(TpScore(field=ifield, qfield=qfield))
        fe.add(TpDist(field=ifield, qfield=qfield))
        fe.add(DocSize(field=ifield))
        fe.add(QueryLength(qfield=qfield))
        fe.add(QueryCoverageRatio(qfield=qfield))
        fe.add(UniqueTermCount(qfield=qfield))
        fe.add(MatchingTermCount(field=ifield, qfield=qfield))
        fe.add(SCS(field=ifield, qfield=qfield))
        fe.add(TfStat(AvgPooler(), field=ifield, qfield=qfield))
        fe.add(TfStat(MedianPooler(), field=ifield, qfield=qfield))
        fe.add(TfStat(SumPooler(), field=ifield, qfield=qfield))
        fe.add(TfStat(MinPooler(), field=ifield, qfield=qfield))
        fe.add(TfStat(MaxPooler(), field=ifield, qfield=qfield))
        fe.add(TfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
        fe.add(TfIdfStat(True, AvgPooler(), field=ifield, qfield=qfield))
        fe.add(TfIdfStat(True, MedianPooler(), field=ifield, qfield=qfield))
        fe.add(TfIdfStat(True, SumPooler(), field=ifield, qfield=qfield))
        fe.add(TfIdfStat(True, MinPooler(), field=ifield, qfield=qfield))
        fe.add(TfIdfStat(True, MaxPooler(), field=ifield, qfield=qfield))
        fe.add(TfIdfStat(True, MaxMinRatioPooler(), field=ifield, qfield=qfield))
        fe.add(NormalizedTfStat(AvgPooler(), field=ifield, qfield=qfield))
        fe.add(NormalizedTfStat(MedianPooler(), field=ifield, qfield=qfield))
        fe.add(NormalizedTfStat(SumPooler(), field=ifield, qfield=qfield))
        fe.add(NormalizedTfStat(MinPooler(), field=ifield, qfield=qfield))
        fe.add(NormalizedTfStat(MaxPooler(), field=ifield, qfield=qfield))
        fe.add(NormalizedTfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
        fe.add(IdfStat(AvgPooler(), field=ifield, qfield=qfield))
        fe.add(IdfStat(MedianPooler(), field=ifield, qfield=qfield))
        fe.add(IdfStat(SumPooler(), field=ifield, qfield=qfield))
        fe.add(IdfStat(MinPooler(), field=ifield, qfield=qfield))
        fe.add(IdfStat(MaxPooler(), field=ifield, qfield=qfield))
        fe.add(IdfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
        fe.add(IcTfStat(AvgPooler(), field=ifield, qfield=qfield))
        fe.add(IcTfStat(MedianPooler(), field=ifield, qfield=qfield))
        fe.add(IcTfStat(SumPooler(), field=ifield, qfield=qfield))
        fe.add(IcTfStat(MinPooler(), field=ifield, qfield=qfield))
        fe.add(IcTfStat(MaxPooler(), field=ifield, qfield=qfield))
        fe.add(IcTfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
        fe.add(UnorderedSequentialPairs(3, field=ifield, qfield=qfield))
        fe.add(UnorderedSequentialPairs(8, field=ifield, qfield=qfield))
        fe.add(UnorderedSequentialPairs(15, field=ifield, qfield=qfield))
        fe.add(OrderedSequentialPairs(3, field=ifield, qfield=qfield))
        fe.add(OrderedSequentialPairs(8, field=ifield, qfield=qfield))
        fe.add(OrderedSequentialPairs(15, field=ifield, qfield=qfield))
        fe.add(UnorderedQueryPairs(3, field=ifield, qfield=qfield))
        fe.add(UnorderedQueryPairs(8, field=ifield, qfield=qfield))
        fe.add(UnorderedQueryPairs(15, field=ifield, qfield=qfield))
        fe.add(OrderedQueryPairs(3, field=ifield, qfield=qfield))
        fe.add(OrderedQueryPairs(8, field=ifield, qfield=qfield))
        fe.add(OrderedQueryPairs(15, field=ifield, qfield=qfield))
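
    # Everything registered above is computed per (query, passage) pair for
    # three parallel views of the text ('analyzed', 'text_unlemm',
    # 'text_bert_tok'); fe.feature_names() below enumerates the full set.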

    # Load the pre-trained IBM Model 1 translation tables; argument order
    # follows the original calls (path, field, tag, qfield).
    for model_path, field, tag, qfield in [
            ('collections/msmarco-ltr-passage/ibm_model/title_unlemm', 'text_unlemm', 'title_unlemm', 'text_unlemm'),
            ('collections/msmarco-ltr-passage/ibm_model/url_unlemm', 'text_unlemm', 'url_unlemm', 'text_unlemm'),
            ('collections/msmarco-ltr-passage/ibm_model/body', 'text_unlemm', 'body', 'text_unlemm'),
            ('collections/msmarco-ltr-passage/ibm_model/text_bert_tok', 'text_bert_tok', 'text_bert_tok', 'text_bert_tok')]:
        start = time.time()
        fe.add(IbmModel1(model_path, field, tag, qfield))
        print('IBM model load took %.2f seconds' % (time.time() - start))
    train_extracted = data_loader('train', sampled_train, queries, fe)
    print("train extracted")
    dev_extracted = data_loader('dev', dev, queries, fe)
    print("dev extracted")
    feature_name = fe.feature_names()
    del sampled_train, dev, queries, fe
    recall_at_200 = gen_dev_group_rel_num(dev_qrel, dev_extracted)
    print("start train")
    # Note: train() reads best_score['valid_0']['mrr@10'], so it is wired to
    # the mrr_at_10 feval; the recall_at_200 closure is built but unused here.
    train_res = train(train_extracted, dev_extracted, feature_name, mrr_at_10)
    print("end train")
    eval_res = eval_mrr(dev_extracted['info'])
    eval_res.update(eval_recall(dev_qrel, dev_extracted['info']))
    dirname = gen_exp_dir()
    save_exp(dirname, train_extracted, dev_extracted, train_res, eval_res)
    total_time = (time.time() - total_start_time)
    print(f'Total training time: {total_time:0.3f} s')
    print('Done!')