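# Anole / chameleon / inference / token_selector.py
# (repository page-header residue: xuefengli · "update" · 7362797 — presumably
#  the last commit's author, message, and short hash)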
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
import torch
class TokenSelector:
    """Interface for choosing the next token id for each row of a batch.

    Subclasses override ``__call__``. This base implementation is a no-op
    that returns ``None``.
    """

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Select one token id per row.

        Args:
            input_ids: tokens seen so far, shape [batch, seq_len].
            probs: next-token probabilities, shape [batch, vocab].

        Returns:
            Selected token ids (one per row) as a LongTensor. The return
            annotation was previously FloatTensor, which did not match the
            integer token ids that implementations produce.
        """
        ...
class ArgmaxTokenSelector(TokenSelector):
    """Greedy selector: take the highest-probability token in every row."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # probs.shape=[batch, vocab]; the input ids are not consulted.
        best_token_ids = torch.argmax(probs, dim=1)
        return best_token_ids
class MultinomialTokenSelector(TokenSelector):
    """Stochastic selector: draw one token per row, weighted by ``probs``."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # probs.shape=[batch, vocab]; the input ids are not consulted.
        # torch.multinomial yields shape [batch, 1]; keep only the single draw.
        drawn = torch.multinomial(probs, num_samples=1)
        return drawn[:, 0]
class ReplicatedInputTokenSelector(TokenSelector):
    """Run a selector on the first of ``n`` stacked groups and tile the result.

    ``input_ids`` and ``probs`` hold ``n`` equally sized groups stacked along
    dim 0 (``[n*batch, ...]``). Only the first group is passed to the wrapped
    selector. Its ``[batch]`` result is repeated ``n`` times, so every group
    receives the same tokens.
    """

    def __init__(self, token_selector: TokenSelector, n: int):
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n}")
        self.token_selector = token_selector  # selector applied to group 0
        self.n = n  # number of groups stacked along dim 0

    def _primary(self, tensor: torch.Tensor, name: str) -> torch.Tensor:
        """Return the first of ``self.n`` equal groups along dim 0.

        Raises:
            ValueError: if dim 0 of ``tensor`` is not divisible by ``self.n``.
        """
        rows = tensor.shape[0]
        if rows % self.n != 0:
            raise ValueError(
                f"{name} has {rows} rows, which is not divisible by n={self.n}"
            )
        return tensor[: rows // self.n]

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # input_ids.shape=[n*batch, seq_len]
        # probs.shape=[n*batch, vocab]
        # torch.chunk silently produces an uneven first chunk when rows are not
        # divisible by n, which made the repeated output the wrong length.
        # Validate explicitly and slice the first group instead.
        primary_input_ids = self._primary(input_ids, "input_ids")
        primary_probs = self._primary(probs, "probs")
        tokens = self.token_selector(primary_input_ids, primary_probs)
        # Tile [t_0 .. t_{batch-1}] n times to match the stacked group layout.
        return tokens.repeat(self.n)