# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
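# Fine-tunes a Gated State Spaces language model (lucidrains'
# gated-state-spaces-pytorch) on the C4 corpus using Hugging Face Accelerate,
# with gradient accumulation, optional mixed precision (accelerator.autocast()),
# and Weights & Biases logging of losses and generated samples.
# Run under Accelerate, e.g.:
#   accelerate launch train.py   # script filename assumed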
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader

import wandb
from tqdm import tqdm
from accelerate import Accelerator
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from c4x import C4X  # local module providing the C4 dataset wrapper


def main():
    accelerator = Accelerator(
        gradient_accumulation_steps=4,
    )
    if accelerator.is_main_process:
        wandb.init(
            project="gated-state-space",
            entity="naxalpha",
        )
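    # Model: a 24-layer Gated State Spaces LM with a GPT-2-sized vocabulary.
    # dim=1600 / num_tokens=50257 match GPT-2 XL, so the frozen token embedding
    # and output projection below are presumably initialized from GPT-2 XL
    # weights carried in the 'model.pt' checkpoint loaded further down; only a
    # fresh LayerNorm is inserted in front of the (frozen) logits projection.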
    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )
    model.net.token_emb.weight.requires_grad_(False)
    model.net.to_logits.weight.requires_grad_(False)
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )
    model = model.to(accelerator.device)
    if accelerator.is_main_process:
        wandb.watch(model)

    # Resume from an existing checkpoint; this fails if 'model.pt' is absent.
    model.load_state_dict(torch.load('model.pt'))

    optim = AdamW(model.parameters(), lr=2e-5)
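    # Data: C4X appears to stream tokenized C4 text, yielding sequences of
    # kk + 1 token ids (so the autoregressive wrapper can split inputs from
    # shifted targets); it also exposes decode() for detokenizing samples.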
    bs = 16   # per-process batch size
    kk = 128  # training sequence length
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=8,
    )

    # Prepare before wrapping in tqdm so the progress bar iterates the
    # sharded, device-aware dataloader rather than the raw one.
    model, optim, dlx = accelerator.prepare(model, optim, dlx)
    prog = tqdm(dlx, disable=not accelerator.is_main_process)
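    # Training loop. accelerator.accumulate() handles gradient accumulation:
    # accelerator.backward() scales the loss by the accumulation factor, and
    # sync_gradients is True only on the steps where gradients are actually
    # synchronized (every 4th micro-batch here), which is when clipping and
    # the optimizer step should run.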
    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(
                    model.parameters(),
                    1.0,
                )
            optim.step()
            optim.zero_grad()
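        # Every 1000 steps: sample from the model, log the text to W&B, and
        # checkpoint. All ranks must reach wait_for_everyone(), so the barrier
        # sits outside the main-process-only branch.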
        if i % 1000 == 0:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            if accelerator.is_main_process:
                print('generating...')
                b, n = 4, 512
                # Prompt every sample with <|endoftext|> (token id 50256).
                init = torch.tensor([[50256]] * b).to(accelerator.device)
                prd = unwrapped_model.generate(init, n)
                prd = [dsx.decode(p) for p in prd]
                try:
                    wandb.log(dict(
                        text=wandb.Html(
                            '<hr>'.join(
                                p.replace('\n', '<br>') for p in prd
                            )
                        )), step=i)
                except Exception as ex:
                    print('Failed to log to W&B...', ex)
                accelerator.save(unwrapped_model.state_dict(), 'model.pt')
        if i % 10 == 0 and accelerator.is_main_process:
            print('logging...')
            wandb.log(dict(
                loss=los.item(),
            ), step=i)
            prog.set_postfix(loss=los.item())


if __name__ == '__main__':
    main()