|
from collections import Counter |
|
import json |
|
import torch |
|
|
|
from tqdm import tqdm |
|
from relik.retriever.data.labels import Labels |
|
|
|
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex |
|
|
|
if __name__ == "__main__":
    # Script: shrink a pretrained in-memory document index down to an
    # allow-list of entity titles (the most frequent BLINK titles plus every
    # title that has a definition), then save the filtered index to disk.
    from itertools import islice

    # Each line of the frequency file is "<title>\t<count>", sorted by
    # frequency (presumably descending — TODO confirm). Keep only the first
    # 1M titles, streaming instead of readlines()-ing the whole file.
    with open("frequency_blink.txt") as f_in:
        frequencies = {
            line.strip().split("\t")[0] for line in islice(f_in, 1_000_000)
        }

    # Always keep titles that have a definition entry, regardless of
    # frequency rank. Lines look like "<title> <def> <definition text>".
    with open(
        "/root/golden-retriever-v2/data/dpr-like/el/definitions_only_data.txt"
    ) as f_in:
        for line in f_in:
            title = line.strip().split(" <def>")[0].strip()
            frequencies.add(title)

    document_index = InMemoryDocumentIndex.from_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
    )

    # Filter: keep a document (and its embedding row) only when its title is
    # in the allow-list. new_doc_index maps label -> new contiguous position,
    # kept aligned with new_embeddings by construction.
    new_doc_index = {}
    new_embeddings = []
    for i in range(document_index.documents.get_label_size()):
        doc = document_index.documents.get_label_from_index(i)
        # Match on the title portion only, not the full "<title> <def> ..." label.
        title = doc.split(" <def>")[0].strip()
        if title in frequencies:
            new_doc_index[doc] = len(new_doc_index)
            new_embeddings.append(document_index.embeddings[i])

    print(len(new_doc_index))
    print(len(new_embeddings))

    # Stack kept rows into a single tensor and cast to fp16 to halve the
    # on-disk / in-memory footprint of the filtered index.
    new_embeddings = torch.stack(new_embeddings, dim=0)
    new_embeddings = new_embeddings.to(torch.float16)

    print(new_embeddings.shape)

    # Re-wrap the filtered labels + embeddings and persist them alongside the
    # original index under a "_filtered" suffix.
    new_label_index = Labels()
    new_label_index.add_labels(new_doc_index)
    new_document_index = InMemoryDocumentIndex(
        documents=new_label_index,
        embeddings=new_embeddings,
    )

    new_document_index.save_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered"
    )
|
|