visheratin committed on
Commit
27ab534
·
verified ·
1 Parent(s): f4f9035

Update nllb_mrl.py

Browse files
Files changed (1) hide show
  1. nllb_mrl.py +56 -36
nllb_mrl.py CHANGED
@@ -1,26 +1,21 @@
 
1
  from typing import List, Union
2
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
- from open_clip import create_model, get_tokenizer
7
- from open_clip.transform import PreprocessCfg, image_transform_v2
8
  from PIL import Image
9
- from transformers import PretrainedConfig, PreTrainedModel
10
 
11
 
 
12
  class MatryoshkaNllbClipConfig(PretrainedConfig):
13
- def __init__(
14
- self,
15
- clip_model_name: str = "",
16
- target_resolution: int = -1,
17
- mrl_resolutions: List[int] = [],
18
- **kwargs,
19
- ):
20
- super().__init__(**kwargs)
21
- self.clip_model_name = clip_model_name
22
- self.target_resolution = target_resolution
23
- self.mrl_resolutions = mrl_resolutions
24
 
25
 
26
  class MatryoshkaLayer(nn.Module):
@@ -42,23 +37,16 @@ class MatryoshkaLayer(nn.Module):
42
  return outputs
43
 
44
 
45
- class MatryoshkaNllbClip(PreTrainedModel):
46
- config_class = MatryoshkaNllbClipConfig
47
-
48
  def __init__(self, config: MatryoshkaNllbClipConfig, device):
49
- super().__init__(config)
50
  if isinstance(device, str):
51
  device = torch.device(device)
52
  self.config = config
53
- self.model = create_model(
54
- config.clip_model_name, output_dict=True
55
  )
56
- pp_cfg = PreprocessCfg(**self.model.visual.preprocess_cfg)
57
- self.transform = image_transform_v2(
58
- pp_cfg,
59
- is_train=False,
60
- )
61
- self._device = device
62
  self.model.to(device)
63
  self.matryoshka_layer = MatryoshkaLayer(
64
  config.mrl_resolutions, config.target_resolution
@@ -67,8 +55,8 @@ class MatryoshkaNllbClip(PreTrainedModel):
67
  self.tokenizer = get_tokenizer(config.clip_model_name)
68
 
69
  def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
70
- image_inputs = image_inputs.to(self._device)
71
- input_ids = input_ids.to(self._device)
72
  outputs = self.model(
73
  image=image_inputs,
74
  text=input_ids,
@@ -91,14 +79,46 @@ class MatryoshkaNllbClip(PreTrainedModel):
91
  "logit_bias": outputs["logit_bias"],
92
  }
93
 
94
- def encode_images(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  self,
96
  images: List[Image.Image],
97
  normalize=False,
98
  resolution: Union[int, None] = None,
99
  ):
100
  image_inputs = [self.transform(image) for image in images]
101
- image_inputs = torch.stack(image_inputs, dim=0).to(self._device)
102
  with torch.inference_mode():
103
  features = self.model.visual(image_inputs)
104
  if resolution is not None:
@@ -109,7 +129,7 @@ class MatryoshkaNllbClip(PreTrainedModel):
109
  features = self.matryoshka_layer.layers[str(resolution)](features)
110
  return F.normalize(features, dim=-1) if normalize else features
111
 
112
- def encode_texts(
113
  self,
114
  texts: List[str],
115
  langs: Union[List[str], None] = None,
@@ -118,10 +138,10 @@ class MatryoshkaNllbClip(PreTrainedModel):
118
  ):
119
  if langs is None:
120
  langs = ["eng_Latn"] * len(texts)
121
- texts = [f"{lang}{text}" for lang, text in zip(langs, texts)]
122
  input_ids = self.tokenizer.tokenizer.batch_encode_plus(
123
  texts, return_tensors="pt", padding="longest", add_special_tokens=False
124
- )["input_ids"].to(self._device)
125
  with torch.inference_mode():
126
  features = self.model.text(input_ids)
127
  if resolution is not None:
@@ -139,10 +159,10 @@ class MatryoshkaNllbClip(PreTrainedModel):
139
  langs: Union[List[str], None] = None,
140
  resolution: Union[int, None] = None,
141
  ):
142
- image_features = self.encode_images(
143
  images, normalize=True, resolution=resolution
144
  )
145
- text_features = self.encode_texts(
146
  texts, langs, normalize=True, resolution=resolution
147
  )
148
  with torch.inference_mode():
@@ -152,4 +172,4 @@ class MatryoshkaNllbClip(PreTrainedModel):
152
  if self.model.logit_bias is not None:
153
  image_logits += self.model.logit_bias
154
  text_logits = image_logits.T
155
- return image_logits, text_logits
 
1
+ from dataclasses import dataclass
2
  from typing import List, Union
3
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+ from open_clip import create_model_and_transforms, get_tokenizer
9
  from PIL import Image
10
+ from transformers import PretrainedConfig
11
 
12
 
13
+ @dataclass
14
  class MatryoshkaNllbClipConfig(PretrainedConfig):
15
+ clip_model_name: str
16
+ clip_model_version: str
17
+ target_resolution: int
18
+ mrl_resolutions: List[int]
 
 
 
 
 
 
 
19
 
20
 
21
  class MatryoshkaLayer(nn.Module):
 
37
  return outputs
38
 
39
 
40
+ class MatryoshkaNllbClip(nn.Module, PyTorchModelHubMixin):
 
 
41
  def __init__(self, config: MatryoshkaNllbClipConfig, device):
42
+ super().__init__()
43
  if isinstance(device, str):
44
  device = torch.device(device)
45
  self.config = config
46
+ self.model, _, self.transform = create_model_and_transforms(
47
+ config.clip_model_name, config.clip_model_version, output_dict=True
48
  )
49
+ self.device = device
 
 
 
 
 
50
  self.model.to(device)
51
  self.matryoshka_layer = MatryoshkaLayer(
52
  config.mrl_resolutions, config.target_resolution
 
55
  self.tokenizer = get_tokenizer(config.clip_model_name)
56
 
57
  def forward(self, image_inputs, input_ids, resolution: Union[int, None] = None):
58
+ image_inputs = image_inputs.to(self.device)
59
+ input_ids = input_ids.to(self.device)
60
  outputs = self.model(
61
  image=image_inputs,
62
  text=input_ids,
 
79
  "logit_bias": outputs["logit_bias"],
80
  }
81
 
82
+ def encode_image(
83
+ self,
84
+ image,
85
+ normalize=False,
86
+ resolution: Union[int, None] = None,
87
+ ):
88
+ with torch.inference_mode():
89
+ features = self.model.visual(image)
90
+ if resolution is not None:
91
+ if resolution not in self.matryoshka_layer.resolutions:
92
+ raise ValueError(
93
+ f"Resolution {resolution} not in {self.matryoshka_layer.resolutions}"
94
+ )
95
+ features = self.matryoshka_layer.layers[str(resolution)](features)
96
+ return F.normalize(features, dim=-1) if normalize else features
97
+
98
+ def encode_text(
99
+ self,
100
+ text,
101
+ normalize=False,
102
+ resolution: Union[int, None] = None,
103
+ ):
104
+ with torch.inference_mode():
105
+ features = self.model.text(text)
106
+ if resolution is not None:
107
+ if resolution not in self.matryoshka_layer.resolutions:
108
+ raise ValueError(
109
+ f"Resolution {resolution} not in {self.matryoshka_layer.resolutions}"
110
+ )
111
+ features = self.matryoshka_layer.layers[str(resolution)](features)
112
+ return F.normalize(features, dim=-1) if normalize else features
113
+
114
+ def image_features(
115
  self,
116
  images: List[Image.Image],
117
  normalize=False,
118
  resolution: Union[int, None] = None,
119
  ):
120
  image_inputs = [self.transform(image) for image in images]
121
+ image_inputs = torch.stack(image_inputs, dim=0).to(self.device)
122
  with torch.inference_mode():
123
  features = self.model.visual(image_inputs)
124
  if resolution is not None:
 
129
  features = self.matryoshka_layer.layers[str(resolution)](features)
130
  return F.normalize(features, dim=-1) if normalize else features
131
 
132
+ def text_features(
133
  self,
134
  texts: List[str],
135
  langs: Union[List[str], None] = None,
 
138
  ):
139
  if langs is None:
140
  langs = ["eng_Latn"] * len(texts)
141
+ texts = [f"{lang} {text}" for lang, text in zip(langs, texts)]
142
  input_ids = self.tokenizer.tokenizer.batch_encode_plus(
143
  texts, return_tensors="pt", padding="longest", add_special_tokens=False
144
+ )["input_ids"].to(self.device)
145
  with torch.inference_mode():
146
  features = self.model.text(input_ids)
147
  if resolution is not None:
 
159
  langs: Union[List[str], None] = None,
160
  resolution: Union[int, None] = None,
161
  ):
162
+ image_features = self.image_features(
163
  images, normalize=True, resolution=resolution
164
  )
165
+ text_features = self.text_features(
166
  texts, langs, normalize=True, resolution=resolution
167
  )
168
  with torch.inference_mode():
 
172
  if self.model.logit_bias is not None:
173
  image_logits += self.model.logit_bias
174
  text_logits = image_logits.T
175
+ return image_logits, text_logits