"""Train a (multilingual) cross-encoder with an MSE margin loss.

Usage:
    python <this script> MODEL_NAME [NUM_LAYERS]

NUM_LAYERS (optional, one of 6, 4, 2) prunes the encoder down to a fixed
subset of its layers before training.

For each training triple (qid, pid1, pid2) read from the training file, the
model scores both (query, passage) pairs and is trained so that the difference
of its two logits matches pos_score - neg_score from that file (MSE loss).
Query and passage texts are drawn at random, independently per text, from the
English MS MARCO files or their German translations.

Every `auto_save` steps (and once more at the end of training) the current
weights are saved to the "-latest" directory and evaluated by nDCG@10 on the
TREC-DL 2019 passage set; the best-scoring weights are kept in
`output_save_path`.
"""
import gzip
import logging
import os
import random
import sys
from collections import defaultdict
from datetime import datetime
from shutil import copyfile

import numpy as np
import pytrec_eval
import torch
import tqdm
import transformers
from torch.cuda.amp import autocast
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW
from sentence_transformers import SentenceTransformer, util, CrossEncoder

# NOTE(review): `transformers.AdamW` is deprecated in recent transformers
# releases; `torch.optim.AdamW` is the replacement (note its default eps
# differs), so switching would slightly change the optimisation.

model_name = sys.argv[1]
max_length = 350  # max tokens per (query, passage) pair during training

######### Evaluation data (TREC-DL 2019 passage ranking)
queries_filepath = 'msmarco-data/trec2019/msmarco-test2019-queries.tsv.gz'
queries_eval = {}
with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")[0:2]
        queries_eval[qid] = query

# rel[qid][pid] = graded relevance label; only labels > 0 are stored.
rel = defaultdict(lambda: defaultdict(int))
with open('msmarco-data/trec2019/2019qrels-pass.txt', encoding='utf8') as fIn:
    for line in fIn:
        qid, _, pid, score = line.strip().split()
        score = int(score)
        if score > 0:
            rel[qid][pid] = score

# Only evaluate queries with at least one relevant passage. `.get()` avoids
# inserting empty entries into the defaultdict just by looking a qid up.
relevant_qid = [qid for qid in queries_eval if rel.get(qid)]

# passage_cand[qid] = [[pid, passage], ...]: the candidates to re-rank.
passage_cand = {}
with gzip.open('msmarco-data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pid, query, passage = line.strip().split("\t")
        passage_cand.setdefault(qid, []).append([pid, passage])


def eval_modal(model_path):
    """Evaluate a saved cross-encoder on TREC-DL 2019 passage re-ranking.

    Loads `model_path` as a sentence-transformers CrossEncoder, scores every
    (query, candidate passage) pair from the top-1000 file for each query in
    `relevant_qid`, and computes nDCG@10 against the 2019 qrels with
    pytrec_eval.

    Args:
        model_path: directory holding a saved model and tokenizer.

    Returns:
        Mean nDCG@10 over the evaluated queries (a value in [0, 1]); it is
        also printed as a percentage.
    """
    run = {}
    # NOTE: evaluation truncates pairs at 512 tokens, training at `max_length`.
    model = CrossEncoder(model_path, max_length=512)
    for qid in relevant_qid:
        query = queries_eval[qid]
        cand = passage_cand[qid]
        pids = [c[0] for c in cand]
        cross_inp = [[query, c[1]] for c in cand]

        if model.config.num_labels > 1:
            # Multi-label head: rank by column 1 of the softmax output.
            cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
        else:
            # Single-logit head (as trained by this script): rank by raw logit.
            cross_scores = model.predict(cross_inp, activation_fct=torch.nn.Identity()).tolist()

        # If a pid appears twice, the later score wins.
        run[qid] = {pid: float(score) for pid, score in zip(pids, cross_scores)}

    evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
    scores = evaluator.evaluate(run)
    scores_mean = np.mean([ele["ndcg_cut_10"] for ele in scores.values()])
    print("NDCG@10: {:.2f}".format(scores_mean * 100))
    return scores_mean


################################ Model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.cuda.amp (autocast + GradScaler) only applies on a CUDA device.
use_amp = device == 'cuda'

config = AutoConfig.from_pretrained(model_name)
config.num_labels = 1  # one logit per (query, passage) pair
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

############# Optionally remove encoder layers
# Number of layers to keep -> indices of the (12) encoder layers kept.
LAYERS_TO_KEEP = {
    6: [0, 2, 4, 6, 8, 10],
    4: [1, 4, 7, 10],
    2: [3, 7],
}
if len(sys.argv) > 2:
    num_layers = int(sys.argv[2])
    if num_layers not in LAYERS_TO_KEEP:
        # Exit with a non-zero status; the message goes to stderr.
        sys.exit("Unknown number of layers to keep: {}".format(num_layers))
    layers_to_keep = LAYERS_TO_KEEP[num_layers]

    print("Reduce model to {} layers".format(len(layers_to_keep)))
    # base_model is the wrapped encoder (`model.bert` for BERT-style models).
    encoder = model.base_model.encoder
    encoder.layer = torch.nn.ModuleList(
        [layer_module for i, layer_module in enumerate(encoder.layer) if i in layers_to_keep]
    )
    model.config.num_hidden_layers = len(layers_to_keep)
    model_name += "_L-{}".format(len(layers_to_keep))

####################### Output directories
output_save_path = 'output/train_cross-encoder_mse-{}-{}'.format(
    model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
output_save_path_latest = output_save_path + "-latest"

# Save the tokenizer and a copy of this script (plus the command line used)
# into both directories. tokenizer.save_pretrained() runs first because it
# creates the directory the script copy is written into.
for save_dir in (output_save_path, output_save_path_latest):
    tokenizer.save_pretrained(save_dir)
    train_script_path = os.path.join(save_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))


#### Read train files
class MultilingualDataset(Dataset):
    """Texts keyed by id, with one or more versions per language.

    `examples[id][lang]` is a list of texts (e.g. several translations of the
    same passage). Indexing by id returns one text: a random language first,
    then a random text within that language.
    """

    def __init__(self):
        super().__init__()
        self.examples = defaultdict(lambda: defaultdict(list))  # [id][lang] => [samples...]

    def add(self, lang, filepath):
        """Add every `id<TAB>text` line of `filepath` (optionally .gz) under `lang`."""
        open_method = gzip.open if filepath.endswith('.gz') else open
        with open_method(filepath, 'rt', encoding='utf8') as fIn:
            for line in fIn:
                pid, passage = line.strip().split("\t")
                self.examples[pid][lang].append(passage)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        # .get() instead of [] so an unknown id raises KeyError rather than
        # inserting an empty entry and then failing inside random.choice().
        all_examples = self.examples.get(item)  # all examples in all languages
        if not all_examples:
            raise KeyError(item)
        lang_examples = random.choice(list(all_examples.values()))  # examples in one specific language
        return random.choice(lang_examples)  # one random example


train_corpus = MultilingualDataset()
train_corpus.add('en', 'msmarco-data/collection.tsv')
train_corpus.add('de', 'msmarco-data/de/collection.de.opus-mt.tsv.gz')
train_corpus.add('de', 'msmarco-data/de/collection.de.wmt19.tsv.gz')

train_queries = MultilingualDataset()
train_queries.add('en', 'msmarco-data/queries.train.tsv')
train_queries.add('de', 'msmarco-data/de/queries.train.de.opus-mt.tsv.gz')
train_queries.add('de', 'msmarco-data/de/queries.train.de.wmt19.tsv.gz')


############## MSE Dataset
class MSEDataset(Dataset):
    """Training triples with a target score margin.

    Each line of `filepath` is `pos_score<TAB>neg_score<TAB>qid<TAB>pid1<TAB>pid2`;
    the corresponding example is `[qid, pid1, pid2, pos_score - neg_score]`.
    """

    def __init__(self, filepath):
        super().__init__()
        self.examples = []
        with open(filepath, encoding='utf8') as fIn:
            for line in fIn:
                pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")
                self.examples.append([qid, pid1, pid2, float(pos_score) - float(neg_score)])

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return self.examples[item]


train_batch_size = 16
train_dataset = MSEDataset('msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
# drop_last keeps every batch at exactly train_batch_size.
train_dataloader = DataLoader(train_dataset, drop_last=True, shuffle=True, batch_size=train_batch_size)

############## Optimizer
weight_decay = 0.01
max_grad_norm = 1
param_optimizer = list(model.named_parameters())
# No weight decay on biases and LayerNorm parameters.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
# 1000 warm-up steps, then linear decay over a single pass over the data.
scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader)
)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
loss_fct = torch.nn.MSELoss()


def save_and_evaluate(best_score):
    """Checkpoint the current weights, evaluate them and keep them if best.

    Saves the model to `output_save_path_latest`, evaluates that directory
    with `eval_modal`, and if the nDCG@10 is >= `best_score` also saves the
    model to `output_save_path`.

    Returns:
        The updated best score.
    """
    model.save_pretrained(output_save_path_latest)
    ndcg_score = eval_modal(output_save_path_latest)
    if ndcg_score >= best_score:
        best_score = ndcg_score
        print("Save to:", output_save_path)
        model.save_pretrained(output_save_path)
    return best_score


### Start training
model.to(device)
# from_pretrained() returns the model in eval mode (dropout disabled);
# switch to training mode before fine-tuning.
model.train()
auto_save = 10000  # evaluate / checkpoint every `auto_save` steps
best_ndcg_score = 0
num_steps = len(train_dataloader)

for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=num_steps):
    batch_qids, batch_pids_pos, batch_pids_neg, batch_margins = batch
    # Draw one text (in a random language) for every query / passage id.
    batch_queries = [train_queries[qid] for qid in batch_qids]
    batch_pos = [train_corpus[cid] for cid in batch_pids_pos]
    batch_neg = [train_corpus[cid] for cid in batch_pids_neg]
    scores = batch_margins.float().to(device)  # target: pos_score - neg_score

    with autocast(enabled=use_amp):
        inp_pos = tokenizer(batch_queries, batch_pos, max_length=max_length, padding=True,
                            truncation='longest_first', return_tensors='pt').to(device)
        # squeeze(-1): (batch, 1) -> (batch,) without collapsing a size-1 batch.
        pred_pos = model(**inp_pos).logits.squeeze(-1)

        inp_neg = tokenizer(batch_queries, batch_neg, max_length=max_length, padding=True,
                            truncation='longest_first', return_tensors='pt').to(device)
        pred_neg = model(**inp_neg).logits.squeeze(-1)

        pred_diff = pred_pos - pred_neg
        loss_value = loss_fct(pred_diff, scores)

    scaler.scale(loss_value).backward()
    # Unscale first so gradient clipping sees the true gradient norm.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    scheduler.step()

    if (step_idx + 1) % auto_save == 0:
        print("Step:", step_idx + 1)
        best_ndcg_score = save_and_evaluate(best_ndcg_score)

# Evaluate the final weights too (unless the last step was just evaluated),
# instead of unconditionally overwriting the best checkpoint with them.
if num_steps == 0 or num_steps % auto_save != 0:
    print("Step:", num_steps)
    best_ndcg_score = save_and_evaluate(best_ndcg_score)

# Script was called via:
#python train_cross-encoder_mse_multilingual.py microsoft/Multilingual-MiniLM-L12-H384 6