File size: 3,306 Bytes
4f83ec0
 
 
 
 
 
 
 
72a89db
4f83ec0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import logging
from typing import Tuple

import torch
from outlines.samplers import MultinomialSampler

# Module-scoped logger, named after this module so its output can be filtered per module.
logger: logging.Logger = logging.getLogger(name=__name__)


class PenalizedMultinomialSampler(MultinomialSampler):
    """Multinomial sampler that caps how often groups of tokens may be sampled.

    Groups of token ids are registered with :meth:`set_max_repeats`. Every time
    a token belonging to a group is sampled, that group's counter is
    incremented. Once the counter reaches the group's limit, all tokens of the
    group receive a ``-inf`` logit before sampling, so they can no longer be
    drawn.

    NOTE(review): counters are shared by every sequence in the batch -- a token
    sampled by any sequence counts against the limit of all sequences. The
    behaviour is only exact when one sequence is generated per call.
    NOTE(review): if every token of some row ends up masked, the parent sampler
    receives an all ``-inf`` row; confirm how it handles that case.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Token ids of each registered group, indexed by group index.
        self.penalized_tokens_group: list[torch.IntTensor] = []
        # Maximum number of samples allowed for each group.
        self.max_repeats_per_token_group: list[int] = []
        # How many times a token of each group has been sampled so far.
        self.repeats_per_token_group: list[int] = []
        # Reverse index: token id -> indices of the groups that contain it.
        self.token_id_to_tokens_groups: list[list[int]] = []

    def set_max_repeats(self, token_ids: list[int], max_repeats: int) -> None:
        """Register a group of tokens that may be sampled at most ``max_repeats`` times.

        Parameters
        ----------
        token_ids
            Non-negative vocabulary ids forming the group. Sampling any one of
            them counts once against the group's limit; ids listed more than
            once are only counted once per sample. An empty group is accepted
            and never affects sampling.
        max_repeats
            Number of samples allowed before every token of the group is
            masked out. ``0`` (or less) forbids the group from the first token.

        Raises
        ------
        ValueError
            If any token id is negative.
        """
        token_ids = list(token_ids)
        if any(token_id < 0 for token_id in token_ids):
            raise ValueError(f"token ids must be non-negative, got {token_ids}")
        group_index = len(self.penalized_tokens_group)
        # Grow the reverse index so every id of the new group is addressable.
        missing = max(token_ids, default=-1) + 1 - len(self.token_id_to_tokens_groups)
        if missing > 0:
            self.token_id_to_tokens_groups.extend([] for _ in range(missing))
        # dict.fromkeys de-duplicates while preserving order, so a token listed
        # twice does not increment the group's counter twice per sample.
        for token_id in dict.fromkeys(token_ids):
            self.token_id_to_tokens_groups[token_id].append(group_index)
        self.penalized_tokens_group.append(torch.tensor(token_ids, dtype=torch.int32))
        self.max_repeats_per_token_group.append(max_repeats)
        self.repeats_per_token_group.append(0)

    def reset_repeats(self) -> None:
        """Reset the sample counter of every registered group to zero."""
        self.repeats_per_token_group = [0] * len(self.repeats_per_token_group)

    def __call__(
        self,
        next_token_logits: torch.DoubleTensor,
        sequence_weights: torch.DoubleTensor,
        rng: torch.Generator,
    ) -> Tuple[torch.DoubleTensor, torch.DoubleTensor, torch.DoubleTensor]:
        """Mask exhausted token groups, then call the multinomial sampler.

        Parameters
        ----------
        next_token_logits
            A tensor of shape ``(n_seqs, vocab_size,)`` holding the logits of
            the next token over the vocabulary. It is not modified in place.
        sequence_weights
            A tensor of shape ``(n_seqs,)`` that represents the cumulative
            weight of each sequence. When all weights are exactly zero the
            group counters are reset (treated as the start of a generation).
        rng
            A random number generator.

        Returns
        -------
        A tuple with an array that contains the ids of the sampled tokens of
        shape ``(n_seqs, 1)``, an array that contains the ancestors of each
        sampled id of shape ``(n_seqs,)`` and an array that contains the updated
        cumulative weights of each sequence of shape ``(n_seqs,)``.

        """
        if sequence_weights.min() == sequence_weights.max() == 0:
            # Heuristic for "a new generation starts here".
            # NOTE(review): presumably the parent accumulates log-probabilities
            # into the weights, so they can also stay exactly 0 mid-generation
            # when every sampled token had probability 1 -- confirm; call
            # reset_repeats() explicitly if that matters.
            self.reset_repeats()

        # Applied unconditionally (also right after a reset) so that groups
        # registered with max_repeats <= 0 are blocked from the first token.
        exhausted_groups = [
            group
            for group, limit, count in zip(
                self.penalized_tokens_group,
                self.max_repeats_per_token_group,
                self.repeats_per_token_group,
            )
            if count >= limit
        ]
        if exhausted_groups:
            # Clone so the caller's logits tensor is left untouched.
            next_token_logits = next_token_logits.clone()
            for group in exhausted_groups:
                # Long indices on the logits' device are accepted by every
                # PyTorch version and avoid cross-device indexing.
                index = group.to(device=next_token_logits.device, dtype=torch.long)
                next_token_logits[:, index] = -torch.inf

        next_token_ids, ancestors, weights = super().__call__(
            next_token_logits=next_token_logits,
            sequence_weights=sequence_weights,
            rng=rng
        )

        # Flatten the (n_seqs, 1) ids to plain Python ints for the list lookup.
        for token_id in next_token_ids.reshape(-1).tolist():
            if token_id < len(self.token_id_to_tokens_groups):
                for group_index in self.token_id_to_tokens_groups[token_id]:
                    self.repeats_per_token_group[group_index] += 1
        return next_token_ids, ancestors, weights