File size: 2,365 Bytes
3bd5293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import pickle

import torch
import transformers


class MultilingualClip(torch.nn.Module):
    """Text encoder: pretrained transformer, masked mean pooling, linear head.

    The tokenizer and transformer are loaded through ``transformers``
    (``AutoTokenizer`` / ``AutoModel``); the weights of the final linear
    layer (``clip_head``) are read from a pickle file on disk.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
        head_name: File name of the pickled linear-head weights.
        weights_dir: Directory containing ``head_name``. A trailing path
            separator is optional.
        cache_dir: Forwarded as ``cache_dir`` to both ``from_pretrained`` calls.
    """

    def __init__(self, model_name, tokenizer_name, head_name, weights_dir='data/weights/', cache_dir=None):
        super().__init__()
        self.model_name = model_name
        self.tokenizer_name = tokenizer_name
        # os.path.join accepts weights_dir with or without a trailing
        # separator; plain string concatenation silently produced a wrong
        # path for e.g. weights_dir='data/weights'.
        self.head_path = os.path.join(weights_dir, head_name)

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir)
        self.transformer = transformers.AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
        # Placeholder dimensions only: _load_head() replaces both weight and
        # bias (and updates in_features/out_features) from the pickle file.
        self.clip_head = torch.nn.Linear(in_features=768, out_features=640)
        self._load_head()

    def forward(self, txt):
        """Embed ``txt`` and return the output of the linear head.

        Args:
            txt: Text input handed to the tokenizer unchanged
                (presumably a string or a list of strings -- confirm
                against the tokenizer in use).

        Returns:
            The tensor produced by ``clip_head`` from the mean-pooled
            token embeddings, one row per input sequence.
        """
        txt_tok = self.tokenizer(txt, padding=True, return_tensors='pt')
        # The tokenizer always produces CPU tensors; move them to wherever
        # this module's parameters live so forward() also works after
        # ``model.to(device)``.
        txt_tok = txt_tok.to(self.clip_head.weight.device)
        # First element of the transformer output; assumed to be the
        # per-token hidden states (batch, seq, hidden) -- TODO confirm for
        # non-BERT backbones.
        embs = self.transformer(**txt_tok)[0]
        att = txt_tok['attention_mask']
        # Masked mean over the sequence axis: padded positions are zeroed
        # out and the sum is divided by the count of attended tokens.
        embs = (embs * att.unsqueeze(2)).sum(dim=1) / att.sum(dim=1)[:, None]
        return self.clip_head(embs)

    def _load_head(self):
        """Load the linear-head weights from ``self.head_path``.

        The pickle is expected to hold an indexable pair
        ``(weight, bias)``; ``weight`` is stored as (in_features,
        out_features) and is transposed to the (out_features, in_features)
        layout used by ``torch.nn.Linear``.
        """
        # NOTE(security): pickle can execute arbitrary code while loading.
        # Only point head_path at weight files from a trusted source.
        with open(self.head_path, 'rb') as f:
            lin_weights = pickle.load(f)
        weight = torch.tensor(lin_weights[0]).float().t().contiguous()
        bias = torch.tensor(lin_weights[1]).float()
        self.clip_head.weight = torch.nn.Parameter(weight)
        self.clip_head.bias = torch.nn.Parameter(bias)
        # Keep the layer's advertised dimensions in sync with the loaded
        # weights, which need not match the 768/640 placeholder.
        self.clip_head.out_features, self.clip_head.in_features = weight.shape


# Registry of loadable models, keyed by short name. Each value holds the
# keyword arguments MultilingualClip needs: the repository identifier (used
# for both the model and its tokenizer) and the linear-head weight file name.
AVAILABLE_MODELS = {
    short_name: {
        'model_name': repo_id,
        'tokenizer_name': repo_id,
        'head_name': head_file,
    }
    for short_name, repo_id, head_file in (
        ('M-BERT-Distil-40', 'M-CLIP/M-BERT-Distil-40', 'M-BERT Distil 40 Linear Weights.pkl'),
        ('M-BERT-Base-69', 'M-CLIP/M-BERT-Base-69', 'M-BERT-Base-69 Linear Weights.pkl'),
        ('Swe-CLIP-500k', 'M-CLIP/Swedish-500k', 'Swedish-500k Linear Weights.pkl'),
        ('Swe-CLIP-2M', 'M-CLIP/Swedish-2M', 'Swedish-2M Linear Weights.pkl'),
        ('M-BERT-Base-ViT-B', 'M-CLIP/M-BERT-Base-ViT-B', 'M-BERT-Base-69-ViT Linear Weights.pkl'),
    )
}


def load_model(name, cache_dir=None, weights_dir=None):
    """Build the ``MultilingualClip`` registered under ``name``.

    Args:
        name: A key of ``AVAILABLE_MODELS``.
        cache_dir: Forwarded to ``MultilingualClip`` as ``cache_dir``.
        weights_dir: Optional directory holding the linear-head weight
            files. When ``None`` the ``MultilingualClip`` default is used.

    Returns:
        A ``MultilingualClip`` instance configured from the registry entry.

    Raises:
        KeyError: If ``name`` is not a key of ``AVAILABLE_MODELS``; the
            message lists the valid names.
    """
    try:
        config = AVAILABLE_MODELS[name]
    except KeyError:
        known = ', '.join(sorted(AVAILABLE_MODELS))
        # Still a KeyError so existing handlers keep working, but with a
        # message that tells the caller which names are valid.
        raise KeyError(f'Unknown model {name!r}; available models: {known}') from None
    # Only override weights_dir when the caller asked for it, so the class
    # default stays the single source of truth.
    overrides = {} if weights_dir is None else {'weights_dir': weights_dir}
    return MultilingualClip(**config, cache_dir=cache_dir, **overrides)