# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from transformers import (
    LogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation.streamers import BaseStreamer

from chameleon.inference.alignment import AlignPromptLeft, PromptAlignment
from chameleon.inference.model_adapter import ModelAdapter
from chameleon.inference.stopping_criteria import (
    StoppingCriteria,
    StoppingCriteriaList,
)
from chameleon.inference.token_selector import MultinomialTokenSelector, TokenSelector


class ChameleonGenerator:
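    """Step-by-step token generator over a batch of prompts.

    Iterating yields one Token per step: first the aligned prompt tokens
    (with ``logits=None``, since the model is not run for them), then
    model-sampled tokens until ``stopping_criteria`` fires.
    """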
    @dataclass
    class Token:
        id: torch.LongTensor
        logits: torch.Tensor | None

    def __init__(
        self,
        model: ModelAdapter,
        input_ids: list[list[int]],
        stopping_criteria: StoppingCriteriaList | list[StoppingCriteria] | None = None,
        logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
        probability_processors: LogitsProcessorList
        | list[LogitsProcessor]
        | None = None,
        token_selector: TokenSelector | None = None,
        alignment: PromptAlignment = AlignPromptLeft(),
    ):
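        """Set up batched generation over ``input_ids``.

        The criteria and processor arguments default to empty lists via the
        property setters below; ``token_selector`` defaults to multinomial
        sampling.
        """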
        assert model.supports_alignment(alignment)

        self.model = model
        self.stopping_criteria = stopping_criteria
        self.logits_processors = logits_processors
        self.probability_processors = probability_processors
        self.token_selector: TokenSelector = (
            token_selector or MultinomialTokenSelector()
        )
        self.alignment = alignment

        self.model.initialize(input_ids)

        self._inputs = self.alignment.prepare_inputs(
            input_ids
        )  # inputs.shape = [batch, seq-len]
        self._idx = 0
        self._start_idx = self.alignment.start_index(input_ids)

        self._original_inputs = self._inputs.clone()
        self._inputs = self._inputs[:, : self._start_idx]

    def __iter__(self):
        return self

    def __next__(self) -> Token:
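        """Advance generation one step and return the resulting Token.

        Prompt tokens are emitted first, without running the model; after
        that, each call runs the model on the accumulated inputs, applies
        the logits and probability processors, and samples one token per
        batch element.
        """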
        # Are we done?
        if self.stopping_criteria(self._inputs, None):
            raise StopIteration

        # Emit the initial prompt tokens first; the model is not run for
        # these. If you want their logits, do a separate forward pass
        # outside of generation.
        if self._idx < self._start_idx:
            idx, self._idx = self._idx, self._idx + 1
            return ChameleonGenerator.Token(id=self._inputs[:, idx], logits=None)

        # Run the model for the next token.
        self._inputs = self._inputs.contiguous()
        outputs = self.model(self._inputs)  # outputs.shape = [batch, seq-len, vocab]

        # Pull out and process the logits.
        logits = outputs[:, -1, :]  # logits.shape = [batch, vocab]
        logits = self.logits_processors(self._inputs, logits)
        probs = logits.softmax(dim=1)  # probs.shape = [batch, vocab]
        probs = self.probability_processors(self._inputs, probs)

        # Select a token and append it to the inputs.
        next_tokens = self.token_selector(
            self._inputs, probs
        )  # next_tokens.shape = [batch]
        self._inputs = torch.cat([self._inputs, next_tokens[:, None]], dim=1)

        # Run alignment-specific postprocessing.
        self._inputs = self.alignment.postprocess_inputs(
            self._inputs, self._original_inputs
        )

        # Return the result of this step.
        return ChameleonGenerator.Token(id=self._inputs[:, -1], logits=logits)

    @property
    def stopping_criteria(self) -> StoppingCriteriaList:
        return self._stopping_criteria

    @stopping_criteria.setter
    def stopping_criteria(
        self, value: StoppingCriteriaList | list[StoppingCriteria] | None
    ):
        self._stopping_criteria = StoppingCriteriaList(value or [])

    @property
    def logits_processors(self) -> LogitsProcessorList:
        return self._logits_processors

    @logits_processors.setter
    def logits_processors(
        self, value: LogitsProcessorList | list[LogitsProcessor] | None
    ):
        self._logits_processors = LogitsProcessorList(value or [])

    @property
    def probability_processors(self) -> LogitsProcessorList:
        return self._probability_processors

    @probability_processors.setter
    def probability_processors(
        self, value: LogitsProcessorList | list[LogitsProcessor] | None
    ):
        self._probability_processors = LogitsProcessorList(value or [])


def run_generation(
    model: ModelAdapter,
    input_ids: list[list[int]],
    stopping_criteria: StoppingCriteriaList | list[StoppingCriteria],
    logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
    probability_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
    token_selector: TokenSelector | None = None,
    alignment: PromptAlignment = AlignPromptLeft(),
    streamer: BaseStreamer | None = None,
) -> torch.LongTensor:
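    """Run a ChameleonGenerator to completion over a batch of prompts.

    Each step's token ids are forwarded to ``streamer`` (if given) and
    collected into the returned tensor, which includes the prompt tokens
    the generator emits before the model is first run.
    """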
    result = torch.empty((len(input_ids), 0), dtype=torch.long)

    for tok in ChameleonGenerator(
        model=model,
        input_ids=input_ids,
        stopping_criteria=stopping_criteria,
        logits_processors=logits_processors,
        probability_processors=probability_processors,
        token_selector=token_selector,
        alignment=alignment,
    ):
        if streamer is not None:
            streamer.put(tok.id)
        result = torch.cat([result, tok.id.view(-1, 1)], dim=1)

    if streamer is not None:
        streamer.end()

    return result
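

# Example usage (a minimal sketch, not part of the original module). The
# adapter and prompt ids below are hypothetical placeholders; any object
# implementing the ModelAdapter interface used above (supports_alignment,
# initialize, and a forward call returning [batch, seq-len, vocab] logits)
# will do. It also assumes StoppingCriteria can be subclassed with the
# __call__(input_ids, scores) signature that __next__ relies on.
#
#     class StopAfter(StoppingCriteria):
#         def __init__(self, max_len: int):
#             self.max_len = max_len
#
#         def __call__(self, input_ids: torch.LongTensor, scores) -> bool:
#             return input_ids.shape[1] >= self.max_len
#
#     tokens = run_generation(
#         model=my_adapter,                      # hypothetical ModelAdapter
#         input_ids=[[1, 42, 7], [1, 13, 99]],   # two tokenized prompts
#         stopping_criteria=[StopAfter(max_len=32)],
#     )  # -> LongTensor of shape [batch, seq-len]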