Spaces:
Runtime error
Runtime error
Soutrik
committed on
Commit
•
de7d21e
1
Parent(s):
b0bdbcf
new train model
Browse files
- configs/callbacks/early_stopping.yaml +1 -1
- configs/callbacks/model_checkpoint.yaml +1 -1
- configs/callbacks/rich_model_summary.yaml +2 -1
- configs/callbacks/rich_progress_bar.yaml +2 -1
- configs/experiment/catdog_experiment.yaml +10 -10
- configs/model/catdog_classifier.yaml +4 -4
- configs/trainer/default.yaml +2 -0
- docker-compose.yaml +2 -2
- src/datamodules/catdog_datamodule.py +8 -10
- src/models/catdog_model.py +35 -39
- src/train_new.py +1 -1
configs/callbacks/early_stopping.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
|
2 |
|
3 |
early_stopping:
|
4 |
-
_target_:
|
5 |
monitor: val_loss # quantity to be monitored, must be specified !!!
|
6 |
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
7 |
patience: 3 # number of checks with no improvement after which training will be stopped
|
|
|
1 |
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
|
2 |
|
3 |
early_stopping:
|
4 |
+
_target_: lightning.pytorch.callbacks.EarlyStopping
|
5 |
monitor: val_loss # quantity to be monitored, must be specified !!!
|
6 |
min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
|
7 |
patience: 3 # number of checks with no improvement after which training will be stopped
|
configs/callbacks/model_checkpoint.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
2 |
|
3 |
model_checkpoint:
|
4 |
-
_target_:
|
5 |
dirpath: null # directory to save the model file
|
6 |
filename: best-checkpoint # checkpoint filename
|
7 |
monitor: val_loss # name of the logged metric which determines when model is improving
|
|
|
1 |
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
|
2 |
|
3 |
model_checkpoint:
|
4 |
+
_target_: lightning.pytorch.callbacks.ModelCheckpoint
|
5 |
dirpath: null # directory to save the model file
|
6 |
filename: best-checkpoint # checkpoint filename
|
7 |
monitor: val_loss # name of the logged metric which determines when model is improving
|
configs/callbacks/rich_model_summary.yaml
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
rich_model_summary:
|
2 |
-
_target_:
|
3 |
max_depth: 1
|
|
|
1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
|
2 |
rich_model_summary:
|
3 |
+
_target_: lightning.pytorch.callbacks.RichModelSummary
|
4 |
max_depth: 1
|
configs/callbacks/rich_progress_bar.yaml
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
rich_progress_bar:
|
2 |
-
_target_:
|
3 |
refresh_rate: 1
|
|
|
1 |
+
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichProgressBar.html
|
2 |
rich_progress_bar:
|
3 |
+
_target_: lightning.pytorch.callbacks.RichProgressBar
|
4 |
refresh_rate: 1
|
configs/experiment/catdog_experiment.yaml
CHANGED
@@ -18,28 +18,28 @@ seed: 42
|
|
18 |
name: "catdog_experiment"
|
19 |
|
20 |
data:
|
21 |
-
|
22 |
-
batch_size:
|
23 |
num_workers: 8
|
24 |
pin_memory: True
|
25 |
image_size: 224
|
26 |
|
27 |
model:
|
28 |
-
lr:
|
29 |
weight_decay: 1e-5
|
30 |
-
factor: 0.
|
31 |
-
patience:
|
32 |
min_lr: 1e-6
|
33 |
num_classes: 2
|
34 |
patch_size: 16
|
35 |
-
embed_dim:
|
36 |
-
depth:
|
37 |
-
num_heads:
|
38 |
-
mlp_ratio:
|
39 |
|
40 |
trainer:
|
41 |
min_epochs: 1
|
42 |
-
max_epochs:
|
43 |
|
44 |
callbacks:
|
45 |
model_checkpoint:
|
|
|
18 |
name: "catdog_experiment"
|
19 |
|
20 |
data:
|
21 |
+
data_dir: "cats_and_dogs_filtered"
|
22 |
+
batch_size: 64
|
23 |
num_workers: 8
|
24 |
pin_memory: True
|
25 |
image_size: 224
|
26 |
|
27 |
model:
|
28 |
+
lr: 5e-5
|
29 |
weight_decay: 1e-5
|
30 |
+
factor: 0.5
|
31 |
+
patience: 5
|
32 |
min_lr: 1e-6
|
33 |
num_classes: 2
|
34 |
patch_size: 16
|
35 |
+
embed_dim: 256
|
36 |
+
depth: 4
|
37 |
+
num_heads: 4
|
38 |
+
mlp_ratio: 4
|
39 |
|
40 |
trainer:
|
41 |
min_epochs: 1
|
42 |
+
max_epochs: 6
|
43 |
|
44 |
callbacks:
|
45 |
model_checkpoint:
|
configs/model/catdog_classifier.yaml
CHANGED
@@ -3,13 +3,13 @@
|
|
3 |
_target_: src.models.catdog_model.ViTTinyClassifier
|
4 |
|
5 |
# model params
|
6 |
-
img_size:
|
7 |
patch_size: 16
|
8 |
num_classes: 2
|
9 |
-
embed_dim:
|
10 |
depth: 6
|
11 |
-
num_heads:
|
12 |
-
mlp_ratio:
|
13 |
pre_norm: False
|
14 |
|
15 |
# optimizer params
|
|
|
3 |
_target_: src.models.catdog_model.ViTTinyClassifier
|
4 |
|
5 |
# model params
|
6 |
+
img_size: ${data.image_size}
|
7 |
patch_size: 16
|
8 |
num_classes: 2
|
9 |
+
embed_dim: 128
|
10 |
depth: 6
|
11 |
+
num_heads: 4
|
12 |
+
mlp_ratio: 4
|
13 |
pre_norm: False
|
14 |
|
15 |
# optimizer params
|
configs/trainer/default.yaml
CHANGED
@@ -17,3 +17,5 @@ deterministic: True
|
|
17 |
# Log every N steps in training and validation
|
18 |
log_every_n_steps: 10
|
19 |
fast_dev_run: False
|
|
|
|
|
|
17 |
# Log every N steps in training and validation
|
18 |
log_every_n_steps: 10
|
19 |
fast_dev_run: False
|
20 |
+
|
21 |
+
gradient_clip_val: 1.0
|
docker-compose.yaml
CHANGED
@@ -5,7 +5,7 @@ services:
|
|
5 |
build:
|
6 |
context: .
|
7 |
command: |
|
8 |
-
python -m src.
|
9 |
touch /app/checkpoints/train_done.flag
|
10 |
volumes:
|
11 |
- ./data:/app/data
|
@@ -25,7 +25,7 @@ services:
|
|
25 |
build:
|
26 |
context: .
|
27 |
command: |
|
28 |
-
sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.
|
29 |
volumes:
|
30 |
- ./data:/app/data
|
31 |
- ./checkpoints:/app/checkpoints
|
|
|
5 |
build:
|
6 |
context: .
|
7 |
command: |
|
8 |
+
python -m src.train_new experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
|
9 |
touch /app/checkpoints/train_done.flag
|
10 |
volumes:
|
11 |
- ./data:/app/data
|
|
|
25 |
build:
|
26 |
context: .
|
27 |
command: |
|
28 |
+
sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_new experiment=catdog_experiment ++task_name=eval ++train=False ++test=True'
|
29 |
volumes:
|
30 |
- ./data:/app/data
|
31 |
- ./checkpoints:/app/checkpoints
|
src/datamodules/catdog_datamodule.py
CHANGED
@@ -14,7 +14,7 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
14 |
|
15 |
def __init__(
|
16 |
self,
|
17 |
-
|
18 |
data_dir: Union[str, Path] = "cats_and_dogs_filtered",
|
19 |
batch_size: int = 32,
|
20 |
num_workers: int = 4,
|
@@ -24,7 +24,7 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
24 |
url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
|
25 |
):
|
26 |
super().__init__()
|
27 |
-
self.
|
28 |
self.data_dir = data_dir
|
29 |
self.batch_size = batch_size
|
30 |
self.num_workers = num_workers
|
@@ -40,11 +40,11 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
40 |
|
41 |
def prepare_data(self):
|
42 |
"""Download the dataset if it doesn't exist."""
|
43 |
-
self.dataset_path = self.
|
44 |
if not self.dataset_path.exists():
|
45 |
logger.info("Downloading and extracting dataset.")
|
46 |
download_and_extract_archive(
|
47 |
-
url=self.url, download_root=self.
|
48 |
)
|
49 |
logger.info("Download completed.")
|
50 |
|
@@ -56,11 +56,9 @@ class CatDogImageDataModule(L.LightningDataModule):
|
|
56 |
train_transform = transforms.Compose(
|
57 |
[
|
58 |
transforms.Resize((self.image_size, self.image_size)),
|
59 |
-
transforms.RandomHorizontalFlip(0.
|
60 |
-
transforms.RandomRotation(
|
61 |
-
transforms.
|
62 |
-
transforms.RandomAutocontrast(0.1),
|
63 |
-
transforms.RandomAdjustSharpness(2, 0.1),
|
64 |
transforms.ToTensor(),
|
65 |
transforms.Normalize(
|
66 |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
@@ -134,7 +132,7 @@ if __name__ == "__main__":
|
|
134 |
def test_datamodule(cfg: DictConfig):
|
135 |
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
|
136 |
datamodule = CatDogImageDataModule(
|
137 |
-
|
138 |
data_dir=cfg.data.data_dir,
|
139 |
batch_size=cfg.data.batch_size,
|
140 |
num_workers=cfg.data.num_workers,
|
|
|
14 |
|
15 |
def __init__(
|
16 |
self,
|
17 |
+
root_dir: Union[str, Path] = "data",
|
18 |
data_dir: Union[str, Path] = "cats_and_dogs_filtered",
|
19 |
batch_size: int = 32,
|
20 |
num_workers: int = 4,
|
|
|
24 |
url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
|
25 |
):
|
26 |
super().__init__()
|
27 |
+
self.root_dir = Path(root_dir)
|
28 |
self.data_dir = data_dir
|
29 |
self.batch_size = batch_size
|
30 |
self.num_workers = num_workers
|
|
|
40 |
|
41 |
def prepare_data(self):
|
42 |
"""Download the dataset if it doesn't exist."""
|
43 |
+
self.dataset_path = self.root_dir / self.data_dir
|
44 |
if not self.dataset_path.exists():
|
45 |
logger.info("Downloading and extracting dataset.")
|
46 |
download_and_extract_archive(
|
47 |
+
url=self.url, download_root=self.root_dir, remove_finished=True
|
48 |
)
|
49 |
logger.info("Download completed.")
|
50 |
|
|
|
56 |
train_transform = transforms.Compose(
|
57 |
[
|
58 |
transforms.Resize((self.image_size, self.image_size)),
|
59 |
+
transforms.RandomHorizontalFlip(0.5), # Flip probability increased
|
60 |
+
transforms.RandomRotation(5), # Reduced rotation for stability
|
61 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
|
|
|
|
|
62 |
transforms.ToTensor(),
|
63 |
transforms.Normalize(
|
64 |
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
|
|
132 |
def test_datamodule(cfg: DictConfig):
|
133 |
logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
|
134 |
datamodule = CatDogImageDataModule(
|
135 |
+
root_dir=cfg.data.root_dir,
|
136 |
data_dir=cfg.data.data_dir,
|
137 |
batch_size=cfg.data.batch_size,
|
138 |
num_workers=cfg.data.num_workers,
|
src/models/catdog_model.py
CHANGED
@@ -1,15 +1,16 @@
|
|
1 |
import lightning as L
|
2 |
-
import torch
|
3 |
-
from torch import
|
4 |
-
from torchmetrics import Accuracy,
|
5 |
from timm.models import VisionTransformer
|
|
|
6 |
|
7 |
|
8 |
class ViTTinyClassifier(L.LightningModule):
|
9 |
def __init__(
|
10 |
self,
|
11 |
img_size: int = 224,
|
12 |
-
num_classes: int = 2, #
|
13 |
embed_dim: int = 64,
|
14 |
depth: int = 6,
|
15 |
num_heads: int = 2,
|
@@ -25,7 +26,7 @@ class ViTTinyClassifier(L.LightningModule):
|
|
25 |
super().__init__()
|
26 |
self.save_hyperparameters()
|
27 |
|
28 |
-
#
|
29 |
self.model = VisionTransformer(
|
30 |
img_size=img_size,
|
31 |
patch_size=patch_size,
|
@@ -35,51 +36,40 @@ class ViTTinyClassifier(L.LightningModule):
|
|
35 |
depth=depth,
|
36 |
num_heads=num_heads,
|
37 |
mlp_ratio=mlp_ratio,
|
38 |
-
qkv_bias=
|
39 |
pre_norm=pre_norm,
|
40 |
global_pool="token",
|
41 |
)
|
42 |
|
43 |
-
#
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
"recall": Recall(task="binary"),
|
48 |
-
"f1": F1Score(task="binary"),
|
49 |
-
}
|
50 |
-
|
51 |
-
# Initialize metrics for each stage
|
52 |
-
self.train_metrics = nn.ModuleDict(
|
53 |
-
{name: metric.clone() for name, metric in metrics.items()}
|
54 |
-
)
|
55 |
-
self.val_metrics = nn.ModuleDict(
|
56 |
-
{name: metric.clone() for name, metric in metrics.items()}
|
57 |
-
)
|
58 |
-
self.test_metrics = nn.ModuleDict(
|
59 |
-
{name: metric.clone() for name, metric in metrics.items()}
|
60 |
-
)
|
61 |
|
62 |
-
|
63 |
-
self.
|
|
|
64 |
|
65 |
def forward(self, x):
|
66 |
return self.model(x)
|
67 |
|
68 |
-
def _shared_step(self, batch, stage
|
69 |
x, y = batch
|
70 |
-
logits = self(x)
|
71 |
-
loss =
|
72 |
-
preds =
|
73 |
-
|
74 |
-
#
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
-
# Log metrics
|
81 |
-
self.log(f"{stage}_loss", loss, prog_bar=True)
|
82 |
-
self.log_dict(metric_logs, prog_bar=True, on_step=False, on_epoch=True)
|
83 |
return loss
|
84 |
|
85 |
def training_step(self, batch, batch_idx):
|
@@ -100,6 +90,7 @@ class ViTTinyClassifier(L.LightningModule):
|
|
100 |
|
101 |
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
102 |
optimizer,
|
|
|
103 |
factor=self.hparams.factor,
|
104 |
patience=self.hparams.patience,
|
105 |
min_lr=self.hparams.min_lr,
|
@@ -113,3 +104,8 @@ class ViTTinyClassifier(L.LightningModule):
|
|
113 |
"interval": "epoch",
|
114 |
},
|
115 |
}
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import lightning as L
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import optim
|
4 |
+
from torchmetrics.classification import Accuracy, F1Score
|
5 |
from timm.models import VisionTransformer
|
6 |
+
import torch
|
7 |
|
8 |
|
9 |
class ViTTinyClassifier(L.LightningModule):
|
10 |
def __init__(
|
11 |
self,
|
12 |
img_size: int = 224,
|
13 |
+
num_classes: int = 2, # Binary classification with two classes
|
14 |
embed_dim: int = 64,
|
15 |
depth: int = 6,
|
16 |
num_heads: int = 2,
|
|
|
26 |
super().__init__()
|
27 |
self.save_hyperparameters()
|
28 |
|
29 |
+
# Vision Transformer model initialization
|
30 |
self.model = VisionTransformer(
|
31 |
img_size=img_size,
|
32 |
patch_size=patch_size,
|
|
|
36 |
depth=depth,
|
37 |
num_heads=num_heads,
|
38 |
mlp_ratio=mlp_ratio,
|
39 |
+
qkv_bias=True,
|
40 |
pre_norm=pre_norm,
|
41 |
global_pool="token",
|
42 |
)
|
43 |
|
44 |
+
# Define accuracy and F1 metrics for binary classification
|
45 |
+
self.train_acc = Accuracy(task="binary")
|
46 |
+
self.val_acc = Accuracy(task="binary")
|
47 |
+
self.test_acc = Accuracy(task="binary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
+
self.train_f1 = F1Score(task="binary")
|
50 |
+
self.val_f1 = F1Score(task="binary")
|
51 |
+
self.test_f1 = F1Score(task="binary")
|
52 |
|
53 |
def forward(self, x):
|
54 |
return self.model(x)
|
55 |
|
56 |
+
def _shared_step(self, batch, stage):
|
57 |
x, y = batch
|
58 |
+
logits = self(x) # Model output shape: [batch_size, num_classes]
|
59 |
+
loss = F.cross_entropy(logits, y) # Cross-entropy for binary classification
|
60 |
+
preds = torch.argmax(logits, dim=1) # Predicted class (0 or 1)
|
61 |
+
|
62 |
+
# Update and log metrics
|
63 |
+
acc = getattr(self, f"{stage}_acc")
|
64 |
+
f1 = getattr(self, f"{stage}_f1")
|
65 |
+
acc(preds, y)
|
66 |
+
f1(preds, y)
|
67 |
+
|
68 |
+
# Logging of metrics and loss
|
69 |
+
self.log(f"{stage}_loss", loss, prog_bar=True, on_epoch=True)
|
70 |
+
self.log(f"{stage}_acc", acc, prog_bar=True, on_epoch=True)
|
71 |
+
self.log(f"{stage}_f1", f1, prog_bar=True, on_epoch=True)
|
72 |
|
|
|
|
|
|
|
73 |
return loss
|
74 |
|
75 |
def training_step(self, batch, batch_idx):
|
|
|
90 |
|
91 |
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
92 |
optimizer,
|
93 |
+
mode="min",
|
94 |
factor=self.hparams.factor,
|
95 |
patience=self.hparams.patience,
|
96 |
min_lr=self.hparams.min_lr,
|
|
|
104 |
"interval": "epoch",
|
105 |
},
|
106 |
}
|
107 |
+
|
108 |
+
|
109 |
+
if __name__ == "__main__":
|
110 |
+
model = ViTTinyClassifier()
|
111 |
+
print(model)
|
src/train_new.py
CHANGED
@@ -160,7 +160,7 @@ def setup_run_trainer(cfg: DictConfig):
|
|
160 |
# Set up callbacks, loggers, and Trainer
|
161 |
callbacks = instantiate_callbacks(cfg.callbacks)
|
162 |
logger.info(f"Callbacks: {callbacks}")
|
163 |
-
loggers = instantiate_loggers(cfg.
|
164 |
logger.info(f"Loggers: {loggers}")
|
165 |
trainer: L.Trainer = hydra.utils.instantiate(
|
166 |
cfg.trainer, callbacks=callbacks, logger=loggers
|
|
|
160 |
# Set up callbacks, loggers, and Trainer
|
161 |
callbacks = instantiate_callbacks(cfg.callbacks)
|
162 |
logger.info(f"Callbacks: {callbacks}")
|
163 |
+
loggers = instantiate_loggers(cfg.logger)
|
164 |
logger.info(f"Loggers: {loggers}")
|
165 |
trainer: L.Trainer = hydra.utils.instantiate(
|
166 |
cfg.trainer, callbacks=callbacks, logger=loggers
|