from typing import List, Optional, Union

import torch
from PIL import Image
from transformers import BatchFeature

from .processing_florence2 import Florence2Processor

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor


class ColFlorProcessor(BaseVisualRetrieverProcessor, Florence2Processor):
    """
    Processor for ColFlor, a ColPali-style visual retriever built on Florence-2.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mock_image = Image.new("RGB", (16, 16), color="black")

    def process_images(
        self,
        images: List[Image.Image],
    ) -> BatchFeature:
        """
        Process images for ColFlor. Each image is converted to RGB and paired
        with the Florence-2 `<OCR>` task prompt.
        """
        texts_doc = ["<OCR>"] * len(images)
        images = [image.convert("RGB") for image in images]

        batch_doc = self(
            text=texts_doc,
            images=images,
            return_tensors="pt",
            padding="longest",
        )

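        # Florence-2 prepends `image_seq_length` visual tokens (577 for Florence-2)
        # to the text tokens, so extend the text attention mask with ones that
        # cover them. Match the dtype and device of the text mask.
        attention_mask = batch_doc["attention_mask"]
        image_mask = torch.ones(
            (attention_mask.size(0), self.image_seq_length),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        batch_doc["full_attention_mask"] = torch.cat([image_mask, attention_mask], dim=1)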

        return batch_doc

    def process_queries(
        self,
        queries: List[str],
        max_length: int = 50,
        suffix: Optional[str] = None,
    ) -> BatchFeature:
        """
        Process queries for ColFlor. Each query is prefixed with "Question: " and
        followed by `suffix` (ten `<pad>` tokens by default).
        """
        if suffix is None:
            suffix = "<pad>" * 10
        texts_query: List[str] = []

        for query in queries:
            query = f"Question: {query}"
            query += suffix  # add suffix (pad tokens)
            texts_query.append(query)

        batch_query = self.tokenizer(
            text=texts_query,
            return_tensors="pt",
            padding="longest",
            max_length=max_length + self.image_seq_length,
            truncation=True,
        )

        return batch_query

    def score(
        self,
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        device: Optional[Union[str, torch.device]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the ColBERT-style MaxSim score for the given multi-vector query and passage embeddings.
        """
        return self.score_multi_vector(qs, ps, device=device, **kwargs)
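

# Hypothetical reference sketch, not part of the colpali_engine API: a minimal,
# loop-based version of the ColBERT-style MaxSim score that `score` delegates
# to `score_multi_vector`. It assumes every query and passage embedding is an
# unpadded 2-D tensor of shape (num_tokens, dim). The library implementation
# likely also handles padding, batching, and device placement, which this
# illustration omits.
def _maxsim_reference(qs: List[torch.Tensor], ps: List[torch.Tensor]) -> torch.Tensor:
    """Return a (len(qs), len(ps)) matrix of MaxSim scores."""
    scores = torch.empty(len(qs), len(ps))
    for i, q in enumerate(qs):
        for j, p in enumerate(ps):
            # Dot-product similarity of every query token with every passage token.
            sim = q @ p.T  # (num_query_tokens, num_passage_tokens)
            # For each query token keep its best-matching passage token, then sum.
            scores[i, j] = sim.max(dim=1).values.sum()
    return scores


# Usage (illustrative): _maxsim_reference([query_emb], [page_emb_1, page_emb_2])
# returns a tensor of shape (1, 2), one score per (query, page) pair.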