gmastrapas
commited on
Commit
•
cd1adcb
1
Parent(s):
4ed2f34
style: apply ruff & isort
Browse files- modeling_clip.py +62 -46
modeling_clip.py
CHANGED
@@ -5,20 +5,28 @@
|
|
5 |
# and adjusted for Jina CLIP
|
6 |
|
7 |
from functools import partial
|
8 |
-
from typing import Optional, Tuple, Union
|
9 |
|
10 |
import numpy as np
|
11 |
import torch
|
12 |
import torch.nn.functional as f
|
13 |
import torch.utils.checkpoint
|
14 |
from torch import nn
|
15 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
from transformers.models.clip.modeling_clip import (
|
17 |
CLIPOutput,
|
18 |
CLIPTextModelOutput,
|
19 |
CLIPVisionModelOutput,
|
20 |
clip_loss,
|
21 |
)
|
|
|
22 |
try:
|
23 |
from tqdm.autonotebook import trange
|
24 |
|
@@ -226,6 +234,20 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
226 |
self.preprocess = None
|
227 |
self.post_init()
|
228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
229 |
def get_text_features(
|
230 |
self,
|
231 |
input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
|
@@ -248,11 +270,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
248 |
)
|
249 |
return self.visual_projection(self.vision_model(x=x))
|
250 |
|
251 |
-
def get_tokenizer(self):
|
252 |
-
if not self.tokenizer:
|
253 |
-
self.tokenizer = AutoTokenizer.from_pretrained(self.config._name_or_path, trust_remote_code=True)
|
254 |
-
return self.tokenizer
|
255 |
-
|
256 |
@torch.inference_mode()
|
257 |
def encode_text(
|
258 |
self,
|
@@ -266,38 +283,41 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
266 |
**tokenizer_kwargs,
|
267 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
268 |
"""
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
|
|
|
|
294 |
"""
|
295 |
is_training = self.training
|
296 |
self.eval()
|
297 |
all_embeddings = []
|
298 |
|
299 |
self.tokenizer = self.get_tokenizer()
|
300 |
-
|
301 |
if show_progress_bar is None:
|
302 |
show_progress_bar = (
|
303 |
logger.getEffectiveLevel() == logging.INFO
|
@@ -328,7 +348,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
328 |
0,
|
329 |
len(sentences),
|
330 |
batch_size,
|
331 |
-
desc=
|
332 |
disable=not show_progress_bar,
|
333 |
)
|
334 |
else:
|
@@ -361,13 +381,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
361 |
self.train(is_training)
|
362 |
return all_embeddings
|
363 |
|
364 |
-
|
365 |
-
def get_preprocess(self):
|
366 |
-
if not self.preprocess:
|
367 |
-
self.preprocess = AutoImageProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
|
368 |
-
return self.preprocess
|
369 |
-
|
370 |
-
|
371 |
@torch.inference_mode()
|
372 |
def encode_image(
|
373 |
self,
|
@@ -389,7 +402,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
389 |
Batch size for the computation
|
390 |
show_progress_bar(`bool`, *optional*, defaults to None):
|
391 |
Show a progress bar when encoding images.
|
392 |
-
If set to None, progress bar is only shown when
|
|
|
393 |
convert_to_numpy(`bool`, *optional*, defaults to True):
|
394 |
If true, the output is a list of numpy vectors.
|
395 |
Else, it is a list of pytorch tensors.
|
@@ -399,14 +413,16 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
399 |
device(`torch.device`, *optional*, defaults to None):
|
400 |
Which torch.device to use for the computation
|
401 |
normalize_embeddings(`bool`, *optional*, defaults to False):
|
402 |
-
If set to true, returned vectors will have length 1. In that case,
|
|
|
|
|
403 |
Returns:
|
404 |
By default, a list of tensors is returned.
|
405 |
If convert_to_tensor, a stacked tensor is returned.
|
406 |
If convert_to_numpy, a numpy matrix is returned.
|
407 |
"""
|
408 |
from PIL import Image
|
409 |
-
|
410 |
is_training = self.training
|
411 |
self.eval()
|
412 |
|
@@ -439,13 +455,13 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
439 |
0,
|
440 |
len(images),
|
441 |
batch_size,
|
442 |
-
desc=
|
443 |
disable=not show_progress_bar,
|
444 |
)
|
445 |
else:
|
446 |
range_iter = range(0, len(images), batch_size)
|
447 |
|
448 |
-
for
|
449 |
processed_inputs = self.preprocess([Image.open(image) for image in images])
|
450 |
embeddings = self.get_image_features(processed_inputs)
|
451 |
if normalize_embeddings:
|
|
|
5 |
# and adjusted for Jina CLIP
|
6 |
|
7 |
from functools import partial
|
8 |
+
from typing import List, Optional, Tuple, Union
|
9 |
|
10 |
import numpy as np
|
11 |
import torch
|
12 |
import torch.nn.functional as f
|
13 |
import torch.utils.checkpoint
|
14 |
from torch import nn
|
15 |
+
from transformers import (
|
16 |
+
AutoImageProcessor,
|
17 |
+
AutoTokenizer,
|
18 |
+
BatchEncoding,
|
19 |
+
BatchFeature,
|
20 |
+
PreTrainedModel,
|
21 |
+
logging,
|
22 |
+
)
|
23 |
from transformers.models.clip.modeling_clip import (
|
24 |
CLIPOutput,
|
25 |
CLIPTextModelOutput,
|
26 |
CLIPVisionModelOutput,
|
27 |
clip_loss,
|
28 |
)
|
29 |
+
|
30 |
try:
|
31 |
from tqdm.autonotebook import trange
|
32 |
|
|
|
234 |
self.preprocess = None
|
235 |
self.post_init()
|
236 |
|
237 |
+
def get_tokenizer(self):
|
238 |
+
if not self.tokenizer:
|
239 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
240 |
+
self.config._name_or_path, trust_remote_code=True
|
241 |
+
)
|
242 |
+
return self.tokenizer
|
243 |
+
|
244 |
+
def get_preprocess(self):
|
245 |
+
if not self.preprocess:
|
246 |
+
self.preprocess = AutoImageProcessor.from_pretrained(
|
247 |
+
self.config._name_or_path, trust_remote_code=True
|
248 |
+
)
|
249 |
+
return self.preprocess
|
250 |
+
|
251 |
def get_text_features(
|
252 |
self,
|
253 |
input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
|
|
|
270 |
)
|
271 |
return self.visual_projection(self.vision_model(x=x))
|
272 |
|
|
|
|
|
|
|
|
|
|
|
273 |
@torch.inference_mode()
|
274 |
def encode_text(
|
275 |
self,
|
|
|
283 |
**tokenizer_kwargs,
|
284 |
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
285 |
"""
|
286 |
+
Computes sentence embeddings
|
287 |
+
Args:
|
288 |
+
sentences(`str` or `List[str]`):
|
289 |
+
Sentence or sentences to be encoded
|
290 |
+
batch_size(`int`, *optional*, defaults to 32):
|
291 |
+
Batch size for the computation
|
292 |
+
show_progress_bar(`bool`, *optional*, defaults to None):
|
293 |
+
Show a progress bar when encoding sentences.
|
294 |
+
If set to None, progress bar is only shown when
|
295 |
+
`logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
296 |
+
convert_to_numpy(`bool`, *optional*, defaults to True):
|
297 |
+
If true, the output is a list of numpy vectors.
|
298 |
+
Else, it is a list of pytorch tensors.
|
299 |
+
convert_to_tensor(`bool`, *optional*, defaults to False):
|
300 |
+
If true, you get one large tensor as return.
|
301 |
+
Overwrites any setting from convert_to_numpy
|
302 |
+
device(`torch.device`, *optional*, defaults to None):
|
303 |
+
Which torch.device to use for the computation
|
304 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
305 |
+
If set to true, returned vectors will have length 1. In that case,
|
306 |
+
the faster dot-product (util.dot_score) instead of cosine similarity
|
307 |
+
can be used.
|
308 |
+
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
309 |
+
Keyword arguments for the tokenizer
|
310 |
+
Returns:
|
311 |
+
By default, a list of tensors is returned.
|
312 |
+
If convert_to_tensor, a stacked tensor is returned.
|
313 |
+
If convert_to_numpy, a numpy matrix is returned.
|
314 |
"""
|
315 |
is_training = self.training
|
316 |
self.eval()
|
317 |
all_embeddings = []
|
318 |
|
319 |
self.tokenizer = self.get_tokenizer()
|
320 |
+
|
321 |
if show_progress_bar is None:
|
322 |
show_progress_bar = (
|
323 |
logger.getEffectiveLevel() == logging.INFO
|
|
|
348 |
0,
|
349 |
len(sentences),
|
350 |
batch_size,
|
351 |
+
desc='Encoding',
|
352 |
disable=not show_progress_bar,
|
353 |
)
|
354 |
else:
|
|
|
381 |
self.train(is_training)
|
382 |
return all_embeddings
|
383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
@torch.inference_mode()
|
385 |
def encode_image(
|
386 |
self,
|
|
|
402 |
Batch size for the computation
|
403 |
show_progress_bar(`bool`, *optional*, defaults to None):
|
404 |
Show a progress bar when encoding images.
|
405 |
+
If set to None, progress bar is only shown when
|
406 |
+
`logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
407 |
convert_to_numpy(`bool`, *optional*, defaults to True):
|
408 |
If true, the output is a list of numpy vectors.
|
409 |
Else, it is a list of pytorch tensors.
|
|
|
413 |
device(`torch.device`, *optional*, defaults to None):
|
414 |
Which torch.device to use for the computation
|
415 |
normalize_embeddings(`bool`, *optional*, defaults to False):
|
416 |
+
If set to true, returned vectors will have length 1. In that case,
|
417 |
+
the faster dot-product (util.dot_score) instead of cosine similarity
|
418 |
+
can be used.
|
419 |
Returns:
|
420 |
By default, a list of tensors is returned.
|
421 |
If convert_to_tensor, a stacked tensor is returned.
|
422 |
If convert_to_numpy, a numpy matrix is returned.
|
423 |
"""
|
424 |
from PIL import Image
|
425 |
+
|
426 |
is_training = self.training
|
427 |
self.eval()
|
428 |
|
|
|
455 |
0,
|
456 |
len(images),
|
457 |
batch_size,
|
458 |
+
desc='Encoding',
|
459 |
disable=not show_progress_bar,
|
460 |
)
|
461 |
else:
|
462 |
range_iter = range(0, len(images), batch_size)
|
463 |
|
464 |
+
for _ in range_iter:
|
465 |
processed_inputs = self.preprocess([Image.open(image) for image in images])
|
466 |
embeddings = self.get_image_features(processed_inputs)
|
467 |
if normalize_embeddings:
|