naxalpha commited on
Commit
d87b97b
·
1 Parent(s): 7a42e6d

add training code

Browse files
Files changed (2) hide show
  1. app.py +109 -0
  2. c4x.py +62 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.optim import AdamW
4
+ from torch.nn import functional as F
5
+ from torch.utils.data import DataLoader
6
+ from torch.nn.utils import clip_grad_norm_
7
+
8
+ import wandb
9
+ from tqdm import tqdm
10
+ from transformers import GPT2LMHeadModel
11
+ from gated_state_spaces_pytorch import GatedStateSpacesLM
12
+ from gated_state_spaces_pytorch.autoregressive_wrapper import AutoregressiveWrapper
13
+
14
+ from c4x import C4X
15
+
16
+
17
+ if __name__ == '__main__':
18
+ wandb.init(
19
+ project="gated-state-space",
20
+ entity="naxalpha",
21
+ )
22
+
23
+ gpt_2 = GPT2LMHeadModel.from_pretrained('gpt2-xl')
24
+ gpt_2.requires_grad_(False)
25
+ gpt_2 = gpt_2.cuda()
26
+
27
+ f_emb = 1600
28
+ model = AutoregressiveWrapper(
29
+ GatedStateSpacesLM(
30
+ num_tokens=50257,
31
+ dim=f_emb,
32
+ depth=24,
33
+ ),
34
+ )
35
+ wandb.watch(model)
36
+
37
+ emb = gpt_2.state_dict()['transformer.wte.weight']
38
+
39
+ model.net.token_emb.weight.requires_grad_(False)
40
+ model.net.token_emb.weight.copy_(emb)
41
+
42
+ model.net.to_logits.weight.requires_grad_(False)
43
+ model.net.to_logits.weight.copy_(emb)
44
+
45
+ model.net.to_logits = nn.Sequential(
46
+ nn.LayerNorm(f_emb),
47
+ model.net.to_logits,
48
+ )
49
+
50
+ model = model.cuda()
51
+ optim = AdamW(model.parameters(), 2e-5)
52
+
53
+ bs = 8
54
+ kk = 128
55
+ dsx = C4X(kk+1)
56
+ dlx = DataLoader(
57
+ dsx,
58
+ batch_size=bs,
59
+ num_workers=16,
60
+ )
61
+
62
+ k = 4
63
+ prog = tqdm(dlx)
64
+ optim.zero_grad()
65
+
66
+ for i, batch in enumerate(prog):
67
+ batch = batch.cuda()
68
+ if i % 2 == 0: # distil
69
+ batch = batch[:, :-1]
70
+ with torch.no_grad():
71
+ logits = gpt_2(batch).logits
72
+ probs = logits.softmax(dim=-1)
73
+ out = model.net(batch)
74
+ los = F.cross_entropy(
75
+ out.flatten(0,1),
76
+ probs.flatten(0,1),
77
+ )
78
+ else: # scratch
79
+ los = model(batch)
80
+
81
+ (los / k).backward()
82
+ if (i+1) % k == 0:
83
+ clip_grad_norm_(
84
+ model.parameters(),
85
+ max_norm=1.,
86
+ )
87
+ optim.step()
88
+ optim.zero_grad()
89
+
90
+ if i % 1000 == 0:
91
+ b, n = 4, 512
92
+ init = torch.tensor([[50256]]*b).cuda()
93
+ prd = model.generate(init, n)
94
+ prd = [dsx.decode(p) for p in prd]
95
+ try:
96
+ wandb.log(dict(
97
+ text=wandb.Html(
98
+ '<hr>'.join(
99
+ p.replace('\n', '<br>') for p in prd
100
+ )
101
+ )), step=i)
102
+ except Exception as ex:
103
+ print('Failed to log to W&B...', ex)
104
+ torch.save(model.state_dict(), 'model.pt')
105
+
106
+ wandb.log(dict(
107
+ loss=los.item(),
108
+ ), step=i)
109
+ prog.set_postfix(loss=los.item())
c4x.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stream C4 dataset from Huggingface with GPT-2 Tokenizer for PyTorch Language Model Training
2
+ import json
3
+ import torch
4
+ import random
5
+ from datasets import load_dataset
6
+ from transformers import GPT2Tokenizer
7
+ from torch.utils.data import Dataset, get_worker_info
8
+
9
+
10
+ def cycled(itr):
11
+ while True:
12
+ for itm in itr:
13
+ yield itm
14
+
15
+ class C4X(Dataset):
16
+
17
+ def __init__(self, seq_len=512, split='train'):
18
+ self.seq = seq_len
19
+ self.ds = load_dataset(
20
+ 'c4',
21
+ name='en',
22
+ split=split,
23
+ streaming=True,
24
+ )
25
+ self.tok = GPT2Tokenizer.from_pretrained('gpt2')
26
+ self.init = False
27
+
28
+ def __len__(self):
29
+ return 1_000_000_000
30
+
31
+ def _init(self):
32
+ if self.init:
33
+ return
34
+ wi = get_worker_info()
35
+ self.ds = cycled(
36
+ self.ds.shuffle(
37
+ seed=wi.seed,
38
+ buffer_size=10_000,
39
+ )
40
+ )
41
+ self.init = True
42
+
43
+ def _get_next(self):
44
+ self._init()
45
+ obj = next(self.ds)['text']
46
+ tkn = self.tok.encode(obj)
47
+ return tkn
48
+
49
+ def _get_full(self):
50
+ obj = []
51
+ while len(obj) < self.seq:
52
+ obj += self._get_next()
53
+ obj.append(self.tok.eos_token_id)
54
+ s = random.randint(0, len(obj)-self.seq)
55
+ return obj[s:s+self.seq]
56
+
57
+ def __getitem__(self, _):
58
+ return torch.tensor(self._get_full())
59
+
60
+ def decode(self, tkns):
61
+ return self.tok.decode(tkns)
62
+