naxalpha committed
Commit 9f1ebfc
1 Parent(s): d230351

update to resume training

Files changed (2):
  1. .gitignore +3 -0
  2. app.py +10 -19
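
This commit makes training resumable: as the app.py hunks below show, the script now restores weights from model.pt before moving the model to the GPU, and the GPT-2 distillation path is disabled. A minimal sketch of the matching save/resume pattern, assuming the training loop writes model.state_dict() to model.pt somewhere else (the save side is not part of this diff; the helper names here are illustrative):

import torch

CKPT = 'model.pt'  # file name taken from the diff below

def save_checkpoint(model, path=CKPT):
    # Persist only the weights; this is the format that
    # model.load_state_dict(torch.load(path)) in app.py expects.
    torch.save(model.state_dict(), path)

def resume(model, path=CKPT):
    # Mirrors the new line in app.py: restore weights, then move to CUDA.
    model.load_state_dict(torch.load(path))
    return model.cuda()

Note that only model weights are restored here; resuming the optimizer as well would additionally require saving and loading optim.state_dict(), which this commit does not do.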
.gitignore ADDED
@@ -0,0 +1,3 @@
+wandb
+__pycache__
+.ipynb_checkpoints
app.py CHANGED
@@ -1,3 +1,4 @@
+# pip install accelerate datasets transformers huggingface_hub wandb gated_state_spaces_pytorch
 import torch
 import torch.nn as nn
 from torch.optim import AdamW
@@ -20,9 +21,9 @@ if __name__ == '__main__':
         entity="naxalpha",
     )
 
-    gpt_2 = GPT2LMHeadModel.from_pretrained('gpt2-xl')
-    gpt_2.requires_grad_(False)
-    gpt_2 = gpt_2.cuda()
+    # gpt_2 = GPT2LMHeadModel.from_pretrained('gpt2-xl')
+    # gpt_2.requires_grad_(False)
+    # gpt_2 = gpt_2.cuda()
 
     f_emb = 1600
     model = AutoregressiveWrapper(
@@ -34,19 +35,20 @@ if __name__ == '__main__':
     )
     wandb.watch(model)
 
-    emb = gpt_2.state_dict()['transformer.wte.weight']
+    # emb = gpt_2.state_dict()['transformer.wte.weight']
 
     model.net.token_emb.weight.requires_grad_(False)
-    model.net.token_emb.weight.copy_(emb)
+    # model.net.token_emb.weight.copy_(emb)
 
     model.net.to_logits.weight.requires_grad_(False)
-    model.net.to_logits.weight.copy_(emb)
+    # model.net.to_logits.weight.copy_(emb)
 
     model.net.to_logits = nn.Sequential(
         nn.LayerNorm(f_emb),
         model.net.to_logits,
     )
-
+
+    model.load_state_dict(torch.load('model.pt'))
     model = model.cuda()
     optim = AdamW(model.parameters(), 2e-5)
 
@@ -65,18 +67,7 @@ if __name__ == '__main__':
 
     for i, batch in enumerate(prog):
         batch = batch.cuda()
-        if i % 2 == 0: # distil
-            batch = batch[:, :-1]
-            with torch.no_grad():
-                logits = gpt_2(batch).logits
-            probs = logits.softmax(dim=-1)
-            out = model.net(batch)
-            los = F.cross_entropy(
-                out.flatten(0,1),
-                probs.flatten(0,1),
-            )
-        else: # scratch
-            los = model(batch)
+        los = model(batch)
 
         (los / k).backward()
         if (i+1) % k == 0:
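
For context, the training loop previously alternated between a distillation step against the frozen gpt2-xl teacher and a from-scratch autoregressive step; after this commit only the plain autoregressive loss remains. A minimal sketch of the two objectives, with the removed branch and the kept branch factored into standalone functions for illustration (assuming model is the AutoregressiveWrapper and gpt_2 the frozen GPT2LMHeadModel from the earlier revision):

import torch
import torch.nn.functional as F

def distil_loss(model, gpt_2, batch):
    # Removed branch: soft-target cross-entropy against teacher probabilities.
    batch = batch[:, :-1]
    with torch.no_grad():
        probs = gpt_2(batch).logits.softmax(dim=-1)  # teacher next-token probs
    out = model.net(batch)  # student logits
    # F.cross_entropy accepts class probabilities as targets (PyTorch >= 1.10).
    return F.cross_entropy(out.flatten(0, 1), probs.flatten(0, 1))

def lm_loss(model, batch):
    # Kept branch: AutoregressiveWrapper computes the LM loss internally.
    return model(batch)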