# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from functools import cached_property

import torch

class VocabInfo:
    def __init__(self, vocab_map: dict[str, int]):
        self.name2val = vocab_map

        # Special-token ids; any token missing from the vocab maps to None.
        self.bos_id = vocab_map.get("<s>")
        self.eos_id = vocab_map.get("</s>")
        self.boi_id = vocab_map.get("<racm3:break>")
        self.eoi_id = vocab_map.get("<eoss>")
        self.pad_id = vocab_map.get("<pad>")
        self.eot_id = vocab_map.get("<reserved08706>")

    @property
    def begin_sequence(self) -> int:
        return self.bos_id

    @property
    def end_sequence(self) -> int:
        return self.eos_id

    @property
    def begin_image(self) -> int:
        return self.boi_id

    @property
    def end_image(self) -> int:
        return self.eoi_id

    @property
    def padding(self) -> int:
        return self.pad_id

    @property
    def end_turn(self) -> int:
        return self.eot_id

    @cached_property
    def val2name(self) -> dict[int, str]:
        return {v: k for k, v in self.name2val.items()}

    @cached_property
    def all_tokens(self) -> list[int]:
        return sorted(self.name2val.values())

    @cached_property
    def image_tokens(self) -> list[int]:
        return sorted(
            [val for name, val in self.name2val.items() if name.startswith("IMGIMG")]
        )

    @cached_property
    def special_tokens(self) -> list[int]:
        return sorted(
            [
                val
                for name, val in self.name2val.items()
                if name.startswith("<") and name != "<"
            ]
        )

    @cached_property
    def text_tokens(self) -> list[int]:
        return sorted(
            set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens)
        )

class VocabTranslation:
    def __init__(self, vocab_info: VocabInfo, device: str | None = None):
        self._vocab = vocab_info
        self._device = device

    @cached_property
    def bpe2img(self) -> dict[int, int]:
        # Maps a BPE vocab id to its image codebook id, i.e. [4:8195] => [0:8191].
        # Image token names encode digits as letters: A-J stand for 0-9.
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}

        def remap(old_name: str) -> str:
            # Strip the "IMGIMG" prefix and the trailing "Z", then decode the digits,
            # e.g. "IMGIMGFDZ" => "FD" => "53".
            return "".join(
                img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]
            )

        return {
            tok: int(remap(self._vocab.val2name[tok]))
            for tok in self._vocab.image_tokens  # tokens whose names start with "IMGIMG"
        }

    @cached_property
    def img2bpe(self) -> dict[int, int]:
        # Inverse mapping: codebook id => vocab id, i.e. [0:8191] => [4:8195].
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
        sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device)
        sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device)
        return sorted_bpe, sorted_img

    @cached_property
    def img2bpe_mapping_tensor(self) -> torch.LongTensor:
        mapping = torch.zeros(
            max(self.img2bpe.keys()) + 1,
            dtype=torch.int,
            device=self._device,
        )
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor:
        bpe_tok, img_tok = self.bpe2img_search_tensors
        return img_tok[torch.searchsorted(bpe_tok, bpe_batch)]

    def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor:
        return self.img2bpe_mapping_tensor[img_batch]
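

if __name__ == "__main__":
    # Illustrative usage sketch only, not part of the original module. The tiny
    # vocabulary below is a made-up assumption for demonstration; the real
    # Chameleon vocabulary is loaded from the released tokenizer and covers
    # 8192 image codebook entries ([0:8191]).
    toy_vocab = {
        "<s>": 0,
        "</s>": 1,
        "<pad>": 2,
        "<racm3:break>": 3,
        "<eoss>": 4,
        "<reserved08706>": 5,
        "IMGIMGAAZ": 6,  # codebook id 0
        "IMGIMGABZ": 7,  # codebook id 1
        "IMGIMGACZ": 8,  # codebook id 2
        "hello": 9,
        "world": 10,
    }

    info = VocabInfo(toy_vocab)
    translation = VocabTranslation(info)

    print(info.image_tokens)    # [6, 7, 8]
    print(info.text_tokens)     # [9, 10]
    print(translation.bpe2img)  # {6: 0, 7: 1, 8: 2}

    # Vocab ids -> codebook ids and back again.
    bpe_ids = torch.tensor([6, 8, 7])
    img_ids = translation.convert_bpe2img(bpe_ids)     # tensor([0, 2, 1])
    round_trip = translation.convert_img2bp2(img_ids)  # tensor([6, 8, 7], dtype=torch.int32)
    print(img_ids, round_trip)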