gmastrapas committed
Commit cd1adcb
1 Parent(s): 4ed2f34

style: apply ruff & isort
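
The changes below are consistent with that: imports are alphabetized and regrouped in isort's style, double-quoted strings become single-quoted, and over-long lines are wrapped. The exact tool configuration is not recorded in this commit.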

Files changed (1)
  1. modeling_clip.py +62 -46
modeling_clip.py CHANGED
@@ -5,20 +5,28 @@
 # and adjusted for Jina CLIP
 
 from functools import partial
-from typing import Optional, Tuple, Union, List
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn.functional as f
 import torch.utils.checkpoint
 from torch import nn
-from transformers import BatchEncoding, BatchFeature, PreTrainedModel, logging, AutoImageProcessor, AutoTokenizer
+from transformers import (
+    AutoImageProcessor,
+    AutoTokenizer,
+    BatchEncoding,
+    BatchFeature,
+    PreTrainedModel,
+    logging,
+)
 from transformers.models.clip.modeling_clip import (
     CLIPOutput,
     CLIPTextModelOutput,
     CLIPVisionModelOutput,
     clip_loss,
 )
+
 try:
     from tqdm.autonotebook import trange
 
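
The isort conventions are visible here: the names in `from typing import ...` are alphabetized, and the long single-line `transformers` import becomes a parenthesized block with one name per line and a trailing comma; a blank line now separates the import section from the `try:` block.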
 
@@ -226,6 +234,20 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self.preprocess = None
         self.post_init()
 
+    def get_tokenizer(self):
+        if not self.tokenizer:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.config._name_or_path, trust_remote_code=True
+            )
+        return self.tokenizer
+
+    def get_preprocess(self):
+        if not self.preprocess:
+            self.preprocess = AutoImageProcessor.from_pretrained(
+                self.config._name_or_path, trust_remote_code=True
+            )
+        return self.preprocess
+
     def get_text_features(
         self,
         input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
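
The two getters moved into this spot implement lazy initialization: the tokenizer and image processor are loaded from the checkpoint only on first use and then cached on the instance. A minimal usage sketch, assuming a checkpoint that ships this modeling file (the repo id below is illustrative, not taken from this commit):

    from transformers import AutoModel

    # Illustrative checkpoint id; any repo that ships this modeling file
    # with trust_remote_code support behaves the same.
    model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

    tokenizer = model.get_tokenizer()          # first call loads the AutoTokenizer
    assert tokenizer is model.get_tokenizer()  # later calls return the cached instance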
@@ -248,11 +270,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         )
         return self.visual_projection(self.vision_model(x=x))
 
-    def get_tokenizer(self):
-        if not self.tokenizer:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.config._name_or_path, trust_remote_code=True)
-        return self.tokenizer
-
     @torch.inference_mode()
     def encode_text(
         self,
@@ -266,38 +283,41 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
-        Computes sentence embeddings
-        Args:
-            sentences(`str` or `List[str]`):
-                Sentence or sentences to be encoded
-            batch_size(`int`, *optional*, defaults to 32):
-                Batch size for the computation
-            show_progress_bar(`bool`, *optional*, defaults to None):
-                Show a progress bar when encoding sentences.
-                If set to None, progress bar is only shown when `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
-            convert_to_numpy(`bool`, *optional*, defaults to True):
-                If true, the output is a list of numpy vectors.
-                Else, it is a list of pytorch tensors.
-            convert_to_tensor(`bool`, *optional*, defaults to False):
-                If true, you get one large tensor as return.
-                Overwrites any setting from convert_to_numpy
-            device(`torch.device`, *optional*, defaults to None):
-                Which torch.device to use for the computation
-            normalize_embeddings(`bool`, *optional*, defaults to False):
-                If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
-            tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
-                Keyword arguments for the tokenizer
-        Returns:
-            By default, a list of tensors is returned.
-            If convert_to_tensor, a stacked tensor is returned.
-            If convert_to_numpy, a numpy matrix is returned.
+        Computes sentence embeddings
+        Args:
+            sentences(`str` or `List[str]`):
+                Sentence or sentences to be encoded
+            batch_size(`int`, *optional*, defaults to 32):
+                Batch size for the computation
+            show_progress_bar(`bool`, *optional*, defaults to None):
+                Show a progress bar when encoding sentences.
+                If set to None, progress bar is only shown when
+                `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
+            convert_to_numpy(`bool`, *optional*, defaults to True):
+                If true, the output is a list of numpy vectors.
+                Else, it is a list of pytorch tensors.
+            convert_to_tensor(`bool`, *optional*, defaults to False):
+                If true, you get one large tensor as return.
+                Overwrites any setting from convert_to_numpy
+            device(`torch.device`, *optional*, defaults to None):
+                Which torch.device to use for the computation
+            normalize_embeddings(`bool`, *optional*, defaults to False):
+                If set to true, returned vectors will have length 1. In that case,
+                the faster dot-product (util.dot_score) instead of cosine similarity
+                can be used.
+            tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
+                Keyword arguments for the tokenizer
+        Returns:
+            By default, a list of tensors is returned.
+            If convert_to_tensor, a stacked tensor is returned.
+            If convert_to_numpy, a numpy matrix is returned.
         """
         is_training = self.training
         self.eval()
         all_embeddings = []
 
         self.tokenizer = self.get_tokenizer()
-
+
         if show_progress_bar is None:
             show_progress_bar = (
                 logger.getEffectiveLevel() == logging.INFO
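
The reformatted docstring corresponds to a call pattern like the following sketch (sentences and flag choices are illustrative; `model` is loaded as in the sketch above):

    import numpy as np

    # encode_text tokenizes internally via the lazily created tokenizer,
    # processes the sentences in batches, and returns a numpy matrix by
    # default (convert_to_numpy=True).
    embeddings = model.encode_text(
        ['a photo of a cat', 'a photo of a dog'],
        batch_size=2,
        normalize_embeddings=True,  # unit-length vectors: dot product == cosine
    )
    assert isinstance(embeddings, np.ndarray)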
@@ -328,7 +348,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 0,
                 len(sentences),
                 batch_size,
-                desc="Encoding",
+                desc='Encoding',
                 disable=not show_progress_bar,
             )
         else:
@@ -361,13 +381,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self.train(is_training)
         return all_embeddings
 
-
-    def get_preprocess(self):
-        if not self.preprocess:
-            self.preprocess = AutoImageProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
-        return self.preprocess
-
-
     @torch.inference_mode()
     def encode_image(
         self,
@@ -389,7 +402,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 Batch size for the computation
             show_progress_bar(`bool`, *optional*, defaults to None):
                 Show a progress bar when encoding images.
-                If set to None, progress bar is only shown when `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
+                If set to None, progress bar is only shown when
+                `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
             convert_to_numpy(`bool`, *optional*, defaults to True):
                 If true, the output is a list of numpy vectors.
                 Else, it is a list of pytorch tensors.
@@ -399,14 +413,16 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             device(`torch.device`, *optional*, defaults to None):
                 Which torch.device to use for the computation
             normalize_embeddings(`bool`, *optional*, defaults to False):
-                If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
+                If set to true, returned vectors will have length 1. In that case,
+                the faster dot-product (util.dot_score) instead of cosine similarity
+                can be used.
         Returns:
             By default, a list of tensors is returned.
             If convert_to_tensor, a stacked tensor is returned.
             If convert_to_numpy, a numpy matrix is returned.
         """
         from PIL import Image
-
+
         is_training = self.training
         self.eval()
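
encode_image mirrors this interface, taking image paths (or file-like objects) that are opened through PIL inside the batching loop; a sketch with illustrative paths:

    # Paths are illustrative; each entry is passed to PIL.Image.open.
    image_embeddings = model.encode_image(
        ['cat.jpg', 'dog.jpg'],
        batch_size=2,
        normalize_embeddings=True,
    )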
 
@@ -439,13 +455,13 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
                 0,
                 len(images),
                 batch_size,
-                desc="Encoding",
+                desc='Encoding',
                 disable=not show_progress_bar,
             )
         else:
             range_iter = range(0, len(images), batch_size)
 
-        for i in range_iter:
+        for _ in range_iter:
             processed_inputs = self.preprocess([Image.open(image) for image in images])
             embeddings = self.get_image_features(processed_inputs)
             if normalize_embeddings:
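
One caveat on this last hunk: renaming the unused loop variable `i` to `_` satisfies the linter, but the loop body still preprocesses the full `images` list on every iteration rather than slicing out one batch. A per-batch version would keep the index and slice on it; a sketch of that shape, not part of this commit:

    for i in range_iter:
        batch = images[i : i + batch_size]  # only the current batch
        processed_inputs = self.preprocess([Image.open(image) for image in batch])
        embeddings = self.get_image_features(processed_inputs)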
 