feat-inference-mode
#1
by
bwang0911
- opened
- modeling_clip.py +221 -23
modeling_clip.py
CHANGED
@@ -5,8 +5,9 @@
|
|
5 |
# and adjusted for Jina CLIP
|
6 |
|
7 |
from functools import partial
|
8 |
-
from typing import Optional, Tuple, Union
|
9 |
|
|
|
10 |
import torch
|
11 |
import torch.nn.functional as f
|
12 |
import torch.utils.checkpoint
|
@@ -18,6 +19,12 @@ from transformers.models.clip.modeling_clip import (
|
|
18 |
CLIPVisionModelOutput,
|
19 |
clip_loss,
|
20 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
|
23 |
from .eva_model import EVAVisionTransformer
|
@@ -215,6 +222,8 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
215 |
self.visual_projection = nn.Identity()
|
216 |
self.text_projection = nn.Identity()
|
217 |
|
|
|
|
|
218 |
self.post_init()
|
219 |
|
220 |
def get_text_features(
|
@@ -239,33 +248,222 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
239 |
)
|
240 |
return self.visual_projection(self.vision_model(x=x))
|
241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
def encode_text(
|
243 |
self,
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
def encode_image(
|
257 |
self,
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
def forward(
|
271 |
self,
|
|
|
5 |
# and adjusted for Jina CLIP
|
6 |
|
7 |
from functools import partial
|
8 |
+
from typing import Optional, Tuple, Union, List
|
9 |
|
10 |
+
import numpy as np
|
11 |
import torch
|
12 |
import torch.nn.functional as f
|
13 |
import torch.utils.checkpoint
|
|
|
19 |
CLIPVisionModelOutput,
|
20 |
clip_loss,
|
21 |
)
|
22 |
+
try:
|
23 |
+
from tqdm.autonotebook import trange
|
24 |
+
|
25 |
+
has_tqdm = True
|
26 |
+
except ImportError:
|
27 |
+
has_tqdm = False
|
28 |
|
29 |
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
|
30 |
from .eva_model import EVAVisionTransformer
|
|
|
222 |
self.visual_projection = nn.Identity()
|
223 |
self.text_projection = nn.Identity()
|
224 |
|
225 |
+
self.tokenizer = None
|
226 |
+
self.preprocess = None
|
227 |
self.post_init()
|
228 |
|
229 |
def get_text_features(
|
|
|
248 |
)
|
249 |
return self.visual_projection(self.vision_model(x=x))
|
250 |
|
251 |
+
def get_tokenizer(self):
|
252 |
+
if not self.tokenizer:
|
253 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True)
|
254 |
+
return self.tokenizer
|
255 |
+
|
256 |
+
@torch.inference_mode()
|
257 |
def encode_text(
|
258 |
self,
|
259 |
+
sentences: Union[str, List[str]],
|
260 |
+
batch_size: int = 32,
|
261 |
+
show_progress_bar: Optional[bool] = None,
|
262 |
+
convert_to_numpy: bool = True,
|
263 |
+
convert_to_tensor: bool = False,
|
264 |
+
device: Optional[torch.device] = None,
|
265 |
+
normalize_embeddings: bool = False,
|
266 |
+
**tokenizer_kwargs,
|
267 |
+
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
268 |
+
"""
|
269 |
+
Computes sentence embeddings
|
270 |
+
Args:
|
271 |
+
sentences(`str` or `List[str]`):
|
272 |
+
Sentence or sentences to be encoded
|
273 |
+
batch_size(`int`, *optional*, defaults to 32):
|
274 |
+
Batch size for the computation
|
275 |
+
show_progress_bar(`bool`, *optional*, defaults to None):
|
276 |
+
Show a progress bar when encoding sentences.
|
277 |
+
If set to None, progress bar is only shown when `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
278 |
+
convert_to_numpy(`bool`, *optional*, defaults to True):
|
279 |
+
If true, the output is a list of numpy vectors.
|
280 |
+
Else, it is a list of pytorch tensors.
|
281 |
+
convert_to_tensor(`bool`, *optional*, defaults to False):
|
282 |
+
If true, you get one large tensor as return.
|
283 |
+
Overwrites any setting from convert_to_numpy
|
284 |
+
device(`torch.device`, *optional*, defaults to None):
|
285 |
+
Which torch.device to use for the computation
|
286 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
287 |
+
If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
|
288 |
+
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
289 |
+
Keyword arguments for the tokenizer
|
290 |
+
Returns:
|
291 |
+
By default, a list of tensors is returned.
|
292 |
+
If convert_to_tensor, a stacked tensor is returned.
|
293 |
+
If convert_to_numpy, a numpy matrix is returned.
|
294 |
+
"""
|
295 |
+
is_training = self.training
|
296 |
+
self.eval()
|
297 |
+
|
298 |
+
self.tokenizer = self.get_tokenizer()
|
299 |
+
|
300 |
+
if show_progress_bar is None:
|
301 |
+
show_progress_bar = (
|
302 |
+
logger.getEffectiveLevel() == logging.INFO
|
303 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
304 |
+
)
|
305 |
+
|
306 |
+
if convert_to_tensor:
|
307 |
+
convert_to_numpy = False
|
308 |
+
|
309 |
+
input_was_string = False
|
310 |
+
if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
|
311 |
+
sentences = [sentences]
|
312 |
+
input_was_string = True
|
313 |
+
|
314 |
+
if device is not None:
|
315 |
+
self.to(device)
|
316 |
+
|
317 |
+
permutation = np.argsort([-len(i) for i in sentences])
|
318 |
+
inverse_permutation = np.argsort(permutation)
|
319 |
+
sentences = [sentences[idx] for idx in permutation]
|
320 |
|
321 |
+
tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
|
322 |
+
tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
|
323 |
+
tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
|
324 |
+
|
325 |
+
if has_tqdm:
|
326 |
+
range_iter = trange(
|
327 |
+
0,
|
328 |
+
len(sentences),
|
329 |
+
batch_size,
|
330 |
+
desc="Encoding",
|
331 |
+
disable=not show_progress_bar,
|
332 |
+
)
|
333 |
+
else:
|
334 |
+
range_iter = range(0, len(sentences), batch_size)
|
335 |
+
|
336 |
+
for i in range_iter:
|
337 |
+
encoded_input = self.tokenizer(
|
338 |
+
sentences[i : i + batch_size],
|
339 |
+
return_tensors='pt',
|
340 |
+
**tokenizer_kwargs,
|
341 |
+
).to(self.device)
|
342 |
+
|
343 |
+
embeddings = self.get_text_features(input_ids=encoded_input)
|
344 |
+
if normalize_embeddings:
|
345 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
346 |
+
if convert_to_numpy:
|
347 |
+
embeddings = embeddings.cpu()
|
348 |
+
all_embeddings.extend(embeddings)
|
349 |
+
|
350 |
+
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
351 |
+
|
352 |
+
if convert_to_tensor:
|
353 |
+
all_embeddings = torch.stack(all_embeddings)
|
354 |
+
elif convert_to_numpy:
|
355 |
+
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
356 |
+
|
357 |
+
if input_was_string:
|
358 |
+
all_embeddings = all_embeddings[0]
|
359 |
+
|
360 |
+
self.train(is_training)
|
361 |
+
return all_embeddings
|
362 |
+
|
363 |
+
|
364 |
+
def get_preprocess(self):
|
365 |
+
if not self.preprocess:
|
366 |
+
self.preprocess = AutoImageProcessor.from_pretrained(config._name_or_path, trust_remote_code=True)
|
367 |
+
return self.preprocess
|
368 |
+
|
369 |
+
|
370 |
+
@torch.inference_mode()
|
371 |
def encode_image(
|
372 |
self,
|
373 |
+
images: Union[str, List[str]],
|
374 |
+
batch_size: int = 32,
|
375 |
+
show_progress_bar: Optional[bool] = None,
|
376 |
+
convert_to_numpy: bool = True,
|
377 |
+
convert_to_tensor: bool = False,
|
378 |
+
device: Optional[torch.device] = None,
|
379 |
+
normalize_embeddings: bool = False,
|
380 |
+
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
381 |
+
"""
|
382 |
+
Computes image embeddings.
|
383 |
+
|
384 |
+
Args:
|
385 |
+
images(`str` or `List[str]`):
|
386 |
+
image or images paths to be encoded
|
387 |
+
batch_size(`int`, *optional*, defaults to 32):
|
388 |
+
Batch size for the computation
|
389 |
+
show_progress_bar(`bool`, *optional*, defaults to None):
|
390 |
+
Show a progress bar when encoding images.
|
391 |
+
If set to None, progress bar is only shown when `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
392 |
+
convert_to_numpy(`bool`, *optional*, defaults to True):
|
393 |
+
If true, the output is a list of numpy vectors.
|
394 |
+
Else, it is a list of pytorch tensors.
|
395 |
+
convert_to_tensor(`bool`, *optional*, defaults to False):
|
396 |
+
If true, you get one large tensor as return.
|
397 |
+
Overwrites any setting from convert_to_numpy
|
398 |
+
device(`torch.device`, *optional*, defaults to None):
|
399 |
+
Which torch.device to use for the computation
|
400 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
401 |
+
If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.
|
402 |
+
Returns:
|
403 |
+
By default, a list of tensors is returned.
|
404 |
+
If convert_to_tensor, a stacked tensor is returned.
|
405 |
+
If convert_to_numpy, a numpy matrix is returned.
|
406 |
+
"""
|
407 |
+
from PIL.Image import Image
|
408 |
+
|
409 |
+
is_training = self.training
|
410 |
+
self.eval()
|
411 |
+
|
412 |
+
self.preprocess = self.get_preprocess()
|
413 |
+
|
414 |
+
if show_progress_bar is None:
|
415 |
+
show_progress_bar = (
|
416 |
+
logger.getEffectiveLevel() == logging.INFO
|
417 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
418 |
+
)
|
419 |
+
|
420 |
+
if convert_to_tensor:
|
421 |
+
convert_to_numpy = False
|
422 |
+
|
423 |
+
input_was_single_img = False
|
424 |
+
if isinstance(images, str) or not hasattr(images, '__len__'):
|
425 |
+
images = [images]
|
426 |
+
input_was_single_img = True
|
427 |
+
|
428 |
+
if device is not None:
|
429 |
+
self.to(device)
|
430 |
+
|
431 |
+
permutation = np.argsort([-len(i) for i in images])
|
432 |
+
inverse_permutation = np.argsort(permutation)
|
433 |
+
images = [images[idx] for idx in permutation]
|
434 |
+
|
435 |
+
if has_tqdm:
|
436 |
+
range_iter = trange(
|
437 |
+
0,
|
438 |
+
len(images),
|
439 |
+
batch_size,
|
440 |
+
desc="Encoding",
|
441 |
+
disable=not show_progress_bar,
|
442 |
+
)
|
443 |
+
else:
|
444 |
+
range_iter = range(0, len(images), batch_size)
|
445 |
+
|
446 |
+
for i in range_iter:
|
447 |
+
processed_inputs = self.process([Image.open(image) for image in images])
|
448 |
+
embeddings = self.get_image_features(processed_inputs)
|
449 |
+
if normalize_embeddings:
|
450 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
451 |
+
if convert_to_numpy:
|
452 |
+
embeddings = embeddings.cpu()
|
453 |
+
all_embeddings.extend(embeddings)
|
454 |
+
|
455 |
+
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
456 |
+
|
457 |
+
if convert_to_tensor:
|
458 |
+
all_embeddings = torch.stack(all_embeddings)
|
459 |
+
elif convert_to_numpy:
|
460 |
+
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
461 |
+
|
462 |
+
if input_was_single_img:
|
463 |
+
all_embeddings = all_embeddings[0]
|
464 |
+
|
465 |
+
self.train(is_training)
|
466 |
+
return all_embeddings
|
467 |
|
468 |
def forward(
|
469 |
self,
|