# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
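# Typical launch (the filename `train.py` is only an assumption about where this
# script is saved; adjust to your setup):
#   accelerate config          # one-time hardware / precision setup
#   accelerate launch train.py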
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
import wandb
from tqdm import tqdm
from transformers import GPT2LMHeadModel
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from c4x import C4X
from accelerate import Accelerator
def main():
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=4,
    )
    accelerator.init_trackers("gated-state-space")

    # Gated State Space LM wrapped for autoregressive (next-token) training.
    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )

    # Freeze the token embedding and output projection, and prepend a
    # trainable LayerNorm to the frozen output head.
    model.net.token_emb.weight.requires_grad_(False)
    model.net.to_logits.weight.requires_grad_(False)
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )

    model = model.to(accelerator.device)
    if accelerator.is_main_process:
        wandb.watch(model)

    # Resume from a previously saved checkpoint (expects model.pt to exist).
    model.load_state_dict(torch.load('model.pt'))

    optim = AdamW(model.parameters(), 2e-5)

    bs = 1      # batch size per process
    kk = 2048   # sequence length

    # C4 dataset wrapper; yields kk + 1 tokens per sample (input + shifted target).
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=4,
    )

    prog = tqdm(dlx, disable=not accelerator.is_main_process)

    model = accelerator.prepare(model)
    optim, dlx = accelerator.prepare(optim, dlx)

    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)

        # Gradient accumulation + mixed precision handled by Accelerate.
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(
                    model.parameters(),
                    1.0,
                )
            optim.step()
            optim.zero_grad()

        # Periodically sample from the model, log the text, and checkpoint.
        if i % 1000 == 0:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)

            b, n = 4, 512
            init = torch.tensor([[50256]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [dsx.decode(p) for p in prd]
            try:
                accelerator.log(dict(
                    text=wandb.Html(
                        '<hr>'.join(
                            p.replace('\n', '<br>') for p in prd
                        )
                    )), step=i)
            except Exception as ex:
                accelerator.print('Failed to log to W&B...', ex)

            accelerator.save(unwrapped_model.state_dict(), 'model2.pt')

        if i % 10 == 0:
            accelerator.log(dict(
                loss=los.item(),
            ), step=i)
            prog.set_postfix(loss=los.item())


if __name__ == '__main__':
    main()
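

# ---------------------------------------------------------------------------
# The script imports C4X from a local `c4x` module that is not part of any of
# the pip packages listed above. The class below is only a minimal sketch of
# what that module might look like (streamed C4 text, GPT-2 tokenization,
# fixed-length chunks, a `decode` helper); the real implementation may differ.
# It is kept here purely for reference and would normally live in `c4x.py`.
import torch
from datasets import load_dataset
from torch.utils.data import IterableDataset
from transformers import GPT2TokenizerFast


class C4XSketch(IterableDataset):
    """Hypothetical stand-in for c4x.C4X: yields LongTensors of `seq_len` GPT-2 token ids."""

    def __init__(self, seq_len: int):
        self.seq_len = seq_len
        self.tok = GPT2TokenizerFast.from_pretrained('gpt2')
        # Streaming avoids downloading the full C4 corpus up front.
        self.ds = load_dataset('allenai/c4', 'en', split='train', streaming=True)

    def decode(self, ids) -> str:
        return self.tok.decode(ids)

    def __iter__(self):
        buf = []
        for row in self.ds:
            # Concatenate documents, separated by the GPT-2 EOS token (50256).
            buf.extend(self.tok.encode(row['text']) + [self.tok.eos_token_id])
            while len(buf) >= self.seq_len:
                chunk, buf = buf[:self.seq_len], buf[self.seq_len:]
                yield torch.tensor(chunk, dtype=torch.long)

    # Note: with DataLoader(num_workers=4) each worker would replay the same
    # stream; the real c4x module presumably shards the data across workers.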