harry committed on
Commit
91c78a9
1 Parent(s): d723609

feat: add init model

Browse files
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .pytest_cache/
4
+ wandb/
5
+ checkpoints/
6
+ *.egg-info/
7
+ dist/
README.md CHANGED
@@ -1,9 +1,11 @@
1
  ---
2
  license: mit
3
  datasets:
4
- - ylecun/mnist
5
  language:
6
- - en
7
  ---
8
 
9
- MNIST classifier model for learning transformer fundamentals.
 
 
 
1
  ---
2
  license: mit
3
  datasets:
4
+ - ylecun/mnist
5
  language:
6
+ - en
7
  ---
8
 
9
+ # MNIST classifier
10
+
11
+ MNIST classifier model for learning transformer fundamentals.
mnist_classifier/configs/config.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training:
2
+ batch_size: 64
3
+ max_epochs: 10
4
+ learning_rate: 0.001
5
+ early_stopping_patience: 5
6
+
7
+ model:
8
+ conv1_channels: 32
9
+ conv2_channels: 64
10
+ fc1_size: 128
11
+ dropout_rate: 0.25
12
+
13
+ wandb:
14
+ project: "mnist-classifier"
15
+ entity: "bardenha"
mnist_classifier/data/datamodule.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+
3
+ import pytorch_lightning as pl
4
+ from datasets import load_dataset
5
+ from torch.utils.data import DataLoader
6
+
7
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module that serves the Hugging Face ``mnist`` dataset.

    Batches are dictionaries with a ``pixel_values`` entry and a ``label``
    entry.

    Args:
        config: Project configuration. When it is a dict (as produced by
            loading the YAML config), the batch size is read from
            ``config['training']['batch_size']``. Objects exposing
            ``batch_size`` / ``transform_dataset`` attributes are still
            honoured for backward compatibility.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config

    def _batch_size(self) -> int:
        """Return the configured batch size for either config layout."""
        # A dict has no ``batch_size`` attribute; the value is nested under
        # the ``training`` section of the YAML configuration.
        if isinstance(self.config, dict):
            return self.config['training']['batch_size']
        return self.config.batch_size

    @staticmethod
    def _to_pixel_values(examples):
        """Convert a batch of dataset rows into model-ready tensors.

        Each image becomes a float32 tensor scaled to ``[0, 1]`` with a
        leading channel axis added by ``unsqueeze(0)``.
        """
        # NOTE(review): assumes the dataset exposes ``image`` and ``label``
        # columns and that images are single-channel -- confirm against the
        # dataset card.
        # Imported locally so this block does not depend on new file-level
        # imports.
        import numpy as np
        import torch

        pixel_values = [
            torch.from_numpy(np.array(image, dtype=np.float32) / 255.0).unsqueeze(0)
            for image in examples['image']
        ]
        return {'pixel_values': pixel_values, 'label': examples['label']}

    def setup(self, stage=None):
        """Load the dataset and attach the on-the-fly transform."""
        # Prefer a transform supplied on the config object; fall back to the
        # built-in conversion when the config is a plain dict.
        transform = getattr(self.config, 'transform_dataset', None) or self._to_pixel_values
        self.dataset = load_dataset('mnist').with_transform(transform)

    def train_dataloader(self):
        """Return a shuffled loader over the ``train`` split."""
        return DataLoader(
            self.dataset['train'],
            batch_size=self._batch_size(),
            shuffle=True,
        )

    def val_dataloader(self):
        """Return a loader over the ``test`` split, used here as validation."""
        return DataLoader(
            self.dataset['test'],  # Using test set as validation
            batch_size=self._batch_size(),
        )
mnist_classifier/models/mnist_model.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any
2
+
3
+ import pytorch_lightning as pl
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchmetrics
7
+ import wandb
8
+
9
+ # Simple CNN architecture for MNIST
10
class MNISTNet(nn.Module):
    """Small two-convolution CNN that maps 1x28x28 images to 10 class logits.

    Each 3x3 convolution is followed by ReLU and 2x2 max pooling, so a 28x28
    input shrinks 28 -> 26 -> 13 -> 11 -> 5. This is the 5x5 feature map that
    ``fc1`` is sized for.

    Args:
        config: Configuration dict whose ``model`` section provides
            ``conv1_channels``, ``conv2_channels``, ``fc1_size`` and
            ``dropout_rate``.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        model_cfg = config['model']
        conv1_channels = model_cfg['conv1_channels']
        conv2_channels = model_cfg['conv2_channels']
        fc1_size = model_cfg['fc1_size']

        self.conv1 = nn.Conv2d(1, conv1_channels, kernel_size=3)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, kernel_size=3)
        self.pool = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(model_cfg['dropout_rate'])
        self.fc1 = nn.Linear(conv2_channels * 5 * 5, fc1_size)
        self.fc2 = nn.Linear(fc1_size, 10)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, 10)``."""
        # Pool after *both* convolutions; pooling only once leaves a 12x12
        # map that does not match the 5x5 input size of ``fc1``.
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.dropout(x)
        # Flatten everything but the batch axis instead of hard-coding the
        # channel count, so any configured ``conv2_channels`` works.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
30
+
31
+
32
+
33
class MNISTModule(pl.LightningModule):
    """Lightning wrapper that trains ``MNISTNet`` and logs metrics.

    Tracks accuracy and F1 for training and validation, logs sample
    predictions for the first validation batch of every epoch, and logs a
    confusion matrix at the end of each validation epoch.

    Args:
        config: Configuration dict. Its ``model`` section is consumed by
            ``MNISTNet``; ``training.learning_rate`` (default ``1e-3``) sets
            the optimizer learning rate.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config
        self.model = MNISTNet(config)

        # Initialize metrics
        self.train_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = torchmetrics.Accuracy(task='multiclass', num_classes=10)
        self.train_f1 = torchmetrics.F1Score(task='multiclass', num_classes=10)
        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=10)
        self.confusion_matrix = torchmetrics.ConfusionMatrix(task='multiclass', num_classes=10)

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run the forward pass and return ``(inputs, targets, loss, preds)``."""
        x, y = batch['pixel_values'], batch['label']
        logits = self(x)
        # Functional form avoids building a new loss module on every step.
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        return x, y, loss, preds

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log loss, accuracy and F1."""
        _, y, loss, preds = self._shared_step(batch)

        self.train_accuracy(preds, y)
        self.train_f1(preds, y)

        self.log('train_loss', loss, prog_bar=True)
        self.log('train_accuracy', self.train_accuracy, prog_bar=True)
        self.log('train_f1', self.train_f1, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Update validation metrics and log loss, accuracy and F1."""
        x, y, loss, preds = self._shared_step(batch)

        self.val_accuracy(preds, y)
        self.val_f1(preds, y)
        self.confusion_matrix(preds, y)

        self.log('val_loss', loss, prog_bar=True)
        self.log('val_accuracy', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)

        # Log sample predictions periodically
        if batch_idx == 0:  # First batch of each epoch
            self._log_sample_predictions(x, y, preds)

    def _log_sample_predictions(self, images, labels, predictions):
        """Log up to 16 captioned images from the batch to the logger."""
        if not self.logger:
            return
        n_samples = min(16, len(images))
        self.logger.experiment.log({
            "sample_predictions": [
                wandb.Image(
                    images[i],
                    caption=f"True: {labels[i].item()} Pred: {predictions[i].item()}"
                )
                for i in range(n_samples)
            ]
        })

    def on_validation_epoch_end(self):
        """Log the accumulated confusion matrix and reset it."""
        conf_mat = self.confusion_matrix.compute()
        # Reset first so counts never leak into the next epoch, even when
        # there is no logger to report to.
        self.confusion_matrix.reset()
        if not self.logger:
            return

        # ``wandb.plot.confusion_matrix`` expects per-sample labels, not an
        # already aggregated matrix, so the chart is built from the counts
        # directly.
        # NOTE(review): assumes rows of ``conf_mat`` are true classes and
        # columns are predicted classes -- confirm against the torchmetrics
        # docs.
        rows = [
            [str(actual), str(predicted), count]
            for actual, row in enumerate(conf_mat.cpu().tolist())
            for predicted, count in enumerate(row)
        ]
        table = wandb.Table(columns=["Actual", "Predicted", "nPredictions"], data=rows)
        self.logger.experiment.log({
            "confusion_matrix": wandb.plot_table(
                "wandb/confusion_matrix/v1",
                table,
                {"Actual": "Actual", "Predicted": "Predicted", "nPredictions": "nPredictions"},
                {"title": "Confusion matrix"},
            )
        })

    def configure_optimizers(self):
        """Return Adam plus a plateau scheduler monitoring ``val_loss``."""
        # Honour the configured learning rate; fall back to the previous
        # hard-coded value when the key is absent.
        learning_rate = self.config.get('training', {}).get('learning_rate', 1e-3)
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
        # ``verbose`` is deprecated in recent torch releases, so it is no
        # longer passed here.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.1, patience=3
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss"
            }
        }
mnist_classifier/train.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning.loggers import WandbLogger
3
+ from pathlib import Path
4
+
5
+ from mnist_classifier.models.mnist_model import MNISTModule
6
+ from mnist_classifier.data.datamodule import MNISTDataModule
7
+ from mnist_classifier.utils.metrics import load_config
8
+
9
def main():
    """Train the MNIST classifier using the YAML configuration.

    Loads ``mnist_classifier/configs/config.yaml`` (relative to the current
    working directory), sets up Weights & Biases logging, checkpointing,
    early stopping and learning-rate monitoring, then fits the model.
    """
    config = load_config(Path("mnist_classifier/configs/config.yaml"))

    # Initialize wandb logger
    wandb_logger = WandbLogger(
        project=config['wandb']['project'],
        entity=config['wandb']['entity']
    )

    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=config['training']['max_epochs'],
        # 'auto' selects an available accelerator and falls back to CPU,
        # instead of failing outright on machines without a GPU.
        accelerator='auto',
        devices=1,
        logger=wandb_logger,
        callbacks=[
            pl.callbacks.ModelCheckpoint(
                dirpath='checkpoints',
                filename='mnist-{epoch:02d}-{val_loss:.2f}',
                save_top_k=3,
                monitor='val_loss',
                mode='min'
            ),
            pl.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=config['training']['early_stopping_patience'],
                mode='min'
            ),
            pl.callbacks.LearningRateMonitor(logging_interval='epoch')
        ]
    )

    # Initialize data module and model
    data_module = MNISTDataModule(config)
    model = MNISTModule(config)

    # Train
    trainer.fit(model, data_module)


if __name__ == "__main__":
    main()
mnist_classifier/utils/metrics.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import yaml
2
+ from pathlib import Path
3
+
4
def load_config(config_path: str | Path) -> dict:
    """Load a YAML configuration file into a dictionary.

    Args:
        config_path: Location of the YAML file, as a string or ``Path``.

    Returns:
        The parsed top-level mapping.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is empty or its top level is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # An empty file parses to ``None``; fail here with a clear message rather
    # than with an opaque ``TypeError`` when a caller indexes the result.
    if not isinstance(config, dict):
        raise ValueError(f"Config file {config_path} must contain a YAML mapping")
    return config
poetry.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -8,7 +8,7 @@ readme = "README.md"
8
 
9
  [tool.poetry.dependencies]
10
  python = "^3.10"
11
- torch = "^2.0.0"
12
  torchvision = "^0.15.0"
13
  pytorch-lightning = "^2.0.0"
14
  wandb = "^0.15.0"
 
8
 
9
  [tool.poetry.dependencies]
10
  python = "^3.10"
11
+ torch = "^2.4.0"
12
  torchvision = "^0.15.0"
13
  pytorch-lightning = "^2.0.0"
14
  wandb = "^0.15.0"
tests/test_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import torch
3
+ from mnist_classifier.models.mnist_model import MNISTNet
4
+ from mnist_classifier.utils.metrics import load_config
5
+
6
def test_mnist_net_forward():
    """A single 1x28x28 input must yield one row of 10 logits."""
    net = MNISTNet(load_config('mnist_classifier/configs/config.yaml'))
    batch = torch.randn(1, 1, 28, 28)
    logits = net(batch)
    expected_shape = (1, 10)
    assert logits.shape == expected_shape