distributed training
app.py CHANGED
@@ -18,14 +18,10 @@ from accelerate import Accelerator
 
 def main():
     accelerator = Accelerator(
+        log_with="wandb",
         gradient_accumulation_steps=4,
     )
-
-    if accelerator.is_main_process:
-        wandb.init(
-            project="gated-state-space",
-            entity="naxalpha",
-        )
+    accelerator.init_trackers("gated-state-space")
 
     f_emb = 1600
     model = AutoregressiveWrapper(
@@ -49,7 +45,7 @@ def main():
     model.load_state_dict(torch.load('model.pt'))
     optim = AdamW(model.parameters(), 2e-5)
 
-    bs =
+    bs = 24
     kk = 128
     dsx = C4X(kk+1)
     dlx = DataLoader(
@@ -58,7 +54,6 @@ def main():
         num_workers=8,
     )
 
-    k = 4
     prog = tqdm(dlx, disable=not accelerator.is_main_process)
 
     model, optim, dlx = accelerator.prepare(model, optim, dlx)
@@ -78,8 +73,7 @@ def main():
             optim.step()
             optim.zero_grad()
 
-        if i % 1000 == 0:
-            print('generating...')
+        if i % 1000 == 0:
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
             b, n = 4, 512
@@ -87,19 +81,18 @@ def main():
             prd = unwrapped_model.generate(init, n)
             prd = [dsx.decode(p) for p in prd]
             try:
-                wandb.log(dict(
+                accelerator.log(dict(
                     text=wandb.Html(
                         '<hr>'.join(
                             p.replace('\n', '<br>') for p in prd
                         )
                     )), step=i)
             except Exception as ex:
-                print('Failed to log to W&B...', ex)
-            accelerator.save(unwrapped_model.state_dict(), '
+                accelerator.print('Failed to log to W&B...', ex)
+            accelerator.save(unwrapped_model.state_dict(), 'model2.pt')
 
-        if i % 10 == 0:
-
-            wandb.log(dict(
+        if i % 10 == 0:
+            accelerator.log(dict(
                 loss=los.item(),
             ), step=i)
         prog.set_postfix(loss=los.item())
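
The substance of the commit is a migration from hand-rolled W&B setup (wandb.init guarded by accelerator.is_main_process) to Accelerate's built-in tracker API (log_with="wandb", init_trackers, accelerator.log), plus a distributed-safe checkpoint path. Below is a minimal, self-contained sketch of that pattern, assuming accelerate and wandb are installed; the Linear model and random-tensor dataset are hypothetical stand-ins for the repo's AutoregressiveWrapper and C4X loader, and the end_training() call is added here to close the tracker cleanly.

# Sketch of the tracker pattern this commit adopts. Placeholder model
# and data; not the repo's actual AutoregressiveWrapper / C4X classes.
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator


def main():
    # log_with="wandb" hands the tracker lifecycle to Accelerate, so the
    # manual `if accelerator.is_main_process: wandb.init(...)` block goes away.
    accelerator = Accelerator(
        log_with="wandb",
        gradient_accumulation_steps=4,
    )
    accelerator.init_trackers("gated-state-space")  # W&B project name

    model = torch.nn.Linear(16, 1)  # placeholder model
    optim = AdamW(model.parameters(), 2e-5)
    dsx = TensorDataset(torch.randn(256, 16), torch.randn(256, 1))
    dlx = DataLoader(dsx, batch_size=24)

    # Wrap model, optimizer, and dataloader for the current device setup.
    model, optim, dlx = accelerator.prepare(model, optim, dlx)

    for i, (x, y) in enumerate(dlx):
        with accelerator.accumulate(model):
            los = F.mse_loss(model(x), y)
            accelerator.backward(los)  # replaces los.backward()
            optim.step()
            optim.zero_grad()

        if i % 10 == 0:
            # accelerator.log writes only from the main process, so no
            # explicit rank check is needed around it.
            accelerator.log(dict(loss=los.item()), step=i)

        if i % 1000 == 0:
            # Distributed-safe checkpoint: sync all ranks, then save the
            # underlying module rather than the wrapped one.
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            accelerator.save(unwrapped_model.state_dict(), 'model2.pt')

    accelerator.end_training()  # closes trackers such as the W&B run


if __name__ == '__main__':
    main()

Launched with "accelerate launch app.py", the same script runs on one GPU or many; accelerator.log and accelerator.print act only on the main process, which is why the explicit rank checks disappear from the diff.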