import sys

import torch
import torch.nn as nn
from mamba_ssm.modules.mamba_simple import Mamba
from mamba_ssm.utils.generation import InferenceParams
from torch import Tensor
from tqdm.auto import tqdm

# Model hyperparameters: a bit-level LM, so the vocabulary is just {0, 1}.
dim_model = 512
n_vocab = 2
n_layers = 4


@torch.no_grad()
def string_to_bits(text: str, msb=True, _cache={}) -> Tensor:
    """Encode text as a flat tensor of bits, eight 0/1 tokens per byte."""
    # _cache is a deliberate mutable default: the 256x8 lookup table is
    # built once per bit order and reused across calls.
    if msb not in _cache:
        all_values = torch.arange(0, 256)
        if msb:
            bits = [((all_values & (1 << i)) != 0).int()
                    for i in range(7, -1, -1)]
        else:
            bits = [((all_values & (1 << i)) != 0).int() for i in range(8)]
        _cache[msb] = torch.stack(bits).mT
    bits_tensor = _cache[msb]
    # bytearray gives torch.frombuffer a writable buffer, avoiding the
    # warning torch emits for read-only bytes objects.
    raw = torch.frombuffer(bytearray(text.encode()), dtype=torch.uint8).int()
    return bits_tensor[raw].long().ravel()


@torch.no_grad()
def bits_to_string(bits: Tensor, msb=True):
    """Decode a flat bit tensor back to text, one character per eight bits."""
    if bits.dim() == 2:
        return [bits_to_string(t, msb) for t in bits]
    assert bits.dim() == 1
    assert len(bits) % 8 == 0
    if msb:
        factors = torch.tensor([2**i for i in range(7, -1, -1)])
    else:
        factors = torch.tensor([2**i for i in range(8)])
    factors = factors.to(device=bits.device)
    as_bytes = (bits.view(-1, 8) * factors).sum(-1)
    # Map each byte value straight to a code point (exact for ASCII output).
    return ''.join(chr(x) for x in as_bytes.tolist())
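
# Round-trip sketch (illustrative; ASCII input, MSB-first by default):
#   >>> bits_to_string(string_to_bits("Hi"))
#   'Hi'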


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, dim_model)

    def forward(self, x):
        return self.emb(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(dim_model)
        self.decoder = nn.Linear(dim_model, n_vocab, bias=False)

    def forward(self, x):
        x = self.norm(x)
        x = self.decoder(x)
        return x


class MambaLayer(nn.Module):
    """Pre-norm Mamba block: LayerNorm, then Mamba, then a residual add."""

    def __init__(self, layer_idx=None):
        super().__init__()
        self.in_norm = nn.LayerNorm(dim_model)
        # layer_idx lets each Mamba find its own state slot in InferenceParams.
        self.mamba = Mamba(dim_model, layer_idx=layer_idx)

    def forward(self, x, inference_params=None):
        residual = x
        x = self.in_norm(x)
        x = self.mamba(x, inference_params=inference_params)
        return residual + x


class MambaBit(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.layers = nn.ModuleList(
            [MambaLayer(layer_idx=idx) for idx in range(n_layers)])
        self.dec = Decoder()

    def forward(self, x, inference_params=None):
        x = self.enc(x)
        for layer in self.layers:
            # Each MambaLayer already adds its own residual, so this outer
            # addition applies a second skip connection per layer.
            x = x + layer(x, inference_params=inference_params)
        x = self.dec(x)
        return x
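
# Shape sketch (illustrative): bit tokens of shape (batch, seq_len) with
# values in {0, 1} map to logits of shape (batch, seq_len, n_vocab).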


# Generate by re-running the whole sequence every step: O(N^2) overall,
# with no cached recurrent state.
@torch.no_grad()
def test_n2(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8
    for _ in tqdm(range(process)):
        y = m(x)
        new = y[:, -1:].argmax(-1)
        x = torch.cat((x, new), 1)
    return bits_to_string(x)


# Generate in O(N) by carrying the recurrent state across steps: each new
# bit costs a single one-token forward pass.
@torch.no_grad()
def test_n(m: MambaBit, prompt: str, chars=10):
    x = string_to_bits(prompt).cuda()[None]
    process = chars * 8

    inference_params = InferenceParams(
        max_seqlen=x.shape[1] + process,
        max_batch_size=1)

    # Prime the state on the full prompt, then step one token at a time.
    y = m(x, inference_params=inference_params)
    new = y[:, -1:].argmax(-1)
    for _ in tqdm(range(process)):
        x = torch.cat((x, new), 1)
        # A nonzero offset makes Mamba take its single-step update path.
        inference_params.seqlen_offset = x.shape[1]
        y = m(new, inference_params=inference_params)
        new = y[:, -1:].argmax(-1)
    return bits_to_string(x)
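
# Sanity sketch (illustrative): with the same checkpoint, both decoders
# should emit the same text (up to numerical differences in bfloat16):
#   assert test_n(m, "Once", chars=4) == test_n2(m, "Once", chars=4)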


def run():
    mamba_bit = MambaBit().bfloat16().cuda()
    mamba_bit.load_state_dict(torch.load("mamba_bit.tiny.bin"))

    prompt = sys.argv[1] if len(sys.argv) == 2 else "Once upon a time"
    s = test_n(mamba_bit, prompt, chars=256)[0]
    print(s)


def model_numel(m: nn.Module):
    """Total parameter count of a module."""
    return sum(p.numel() for p in m.parameters())
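
# Usage sketch (illustrative):
#   print(f"{model_numel(MambaBit()):,} parameters")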
    
if __name__ == "__main__":
    run()