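"""Fine-tune a Gated State Space language model on C4 data (via the C4X dataset)
using Hugging Face Accelerate.

Resumes from an existing checkpoint ('model.pt'), logs the training loss and
periodic text samples to Weights & Biases, and saves weights to 'model2.pt'
every 1000 steps.
"""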

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader

import wandb
from tqdm import tqdm
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from c4x import C4X
from accelerate import Accelerator


def main():
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=4,
    )
    accelerator.init_trackers("gated-state-space")
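
    # 1600-dim, 24-layer Gated State Space LM; num_tokens=50257 is assumed to
    # match the GPT-2 BPE vocabulary used by the C4X tokenizer.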
    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )
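
    # Prepend a final LayerNorm to the output projection head.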
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )
    model = model.to(accelerator.device)

    if accelerator.is_main_process:
        wandb.watch(model)
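
    # Resume from an existing checkpoint; load to CPU first, then copy the
    # weights into the already-placed model parameters.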
    model.load_state_dict(torch.load('model.pt', map_location='cpu'))
    optim = AdamW(model.parameters(), lr=5e-6)
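
    # The dataset yields kk + 1 tokens per sample so the autoregressive wrapper
    # can split each sequence into inputs and next-token targets.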
    bs = 1
    kk = 2048
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=4,
    )
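
    # Let Accelerate wrap the model, optimizer, and dataloader for the current
    # device / distributed setup; tqdm then wraps the prepared loader so each
    # process iterates over its own shard.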
    model = accelerator.prepare(model)
    optim, dlx = accelerator.prepare(optim, dlx)

    prog = tqdm(dlx, disable=not accelerator.is_main_process)
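
    # Training loop: gradients accumulate over 4 micro-batches, the forward pass
    # runs under autocast, and clipping and optimizer updates only take effect
    # on the micro-batch where Accelerate synchronizes gradients.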
    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                loss = model(batch)
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(
                    model.parameters(),
                    1.0,
                )
            optim.step()
            optim.zero_grad()
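
        # Every 1000 steps: sample 2048 tokens seeded with the GPT-2
        # end-of-text token (50256), log the text to W&B, and checkpoint.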
        if i % 1000 == 0:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 1, 2048
            init = torch.tensor([[50256]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [dsx.decode(p) for p in prd]
            try:
                accelerator.log(dict(
                    text=wandb.Html(
                        '<hr>'.join(
                            p.replace('\n', '<br>') for p in prd
                        )
                    )
                ), step=i)
            except Exception as ex:
                accelerator.print('Failed to log to W&B...', ex)
            accelerator.save(unwrapped_model.state_dict(), 'model2.pt')
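
        # Log the training loss to W&B and the progress bar every 10 steps.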
        if i % 10 == 0:
            accelerator.log(dict(
                loss=loss.item(),
            ), step=i)
            prog.set_postfix(loss=loss.item())
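

# Typically launched via `accelerate launch` so the Accelerator picks up the
# configured mixed-precision and (multi-)GPU settings.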
if __name__ == '__main__':
    main()