Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# | |
# This source code is licensed under the Chameleon License found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
class TokenSelector:
    """Interface for choosing one token id per row of a batch.

    Subclasses override ``__call__``; this base implementation only defines
    the calling convention shared by every selector in this module.
    """

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Return the selected token id for each row.

        Args:
            input_ids: token ids, shape [batch, seq_len] (presumably the
                sequence generated so far -- confirm with callers).
            probs: per-token probabilities, shape [batch, vocab].

        Returns:
            Selected token ids; the concrete selectors return shape [batch].

        Raises:
            NotImplementedError: always -- subclasses must override this.
        """
        # input_ids.shape=[batch, seq_len]
        # probs.shape=[batch, vocab]
        # Fail loudly instead of silently returning None when a subclass
        # forgets to override the method.
        raise NotImplementedError(
            f"{type(self).__name__} must implement __call__"
        )
class ArgmaxTokenSelector(TokenSelector):
    """Greedy selector: picks the index of the largest probability per row."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # The token history is ignored; only the probabilities matter here.
        # probs.shape=[batch, vocab] -> result shape [batch]
        best_ids = torch.argmax(probs, dim=1)
        return best_ids
class MultinomialTokenSelector(TokenSelector):
    """Stochastic selector: draws one token per row, weighted by ``probs``."""

    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # probs.shape=[batch, vocab]; a single draw per row -> shape [batch, 1]
        drawn = torch.multinomial(probs, num_samples=1)
        # Remove the num_samples axis so the result has shape [batch].
        return drawn.squeeze(dim=1)
class ReplicatedInputTokenSelector(TokenSelector):
    """Run a selector on the first of ``n`` row groups and tile its output.

    The leading dimension of the inputs is treated as ``n`` equal, consecutive
    groups of ``batch`` rows. Only the first group is passed to the wrapped
    selector, and its tokens are repeated for every group. As a result, all
    groups receive identical tokens even when the wrapped selector is
    stochastic (e.g. MultinomialTokenSelector).
    """

    def __init__(self, token_selector: TokenSelector, n: int):
        """Wrap ``token_selector`` for inputs made of ``n`` row groups.

        Raises:
            ValueError: if ``n`` is smaller than 1.
        """
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n}")
        self.token_selector = token_selector
        self.n = n

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        """Select tokens for the first group and repeat them ``n`` times.

        Raises:
            ValueError: if the number of rows in ``probs`` is not a multiple
                of ``n`` (the output would otherwise silently have the wrong
                number of rows).
        """
        # input_ids.shape=[n*batch, seq_len]
        # probs.shape=[n*batch, vocab]
        total_rows = probs.shape[0]
        if total_rows % self.n != 0:
            raise ValueError(
                f"probs has {total_rows} rows, which is not divisible by n={self.n}"
            )
        batch = total_rows // self.n
        # Slicing returns views of the first group without touching the rest.
        primary_input_ids = input_ids[: input_ids.shape[0] // self.n]
        primary_probs = probs[:batch]
        tokens = self.token_selector(primary_input_ids, primary_probs)
        # Tile (not interleave) so row order matches the input layout:
        # [t_0..t_{batch-1}, t_0..t_{batch-1}, ...].
        return tokens.repeat(self.n)