bwang0911 committed on
Commit
47604ae
1 Parent(s): b3163bd

refactor: add docstrings

Browse files
Files changed (1) hide show
  1. modeling_clip.py +31 -3
modeling_clip.py CHANGED
@@ -391,7 +391,33 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
391
  device: Optional[torch.device] = None,
392
  normalize_embeddings: bool = False,
393
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
394
- from PIL.Image import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
 
396
  is_training = self.training
397
  self.eval()
@@ -422,17 +448,19 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
422
  if has_tqdm:
423
  range_iter = trange(
424
  0,
425
- len(sentences),
426
  batch_size,
427
  desc="Encoding",
428
  disable=not show_progress_bar,
429
  )
430
  else:
431
- range_iter = range(0, len(sentences), batch_size)
432
 
433
  for i in range_iter:
434
  processed_inputs = self.process([Image.open(image) for image in images])
435
  embeddings = self.get_image_features(processed_inputs)
 
 
436
  if convert_to_numpy:
437
  embeddings = embeddings.cpu()
438
  all_embeddings.extend(embeddings)
 
391
  device: Optional[torch.device] = None,
392
  normalize_embeddings: bool = False,
393
  ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
394
+ """
395
+ Computes image embeddings.
396
+
397
+ Args:
398
+ images(`str` or `List[str]`):
399
+ image or images paths to be encoded
400
+ batch_size(`int`, *optional*, defaults to 32):
401
+ Batch size for the computation
402
+ show_progress_bar(`bool`, *optional*, defaults to None):
403
+ Show a progress bar when encoding images.
404
+ If set to None, progress bar is only shown when `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
405
+ convert_to_numpy(`bool`, *optional*, defaults to True):
406
+ If true, the output is a list of numpy vectors.
407
+ Else, it is a list of pytorch tensors.
408
+ convert_to_tensor(`bool`, *optional*, defaults to False):
409
+ If true, you get one large tensor as return.
410
+ Overwrites any setting from convert_to_numpy
411
+ device(`torch.device`, *optional*, defaults to None):
412
+ Which torch.device to use for the computation
413
+ normalize_embeddings(`bool`, *optional*, defaults to False):
414
+ If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
415
+ Returns:
416
+ By default, a list of tensors is returned.
417
+ If convert_to_tensor, a stacked tensor is returned.
418
+ If convert_to_numpy, a numpy matrix is returned.
419
+ """
420
+ from PIL.Image import Image
421
 
422
  is_training = self.training
423
  self.eval()
 
448
  if has_tqdm:
449
  range_iter = trange(
450
  0,
451
+ len(images),
452
  batch_size,
453
  desc="Encoding",
454
  disable=not show_progress_bar,
455
  )
456
  else:
457
+ range_iter = range(0, len(images), batch_size)
458
 
459
  for i in range_iter:
460
  processed_inputs = self.process([Image.open(image) for image in images])
461
  embeddings = self.get_image_features(processed_inputs)
462
+ if normalize_embeddings:
463
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
464
  if convert_to_numpy:
465
  embeddings = embeddings.cpu()
466
  all_embeddings.extend(embeddings)