Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,427 Bytes
97214bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from typing import List, Optional, Union
import torch
from PIL import Image
from transformers import BatchFeature
from .processing_florence2 import Florence2Processor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
class ColFlorProcessor(BaseVisualRetrieverProcessor, Florence2Processor):
    """
    Processor for ColFlor.

    Combines the Florence-2 processor (image + text preprocessing) with the
    visual-retriever processor interface: it prepares document images and
    text queries as model inputs and exposes a multi-vector scoring helper.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent processors and create a placeholder image."""
        super().__init__(*args, **kwargs)
        # Small black RGB placeholder image; not used by the methods in this
        # class, but kept as a public attribute for external callers.
        self.mock_image = Image.new("RGB", (16, 16), color="black")

    def process_images(
        self,
        images: List[Image.Image],
    ) -> BatchFeature:
        """
        Process images for ColFlor.

        Every image is converted to RGB and paired with the ``"<OCR>"`` text
        prompt, then run through the processor's ``__call__``.

        Args:
            images: Document page images to encode.

        Returns:
            The processor output, with an additional ``"full_attention_mask"``
            entry: the text ``"attention_mask"`` left-extended with 577
            columns of ones.
        """
        prompts = ["<OCR>"] * len(images)
        rgb_images = [image.convert("RGB") for image in images]

        batch_doc = self(
            text=prompts,
            images=rgb_images,
            return_tensors="pt",
            padding="longest",
        )

        text_mask = batch_doc["attention_mask"]
        # 577 presumably equals the number of image positions placed in front
        # of the text tokens (cf. `self.image_seq_length` used in
        # `process_queries`) -- TODO confirm before replacing the literal.
        # The ones block is allocated directly on the mask's device; its dtype
        # is left at the torch default, exactly as before.
        image_mask = torch.ones((text_mask.shape[0], 577), device=text_mask.device)
        batch_doc["full_attention_mask"] = torch.cat([image_mask, text_mask], dim=1)
        return batch_doc

    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> BatchFeature:
        """
        Process queries for ColFlor.

        Each query is rendered as ``"Question: {query}"`` followed by
        ``suffix`` and then tokenized (text only, no image input).

        Args:
            queries: Raw query strings.
            max_length: Query length budget; ``self.image_seq_length`` is added
                to it before it is handed to the tokenizer.
            suffix: Text appended to every query. Defaults to ten ``"<pad>"``
                tokens.

        Returns:
            The tokenizer output for the batch, padded to the longest query.
        """
        if suffix is None:
            suffix = "<pad>" * 10

        texts_query: List[str] = [f"Question: {query}{suffix}" for query in queries]

        # NOTE(review): no `truncation` argument is passed, so whether
        # `max_length` has any effect together with padding="longest" should
        # be confirmed against the tokenizer documentation.
        batch_query = self.tokenizer(
            text=texts_query,
            return_tensors="pt",
            padding="longest",
            max_length=max_length + self.image_seq_length,
        )
        return batch_query

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the MaxSim score (ColBERT-like) for the given multi-vector query and passage embeddings.

        This is a thin wrapper that delegates to ``self.score_multi_vector``
        (not defined in this class; presumably inherited from the retriever
        base class).

        Args:
            qs: Query embeddings, one tensor per query.
            ps: Passage embeddings, one tensor per passage.
            device: Optional device forwarded to ``score_multi_vector``.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            The score tensor produced by ``score_multi_vector``.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)