# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from transformers import (
    LogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation.streamers import BaseStreamer

from chameleon.inference.alignment import AlignPromptLeft, PromptAlignment
from chameleon.inference.model_adapter import ModelAdapter
from chameleon.inference.stopping_criteria import (
    StoppingCriteria,
    StoppingCriteriaList,
)
from chameleon.inference.token_selector import (
    MultinomialTokenSelector,
    TokenSelector,
)


class ChameleonGenerator:
    @dataclass
    class Token:
        id: torch.LongTensor
        logits: torch.Tensor | None

    def __init__(
        self,
        model: ModelAdapter,
        input_ids: list[list[int]],
        stopping_criteria: StoppingCriteriaList | list[StoppingCriteria] | None = None,
        logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
        probability_processors: LogitsProcessorList
        | list[LogitsProcessor]
        | None = None,
        token_selector: TokenSelector | None = None,
        alignment: PromptAlignment = AlignPromptLeft(),
    ):
        assert model.supports_alignment(alignment)
        self.model = model
        self.stopping_criteria = stopping_criteria
        self.logits_processors = logits_processors
        self.probability_processors = probability_processors
        self.token_selector: TokenSelector = (
            token_selector or MultinomialTokenSelector()
        )
        self.alignment = alignment

        self.model.initialize(input_ids)
        self._inputs = self.alignment.prepare_inputs(
            input_ids
        )  # inputs.shape = [batch, seq-len]
        self._idx = 0
        self._start_idx = self.alignment.start_index(input_ids)

        self._original_inputs = self._inputs.clone()
        self._inputs = self._inputs[:, : self._start_idx]

    def __iter__(self):
        return self

    @torch.inference_mode()
    def __next__(self) -> Token:
        # Are we done?
        if self.stopping_criteria(self._inputs, None):
            raise StopIteration

        # Emit initial tokens.
        # Model is not run for these.
        # If you want the logits, you can do a separate forward pass outside generation.
        if self._idx < self._start_idx:
            idx, self._idx = self._idx, self._idx + 1
            return ChameleonGenerator.Token(id=self._inputs[:, idx], logits=None)

        # Run the model for the next token.
        self._inputs = self._inputs.contiguous()
        outputs = self.model(self._inputs)  # outputs.shape = [batch, seq-len, vocab]

        # Pull out and process the logits.
        logits = outputs[:, -1, :]  # logits.shape = [batch, vocab]
        logits = self.logits_processors(self._inputs, logits)
        probs = logits.softmax(dim=1)  # probs.shape = [batch, vocab]
        probs = self.probability_processors(self._inputs, probs)

        # Select a token and add it to the inputs.
        next_tokens = self.token_selector(
            self._inputs, probs
        )  # next_tokens.shape = [batch]
        self._inputs = torch.cat([self._inputs, next_tokens[:, None]], dim=1)

        # Run alignment specific postprocessing.
        self._inputs = self.alignment.postprocess_inputs(
            self._inputs, self._original_inputs
        )

        # Return the next step result.
        return ChameleonGenerator.Token(id=self._inputs[:, -1], logits=logits)
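
    # A minimal sketch of a custom probability processor. These run after
    # softmax, so unlike logits_processors they should return a valid
    # distribution. For example, a hypothetical renormalizing top-k filter
    # (`TopKProbs` is illustrative, not part of this module):
    #
    #     class TopKProbs(LogitsProcessor):
    #         def __init__(self, k: int):
    #             self.k = k
    #
    #         def __call__(self, input_ids, probs):
    #             topk = probs.topk(self.k, dim=-1)
    #             kept = torch.zeros_like(probs).scatter_(
    #                 -1, topk.indices, topk.values
    #             )
    #             return kept / kept.sum(dim=-1, keepdim=True)
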
    @property
    def stopping_criteria(self) -> StoppingCriteriaList:
        return self._stopping_criteria

    @stopping_criteria.setter
    def stopping_criteria(
        self, value: StoppingCriteriaList | list[StoppingCriteria] | None
    ):
        self._stopping_criteria = StoppingCriteriaList(value or [])

    @property
    def logits_processors(self) -> LogitsProcessorList:
        return self._logits_processors

    @logits_processors.setter
    def logits_processors(
        self, value: LogitsProcessorList | list[LogitsProcessor] | None
    ):
        self._logits_processors = LogitsProcessorList(value or [])

    @property
    def probability_processors(self) -> LogitsProcessorList:
        return self._probability_processors

    @probability_processors.setter
    def probability_processors(
        self, value: LogitsProcessorList | list[LogitsProcessor] | None
    ):
        self._probability_processors = LogitsProcessorList(value or [])


def run_generation(
    model: ModelAdapter,
    input_ids: list[list[int]],
    stopping_criteria: StoppingCriteriaList | list[StoppingCriteria],
    logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
    probability_processors: LogitsProcessorList
    | list[LogitsProcessor]
    | None = None,
    token_selector: TokenSelector | None = None,
    alignment: PromptAlignment = AlignPromptLeft(),
    streamer: BaseStreamer | None = None,
) -> torch.LongTensor:
    result = torch.empty((len(input_ids), 0), dtype=int)

    for tok in ChameleonGenerator(
        model=model,
        input_ids=input_ids,
        stopping_criteria=stopping_criteria,
        logits_processors=logits_processors,
        probability_processors=probability_processors,
        token_selector=token_selector,
        alignment=alignment,
    ):
        if streamer is not None:
            streamer.put(tok.id)
        result = torch.cat([result, tok.id.view(-1, 1)], dim=1)

    if streamer is not None:
        streamer.end()

    return result
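
# A minimal usage sketch. The names `adapter` and `StopAfter` below are
# illustrative placeholders, not part of this module; `adapter` stands for
# any ModelAdapter wrapping a loaded model:
#
#     class StopAfter(StoppingCriteria):
#         """Stop once the sequence reaches max_len tokens."""
#
#         def __init__(self, max_len: int):
#             self.max_len = max_len
#
#         def __call__(self, input_ids, scores):
#             return input_ids.shape[1] >= self.max_len
#
#     tokens = run_generation(
#         model=adapter,
#         input_ids=[[1, 2, 3], [4, 5, 6]],  # one prompt per batch row
#         stopping_criteria=[StopAfter(64)],
#     )  # tokens.shape = [batch, emitted-len], prompt tokens included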