# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_

import wandb
from tqdm import tqdm
from transformers import GPT2LMHeadModel

from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from c4x import C4X
from accelerate import Accelerator


def main():
    # Gradient accumulation multiplies the effective batch size by 4
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
    )

    if accelerator.is_main_process:
        wandb.init(
            project="gated-state-space",
            entity="naxalpha",
        )

    # GPT-2 vocabulary (50257 tokens) and GPT-2 XL embedding width (1600)
    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )

    # Freeze the token embedding and the logits projection weights,
    # and put a trainable LayerNorm in front of the logits head
    model.net.token_emb.weight.requires_grad_(False)
    model.net.to_logits.weight.requires_grad_(False)
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )

    model = model.to(accelerator.device)
    if accelerator.is_main_process:
        wandb.watch(model)

    # Resume from a previously saved checkpoint
    model.load_state_dict(torch.load('model.pt'))

    optim = AdamW(model.parameters(), 2e-5)

    bs = 16   # per-device batch size
    kk = 128  # sequence length; the dataset yields kk+1 tokens per sample for the input/target shift
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=8,
    )

    k = 4
    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    model, optim, dlx = accelerator.prepare(model, optim, dlx)

    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            # Clip only when gradients are actually synchronized
            # (i.e. at the end of an accumulation cycle)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(
                    model.parameters(),
                    1.0,
                )
            optim.step()
            optim.zero_grad()

        # Every 1000 steps, sample from the model on the main process and log the text
        if i % 1000 == 0 and accelerator.is_main_process:
            print('generating...')
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 4, 512
            # Start each of the b sequences from the GPT-2 <|endoftext|> token (id 50256)
            init = torch.tensor([[50256]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [dsx.decode(p) for p in prd]
            try:
                wandb.log(dict(
                    text=wandb.Html(
                        '