# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
import base64
import io
import json
import math
import queue
import threading
from dataclasses import dataclass, field
from enum import Enum
from multiprocessing import managers, queues, synchronize
from typing import Literal, Union

import PIL
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from PIL.Image import Image
from tokenizers import Tokenizer
from tqdm import tqdm
from transformers import (
LogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopPLogitsWarper,
enable_full_determinism,
)
from chameleon.inference import loader
from chameleon.inference.alignment import AlignPromptRight
from chameleon.inference.generation import ChameleonGenerator
from chameleon.inference.image_tokenizer import ImageTokenizer
from chameleon.inference.logits_processor import (
AllowOnlyTokensLogitsProcessor,
DisallowTokensAtOrAfterIndexLogitsProcessor,
InBatchInstructCFGLogitsProcessor,
)
from chameleon.inference.model_adapter import ChameleonModelAdapter
from chameleon.inference.stopping_criteria import (
MaxLengthCriteria,
StopOnEOSAfterBatchIndex,
)
from chameleon.inference.token_selector import (
ArgmaxTokenSelector,
MultinomialTokenSelector,
ReplicatedInputTokenSelector,
)
from chameleon.inference.transformer import Transformer
from chameleon.inference.utils import DynamicGenerator, advance, random_unused_port
from chameleon.inference.vocab import VocabInfo, VocabTranslation
@dataclass
class Options:
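    """Generation options. `txt` / `img` may be False to disable a modality,
    True to use that modality's defaults, or a populated Text / Image instance."""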
@dataclass
class Text:
repetition_penalty: float = 1.2
temp: float = 1.0
top_p: float = 0.9
greedy: bool = False
@dataclass
class Image:
@dataclass
class CFG:
guidance_scale_text: float = 3.0
guidance_scale_image: float = 1.2
cfg: CFG = field(default_factory=CFG)
temp: float = 0.7
top_p: float = 0.9
greedy: bool = False
max_seq_len: int = 4096
max_gen_len: int = 4096
seed: int | None = None
txt: Text | bool = True
img: Image | bool = True
    extra_eos_tokens: list[int | str] = field(default_factory=list)
def __post_init__(self):
if self.txt is True:
self.txt = Options.Text()
if self.img is True:
self.img = Options.Image()
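# Example (illustrative): text-only greedy decoding with a smaller generation budget:
#   opts = Options(img=False, txt=Options.Text(greedy=True), max_gen_len=256)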
class TokenManager:
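    """Converts between text / images and Chameleon token ids, wrapping the BPE
    text tokenizer, the VQGAN image tokenizer, and the vocabulary translation
    between VQGAN codebook ids and BPE image-token ids."""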
def __init__(
self,
tokenizer_path: str,
vqgan_cfg_path: str,
vqgan_ckpt_path: str,
device: str | None = None,
):
self.tokenizer = Tokenizer.from_file(tokenizer_path)
self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
self.translation = VocabTranslation(self.vocab, device=device)
self.image_tokenizer = ImageTokenizer(
cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device=device
)
    def pil_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image.Image:
image_tensor = self.translation.convert_bpe2img(bpe_tokens)
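        # If generation stopped before all 1024 image tokens were produced,
        # pad with copies of the first token so the VQGAN decoder still
        # receives a full 1024-token grid.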
if image_tensor.shape[0] < 1024:
padding = (
torch.ones(
[1024 - image_tensor.shape[0]],
dtype=int,
device=image_tensor.device,
)
* image_tensor[0]
)
image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
return self.image_tokenizer.pil_from_img_toks(image_tensor)
def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
pil = self.pil_from_bpe_tokens(bpe_tokens)
img_io = io.BytesIO()
pil.save(img_io, format="PNG")
return img_io.getvalue()
def tokenize_text(self, text: str) -> list[int]:
return self.tokenizer.encode(text).ids
def tokenize_image(self, img: Image) -> list[int]:
return (
[self.vocab.begin_image]
+ self.translation.convert_img2bp2(
                self.image_tokenizer.img_tokens_from_pil(img)  # VQGAN codebook ids in [0, 8191]
).tolist()
+ [self.vocab.end_image]
)
def tokenize_b64img(self, b64img: str) -> list[int]:
image_data = base64.b64decode(b64img)
image_file = io.BytesIO(image_data)
return self.tokenize_image(PIL.Image.open(image_file))
def tokens_from_ui(self, inputs: list[dict]) -> list[int]:
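        # Each input is a dict with "type" in {"text", "image", "sentinel", "ids"}
        # and a matching "value"; the returned sequence always starts with <BOS>.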
tokens = [self.vocab.bos_id]
for input_ in inputs:
if input_["type"] == "text":
tokens += self.tokenize_text(input_["value"])
elif input_["type"] == "image":
if isinstance(input_["value"], str):
if input_["value"].startswith("data:"):
# Value Format: 'data:image/[^;]+;base64,[A-Za-z0-9+/]+={0,2}'
tokens += self.tokenize_b64img(input_["value"].split(",", 1)[1])
elif input_["value"].startswith("file:"):
tokens += self.tokenize_image(
PIL.Image.open(input_["value"].split(":", 1)[1])
)
else:
raise ValueError("Unknown image format.")
elif isinstance(input_["value"], Image):
tokens += self.tokenize_image(input_["value"])
else:
raise ValueError("Unknown image type.")
elif input_["type"] == "sentinel":
tokens += [
{
"<START-OF-IMAGE>": self.vocab.begin_image,
"<END-OF-TURN>": self.vocab.eot_id,
}[input_["value"]]
]
elif input_["type"] == "ids":
tokens += input_["value"]
else:
raise ValueError("Unknown input type.")
return tokens
def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
if isinstance(ids, torch.Tensor):
ids = ids.tolist()
for row, values in enumerate(ids):
try:
ids[row] = values[: values.index(self.vocab.eos_id)]
except ValueError:
pass
return self.tokenizer.decode_batch(ids)
    def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image.Image]:
return [self.pil_from_bpe_tokens(sample) for sample in ids]
@dataclass
class DecodePiece:
token: ChameleonGenerator.Token
next_decoder: type["Decoder"] | None
class Decoder:
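    """Interface for a single-modality decoding stage; see TextDecoder and ImageDecoder."""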
def __init__(
self,
model: Transformer,
vocab: VocabInfo,
options: Options,
input_ids: list[int],
): ...
def __next__(self) -> DecodePiece: ...
class TextDecoder(Decoder):
def __init__(
self,
model: Transformer,
vocab: VocabInfo,
options: Options,
input_ids: list[list[int]],
):
self.vocab = vocab
self.options = options
assert vocab.eos_id is not None
prompt_lens = [len(inp) for inp in input_ids]
max_prompt_len = max(prompt_lens)
max_seq_len = min(options.max_seq_len, max_prompt_len + options.max_gen_len)
self.eos_ids = [vocab.eos_id]
for extra_eos_token in options.extra_eos_tokens:
if isinstance(extra_eos_token, str):
extra_eos_token = vocab.name2val[extra_eos_token]
assert isinstance(extra_eos_token, int)
self.eos_ids.append(extra_eos_token)
        stopping_criteria = [MaxLengthCriteria(max_seq_len)] + [
            StopOnEOSAfterBatchIndex(eos_id, [max_prompt_len] * len(prompt_lens))
            for eos_id in self.eos_ids
        ]
self.gen = ChameleonGenerator(
model=ChameleonModelAdapter(model, max_seq_len=max_seq_len),
input_ids=input_ids,
stopping_criteria=stopping_criteria,
logits_processors=self._logits_processors(),
alignment=AlignPromptRight(vocab.pad_id),
token_selector=(
ArgmaxTokenSelector()
if options.txt.greedy
else MultinomialTokenSelector()
),
)
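        # Step through the prompt positions so iteration yields only newly
        # generated tokens.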
advance(self.gen, max_prompt_len)
def _allowed_tokens(self) -> list[int]:
allowed_tokens = [self.vocab.eos_id]
if self.options.txt:
allowed_tokens += self.vocab.text_tokens
if self.options.img:
allowed_tokens += [self.vocab.begin_image]
return allowed_tokens
def _logits_processors(self) -> list[LogitsProcessor]:
logits_processors = [
AllowOnlyTokensLogitsProcessor(self._allowed_tokens()),
]
if isinstance(self.options.img, Options.Image):
logits_processors += [
DisallowTokensAtOrAfterIndexLogitsProcessor(
[self.vocab.begin_image],
self.options.max_seq_len - 1026,
),
]
if isinstance(self.options.txt, Options.Text):
logits_processors += [
RepetitionPenaltyLogitsProcessor(self.options.txt.repetition_penalty),
TemperatureLogitsWarper(self.options.txt.temp),
TopPLogitsWarper(self.options.txt.top_p),
]
return logits_processors
def __next__(self) -> DecodePiece:
tok = next(self.gen)
next_decoder = None
if (
self.vocab.begin_image not in self.eos_ids
and (tok.id == self.vocab.begin_image).all()
):
next_decoder = ImageDecoder
return DecodePiece(tok, next_decoder)
class ImageDecoder(Decoder):
def __init__(
self,
model: Transformer,
vocab: VocabInfo,
options: Options,
input_ids: list[list[int]],
):
assert isinstance(options.img, Options.Image)
self.vocab = vocab
self.options = options
self.batch_size = len(input_ids)
logits_processors = [
InBatchInstructCFGLogitsProcessor(
options.img.cfg.guidance_scale_text,
options.img.cfg.guidance_scale_image,
),
AllowOnlyTokensLogitsProcessor(vocab.image_tokens),
TemperatureLogitsWarper(options.img.temp),
TopPLogitsWarper(options.img.top_p),
]
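        # Ensure every prompt ends with <START-OF-IMAGE> before sampling image tokens.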
for inp in input_ids:
if inp[-1] != self.vocab.begin_image:
inp.append(self.vocab.begin_image)
max_prompt_len = max(len(inp) for inp in input_ids)
self.gen = ChameleonGenerator(
model=ChameleonModelAdapter(model, max_seq_len=max_prompt_len + 1024),
input_ids=self._split_inputs_for_cfg(input_ids),
logits_processors=logits_processors,
alignment=AlignPromptRight(vocab.pad_id),
token_selector=ReplicatedInputTokenSelector(
(
ArgmaxTokenSelector()
if options.img.greedy
else MultinomialTokenSelector()
),
n=3,
),
)
advance(self.gen, max_prompt_len)
self.gen_count = 0
def _split_inputs_for_cfg(self, input_ids: list[list[int]]) -> list[list[int]]:
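        # Classifier-free guidance runs three variants of every prompt through
        # the model: the full prompt, an image-only conditioned prompt, and an
        # unconditioned prompt. They are stacked along the batch dimension and
        # recombined by InBatchInstructCFGLogitsProcessor.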
image_conditioned_allowed = set(self.vocab.image_tokens) | {
self.vocab.bos_id,
self.vocab.begin_image,
self.vocab.end_image,
}
full_conditioned = input_ids
image_conditioned = [
[id for id in sample if id in image_conditioned_allowed]
for sample in input_ids
]
unconditioned = [
[
self.vocab.bos_id,
self.vocab.begin_image,
]
] * self.batch_size
return full_conditioned + image_conditioned + unconditioned
def __next__(self) -> DecodePiece:
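        # An image is exactly 1024 VQGAN tokens; after the 1024th, emit
        # <END-OF-IMAGE> and hand control back to TextDecoder.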
if self.gen_count == 1024:
id = torch.tensor([self.vocab.end_image] * self.batch_size)
logits = torch.full(
(self.batch_size, len(self.vocab.all_tokens)), -math.inf
)
logits[:, self.vocab.end_image] = 0
return DecodePiece(
ChameleonGenerator.Token(id=id, logits=logits),
TextDecoder,
)
tok = next(self.gen)
tok.id = tok.id.chunk(3)[0]
self.gen_count += 1
return DecodePiece(tok, None)
class Generator(Decoder):
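    """Top-level decoder that alternates between TextDecoder and ImageDecoder
    whenever a modality boundary is generated."""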
def __init__(
self,
model: Transformer,
vocab: VocabInfo,
options: Options,
input_ids: list[list[int]],
):
if options.seed is not None:
enable_full_determinism(options.seed, warn_only=True)
self.model = model
self.vocab = vocab
self.input_ids = input_ids[:]
self.generated_token_ids: list[torch.LongTensor] = []
self.options = options
if not self.options.txt:
self.dyngen = DynamicGenerator(
ImageDecoder(model, vocab, options, input_ids)
)
else:
self.dyngen = DynamicGenerator(
TextDecoder(model, vocab, options, input_ids)
)
def __iter__(self):
return self
def __next__(self) -> ChameleonGenerator.Token:
piece = next(self.dyngen)
self.generated_token_ids.append(piece.token.id)
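        # A modality switch (e.g. TextDecoder emitted <START-OF-IMAGE>): re-seed
        # the next decoder with the prompt plus everything generated so far.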
if piece.next_decoder is not None:
if not self.options.txt:
raise StopIteration
self.input_ids = [
old_list + generated
for old_list, generated in zip(
self.input_ids, torch.stack(self.generated_token_ids).T.tolist()
)
]
self.generated_token_ids = []
self.dyngen.gen = piece.next_decoder(
self.model,
self.vocab,
self.options,
self.input_ids,
)
return piece.token
class DistributedMode(Enum):
AUTO = 0
THREAD = 1
PROCESS = 2
@dataclass
class _DistributedContext:
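    """Queues, locks, and a barrier shared between the client and the worker
    ranks, backed by threading or multiprocessing depending on DistributedMode."""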
req_q: Union[queue.Queue, queues.Queue]
res_q: Union[queue.Queue, queues.Queue]
active_key: Union[dict[int, Literal[True]], managers.DictProxy]
active_key_lock: Union[threading.Lock, synchronize.Lock]
ready_barrier: Union[threading.Barrier, synchronize.Barrier]
worker_launcher: Union[type[threading.Thread], type[mp.Process]]
@staticmethod
def make_for_threading(world_size: int):
return _DistributedContext(
req_q=queue.Queue(),
res_q=queue.Queue(),
active_key={},
active_key_lock=threading.Lock(),
ready_barrier=threading.Barrier(world_size + 1),
worker_launcher=threading.Thread,
)
@staticmethod
def make_for_multiprocessing(world_size: int):
local_mp = mp.get_context("spawn")
return _DistributedContext(
req_q=local_mp.Queue(),
res_q=local_mp.Queue(),
active_key=local_mp.Manager().dict(),
active_key_lock=local_mp.Lock(),
ready_barrier=local_mp.Barrier(world_size + 1),
worker_launcher=local_mp.Process,
)
@staticmethod
def make(mode: DistributedMode, world_size: int):
if mode == DistributedMode.AUTO:
mode = DistributedMode.PROCESS
if mode == DistributedMode.THREAD:
return _DistributedContext.make_for_threading(world_size)
elif mode == DistributedMode.PROCESS:
return _DistributedContext.make_for_multiprocessing(world_size)
else:
raise ValueError("Unknown DistributedMode")
def _worker_impl(
init_method: str,
model: Transformer | str,
world_size: int,
rank: int,
vocab: VocabInfo,
dctx: _DistributedContext,
):
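    # One worker per model shard. Rank 0 acts as coordinator: it pulls requests
    # off req_q and broadcasts them to all ranks so every shard runs the same
    # generation loop; only rank 0 pushes result tokens onto res_q.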
dist.init_process_group(
"nccl",
init_method=init_method,
world_size=world_size,
rank=rank,
)
torch.set_default_device(f"cuda:{rank}")
torch.cuda.set_device(rank)
if isinstance(model, str):
model = loader.load_model(model, rank=rank)
dctx.ready_barrier.wait()
is_coord = rank == 0
while True:
req = [Options(), [], 0, False]
if is_coord:
req = dctx.req_q.get()
dist.broadcast_object_list(req, src=0)
options, input_ids, key, shutdown = req
if shutdown:
break
for token in Generator(
model=model,
vocab=vocab,
options=options,
input_ids=input_ids,
):
if is_coord:
dctx.res_q.put((key, token))
to_continue = [True]
if is_coord:
with dctx.active_key_lock:
to_continue = [key in dctx.active_key]
dist.broadcast_object_list(to_continue, src=0)
if not to_continue[0]:
break
if is_coord:
dctx.res_q.put((key, None))
class ChameleonInferenceModel:
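    """Client-facing wrapper that spawns one worker per model shard and exposes
    stream / generate / decode helpers over the shared request/response queues."""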
def __init__(
self,
model: Transformer | str,
tokenizer_path: str,
vqgan_cfg_path: str,
vqgan_ckpt_path: str,
*,
options: Options | None = None,
distributed_mode: DistributedMode = DistributedMode.AUTO,
):
self.options = options or Options()
self.next_key = 0
self.token_manager = TokenManager(
tokenizer_path=tokenizer_path,
vqgan_cfg_path=vqgan_cfg_path,
vqgan_ckpt_path=vqgan_ckpt_path,
device="cuda",
)
self.vocab = self.token_manager.vocab
world_size = 1
if isinstance(model, str):
world_size = loader.detect_shard_count(model)
self.dctx = _DistributedContext.make(distributed_mode, world_size)
init_method = f"tcp://0.0.0.0:{random_unused_port()}"
self.workers = [
self.dctx.worker_launcher(
target=_worker_impl,
args=(init_method, model, world_size, i, self.vocab, self.dctx),
daemon=True,
)
for i in range(world_size)
]
for w in self.workers:
w.start()
self.dctx.ready_barrier.wait()
def __del__(self):
try:
with self.dctx.active_key_lock:
self.dctx.active_key.clear()
self.dctx.req_q.put([None, None, None, True])
for w in self.workers:
w.join()
except FileNotFoundError:
pass
def stream(
self,
*,
input_ids: list[int] | None = None,
prompt_text: str | None = None,
prompt_ui: list[dict] | None = None,
batch_input_ids: list[list[int]] | None = None,
batch_prompt_text: list[str] | None = None,
batch_prompt_ui: list[list[dict]] | None = None,
options: Options | None = None,
):
# NOTE: Not thread-safe! Only one instance of generate may be run at a time.
if (
sum(
x is not None
for x in [
input_ids,
prompt_text,
prompt_ui,
batch_input_ids,
batch_prompt_text,
batch_prompt_ui,
]
)
!= 1
):
raise ValueError(
"Must specify exactly one of: input_ids, prompt_text, prompt_ui, batch_input_ids, batch_prompt_text, batch_prompt_ui"
)
options = options or self.options
if prompt_text is not None:
batch_prompt_text = [prompt_text]
if prompt_ui is not None:
batch_prompt_ui = [prompt_ui]
if input_ids is not None:
batch_input_ids = [input_ids]
if batch_prompt_text is not None:
batch_prompt_ui = [
[{"type": "text", "value": prompt_text}]
for prompt_text in batch_prompt_text
]
if batch_prompt_ui is not None:
batch_input_ids = [
self.token_manager.tokens_from_ui(prompt_ui)
for prompt_ui in batch_prompt_ui
]
assert batch_input_ids
if not options.txt and not options.img:
raise ValueError("Must specify at least one modality.")
if options.txt and options.img and len(batch_input_ids) > 1:
raise ValueError(
"Batch generation only supported for one modality at a time."
)
req_key = self.next_key
self.next_key += 1
with self.dctx.active_key_lock:
self.dctx.active_key[req_key] = True
self.dctx.req_q.put([options, batch_input_ids, req_key, False])
try:
while key_token := self.dctx.res_q.get():
key, token = key_token
if key != req_key:
                    # Residual tokens from a prior, abandoned request. Skip.
continue
if token is None:
break
yield token
finally:
with self.dctx.active_key_lock:
del self.dctx.active_key[req_key]
def step(self, *args, **kwargs) -> ChameleonGenerator.Token:
return next(self.stream(*args, **kwargs))
def generate(self, *args, **kwargs) -> torch.LongTensor:
tokens = [t.id for t in self.stream(*args, **kwargs)]
if not tokens:
return torch.LongTensor()
return torch.stack(tokens).T
def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
return self.token_manager.decode_text(ids)
    def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image.Image]:
return self.token_manager.decode_image(ids)
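    # Example usage (illustrative paths):
    #   model = ChameleonInferenceModel("./ckpt", "./text_tokenizer.json",
    #                                   "./vqgan.yaml", "./vqgan.ckpt")
    #   print(model.decode_text(model.generate(prompt_text="Hello"))[0])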
def sft_tokenization(self, json_path: str) -> list[dict]:
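        """Tokenize a JSONL dataset for SFT. Each line must provide 'text' and
        'image' (an image file path); token ids are added to the entry in place."""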
with open(json_path, 'r') as input_file:
jsonl_input = [json.loads(line) for line in input_file]
output_data = []
for entry in tqdm(jsonl_input, desc="Tokenize dataset"):
text_tokens = self.token_manager.tokenize_text(entry['text'])
image_tokens = self.token_manager.tokenize_image(PIL.Image.open(entry['image']))
entry['text_tokens'] = text_tokens
entry['image_tokens'] = image_tokens
output_data.append(entry)
return output_data