refactor: refine encode_text
Browse files- modeling_clip.py +90 -10
modeling_clip.py
CHANGED
@@ -18,6 +18,12 @@ from transformers.models.clip.modeling_clip import (
|
|
18 |
CLIPVisionModelOutput,
|
19 |
clip_loss,
|
20 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
|
23 |
from .eva_model import EVAVisionTransformer
|
@@ -215,6 +221,7 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
215 |
self.visual_projection = nn.Identity()
|
216 |
self.text_projection = nn.Identity()
|
217 |
|
|
|
218 |
self.post_init()
|
219 |
|
220 |
def get_text_features(
|
@@ -239,19 +246,92 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
|
|
239 |
)
|
240 |
return self.visual_projection(self.vision_model(x=x))
|
241 |
|
|
|
242 |
def encode_text(
|
243 |
self,
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
|
|
|
|
|
|
|
|
|
|
248 |
) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
def encode_image(
|
257 |
self,
|
|
|
18 |
CLIPVisionModelOutput,
|
19 |
clip_loss,
|
20 |
)
|
21 |
+
try:
|
22 |
+
from tqdm.autonotebook import trange
|
23 |
+
|
24 |
+
has_tqdm = True
|
25 |
+
except ImportError:
|
26 |
+
has_tqdm = False
|
27 |
|
28 |
from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
|
29 |
from .eva_model import EVAVisionTransformer
|
|
|
221 |
self.visual_projection = nn.Identity()
|
222 |
self.text_projection = nn.Identity()
|
223 |
|
224 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
|
225 |
self.post_init()
|
226 |
|
227 |
def get_text_features(
|
|
|
246 |
)
|
247 |
return self.visual_projection(self.vision_model(x=x))
|
248 |
|
249 |
+
@torch.inference_mode()
|
250 |
def encode_text(
|
251 |
self,
|
252 |
+
sentences: Union[str, List[str]],
|
253 |
+
batch_size: int = 32,
|
254 |
+
show_progress_bar: Optional[bool] = None,
|
255 |
+
output_value: str = 'sentence_embedding',
|
256 |
+
convert_to_numpy: bool = True,
|
257 |
+
convert_to_tensor: bool = False,
|
258 |
+
device: Optional[torch.device] = None,
|
259 |
+
normalize_embeddings: bool = False,
|
260 |
+
**tokenizer_kwargs,
|
261 |
) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
|
262 |
+
|
263 |
+
self.eval()
|
264 |
+
|
265 |
+
if show_progress_bar is None:
|
266 |
+
show_progress_bar = (
|
267 |
+
logger.getEffectiveLevel() == logging.INFO
|
268 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
269 |
+
)
|
270 |
+
|
271 |
+
if convert_to_tensor:
|
272 |
+
convert_to_numpy = False
|
273 |
+
|
274 |
+
if output_value != 'sentence_embedding':
|
275 |
+
convert_to_tensor = False
|
276 |
+
convert_to_numpy = False
|
277 |
+
|
278 |
+
input_was_string = False
|
279 |
+
if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
|
280 |
+
sentences = [sentences]
|
281 |
+
input_was_string = True
|
282 |
+
|
283 |
+
if device is not None:
|
284 |
+
self.to(device)
|
285 |
+
|
286 |
+
permutation = np.argsort([-len(i) for i in sentences])
|
287 |
+
inverse_permutation = np.argsort(permutation)
|
288 |
+
sentences = [sentences[idx] for idx in permutation]
|
289 |
+
|
290 |
+
tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
|
291 |
+
tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
|
292 |
+
tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
|
293 |
+
|
294 |
+
if has_tqdm:
|
295 |
+
range_iter = trange(
|
296 |
+
0,
|
297 |
+
len(sentences),
|
298 |
+
batch_size,
|
299 |
+
desc="Encoding",
|
300 |
+
disable=not show_progress_bar,
|
301 |
+
)
|
302 |
+
else:
|
303 |
+
range_iter = range(0, len(sentences), batch_size)
|
304 |
+
|
305 |
+
for i in range_iter:
|
306 |
+
encoded_input = self.tokenizer(
|
307 |
+
sentences[i : i + batch_size],
|
308 |
+
return_tensors='pt',
|
309 |
+
**tokenizer_kwargs,
|
310 |
+
).to(self.device)
|
311 |
+
|
312 |
+
if output_value == 'token_embeddings':
|
313 |
+
raise NotImplementedError
|
314 |
+
elif output_value is None:
|
315 |
+
raise NotImplementedError
|
316 |
+
else:
|
317 |
+
embeddings = self.get_text_features(input_ids=encoded_input)
|
318 |
+
if normalize_embeddings:
|
319 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
320 |
+
if convert_to_numpy:
|
321 |
+
embeddings = embeddings.cpu()
|
322 |
+
all_embeddings.extend(embeddings)
|
323 |
+
|
324 |
+
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
325 |
+
|
326 |
+
if convert_to_tensor:
|
327 |
+
all_embeddings = torch.stack(all_embeddings)
|
328 |
+
elif convert_to_numpy:
|
329 |
+
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
330 |
+
|
331 |
+
if input_was_string:
|
332 |
+
all_embeddings = all_embeddings[0]
|
333 |
+
|
334 |
+
return all_embeddings
|
335 |
|
336 |
def encode_image(
|
337 |
self,
|