File size: 2,921 Bytes
565faca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Optional

import torch
from torch.nn import functional as F


class NonCausalInferenceMixin:
    """
    Mixin class for non-causal inference in a language model.

    This class provides methods for performing non-causal sampling using a language model.
    The host class must be callable as ``self(idx, speaker_embs=...)`` returning a
    ``(list_of_logits, _)`` pair, and must expose ``self.config.block_size``.
    """

    @torch.no_grad()
    def _non_causal_sample(
        self, *, idx: torch.Tensor, speaker_embs: Optional[torch.Tensor], temperature: float, top_k: Optional[int]
    ) -> torch.Tensor:
        """
        Perform non-causal sampling.

        A single forward pass produces logits for every position; each position of each
        output hierarchy is then sampled independently (no autoregressive feedback).

        Args:
            idx (torch.Tensor): Input tensor of shape (batch_size, num_in_hierarchies, sequence_length).
                ``sequence_length`` must equal ``self.config.block_size``.
            speaker_embs (Optional[torch.Tensor]): Speaker embeddings forwarded to the model,
                presumably of shape (batch_size, embedding_size).
            temperature (float): Divisor applied to the logits before the softmax; must be non-zero.
            top_k (Optional[int]): If given, keep only the ``top_k`` highest logits at each position
                (values tied with the k-th largest are kept too). ``None`` disables the filter.

        Returns:
            torch.Tensor: Sampled indices of shape (batch_size, num_out_hierarchies, sequence_length),
            where ``num_out_hierarchies`` is the number of logit tensors returned by the model.

        Raises:
            AssertionError: If ``sequence_length != self.config.block_size``.
            ValueError: If ``top_k`` is given and is smaller than 1.
        """
        b, _, t = idx.size()
        assert t == self.config.block_size, f"input size {t} != config.block_size {self.config.block_size}"
        if top_k is not None and top_k < 1:
            # topk(k=0) would yield an empty tensor and fail later with an unrelated IndexError
            raise ValueError(f"top_k must be >= 1 or None, got {top_k}")

        # forward the model once; it yields one (b, t, vocab_size) logit tensor per output hierarchy
        list_logits, _ = self(idx, speaker_embs=speaker_embs)

        samples = []
        for logits in list_logits:
            # scale by desired temperature
            logits = logits / temperature

            if top_k is not None:
                # k-th largest logit per position, kept as (b, t, 1) so it broadcasts over the vocab
                kth = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1).values[..., -1:]
                logits = logits.masked_fill(logits < kth, float("-inf"))

            # masking does not crop the vocab axis: probs stay (b, t, vocab_size)
            probs = F.softmax(logits, dim=-1)

            # torch.multinomial accepts only 1-D/2-D input, so draw for all (b * t) rows in one call
            flat = probs.reshape(-1, probs.size(-1))
            samples.append(torch.multinomial(flat, num_samples=1).reshape(probs.shape[:-1]))  # (b, t)

        out = torch.stack(samples, dim=1)  # (b, num_out_hierarchies, t)
        assert out.shape[0] == b and out.shape[2] == t
        return out