amaye15 commited on
Commit
59f3026
·
1 Parent(s): 2c8e3a0

Optimised handler V2

Browse files
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -304,7 +304,7 @@ class EndpointHandler:
304
  with torch.no_grad():
305
  for batch in text_loader:
306
  batch_texts = batch[0].to(self.device, non_blocking=True)
307
- with torch.cuda.amp.autocast():
308
  embeddings = self.model(**batch_texts)
309
  all_embeddings.append(embeddings)
310
  text_embeddings = torch.cat(all_embeddings, dim=0)
@@ -388,7 +388,7 @@ class EndpointHandler:
388
  if image_embeddings is not None and text_embeddings is not None:
389
  self.logger.info("Computing similarity scores.")
390
  try:
391
- with torch.no_grad(), torch.cuda.amp.autocast():
392
  scores = self.processor.score_multi_vector(
393
  text_embeddings, image_embeddings
394
  )
 
304
  with torch.no_grad():
305
  for batch in text_loader:
306
  batch_texts = batch[0].to(self.device, non_blocking=True)
307
+ with torch.amp.autocast():
308
  embeddings = self.model(**batch_texts)
309
  all_embeddings.append(embeddings)
310
  text_embeddings = torch.cat(all_embeddings, dim=0)
 
388
  if image_embeddings is not None and text_embeddings is not None:
389
  self.logger.info("Computing similarity scores.")
390
  try:
391
+ with torch.no_grad(), torch.amp.autocast():
392
  scores = self.processor.score_multi_vector(
393
  text_embeddings, image_embeddings
394
  )