File size: 1,737 Bytes
8197b11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from collections import Counter
import json
import torch

from tqdm import tqdm
from relik.retriever.data.labels import Labels

from relik.retriever.indexers.inmemory import InMemoryDocumentIndex

if __name__ == "__main__":
    # Filter a pretrained in-memory document index down to (a) the 1M
    # highest-ranked titles from a frequency file and (b) every title that has
    # a definition, then save the smaller fp16 index next to the original.
    from itertools import islice

    # Each line is "<title>\t..."; we keep column 0. The file is assumed to be
    # sorted by frequency so the first 1M lines are the most frequent titles
    # — TODO confirm. islice avoids materializing the whole file via
    # readlines() + slicing; like the slice, it stops early at EOF.
    with open("frequency_blink.txt") as f_in:
        frequencies = {
            line.strip().split("\t")[0] for line in islice(f_in, 1_000_000)
        }

    # Always retain titles that have a definition, regardless of frequency.
    # Lines look like "<title> <def> ..."; we keep the title prefix.
    with open(
        "/root/golden-retriever-v2/data/dpr-like/el/definitions_only_data.txt"
    ) as f_in:
        for line in f_in:
            title = line.strip().split(" <def>")[0].strip()
            frequencies.add(title)

    document_index = InMemoryDocumentIndex.from_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index",
    )

    # Walk the full index and keep only documents whose title survived the
    # filter, preserving each kept document's original embedding row.
    # len(new_doc_index) assigns dense, insertion-ordered ids 0..k-1.
    new_doc_index = {}
    new_embeddings = []
    for i in range(document_index.documents.get_label_size()):
        doc = document_index.documents.get_label_from_index(i)
        title = doc.split(" <def>")[0].strip()
        if title in frequencies:
            new_doc_index[doc] = len(new_doc_index)
            new_embeddings.append(document_index.embeddings[i])

    print(len(new_doc_index))
    print(len(new_embeddings))

    # Stack into one (num_kept_docs, dim) tensor; fp16 halves the on-disk and
    # in-memory footprint of the filtered index.
    new_embeddings = torch.stack(new_embeddings, dim=0)
    new_embeddings = new_embeddings.to(torch.float16)

    print(new_embeddings.shape)

    # Rebuild the label index over the filtered documents and persist the
    # smaller index alongside the original one.
    new_label_index = Labels()
    new_label_index.add_labels(new_doc_index)
    new_document_index = InMemoryDocumentIndex(
        documents=new_label_index,
        embeddings=new_embeddings,
    )

    new_document_index.save_pretrained(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered"
    )