# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
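"""Train a Gated State Spaces language model with HuggingFace Accelerate on text
from the local C4X dataset wrapper, resuming from 'model.pt' and logging losses
and generated samples to Weights & Biases."""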
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader

import wandb
from tqdm import tqdm
from gated_state_spaces_pytorch import GatedStateSpacesLM
from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper

from c4x import C4X
from accelerate import Accelerator
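# NOTE: c4x is a local helper module, not a PyPI package. C4X is assumed to be an
# iterable dataset yielding fixed-length sequences of GPT-2 token ids from the C4
# corpus, with a decode() method that maps token ids back to text.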


def main():
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=4,
    )
    accelerator.init_trackers("gated-state-space")

    # GPT-2 XL sized configuration: 50257-token vocabulary, 1600-dim embeddings, 24 layers.
    f_emb = 1600
    model = AutoregressiveWrapper(
        GatedStateSpacesLM(
            num_tokens=50257,
            dim=f_emb,
            depth=24,
        ),
    )

    # Freeze the token embedding and output projection (presumably initialized from
    # GPT-2 weights in the checkpoint) and insert a trainable LayerNorm before the logits.
    model.net.token_emb.weight.requires_grad_(False)
    model.net.to_logits.weight.requires_grad_(False)
    model.net.to_logits = nn.Sequential(
        nn.LayerNorm(f_emb),
        model.net.to_logits,
    )
    model = model.to(accelerator.device)

    if accelerator.is_main_process:
        wandb.watch(model)
    
    # Resume from an existing local checkpoint and fine-tune with a small learning rate.
    model.load_state_dict(torch.load('model.pt'))
    optim = AdamW(model.parameters(), lr=2e-5)

    bs = 24   # per-device batch size
    kk = 128  # context length; samples have kk + 1 tokens so inputs/targets can be shifted by one
    dsx = C4X(kk + 1)
    dlx = DataLoader(
        dsx,
        batch_size=bs,
        num_workers=8,
    )

    model, optim, dlx = accelerator.prepare(model, optim, dlx)

    # Iterate over the prepared dataloader so each process gets its own shard of batches.
    prog = tqdm(dlx, disable=not accelerator.is_main_process)
    
    optim.zero_grad()
    for i, batch in enumerate(prog):
        batch = batch.to(accelerator.device)
        with accelerator.accumulate(model):
            with accelerator.autocast():
                los = model(batch)
            accelerator.backward(los)
            # Clip only on iterations where accumulated gradients are synchronized and applied.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(
                    model.parameters(),
                    1.0,
                )
            optim.step()
            optim.zero_grad()

        if i % 1000 == 0:
            # Periodically sample from the model and checkpoint it.
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            b, n = 4, 512
            # 50256 is GPT-2's <|endoftext|> token, used here as the generation prompt.
            init = torch.tensor([[50256]] * b).to(accelerator.device)
            prd = unwrapped_model.generate(init, n)
            prd = [dsx.decode(p) for p in prd]
            try:
                accelerator.log(dict(
                    text=wandb.Html(
                        '<hr>'.join(
                            p.replace('\n', '<br>') for p in prd
                        )
                    ),
                ), step=i)
            except Exception as ex:
                accelerator.print('Failed to log samples to W&B:', ex)
            accelerator.save(unwrapped_model.state_dict(), 'model2.pt')
        
        if i % 10 == 0:
            accelerator.log(dict(
                loss=los.item(),
            ), step=i)
            prog.set_postfix(loss=los.item())
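
    # Close the configured trackers (ends the W&B run) if the dataloader is ever exhausted.
    accelerator.end_training()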

if __name__ == '__main__':
    main()
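
# Typical usage (the filename train.py is just an example; match it to this file):
#   accelerate config          # one-time hardware / mixed-precision setup
#   accelerate launch train.py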