"""Filter a pretrained document index down to frequent entities.

Builds the set of "allowed" entity titles (the first 1M lines of a
frequency-sorted BLINK title file, plus every title in the definitions
data file), keeps only the index documents whose title is in that set,
casts the surviving embeddings to float16, and saves the smaller index.
"""

from collections import Counter  # noqa: F401 -- unused here; kept from original
import json  # noqa: F401 -- unused here; kept from original

import torch
from tqdm import tqdm  # noqa: F401 -- unused here; kept from original

from relik.retriever.data.labels import Labels
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex

# Top-N most frequent titles to keep from the frequency file.
TOP_K = 1_000_000

FREQUENCY_FILE = "frequency_blink.txt"
DEFINITIONS_FILE = (
    "/root/golden-retriever-v2/data/dpr-like/el/definitions_only_data.txt"
)
MODEL_DIR = (
    "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder"
)
INDEX_PATH = MODEL_DIR + "/document_index"
OUTPUT_PATH = MODEL_DIR + "/document_index_filtered"


def _load_allowed_titles() -> set:
    """Return the set of entity titles to keep in the filtered index."""
    # Frequency file: one tab-separated record per line, title first.
    # Iterate the file object directly instead of readlines() so we do
    # not materialize the whole file just to slice it.
    with open(FREQUENCY_FILE, encoding="utf-8") as f_in:
        titles = [line.strip().split("\t")[0] for line in f_in]
    allowed = set(titles[:TOP_K])

    # Definitions file: the title is the first space-separated token.
    with open(DEFINITIONS_FILE, encoding="utf-8") as f_in:
        for line in f_in:
            allowed.add(line.strip().split(" ")[0].strip())
    return allowed


def _filter_index(document_index, allowed: set):
    """Select documents whose title is in *allowed*.

    Returns (doc -> new position mapping, list of embedding rows), in the
    original index order.
    """
    kept_docs = {}
    kept_embeddings = []
    for idx in range(document_index.documents.get_label_size()):
        doc = document_index.documents.get_label_from_index(idx)
        # Document label convention: title is the first space-separated token.
        title = doc.split(" ")[0].strip()
        if title in allowed:
            kept_docs[doc] = len(kept_docs)
            kept_embeddings.append(document_index.embeddings[idx])
    return kept_docs, kept_embeddings


def main() -> None:
    allowed = _load_allowed_titles()

    document_index = InMemoryDocumentIndex.from_pretrained(INDEX_PATH)

    new_doc_index, new_embeddings = _filter_index(document_index, allowed)
    print(len(new_doc_index))
    print(len(new_embeddings))

    # Stack the kept rows and halve their footprint with float16.
    embeddings = torch.stack(new_embeddings, dim=0).to(torch.float16)
    print(embeddings.shape)

    new_label_index = Labels()
    new_label_index.add_labels(new_doc_index)
    new_document_index = InMemoryDocumentIndex(
        documents=new_label_index,
        embeddings=embeddings,
    )
    new_document_index.save_pretrained(OUTPUT_PATH)


if __name__ == "__main__":
    main()