Soutrik committed on
Commit
de7d21e
1 Parent(s): b0bdbcf

new train model

configs/callbacks/early_stopping.yaml CHANGED
@@ -1,7 +1,7 @@
 # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html
 
 early_stopping:
-  _target_: pytorch_lightning.callbacks.EarlyStopping
+  _target_: lightning.pytorch.callbacks.EarlyStopping
   monitor: val_loss # quantity to be monitored, must be specified !!!
   min_delta: 0. # minimum change in the monitored quantity to qualify as an improvement
   patience: 3 # number of checks with no improvement after which training will be stopped
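Note: this commit moves every callback `_target_` from the legacy `pytorch_lightning` namespace to the unified `lightning.pytorch` package (the same one-line fix repeats in the checkpoint, summary, and progress-bar configs below). That keeps the callbacks consistent with the `import lightning as L` style used in the source files; mixing the two namespaces can produce callbacks the `lightning` Trainer does not recognize as its own. A minimal sketch of how Hydra turns this config into a callback instance, with values mirroring the file above:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Config node mirroring configs/callbacks/early_stopping.yaml after this commit
cfg = OmegaConf.create(
    {
        "_target_": "lightning.pytorch.callbacks.EarlyStopping",
        "monitor": "val_loss",
        "min_delta": 0.0,
        "patience": 3,
    }
)
early_stopping = instantiate(cfg)  # -> lightning.pytorch.callbacks.EarlyStopping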
configs/callbacks/model_checkpoint.yaml CHANGED
@@ -1,7 +1,7 @@
 # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html
 
 model_checkpoint:
-  _target_: pytorch_lightning.callbacks.ModelCheckpoint
+  _target_: lightning.pytorch.callbacks.ModelCheckpoint
   dirpath: null # directory to save the model file
   filename: best-checkpoint # checkpoint filename
   monitor: val_loss # name of the logged metric which determines when model is improving
configs/callbacks/rich_model_summary.yaml CHANGED
@@ -1,3 +1,4 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html
 rich_model_summary:
-  _target_: pytorch_lightning.callbacks.RichModelSummary
+  _target_: lightning.pytorch.callbacks.RichModelSummary
   max_depth: 1
configs/callbacks/rich_progress_bar.yaml CHANGED
@@ -1,3 +1,4 @@
+# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichProgressBar.html
 rich_progress_bar:
-  _target_: pytorch_lightning.callbacks.RichProgressBar
+  _target_: lightning.pytorch.callbacks.RichProgressBar
   refresh_rate: 1
configs/experiment/catdog_experiment.yaml CHANGED
@@ -18,28 +18,28 @@ seed: 42
 name: "catdog_experiment"
 
 data:
-  dataset: "cats_and_dogs_filtered"
-  batch_size: 32
+  data_dir: "cats_and_dogs_filtered"
+  batch_size: 64
   num_workers: 8
   pin_memory: True
   image_size: 224
 
 model:
-  lr: 1e-3
+  lr: 5e-5
   weight_decay: 1e-5
-  factor: 0.1
-  patience: 10
+  factor: 0.5
+  patience: 5
   min_lr: 1e-6
   num_classes: 2
   patch_size: 16
-  embed_dim: 64
-  depth: 6
-  num_heads: 2
-  mlp_ratio: 3
+  embed_dim: 256
+  depth: 4
+  num_heads: 4
+  mlp_ratio: 4
 
 trainer:
   min_epochs: 1
-  max_epochs: 10
+  max_epochs: 6
 
 callbacks:
   model_checkpoint:
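The experiment now trains a wider but shallower ViT (embed_dim 256, depth 4) with a lower learning rate (5e-5 vs 1e-3), a gentler scheduler (factor 0.5, patience 5), double the batch size, and fewer epochs; `dataset:` is renamed to `data_dir:` so the key lines up with the datamodule signature changed later in this commit. A sketch of inspecting the composed overrides via Hydra's compose API (the primary config name `train` is an assumption; substitute the repo's actual root config):

from hydra import compose, initialize

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=catdog_experiment"])

print(cfg.data.batch_size)   # 64 — the experiment value, not the data-group default
print(cfg.model.embed_dim)   # 256 — overrides the 128 set in the model config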
configs/model/catdog_classifier.yaml CHANGED
@@ -3,13 +3,13 @@
 _target_: src.models.catdog_model.ViTTinyClassifier
 
 # model params
-img_size: 160
+img_size: ${data.image_size}
 patch_size: 16
 num_classes: 2
-embed_dim: 64
+embed_dim: 128
 depth: 6
-num_heads: 2
-mlp_ratio: 3.0
+num_heads: 4
+mlp_ratio: 4
 pre_norm: False
 
 # optimizer params
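Replacing the hard-coded `img_size: 160` with the `${data.image_size}` interpolation ties the model's input resolution to the datamodule's resize target (224), removing a silent mismatch between a 160-pixel model and 224-pixel batches. How the interpolation resolves, in a self-contained OmegaConf sketch:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "data": {"image_size": 224},
        "model": {"img_size": "${data.image_size}"},
    }
)
assert cfg.model.img_size == 224  # resolved against data.image_size on access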
configs/trainer/default.yaml CHANGED
@@ -17,3 +17,5 @@ deterministic: True
 # Log every N steps in training and validation
 log_every_n_steps: 10
 fast_dev_run: False
+
+gradient_clip_val: 1.0
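The new `gradient_clip_val: 1.0` caps the global gradient norm each step, a common stabilizer for transformer training, especially at the larger batch size set above. Since the trainer config is instantiated directly into the Trainer, the YAML key maps one-to-one onto the Trainer argument; a minimal equivalent (other flags omitted):

import lightning as L

trainer = L.Trainer(
    min_epochs=1,
    max_epochs=6,
    log_every_n_steps=10,
    gradient_clip_val=1.0,  # clip gradient norm to 1.0 before each optimizer step
)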
docker-compose.yaml CHANGED
@@ -5,7 +5,7 @@ services:
     build:
       context: .
     command: |
-      python -m src.train experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
+      python -m src.train_new experiment=catdog_experiment ++task_name=train ++train=True ++test=False && \
       touch /app/checkpoints/train_done.flag
     volumes:
       - ./data:/app/data
@@ -25,7 +25,7 @@ services:
     build:
       context: .
     command: |
-      sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train experiment=catdog_experiment ++task_name=eval ++train=False ++test=True'
+      sh -c 'while [ ! -f /app/checkpoints/train_done.flag ]; do sleep 10; done && python -m src.train_new experiment=catdog_experiment ++task_name=eval ++train=False ++test=True'
     volumes:
       - ./data:/app/data
       - ./checkpoints:/app/checkpoints
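Both services now launch `src.train_new` instead of `src.train`; the eval service still blocks until the train service touches `train_done.flag`. The polling gate is equivalent to this hypothetical Python loop (illustration only, not code from the repo):

import time
from pathlib import Path

flag = Path("/app/checkpoints/train_done.flag")  # written by the train service
while not flag.exists():
    time.sleep(10)  # poll every 10 seconds, matching the sh one-liner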
src/datamodules/catdog_datamodule.py CHANGED
@@ -14,7 +14,7 @@ class CatDogImageDataModule(L.LightningDataModule):
 
     def __init__(
         self,
-        data_root: Union[str, Path] = "data",
+        root_dir: Union[str, Path] = "data",
         data_dir: Union[str, Path] = "cats_and_dogs_filtered",
         batch_size: int = 32,
         num_workers: int = 4,
@@ -24,7 +24,7 @@ class CatDogImageDataModule(L.LightningDataModule):
         url: str = "https://download.pytorch.org/tutorials/cats_and_dogs_filtered.zip",
     ):
         super().__init__()
-        self.data_root = Path(data_root)
+        self.root_dir = Path(root_dir)
         self.data_dir = data_dir
         self.batch_size = batch_size
         self.num_workers = num_workers
@@ -40,11 +40,11 @@ class CatDogImageDataModule(L.LightningDataModule):
 
     def prepare_data(self):
         """Download the dataset if it doesn't exist."""
-        self.dataset_path = self.data_root / self.data_dir
+        self.dataset_path = self.root_dir / self.data_dir
         if not self.dataset_path.exists():
             logger.info("Downloading and extracting dataset.")
             download_and_extract_archive(
-                url=self.url, download_root=self.data_root, remove_finished=True
+                url=self.url, download_root=self.root_dir, remove_finished=True
             )
             logger.info("Download completed.")
@@ -56,11 +56,9 @@ class CatDogImageDataModule(L.LightningDataModule):
         train_transform = transforms.Compose(
             [
                 transforms.Resize((self.image_size, self.image_size)),
-                transforms.RandomHorizontalFlip(0.1),
-                transforms.RandomRotation(10),
-                transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
-                transforms.RandomAutocontrast(0.1),
-                transforms.RandomAdjustSharpness(2, 0.1),
+                transforms.RandomHorizontalFlip(0.5),  # Flip probability increased
+                transforms.RandomRotation(5),  # Reduced rotation for stability
+                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                 transforms.ToTensor(),
                 transforms.Normalize(
                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
@@ -134,7 +132,7 @@ if __name__ == "__main__":
     def test_datamodule(cfg: DictConfig):
         logger.info(f"Config:\n{OmegaConf.to_yaml(cfg)}")
         datamodule = CatDogImageDataModule(
-            data_root=cfg.paths.data_dir,
+            root_dir=cfg.data.root_dir,
            data_dir=cfg.data.data_dir,
            batch_size=cfg.data.batch_size,
            num_workers=cfg.data.num_workers,
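The `data_root` parameter is renamed to `root_dir` and carried through `prepare_data` and the `__main__` harness (which now reads `cfg.data.root_dir`, so the data config group is assumed to expose that key). The augmentation stack is also slimmed down to flip, small rotation, and color jitter. A usage sketch with the new signature, using defaults from this file (keyword names past `num_workers` are assumed to match the config keys):

from src.datamodules.catdog_datamodule import CatDogImageDataModule

dm = CatDogImageDataModule(
    root_dir="data",                     # was data_root before this commit
    data_dir="cats_and_dogs_filtered",
    batch_size=64,
    num_workers=8,
    image_size=224,
)
dm.prepare_data()            # downloads and extracts the archive if missing
dm.setup(stage="fit")
train_loader = dm.train_dataloader()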
src/models/catdog_model.py CHANGED
@@ -1,15 +1,16 @@
 import lightning as L
-import torch
-from torch import nn, optim
-from torchmetrics import Accuracy, Precision, Recall, F1Score
+import torch.nn.functional as F
+from torch import optim
+from torchmetrics.classification import Accuracy, F1Score
 from timm.models import VisionTransformer
+import torch
 
 
 class ViTTinyClassifier(L.LightningModule):
     def __init__(
         self,
         img_size: int = 224,
-        num_classes: int = 2,  # Should be 2 for binary classification
+        num_classes: int = 2,  # Binary classification with two classes
         embed_dim: int = 64,
         depth: int = 6,
         num_heads: int = 2,
@@ -25,7 +26,7 @@ class ViTTinyClassifier(L.LightningModule):
         super().__init__()
         self.save_hyperparameters()
 
-        # Create ViT model
+        # Vision Transformer model initialization
         self.model = VisionTransformer(
             img_size=img_size,
             patch_size=patch_size,
@@ -35,51 +36,40 @@ class ViTTinyClassifier(L.LightningModule):
             depth=depth,
             num_heads=num_heads,
             mlp_ratio=mlp_ratio,
-            qkv_bias=False,
+            qkv_bias=True,
             pre_norm=pre_norm,
             global_pool="token",
         )
 
-        # Metrics for binary classification
-        metrics = {
-            "acc": Accuracy(task="binary"),
-            "precision": Precision(task="binary"),
-            "recall": Recall(task="binary"),
-            "f1": F1Score(task="binary"),
-        }
-
-        # Initialize metrics for each stage
-        self.train_metrics = nn.ModuleDict(
-            {name: metric.clone() for name, metric in metrics.items()}
-        )
-        self.val_metrics = nn.ModuleDict(
-            {name: metric.clone() for name, metric in metrics.items()}
-        )
-        self.test_metrics = nn.ModuleDict(
-            {name: metric.clone() for name, metric in metrics.items()}
-        )
-
-        # Loss function
-        self.criterion = nn.CrossEntropyLoss()
+        # Define accuracy and F1 metrics for binary classification
+        self.train_acc = Accuracy(task="binary")
+        self.val_acc = Accuracy(task="binary")
+        self.test_acc = Accuracy(task="binary")
+
+        self.train_f1 = F1Score(task="binary")
+        self.val_f1 = F1Score(task="binary")
+        self.test_f1 = F1Score(task="binary")
 
     def forward(self, x):
         return self.model(x)
 
-    def _shared_step(self, batch, stage: str):
+    def _shared_step(self, batch, stage):
         x, y = batch
-        logits = self(x)
-        loss = self.criterion(logits, y)
-        preds = logits.argmax(dim=1)
-
-        # Get appropriate metric dictionary based on stage
-        metrics = getattr(self, f"{stage}_metrics")
-        metric_logs = {
-            f"{stage}_{name}": metric(preds, y) for name, metric in metrics.items()
-        }
-
-        # Log metrics
-        self.log(f"{stage}_loss", loss, prog_bar=True)
-        self.log_dict(metric_logs, prog_bar=True, on_step=False, on_epoch=True)
+        logits = self(x)  # Model output shape: [batch_size, num_classes]
+        loss = F.cross_entropy(logits, y)  # Cross-entropy for binary classification
+        preds = torch.argmax(logits, dim=1)  # Predicted class (0 or 1)
+
+        # Update and log metrics
+        acc = getattr(self, f"{stage}_acc")
+        f1 = getattr(self, f"{stage}_f1")
+        acc(preds, y)
+        f1(preds, y)
+
+        # Logging of metrics and loss
+        self.log(f"{stage}_loss", loss, prog_bar=True, on_epoch=True)
+        self.log(f"{stage}_acc", acc, prog_bar=True, on_epoch=True)
+        self.log(f"{stage}_f1", f1, prog_bar=True, on_epoch=True)
         return loss
 
     def training_step(self, batch, batch_idx):
@@ -100,6 +90,7 @@ class ViTTinyClassifier(L.LightningModule):
 
         scheduler = optim.lr_scheduler.ReduceLROnPlateau(
             optimizer,
+            mode="min",
             factor=self.hparams.factor,
             patience=self.hparams.patience,
             min_lr=self.hparams.min_lr,
@@ -113,3 +104,8 @@ class ViTTinyClassifier(L.LightningModule):
                 "interval": "epoch",
             },
         }
+
+
+if __name__ == "__main__":
+    model = ViTTinyClassifier()
+    print(model)
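The metrics refactor swaps the `nn.ModuleDict` of four cloned metrics for explicit per-stage `Accuracy`/`F1Score` attributes and logs the metric objects directly, letting torchmetrics handle epoch-level aggregation; `qkv_bias` flips to True, the scheduler gets an explicit `mode="min"` to match `val_loss` monitoring, and a `__main__` guard adds a quick smoke test. A slightly fuller hypothetical check that also verifies the output shape:

import torch
from src.models.catdog_model import ViTTinyClassifier

model = ViTTinyClassifier(img_size=224, num_classes=2)
x = torch.randn(4, 3, 224, 224)    # fake batch of four RGB images
logits = model(x)
assert logits.shape == (4, 2)      # one logit per class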
src/train_new.py CHANGED
@@ -160,7 +160,7 @@ def setup_run_trainer(cfg: DictConfig):
     # Set up callbacks, loggers, and Trainer
     callbacks = instantiate_callbacks(cfg.callbacks)
     logger.info(f"Callbacks: {callbacks}")
-    loggers = instantiate_loggers(cfg.loggers)
+    loggers = instantiate_loggers(cfg.logger)
     logger.info(f"Loggers: {loggers}")
     trainer: L.Trainer = hydra.utils.instantiate(
         cfg.trainer, callbacks=callbacks, logger=loggers
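The one-line fix reads logger configs from `cfg.logger`, matching the Hydra config group name; the old `cfg.loggers` key would not exist on the composed config and would fail at runtime. `instantiate_loggers` itself is not shown in this diff; a plausible sketch of such a helper, in the style of common lightning-hydra templates (assumed, not the repo's actual code):

from typing import List

import hydra
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig


def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Instantiate every config node that carries a _target_."""
    loggers: List[Logger] = []
    if not logger_cfg:
        return loggers
    for _, lg_conf in logger_cfg.items():
        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
            loggers.append(hydra.utils.instantiate(lg_conf))
    return loggers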