Spaces:
Running
Running
rzimmerdev
commited on
Commit
·
987f571
1
Parent(s):
7dc7452
feature: Added manual training and PyTorch Lightning training loops
Browse files- src/{main.py → train.py} +26 -5
src/{main.py → train.py}
RENAMED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from torch import nn, optim
|
2 |
from torch.utils.data import random_split
|
3 |
import pytorch_lightning as pl
|
@@ -22,13 +24,32 @@ def main():
|
|
22 |
validate_dataloader = DataLoader(validate_data, num_workers=2)
|
23 |
test_dataloader = DataLoader(test_data, num_workers=8) # My CPU has 8 cores
|
24 |
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
29 |
-
|
30 |
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
32 |
trainer.test(model=pl_net, dataloaders=test_dataloader)
|
33 |
|
34 |
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
from torch import nn, optim
|
4 |
from torch.utils.data import random_split
|
5 |
import pytorch_lightning as pl
|
|
|
24 |
validate_dataloader = DataLoader(validate_data, num_workers=2)
|
25 |
test_dataloader = DataLoader(test_data, num_workers=8) # My CPU has 8 cores
|
26 |
|
27 |
+
net = CNN(input_channels=1, num_classes=10).to("cuda")
|
28 |
+
opt = optim.Adam(net.parameters(), lr=1e-4)
|
29 |
+
loss_fn = nn.CrossEntropyLoss()
|
30 |
+
max_epochs = 10
|
31 |
+
for i in range(max_epochs):
|
32 |
+
for idx, batch in enumerate(train_dataloader):
|
33 |
+
x, y = batch
|
34 |
+
x = x.to("cuda")
|
35 |
+
y = y.to("cuda")
|
36 |
+
|
37 |
+
y_pred = net(x).reshape(1, -1)
|
38 |
+
loss = loss_fn(y_pred, y)
|
39 |
+
|
40 |
+
opt.zero_grad()
|
41 |
+
loss.backward()
|
42 |
+
opt.step()
|
43 |
|
44 |
+
if idx % 1000 == 0:
|
45 |
+
print(f"Loss: {loss.item()} ({idx} / {len(train_dataloader)})")
|
46 |
|
47 |
+
torch.save(net, "../checkpoints/pytorch/version_1.pt")
|
48 |
+
|
49 |
+
# grayscale channels = 1, mnist num_labels = 10
|
50 |
+
trainer = pl.Trainer(limit_train_batches=100, max_epochs=10, default_root_dir="../checkpoints")
|
51 |
+
pl_net = LitTrainer(CNN(input_channels=1, num_classes=10))
|
52 |
+
trainer.fit(pl_net, train_dataloader, validate_dataloader)
|
53 |
trainer.test(model=pl_net, dataloaders=test_dataloader)
|
54 |
|
55 |
|