rzimmerdev commited on
Commit
987f571
·
1 Parent(s): 7dc7452

feature: Added manual training and PyTorch Lightning training loops

Browse files
Files changed (1) hide show
  1. src/{main.py → train.py} +26 -5
src/{main.py → train.py} RENAMED
@@ -1,3 +1,5 @@
 
 
1
  from torch import nn, optim
2
  from torch.utils.data import random_split
3
  import pytorch_lightning as pl
@@ -22,13 +24,32 @@ def main():
22
  validate_dataloader = DataLoader(validate_data, num_workers=2)
23
  test_dataloader = DataLoader(test_data, num_workers=8) # My CPU has 8 cores
24
 
25
- # grayscale channels = 1, mnist num_labels = 10
26
- net = CNN(input_channels=1, num_classes=10)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- pl_net = LitTrainer(net, nn.CrossEntropyLoss(), optim.Adam(net.parameters()))
29
- trainer = pl.Trainer(limit_train_batches=100, max_epochs=1, default_root_dir="../checkpoints")
30
 
31
- trainer.fit(model=pl_net, train_dataloaders=train_dataloader, val_dataloaders=validate_dataloader)
 
 
 
 
 
32
  trainer.test(model=pl_net, dataloaders=test_dataloader)
33
 
34
 
 
1
+ import torch
2
+ import numpy as np
3
  from torch import nn, optim
4
  from torch.utils.data import random_split
5
  import pytorch_lightning as pl
 
24
  validate_dataloader = DataLoader(validate_data, num_workers=2)
25
  test_dataloader = DataLoader(test_data, num_workers=8) # My CPU has 8 cores
26
 
27
+ net = CNN(input_channels=1, num_classes=10).to("cuda")
28
+ opt = optim.Adam(net.parameters(), lr=1e-4)
29
+ loss_fn = nn.CrossEntropyLoss()
30
+ max_epochs = 10
31
+ for i in range(max_epochs):
32
+ for idx, batch in enumerate(train_dataloader):
33
+ x, y = batch
34
+ x = x.to("cuda")
35
+ y = y.to("cuda")
36
+
37
+ y_pred = net(x).reshape(1, -1)
38
+ loss = loss_fn(y_pred, y)
39
+
40
+ opt.zero_grad()
41
+ loss.backward()
42
+ opt.step()
43
 
44
+ if idx % 1000 == 0:
45
+ print(f"Loss: {loss.item()} ({idx} / {len(train_dataloader)})")
46
 
47
+ torch.save(net, "../checkpoints/pytorch/version_1.pt")
48
+
49
+ # grayscale channels = 1, mnist num_labels = 10
50
+ trainer = pl.Trainer(limit_train_batches=100, max_epochs=10, default_root_dir="../checkpoints")
51
+ pl_net = LitTrainer(CNN(input_channels=1, num_classes=10))
52
+ trainer.fit(pl_net, train_dataloader, validate_dataloader)
53
  trainer.test(model=pl_net, dataloaders=test_dataloader)
54
 
55