from mamba_ssm.modules.mamba2_simple import Mamba2Simple
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from pathlib import Path
from einops import repeat
from image_utils import ImageDB, ImageBatch, RGBToModel, ModelToRGB

epochs = 10_000
bs = 16
# original hyperparameters:
# bs = 16
# d_model = 768
# headdim = 64
# n_layer = 4
d_model = 1024
headdim = 64
n_layer = 4
OPTS = {
    'device': "cuda",
    'dtype': torch.bfloat16
}
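
# OPTS is splatted into every module below, so (assuming the image_utils modules
# honour the device/dtype kwargs) all weights are created directly on CUDA in
# bfloat16; the script never calls model.to(...).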
# Because of the KISS flip/flop design, the effective number of Mamba layers is
# actually twice n_layer. This is somewhat comparable to the LLM setup, where one
# block had two Mamba layers: one replacing attention and one replacing the MLP.
weights_path = Path(
    f"data/image-flip-weights-{d_model}x{n_layer}-{str(OPTS['dtype'])}.bin")
print(f"Weight path is {str(weights_path)}")


class MambaWrap(nn.Module):
    """A single Mamba2 layer wrapped as a pre-norm residual block."""

    def __init__(self) -> None:
        super().__init__()
        self.mamba = Mamba2Simple(d_model, **OPTS, headdim=headdim)
        self.norm = nn.LayerNorm(d_model, **OPTS)

    def forward(self, x):
        # x: (B, T, d_model) -> (B, T, d_model)
        residual = x
        x = self.norm(x)
        x = self.mamba(x)
        x = residual + x
        return x
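
# Shape sketch for MambaWrap (hypothetical values, kept as a comment so the
# training script itself is unchanged):
#
#   block = MambaWrap()
#   x = torch.randn(2, 128, d_model, **OPTS)   # (batch, tokens, channels)
#   block(x).shape == x.shape                  # the residual block preserves shape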


class MambaFlipFlop(nn.Module):
    """Two MambaWrap blocks; the second one sees the tail of the sequence reversed.

    swap_order reverses only the last n_values positions before mb_backward and
    undoes the reversal afterwards, so those positions are scanned in both
    directions.
    """

    def __init__(self, n_values) -> None:
        super().__init__()
        self.mb_forward = MambaWrap()
        self.mb_backward = MambaWrap()
        self.n_values = n_values

    def forward(self, x):
        x = self.mb_forward(x)
        x = self.swap_order(x)
        x = self.mb_backward(x)
        x = self.swap_order(x)
        return x

    def swap_order(self, x):
        # Keep the first T - n_values positions in place and reverse the last
        # n_values positions; applying this twice restores the original order.
        T = x.shape[1]
        head = torch.arange(0, T - self.n_values)
        tail = torch.arange(T - 1, T - self.n_values - 1, -1)
        seq = torch.cat((head, tail))
        x = x[:, seq]
        return x
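
# Worked example of the swap_order index math (values chosen only for
# illustration): with T = 6 and n_values = 3,
#
#   head = [0, 1, 2]           # arange(0, T - n_values)
#   tail = [5, 4, 3]           # arange(T - 1, T - n_values - 1, -1)
#   seq  = [0, 1, 2, 5, 4, 3]
#
# so x[:, seq] keeps the prefix in order and reverses the last three positions.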


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.from_rgb = RGBToModel(d_model, **OPTS)
        self.to_rgb = ModelToRGB(d_model, **OPTS)
        # Learned start token and learned positional embedding for the
        # 64x64 output positions.
        self.s0 = nn.Parameter(torch.randn(1, 1, d_model, **OPTS))
        self.suffix = nn.Parameter(torch.randn(64*64, d_model, **OPTS))
        self.layers = nn.ModuleList([MambaFlipFlop(64*64)
                                     for _ in range(n_layer)])
        self.norm0 = nn.LayerNorm(d_model, **OPTS)

    def forward(self, batch: ImageBatch):
        B = batch.n_batch
        batch = batch.as_1d()
        # Embed the 8x8 input, prepend the start token s0 and append the 4096
        # zoomed positions that will be decoded into the 64x64 output.
        batch.im8 = self.from_rgb(batch.im8)
        s0 = self.s0.repeat(B, 1, 1)
        s1 = self.zoom(batch.im8)
        x = torch.cat((s0, batch.im8, s1), 1)
        x = self.norm0(x)
        x = self.mamba(x)
        # Only the last 64*64 positions are decoded back to RGB.
        x = x[:, -64*64:]
        y_hat = self.to_rgb(x)
        y_true = batch.im64
        batch.loss = F.mse_loss(y_hat, y_true)
        batch.im64 = y_hat
        return batch.as_2d()
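
    # Sequence length arithmetic: 1 (start token s0) + 64 (embedded 8x8 pixels)
    # + 64*64 (zoomed positions) = 4161 tokens per image enter the Mamba stack.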

    def zoom(self, im8):
        # Nearest-neighbour upsample the 64 embedded tokens (an 8x8 grid) to a
        # 64x64 grid and add the learned positional suffix.
        im8 = im8.view(im8.shape[0], 8, 8, im8.shape[-1])
        im8 = repeat(im8, "B H W C -> B (H 8) (W 8) C")
        im8 = im8.view(im8.shape[0], 64*64, im8.shape[-1])
        im8 = im8 + self.suffix
        return im8

    def mamba(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
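
# Each MambaFlipFlop holds two MambaWrap blocks, so one forward pass runs
# 2 * n_layer Mamba2 scans in total (8 with the current n_layer = 4), matching
# the note above about the effective layer count being doubled.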


if __name__ == "__main__":
    image_db = ImageDB(dtype=OPTS["dtype"])
    model = Model()
    if weights_path.exists():
        print(f"*** Load {str(weights_path)}")
        model.load_state_dict(torch.load(weights_path))
    opt = torch.optim.AdamW(model.parameters(), fused=True)
    for e in (bar := tqdm(range(epochs))):
        b = model(image_db.random_batch(bs))
        b.loss.backward()
        opt.step()
        opt.zero_grad()
        bar.set_description(f'L:{b.loss.item():.4f}')
        # Checkpoint every 100 steps, and once more after the loop finishes.
        if e and e % 100 == 0:
            torch.save(model.state_dict(), weights_path)
    torch.save(model.state_dict(), weights_path)
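
# To reuse the trained weights outside this script, a minimal sketch (the module
# name below is hypothetical; Model and weights_path are defined in this file):
#
#   from image_flip import Model, weights_path
#   model = Model()
#   model.load_state_dict(torch.load(weights_path))
#   model.eval()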