# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import math
from abc import ABC, abstractmethod

import torch

from chameleon.inference import transformer
from chameleon.inference.alignment import (
    AlignPromptLeft,
    AlignPromptRight,
    PromptAlignment,
)
from chameleon.inference.cudagraph import cudagraph_wrap


class ModelAdapter(ABC):
    """Interface a model must expose to the decode loop: per-prompt setup,
    supported prompt alignments, and a call mapping token inputs to logits."""

    @abstractmethod
    def initialize(self, prompt_tokens: list[list[int]]):
        ...

    @abstractmethod
    def supports_alignment(self, alignment: PromptAlignment) -> bool:
        ...

    @abstractmethod
    def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
        ...
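

# An adapter is initialized once per batch of prompts and then called
# repeatedly by the decode loop: the first call consumes the aligned prompt
# tokens, and later calls consume only the newest token per sequence while
# reusing the KV cache built during that first pass.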


class ChameleonModelAdapter(ModelAdapter):
    """Adapter for a Chameleon-style model that handles state, such as the KV cache."""

    def __init__(
        self,
        model: transformer.Transformer,
        max_seq_len: int,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self._args = model.args
        self._model = model
        self._max_seq_len = max_seq_len
        # Default to the dtype of the model's own parameters.
        self._dtype = dtype or next(model.parameters()).data.dtype

    def initialize(self, prompt_tokens: list[list[int]]):
        self._prompt_lengths = [len(toks) for toks in prompt_tokens]
        batch_size = len(prompt_tokens)

        # One contiguous KV cache sized for the whole batch.
        self._cache = transformer.make_cache(
            args=self._args,
            length=batch_size * self._max_seq_len,
            dtype=self._dtype,
        )

        # Static buffer for single-token decode steps, reused across calls so the
        # CUDA-graph-wrapped forward always sees the same input tensor.
        self._local_inputs = torch.zeros([batch_size], dtype=int, device="cuda")

        self._forward = cudagraph_wrap(self._model.forward_with_attn_bias)

        self._first_pass = True

    def supports_alignment(self, alignment: PromptAlignment) -> bool:
        return isinstance(alignment, (AlignPromptLeft, AlignPromptRight))

    def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
        # inputs.shape = [batch, seq_len]
        batch_size, seq_len = inputs.shape

        if self._first_pass:
            # Prefill: run the prompt tokens through the model and build the
            # attention bias that tracks per-sequence lengths.
            attn_seqlen = [min(pl, seq_len) for pl in self._prompt_lengths]
            self._bias = transformer.AttnBias.from_seqlens(
                q_seqlen=attn_seqlen,
                kv_seqlen=attn_seqlen,
                kv_padding=self._max_seq_len,
            )

            # Boolean mask over the trailing prompt positions of each row.
            mask = torch.zeros_like(inputs, dtype=torch.bool)
            for i, k in enumerate(self._prompt_lengths):
                mask[i, -k:] = True

            flat_outputs: torch.Tensor = self._forward(  # type: ignore
                token_values=inputs[mask],
                attn_bias=self._bias,
                cache=self._cache,
            )
            # Scatter the flat logits back to [batch, seq, vocab]; positions
            # outside the mask stay at -inf.
            self._local_outputs = torch.full(
                (inputs.shape[0], inputs.shape[1], flat_outputs.shape[-1]),
                -math.inf,
            )
            self._local_outputs[mask] = flat_outputs

            self._vocab_size = self._local_outputs.shape[-1]

            # From now on every query is a single token per sequence, so reset
            # the query-length bookkeeping to one token per row.
            self._bias.q_seqinfo.seqstart.copy_(
                torch.arange(batch_size + 1, dtype=torch.int)
            )
            self._bias.q_seqinfo.max_seqlen = 1
            self._bias.q_seqinfo.seqstart_py = self._bias.q_seqinfo.seqstart.tolist()

            self._first_pass = False
        else:
            # Decode: only the newest token per sequence goes through the model,
            # copied into the static buffer captured by the CUDA graph.
            self._local_inputs.copy_(inputs[:, -1])  # type: ignore

            self._local_outputs = self._forward(  # type: ignore
                token_values=self._local_inputs,
                attn_bias=self._bias,
                cache=self._cache,
            )

            # One more token is now present in every sequence's KV cache.
            self._bias.k_seqinfo.seqlen.add_(1)

        return self._local_outputs.view(batch_size, -1, self._vocab_size)
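

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream module): one way a minimal
# greedy decode loop could drive ChameleonModelAdapter. It assumes `model` is
# an already-loaded transformer.Transformer on CUDA and that prompts are
# padded on the left so each row ends at its last real token; `pad_id`,
# `max_seq_len`, and the function name are placeholders, and the real
# generation loop in chameleon.inference may differ.
# ---------------------------------------------------------------------------
def _greedy_decode_sketch(
    model: transformer.Transformer,
    prompt_tokens: list[list[int]],
    max_new_tokens: int = 16,
    max_seq_len: int = 4096,
    pad_id: int = 0,
) -> torch.LongTensor:
    adapter = ChameleonModelAdapter(model, max_seq_len=max_seq_len)
    adapter.initialize(prompt_tokens)

    # Left-pad so every row ends at its last real prompt token.
    max_len = max(len(t) for t in prompt_tokens)
    inputs = torch.tensor(
        [[pad_id] * (max_len - len(t)) + t for t in prompt_tokens],
        device="cuda",
    )

    for _ in range(max_new_tokens):
        # First call: [batch, seq, vocab] prefill logits; later calls: [batch, 1, vocab].
        logits = adapter(inputs)
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True).to(inputs.device)
        inputs = torch.cat([inputs, next_tok], dim=1)
    return inputs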