Files changed (1)
  1. modeling_clip.py +221 -23
modeling_clip.py CHANGED
@@ -5,8 +5,9 @@
 # and adjusted for Jina CLIP
 
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple, Union, List
 
+import numpy as np
 import torch
 import torch.nn.functional as f
 import torch.utils.checkpoint
@@ -18,6 +19,12 @@ from transformers.models.clip.modeling_clip import (
     CLIPVisionModelOutput,
     clip_loss,
 )
+try:
+    from tqdm.autonotebook import trange
+
+    has_tqdm = True
+except ImportError:
+    has_tqdm = False
 
 from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
 from .eva_model import EVAVisionTransformer
@@ -215,6 +222,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         self.visual_projection = nn.Identity()
         self.text_projection = nn.Identity()
 
+        self.tokenizer = None
+        self.preprocess = None
         self.post_init()
 
     def get_text_features(
@@ -239,33 +248,222 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
         )
         return self.visual_projection(self.vision_model(x=x))
 
+    def get_tokenizer(self):
+        if not self.tokenizer:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.config._name_or_path, trust_remote_code=True
+            )
+        return self.tokenizer
+
+    @torch.inference_mode()
     def encode_text(
         self,
-        input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
-        return_dict: Optional[bool] = None,
-        *_,
-        **__,
-    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-        feats = self.get_text_features(input_ids=input_ids)
-        out = CLIPTextModelOutput(text_embeds=feats)
-        return out if return_dict else out.to_tuple()
+        sentences: Union[str, List[str]],
+        batch_size: int = 32,
+        show_progress_bar: Optional[bool] = None,
+        convert_to_numpy: bool = True,
+        convert_to_tensor: bool = False,
+        device: Optional[torch.device] = None,
+        normalize_embeddings: bool = False,
+        **tokenizer_kwargs,
+    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+        """
+        Computes sentence embeddings
+
+        Args:
+            sentences(`str` or `List[str]`):
+                Sentence or sentences to be encoded
+            batch_size(`int`, *optional*, defaults to 32):
+                Batch size for the computation
+            show_progress_bar(`bool`, *optional*, defaults to None):
+                Show a progress bar when encoding sentences.
+                If set to None, the progress bar is only shown when
+                `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
+            convert_to_numpy(`bool`, *optional*, defaults to True):
+                If true, the output is a list of numpy vectors.
+                Else, it is a list of pytorch tensors.
+            convert_to_tensor(`bool`, *optional*, defaults to False):
+                If true, you get one large tensor as return.
+                Overwrites any setting from convert_to_numpy
+            device(`torch.device`, *optional*, defaults to None):
+                Which torch.device to use for the computation
+            normalize_embeddings(`bool`, *optional*, defaults to False):
+                If set to true, returned vectors will have length 1. In that case,
+                the faster dot-product (util.dot_score) instead of cosine similarity
+                can be used.
+            tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
+                Keyword arguments for the tokenizer
+        Returns:
+            By default, a list of tensors is returned.
+            If convert_to_tensor, a stacked tensor is returned.
+            If convert_to_numpy, a numpy matrix is returned.
+        """
+        is_training = self.training
+        self.eval()
+
+        self.tokenizer = self.get_tokenizer()
+
+        if show_progress_bar is None:
+            show_progress_bar = (
+                logger.getEffectiveLevel() == logging.INFO
+                or logger.getEffectiveLevel() == logging.DEBUG
+            )
+
+        if convert_to_tensor:
+            convert_to_numpy = False
+
+        input_was_string = False
+        if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
+            sentences = [sentences]
+            input_was_string = True
+
+        if device is not None:
+            self.to(device)
+
+        # sort by length so each batch contains sentences of similar length
+        permutation = np.argsort([-len(i) for i in sentences])
+        inverse_permutation = np.argsort(permutation)
+        sentences = [sentences[idx] for idx in permutation]
+
+        tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
+        tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
+        tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
+
+        all_embeddings = []
+
+        if has_tqdm:
+            range_iter = trange(
+                0,
+                len(sentences),
+                batch_size,
+                desc="Encoding",
+                disable=not show_progress_bar,
+            )
+        else:
+            range_iter = range(0, len(sentences), batch_size)
+
+        for i in range_iter:
+            encoded_input = self.tokenizer(
+                sentences[i : i + batch_size],
+                return_tensors='pt',
+                **tokenizer_kwargs,
+            ).to(self.device)
+
+            embeddings = self.get_text_features(input_ids=encoded_input)
+            if normalize_embeddings:
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+            if convert_to_numpy:
+                embeddings = embeddings.cpu()
+            all_embeddings.extend(embeddings)
+
+        # undo the length-based sort to restore the original input order
+        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
+
+        if convert_to_tensor:
+            all_embeddings = torch.stack(all_embeddings)
+        elif convert_to_numpy:
+            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+
+        if input_was_string:
+            all_embeddings = all_embeddings[0]
+
+        self.train(is_training)
+        return all_embeddings
 
+    def get_preprocess(self):
+        if not self.preprocess:
+            self.preprocess = AutoImageProcessor.from_pretrained(
+                self.config._name_or_path, trust_remote_code=True
+            )
+        return self.preprocess
+
+    @torch.inference_mode()
     def encode_image(
         self,
-        pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
-        return_dict: Optional[bool] = None,
-        *_,
-        **__,
-    ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
-        feats = self.get_image_features(pixel_values=pixel_values)
-        out = CLIPVisionModelOutput(image_embeds=feats)
-        return out if return_dict else out.to_tuple()
+        images: Union[str, List[str]],
+        batch_size: int = 32,
+        show_progress_bar: Optional[bool] = None,
+        convert_to_numpy: bool = True,
+        convert_to_tensor: bool = False,
+        device: Optional[torch.device] = None,
+        normalize_embeddings: bool = False,
+    ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
+        """
+        Computes image embeddings.
+
+        Args:
+            images(`str` or `List[str]`):
+                Image path or image paths to be encoded
+            batch_size(`int`, *optional*, defaults to 32):
+                Batch size for the computation
+            show_progress_bar(`bool`, *optional*, defaults to None):
+                Show a progress bar when encoding images.
+                If set to None, the progress bar is only shown when
+                `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
+            convert_to_numpy(`bool`, *optional*, defaults to True):
+                If true, the output is a list of numpy vectors.
+                Else, it is a list of pytorch tensors.
+            convert_to_tensor(`bool`, *optional*, defaults to False):
+                If true, you get one large tensor as return.
+                Overwrites any setting from convert_to_numpy
+            device(`torch.device`, *optional*, defaults to None):
+                Which torch.device to use for the computation
+            normalize_embeddings(`bool`, *optional*, defaults to False):
+                If set to true, returned vectors will have length 1. In that case,
+                the faster dot-product (util.dot_score) instead of cosine similarity
+                can be used.
+        Returns:
+            By default, a list of tensors is returned.
+            If convert_to_tensor, a stacked tensor is returned.
+            If convert_to_numpy, a numpy matrix is returned.
+        """
+        from PIL import Image
+
+        is_training = self.training
+        self.eval()
+
+        self.preprocess = self.get_preprocess()
+
+        if show_progress_bar is None:
+            show_progress_bar = (
+                logger.getEffectiveLevel() == logging.INFO
+                or logger.getEffectiveLevel() == logging.DEBUG
+            )
+
+        if convert_to_tensor:
+            convert_to_numpy = False
+
+        input_was_single_img = False
+        if isinstance(images, str) or not hasattr(images, '__len__'):
+            images = [images]
+            input_was_single_img = True
+
+        if device is not None:
+            self.to(device)
+
+        # sort by path length so the original order can be restored afterwards
+        permutation = np.argsort([-len(i) for i in images])
+        inverse_permutation = np.argsort(permutation)
+        images = [images[idx] for idx in permutation]
+
+        all_embeddings = []
+
+        if has_tqdm:
+            range_iter = trange(
+                0,
+                len(images),
+                batch_size,
+                desc="Encoding",
+                disable=not show_progress_bar,
+            )
+        else:
+            range_iter = range(0, len(images), batch_size)
+
+        for i in range_iter:
+            processed_inputs = self.preprocess(
+                [Image.open(image) for image in images[i : i + batch_size]]
+            )
+            embeddings = self.get_image_features(processed_inputs)
+            if normalize_embeddings:
+                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+            if convert_to_numpy:
+                embeddings = embeddings.cpu()
+            all_embeddings.extend(embeddings)
+
+        # undo the path-length sort to restore the original input order
+        all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
+
+        if convert_to_tensor:
+            all_embeddings = torch.stack(all_embeddings)
+        elif convert_to_numpy:
+            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+
+        if input_was_single_img:
+            all_embeddings = all_embeddings[0]
+
+        self.train(is_training)
+        return all_embeddings
 
     def forward(
         self,
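
For illustration only, a minimal usage sketch of the new encode API. The checkpoint name and image paths are placeholders, and the snippet assumes the checkpoint ships this modeling code via trust_remote_code and that AutoTokenizer / AutoImageProcessor resolve from the model's _name_or_path, as get_tokenizer() / get_preprocess() expect:

from transformers import AutoModel

# Placeholder checkpoint; any Jina CLIP checkpoint exposing this modeling
# code through trust_remote_code should provide the same entry points.
model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)

# Text side: returns a numpy matrix by default (convert_to_numpy=True).
text_embeddings = model.encode_text(
    ['A photo of a cat', 'A photo of a dog'],
    batch_size=2,
    normalize_embeddings=True,
)

# Image side: accepts image paths, opens them with PIL and runs the
# preprocessor; convert_to_tensor=True returns one stacked tensor instead.
image_embeddings = model.encode_image(
    ['cat.jpg', 'dog.jpg'],  # placeholder paths
    convert_to_tensor=True,
    normalize_embeddings=True,
)

print(text_embeddings.shape, image_embeddings.shape)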