Runtime error
Runtime error
File size: 5,061 Bytes
58627fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import torch
from colbert.utils.utils import load_checkpoint
from colbert.utils.amp import MixedPrecisionManager
from colbert.utils.utils import flatten
from baleen.utils.loaders import *
from baleen.condenser.model import ElectraReader
from baleen.condenser.tokenization import AnswerAwareTokenizer
class Condenser:
    """Two-stage sentence selector over a JSON-lines passage collection.

    Stage 1 (``modelL1``) scores every sentence slot of the passages in a
    ranking and keeps the best ``(pid, sid)`` pairs; stage 2 (``modelL2``)
    re-scores those candidates packed together into a single passage.
    Both models are ``ElectraReader`` instances built from ELECTRA checkpoints.
    """

    def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda', deviceL2='cuda'):
        """Load both reader checkpoints, the shared tokenizer and the collection.

        Args:
            collectionX_path: path to a JSON-lines file; every line must carry
                ``pid`` (equal to its 0-based line index), ``title`` and a
                list-valued ``text`` of sentences.
            checkpointL1: checkpoint path for the stage-1 model.
            checkpointL2: checkpoint path for the stage-2 model.
            deviceL1: device the stage-1 model is moved to.
            deviceL2: device the stage-2 model is moved to.

        Raises:
            AssertionError: if the two checkpoints use different ``maxlen``
                values (a single tokenizer is shared by both stages).
        """
        self.modelL1, self.maxlenL1 = self._load_model(checkpointL1, deviceL1)
        self.modelL2, self.maxlenL2 = self._load_model(checkpointL2, deviceL2)

        assert self.maxlenL1 == self.maxlenL2, "Add support for different maxlens: use two tokenizers."

        self.amp, self.tokenizer = self._setup_inference(self.maxlenL2)
        self.CollectionX, self.CollectionY = self._load_collection(collectionX_path)

    def condense(self, query, backs, ranking):
        """Run both stages for one query.

        Args:
            query: the query string.
            backs: ``(pid, sid)`` sentence ids whose text is appended to the
                stage-1 query and which lead the stage-1 output.
            ranking: passage ids (keys of ``CollectionX``) to read in stage 1.

        Returns:
            ``(stage1_preds, stage2_preds, stage2_preds_L3x)``, each a list of
            ``(pid, sid)`` tuples.
        """
        stage1_preds = self._stage1(query, backs, ranking)
        stage2_preds, stage2_preds_L3x = self._stage2(query, stage1_preds)

        return stage1_preds, stage2_preds, stage2_preds_L3x

    def _load_model(self, path, device):
        """Build an ``ElectraReader`` from the checkpoint at ``path``.

        The checkpoint is first read on CPU only to look up the base model
        name in ``['arguments']['model']``; ``load_checkpoint`` then loads the
        weights into a freshly built reader.

        Returns:
            ``(model, maxlen)``: the reader on ``device`` in eval mode, and the
            ``maxlen`` recorded in the checkpoint arguments.
        """
        meta = torch.load(path, map_location='cpu')
        ElectraModels = ['google/electra-base-discriminator', 'google/electra-large-discriminator']
        assert meta['arguments']['model'] in ElectraModels, meta['arguments']

        model = ElectraReader.from_pretrained(meta['arguments']['model'])
        checkpoint = load_checkpoint(path, model)

        # Place the reader on the requested device and disable dropout for inference.
        model = model.to(device)
        model.eval()

        maxlen = checkpoint['arguments']['maxlen']

        return model, maxlen

    def _setup_inference(self, maxlen):
        """Return ``(amp, tokenizer)``: an activated mixed-precision manager
        and an ``AnswerAwareTokenizer`` limited to ``maxlen`` tokens."""
        amp = MixedPrecisionManager(activated=True)
        tokenizer = AnswerAwareTokenizer(total_maxlen=maxlen)

        return amp, tokenizer

    def _load_collection(self, collectionX_path):
        """Read the JSON-lines collection into passage- and sentence-level maps.

        Returns:
            ``(CollectionX, CollectionY)`` where ``CollectionX[pid]`` is
            ``[title, sentence_0, sentence_1, ...]`` and
            ``CollectionY[(pid, sid)]`` is ``title + ' | ' + sentence_sid``.

        Raises:
            AssertionError: if a line's ``text`` is not a list or its ``pid``
                does not equal its 0-based line index.
        """
        CollectionX = {}
        CollectionY = {}

        with open(collectionX_path, encoding='utf-8') as f:
            for line_idx, line in enumerate(f):
                line = ujson.loads(line)

                assert isinstance(line['text'], list)
                assert line['pid'] == line_idx, (line_idx, line)

                title, sentences = line['title'], line['text']
                CollectionX[line_idx] = [title] + sentences

                for sid, sentence in enumerate(sentences):
                    CollectionY[(line_idx, sid)] = title + ' | ' + sentence

        return CollectionX, CollectionY

    def _stage1(self, query, BACKS, ranking, TOPK=9):
        """Score every sentence slot of the passages in ``ranking`` with ``modelL1``.

        The query text is extended with the text of the known ``BACKS``
        sentences (joined by ``' # '``). Each passage is encoded as its title
        and sentences joined by ``' [MASK] '``.

        Returns:
            ``BACKS`` followed by the top-scoring ``(pid, sid)`` pairs,
            de-duplicated in order and truncated to ``TOPK``.
        """
        model = self.modelL1
        preds = []

        with torch.inference_mode():
            backs = [self.CollectionY[(pid, sid)] for pid, sid in BACKS if (pid, sid) in self.CollectionY]
            query = ' # '.join([query] + backs)

            actual_ranking = list(ranking)
            passages = [' [MASK] '.join(self.CollectionX[pid]) for pid in actual_ranking]

            # Nothing to read: skip the model instead of feeding it an empty batch.
            if passages:
                obj = self.tokenizer.process([query], passages, None)

                with self.amp.context():
                    # NOTE(review): reconstructed call — assumes `obj.encoding`
                    # supports `.to(device)` and is what ElectraReader expects; confirm.
                    scores = model(obj.encoding.to(model.device)).float()

                # Assumes one score row per passage (aligned with actual_ranking)
                # and one column per [MASK] slot, as the pid/sid expansion requires.
                num_slots = scores.size(1)
                pids = flatten([[pid] * num_slots for pid in actual_ranking])
                sids = flatten([list(range(num_slots)) for _ in actual_ranking])

                scores = scores.view(-1)
                topk = scores.topk(min(TOPK, len(scores))).indices.tolist()
                preds = [(pids[idx], sids[idx]) for idx in topk]

        pred_plus = BACKS + preds
        pred_plus = f7(list(map(tuple, pred_plus)))[:TOPK]

        return pred_plus

    def _stage2(self, query, preds):
        """Jointly re-score the stage-1 candidates with ``modelL2``.

        All candidate sentences are packed into one passage, each preceded by
        a ``[MASK]`` slot, so the model yields one score per candidate.

        Returns:
            ``(preds, preds_L3x)``. ``preds``: the top-5 candidates with a
            positive score. ``preds_L3x``: the top-5 candidates scoring above
            ``min(0, second_best - 1e-10)`` (so at least two), restricted to the
            first four distinct passages among them, in score order.

        Raises:
            AssertionError: if fewer than two candidates have known text.
        """
        model = self.modelL2

        # Score only candidates whose text is known, and keep that same list for
        # pairing below so each [MASK] score stays aligned with its (pid, sid).
        candidates = [(pid, sid) for pid, sid in preds if (pid, sid) in self.CollectionY]
        psg = ' [MASK] '.join([''] + [self.CollectionY[c] for c in candidates])

        with torch.inference_mode():
            obj = self.tokenizer.process([query], [psg], None)

            with self.amp.context():
                # NOTE(review): reconstructed call — same assumption as in _stage1; confirm.
                scores = model(obj.encoding.to(model.device)).float()

        scores = scores.view(-1).tolist()

        scored = sorted(zip(scores, candidates), reverse=True)[:5]
        assert len(scored) >= 2, "Stage 2 needs at least two candidate sentences."

        threshold = min(0, scored[1][0] - 1e-10)  # Take at least 2!
        preds_L3x = [x for score, x in scored if score > threshold]
        preds = [x for score, x in scored if score > 0]

        earliest_pids = f7([pid for pid, _ in preds_L3x])[:4]  # Take at most 4 docs.
        preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]

        assert len(preds_L3x) >= 2
        assert len(f7([pid for pid, _ in preds_L3x])) <= 4

        return preds, preds_L3x