visheratin committed
Commit 1924a68
1 Parent(s): 27ab534

Update nllb_mrl.py

Files changed (1)
  1. nllb_mrl.py +32 -20
nllb_mrl.py CHANGED
@@ -1,21 +1,26 @@
- from dataclasses import dataclass
  from typing import List, Union

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from huggingface_hub import PyTorchModelHubMixin
- from open_clip import create_model_and_transforms, get_tokenizer
+ from open_clip import create_model, get_tokenizer
+ from open_clip.transform import PreprocessCfg, image_transform_v2
  from PIL import Image
- from transformers import PretrainedConfig
+ from transformers import PretrainedConfig, PreTrainedModel


- @dataclass
  class MatryoshkaNllbClipConfig(PretrainedConfig):
-     clip_model_name: str
-     clip_model_version: str
-     target_resolution: int
-     mrl_resolutions: List[int]
+     def __init__(
+         self,
+         clip_model_name: str = "",
+         target_resolution: int = -1,
+         mrl_resolutions: List[int] = [],
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.clip_model_name = clip_model_name
+         self.target_resolution = target_resolution
+         self.mrl_resolutions = mrl_resolutions


  class MatryoshkaLayer(nn.Module):
@@ -37,16 +42,23 @@ class MatryoshkaLayer(nn.Module):
          return outputs


- class MatryoshkaNllbClip(nn.Module, PyTorchModelHubMixin):
+ class MatryoshkaNllbClip(PreTrainedModel):
+     config_class = MatryoshkaNllbClipConfig
+
      def __init__(self, config: MatryoshkaNllbClipConfig, device):
-         super().__init__()
+         super().__init__(config)
          if isinstance(device, str):
              device = torch.device(device)
          self.config = config
-         self.model, _, self.transform = create_model_and_transforms(
-             config.clip_model_name, config.clip_model_version, output_dict=True
+         self.model = create_model(
+             config.clip_model_name, output_dict=True
+         )
+         pp_cfg = PreprocessCfg(**self.model.visual.preprocess_cfg)
+         self.transform = image_transform_v2(
+             pp_cfg,
+             is_train=False,
          )
-         self.device = device
+         self._device = device
          self.model.to(device)
          self.matryoshka_layer = MatryoshkaLayer(
              config.mrl_resolutions, config.target_resolution
@@ -55,8 +67,8 @@ class MatryoshkaNllbClip(nn.Module, PyTorchModelHubMixin):
          self.tokenizer = get_tokenizer(config.clip_model_name)

      def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
-         image_inputs = image_inputs.to(self.device)
-         input_ids = input_ids.to(self.device)
+         image_inputs = image_inputs.to(self._device)
+         input_ids = input_ids.to(self._device)
          outputs = self.model(
              image=image_inputs,
              text=input_ids,
@@ -118,7 +130,7 @@ class MatryoshkaNllbClip(nn.Module, PyTorchModelHubMixin):
          resolution: Union[int, None] = None,
      ):
          image_inputs = [self.transform(image) for image in images]
-         image_inputs = torch.stack(image_inputs, dim=0).to(self.device)
+         image_inputs = torch.stack(image_inputs, dim=0).to(self._device)
          with torch.inference_mode():
              features = self.model.visual(image_inputs)
              if resolution is not None:
@@ -138,10 +150,10 @@ class MatryoshkaNllbClip(nn.Module, PyTorchModelHubMixin):
      ):
          if langs is None:
              langs = ["eng_Latn"] * len(texts)
-         texts = [f"{lang} {text}" for lang, text in zip(langs, texts)]
+         texts = [f"{lang}{text}" for lang, text in zip(langs, texts)]
          input_ids = self.tokenizer.tokenizer.batch_encode_plus(
              texts, return_tensors="pt", padding="longest", add_special_tokens=False
-         )["input_ids"].to(self.device)
+         )["input_ids"].to(self._device)
          with torch.inference_mode():
              features = self.model.text(input_ids)
              if resolution is not None:
@@ -172,4 +184,4 @@ class MatryoshkaNllbClip(nn.Module, PyTorchModelHubMixin):
          if self.model.logit_bias is not None:
              image_logits += self.model.logit_bias
          text_logits = image_logits.T
-         return image_logits, text_logits
+         return image_logits, text_logits
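
For reference, a minimal usage sketch of the updated classes follows. Only MatryoshkaNllbClipConfig, MatryoshkaNllbClip, the forward signature, model.transform, and model.tokenizer come from this file; the checkpoint name, embedding sizes, and example inputs below are illustrative assumptions, not values taken from the commit.

import torch
from PIL import Image

from nllb_mrl import MatryoshkaNllbClip, MatryoshkaNllbClipConfig  # assumes the file is importable as a module

config = MatryoshkaNllbClipConfig(
    clip_model_name="nllb-clip-base-siglip",  # assumed open_clip model name
    target_resolution=768,                    # assumed full embedding size
    mrl_resolutions=[64, 128, 256, 512],      # assumed Matryoshka resolutions
)
model = MatryoshkaNllbClip(config, device="cpu")

image = Image.new("RGB", (384, 384))  # placeholder image for the sketch
image_inputs = torch.stack([model.transform(image)], dim=0)
# the updated text path prepends the FLORES-200 language code directly to the text
input_ids = model.tokenizer.tokenizer.batch_encode_plus(
    ["eng_Latn" + "a photo of a cat"],
    return_tensors="pt",
    padding="longest",
    add_special_tokens=False,
)["input_ids"]
outputs = model(image_inputs, input_ids, resolution=256)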
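
The MatryoshkaLayer body is untouched by this commit and not shown in the diff. As a rough sketch of one common way such a layer is structured, assuming one linear projection head per MRL resolution that maps the full-size embedding down to the requested size; this is not the repository's exact implementation:

import torch
import torch.nn as nn

class MatryoshkaLayerSketch(nn.Module):
    # Illustrative only: one projection head per configured Matryoshka resolution.
    def __init__(self, resolutions, target_resolution):
        super().__init__()
        self.resolutions = resolutions
        self.layers = nn.ModuleDict(
            {str(r): nn.Linear(target_resolution, r) for r in resolutions}
        )

    def forward(self, features, resolution=None):
        if resolution is not None:
            # project the full-size embedding to a single requested resolution
            return self.layers[str(resolution)](features)
        # otherwise return one projection per configured resolution
        outputs = {r: self.layers[str(r)](features) for r in self.resolutions}
        return outputs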