File size: 3,822 Bytes
7362797
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from functools import cached_property

import torch


class VocabInfo:
    def __init__(self, vocab_map: dict[str, int]):
        self.name2val = vocab_map

        self.bos_id = vocab_map.get("<s>")
        self.eos_id = vocab_map.get("</s>")
        self.boi_id = vocab_map.get("<racm3:break>")
        self.eoi_id = vocab_map.get("<eoss>")
        self.pad_id = vocab_map.get("<pad>")
        self.eot_id = vocab_map.get("<reserved08706>")

    @property
    def begin_sequence(self) -> int:
        return self.bos_id

    @property
    def end_sequence(self) -> int:
        return self.eos_id

    @property
    def begin_image(self) -> int:
        return self.boi_id

    @property
    def end_image(self) -> int:
        return self.eoi_id

    @property
    def padding(self) -> int:
        return self.pad_id

    @property
    def end_turn(self) -> int:
        return self.eot_id

    @cached_property
    def val2name(self) -> dict[int, str]:
        return {v: k for k, v in self.name2val.items()}

    @cached_property
    def all_tokens(self) -> list[int]:
        return sorted(self.name2val.values())

    @cached_property
    def image_tokens(self) -> list[int]:
        return sorted(
            [val for name, val in self.name2val.items() if name.startswith("IMGIMG")]
        )

    @cached_property
    def special_tokens(self) -> list[int]:
        return sorted(
            [
                val
                for name, val in self.name2val.items()
                if name.startswith("<") and name != "<"
            ]
        )

    @cached_property
    def text_tokens(self) -> list[int]:
        return sorted(
            set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens)
        )


class VocabTranslation:
    def __init__(self, vocab_info: VocabInfo, device: str | None = None):
        self._vocab = vocab_info
        self._device = device

    @cached_property
    def bpe2img(self) -> dict[int, int]:    # vocab id => codebook id, i.e. [4:8195] => [0:8191]
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}    # A-J: 0-9

        def remap(old_name: str) -> str:
            return "".join(
                img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1] # last chr is 'Z'
            )
            # e.g.: IMGIMGFDZ => FD => 53,

        return {
            tok: int(remap(self._vocab.val2name[tok]))
            for tok in self._vocab.image_tokens # the token starts with 'IMGIMG', value: [4: 8195]
        }

    @cached_property
    def img2bpe(self) -> dict[int, int]:
        return {v: k for k, v in self.bpe2img.items()}  # codebook id => vocab id, i.e. [0:8191] => [4:8191]

    @cached_property
    def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
        sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device)
        sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device)
        return sorted_bpe, sorted_img

    @cached_property
    def img2bpe_mapping_tensor(self) -> torch.LongTensor:
        mapping = torch.zeros(
            max(self.img2bpe.keys()) + 1,
            dtype=torch.int,
            device=self._device,
        )
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor:
        bpe_tok, img_tok = self.bpe2img_search_tensors
        return img_tok[torch.searchsorted(bpe_tok, bpe_batch)]

    def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor:
        return self.img2bpe_mapping_tensor[img_batch]