# File size: 6,591 Bytes
# 075dfca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
from langdetect import detect
import torch.multiprocessing as mp

from colbert import Indexer, Searcher
from colbert.infra import ColBERTConfig, Run
from colbert.utils.utils import print_message
from colbert.data.collection import Collection
from colbert.modeling.checkpoint import Checkpoint
from colbert.indexing.index_saver import IndexSaver
from colbert.search.index_storage import IndexScorer
from colbert.infra.launcher import Launcher, print_memory_stats
from colbert.indexing.collection_encoder import CollectionEncoder
from colbert.indexing.collection_indexer import CollectionIndexer


# Each table maps a two-letter language code to a pair
# (lowercase English language name, adapter/locale code).
# The second element of the pair is what set_xmod_language() hands to the model.
# NOTE(review): the table names suggest the language coverage of the mMARCO,
# Mr. TyDi and MIRACL datasets respectively -- confirm against the datasets.
MMARCO_LANGUAGES = {
    "ar": ("arabic", "ar_AR"),
    "de": ("german", "de_DE"),
    "en": ("english", "en_XX"),
    "es": ("spanish", "es_XX"),
    "fr": ("french", "fr_XX"),
    "hi": ("hindi", "hi_IN"),
    "id": ("indonesian", "id_ID"),
    "it": ("italian", "it_IT"),
    "ja": ("japanese", "ja_XX"),
    "nl": ("dutch", "nl_XX"),
    "pt": ("portuguese", "pt_XX"),
    "ru": ("russian", "ru_RU"),
    "vi": ("vietnamese", "vi_VN"),
    "zh": ("chinese", "zh_CN"),
}
MRTYDI_LANGUAGES = {
    "ar": ("arabic", "ar_AR"),
    "bn": ("bengali", "bn_IN"),
    "en": ("english", "en_XX"),
    "fi": ("finnish", "fi_FI"),
    "id": ("indonesian", "id_ID"),
    "ja": ("japanese", "ja_XX"),
    "ko": ("korean", "ko_KR"),
    "ru": ("russian", "ru_RU"),
    "sw": ("swahili", "sw_KE"),
    "te": ("telugu", "te_IN"),
    "th": ("thai", "th_TH"),
}
MIRACL_LANGUAGES = {
    "ar": ("arabic", "ar_AR"),
    "bn": ("bengali", "bn_IN"),
    "en": ("english", "en_XX"),
    "es": ("spanish", "es_XX"),
    "fa": ("persian", "fa_IR"),
    "fi": ("finnish", "fi_FI"),
    "fr": ("french", "fr_XX"),
    "hi": ("hindi", "hi_IN"),
    "id": ("indonesian", "id_ID"),
    "ja": ("japanese", "ja_XX"),
    "ko": ("korean", "ko_KR"),
    "ru": ("russian", "ru_RU"),
    "sw": ("swahili", "sw_KE"),
    "te": ("telugu", "te_IN"),
    "th": ("thai", "th_TH"),
    "zh": ("chinese", "zh_CN"),
}
# Union of the three tables. Codes shared between tables carry identical
# entries, so the merge order only determines iteration order.
ALL_LANGUAGES = {
    code: entry
    for table in (MMARCO_LANGUAGES, MRTYDI_LANGUAGES, MIRACL_LANGUAGES)
    for code, entry in table.items()
}


def set_xmod_language(model, lang:str):
    """
    Select the model's default language from a detected language tag.

    Only the part of `lang` before the first '-' is used (e.g. 'zh-cn' -> 'zh');
    it is looked up in ALL_LANGUAGES and the second element of the matching
    entry is passed to `model.set_default_language`.
    Source: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/xmod/modeling_xmod.py#L687

    Raises:
        KeyError: if the truncated code has no entry in ALL_LANGUAGES.
    """
    base_code = lang.split('-')[0]
    entry = ALL_LANGUAGES.get(base_code)
    if entry is None:
        raise KeyError(f"Language {base_code} not supported.")
    model.set_default_language(entry[1])

#-----------------------------------------------------------------------------------------------------------------#
#                                               INDEXER
#-----------------------------------------------------------------------------------------------------------------#
class CustomIndexer(Indexer):
    """Indexer whose worker processes run `custom_encode` instead of the stock encoder."""

    def __launch(self, collection):
        """
        Spawn the indexing workers for `collection`.

        Creates one manager-backed list and one single-slot queue per rank
        (`self.config.nranks`) and hands them, together with the config and
        verbosity, to a `Launcher` built around `custom_encode`.
        """
        manager = mp.Manager()
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]
        launcher = Launcher(custom_encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose)

    # Double-underscore names are mangled per class: the method above is stored
    # as `_CustomIndexer__launch`, whereas a `self.__launch(...)` call written
    # inside the base `Indexer` class resolves to `_Indexer__launch`. Without
    # this alias such a call would run the base implementation and bypass
    # `custom_encode` entirely.
    # NOTE(review): upstream ColBERT's `Indexer.index` calls `self.__launch` --
    # confirm against the installed version.
    _Indexer__launch = __launch

def custom_encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    """
    Worker entry point given to `Launcher` by `CustomIndexer`.

    Builds a `CustomCollectionIndexer` for this worker and runs it with
    `shared_lists`. `shared_queues` is accepted but not used here.
    """
    CustomCollectionIndexer(
        config=config,
        collection=collection,
        verbose=verbose,
    ).run(shared_lists)

class CustomCollectionIndexer(CollectionIndexer):
    """CollectionIndexer that selects the X-MOD language before encoding."""

    def __init__(self, config: ColBERTConfig, collection, verbose=2):
        """
        Set up the indexer state without calling the parent constructor.

        When the checkpoint's backbone class name starts with "xmod"
        (case-insensitive), the language of the first collection item is
        detected and set as the backbone's default language before the
        checkpoint is moved to the GPU.
        """
        self.verbose = verbose
        self.config = config
        self.rank = self.config.rank
        self.nranks = self.config.nranks
        self.use_gpu = self.config.total_visible_gpus > 0

        # Only rank 0 prints the configuration summary.
        if self.config.rank == 0 and self.verbose > 1:
            self.config.help()

        self.collection = Collection.cast(collection)
        self.checkpoint = Checkpoint(self.config.checkpoint, colbert_config=self.config)

        backbone = self.checkpoint.bert
        if backbone.__class__.__name__.lower().startswith("xmod"):
            # The language is inferred from the first item only.
            detected = detect(self.collection[0])
            Run().print_main(f"#> Setting X-MOD language adapters to {detected}.")
            set_xmod_language(backbone, lang=detected)

        if self.use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        self.encoder = CollectionEncoder(config, self.checkpoint)
        self.saver = IndexSaver(config)
        print_memory_stats(f'RANK:{self.rank}')

#-----------------------------------------------------------------------------------------------------------------#
#                                               SEARCHER
#-----------------------------------------------------------------------------------------------------------------#
class CustomSearcher(Searcher):
    """Searcher that selects the X-MOD language before loading the index."""

    def __init__(self, index, checkpoint=None, collection=None, config=None, index_root=None, verbose:int = 3):
        """
        Set up the searcher state without calling the parent constructor.

        The effective configuration is assembled from the checkpoint config,
        the index config and the caller/run config. When the checkpoint's
        backbone class name starts with "xmod" (case-insensitive), the language
        of the first collection item is detected and set as the backbone's
        default language.

        Raises:
            ValueError: if a memory-mapped index is requested while GPUs are visible.
        """
        self.verbose = verbose
        if self.verbose > 1:
            print_memory_stats()

        initial_config = ColBERTConfig.from_existing(config, Run().config)

        # An explicit index_root wins; otherwise use the configured one.
        fallback_root = initial_config.index_root_
        resolved_root = index_root or fallback_root
        self.index = os.path.join(resolved_root, index)
        self.index_config = ColBERTConfig.load_from_index(self.index)

        # Until the model is loaded below, self.checkpoint holds whatever
        # identifier was passed in or stored in the index config.
        self.checkpoint = checkpoint or self.index_config.checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(self.checkpoint)
        self.config = ColBERTConfig.from_existing(
            self.checkpoint_config,
            self.index_config,
            initial_config,
        )

        self.collection = Collection.cast(collection or self.config.collection)
        self.configure(checkpoint=self.checkpoint, collection=self.collection)

        self.checkpoint = Checkpoint(self.checkpoint, colbert_config=self.config, verbose=self.verbose)
        backbone = self.checkpoint.bert
        if backbone.__class__.__name__.lower().startswith("xmod"):
            # The language is inferred from the first item only.
            detected = detect(self.collection[0])
            print_message(f"#> Setting X-MOD language adapters to {detected}.")
            set_xmod_language(backbone, lang=detected)

        on_gpu = self.config.total_visible_gpus > 0
        if on_gpu:
            self.checkpoint = self.checkpoint.cuda()

        mmap_index = self.config.load_index_with_mmap
        if mmap_index and on_gpu:
            raise ValueError(f"Memory-mapped index can only be used with CPU!")
        self.ranker = IndexScorer(self.index, on_gpu, mmap_index)
        print_memory_stats()