naxalpha committed
Commit eec55b1
1 Parent(s): ad01999

distributed training

Files changed (1)
app.py +9 -16
app.py CHANGED
@@ -18,14 +18,10 @@ from accelerate import Accelerator
 
 def main():
     accelerator = Accelerator(
+        log_with="wandb",
         gradient_accumulation_steps=4,
     )
-
-    if accelerator.is_main_process:
-        wandb.init(
-            project="gated-state-space",
-            entity="naxalpha",
-        )
+    accelerator.init_trackers("gated-state-space")
 
     f_emb = 1600
     model = AutoregressiveWrapper(
@@ -49,7 +45,7 @@ def main():
     model.load_state_dict(torch.load('model.pt'))
     optim = AdamW(model.parameters(), 2e-5)
 
-    bs = 16
+    bs = 24
     kk = 128
     dsx = C4X(kk+1)
     dlx = DataLoader(
@@ -58,7 +54,6 @@
         num_workers=8,
     )
 
-    k = 4
     prog = tqdm(dlx, disable=not accelerator.is_main_process)
 
     model, optim, dlx = accelerator.prepare(model, optim, dlx)
@@ -78,8 +73,7 @@
             optim.step()
             optim.zero_grad()
 
-        if i % 1000 == 0 and accelerator.is_main_process:
-            print('generating...')
+        if i % 1000 == 0:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
             b, n = 4, 512
@@ -87,19 +81,18 @@
             prd = unwrapped_model.generate(init, n)
             prd = [dsx.decode(p) for p in prd]
             try:
-                wandb.log(dict(
+                accelerator.log(dict(
                     text=wandb.Html(
                         '<hr>'.join(
                             p.replace('\n', '<br>') for p in prd
                         )
                     )), step=i)
             except Exception as ex:
-                print('Failed to log to W&B...', ex)
-            accelerator.save(unwrapped_model.state_dict(), 'model.pt')
+                accelerator.print('Failed to log to W&B...', ex)
+            accelerator.save(unwrapped_model.state_dict(), 'model2.pt')
 
-        if i % 10 == 0 and accelerator.is_main_process:
-            print('logging...')
-            wandb.log(dict(
+        if i % 10 == 0:
+            accelerator.log(dict(
                 loss=los.item(),
             ), step=i)
         prog.set_postfix(loss=los.item())
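
The gist of the change: instead of calling wandb.init and wandb.log only on the main process, the script now lets Accelerate own the tracker via log_with="wandb", init_trackers, accelerator.log and accelerator.print, which are already main-process-aware under distributed training. A minimal sketch of that pattern, assuming accelerate and wandb are installed and the machine is logged in to W&B; the linear model, random tensors, and loss below are placeholders for illustration, not the repo's gated-state-space model or its C4X loader:

# Sketch of Accelerate-managed W&B logging under distributed training.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

def main():
    accelerator = Accelerator(
        log_with="wandb",                 # route accelerator.log(...) to W&B
        gradient_accumulation_steps=4,
    )
    # Replaces the manual, main-process-only wandb.init: Accelerate
    # initializes the tracker on the main process only.
    accelerator.init_trackers("gated-state-space")

    model = torch.nn.Linear(16, 1)        # placeholder model
    optim = torch.optim.AdamW(model.parameters(), 2e-5)
    data = TensorDataset(torch.randn(256, 16), torch.randn(256, 1))
    dlx = DataLoader(data, batch_size=24)

    model, optim, dlx = accelerator.prepare(model, optim, dlx)

    for i, (x, y) in enumerate(dlx):
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)
            optim.step()
            optim.zero_grad()

        if i % 10 == 0:
            # accelerator.log is a no-op on non-main processes, so the
            # explicit `if accelerator.is_main_process` guard is not needed.
            accelerator.log({"loss": loss.item()}, step=i)

    accelerator.end_training()            # closes the W&B run cleanly

if __name__ == "__main__":
    main()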