devforfu commited on
Commit
12babad
·
1 Parent(s): 7b1ae8d

Fine-tuning support

Browse files
realfake/callbacks.py CHANGED
@@ -47,6 +47,12 @@ class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
47
  def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
48
  rank_zero_info("Freezing backbone")
49
  self.freeze(_get_backbone(pl_module.model))
 
 
 
 
 
 
50
 
51
  def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer, opt_idx: int) -> None:
52
  if epoch == self._unfreeze_at_epoch:
 
47
  def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
48
  rank_zero_info("Freezing backbone")
49
  self.freeze(_get_backbone(pl_module.model))
50
+ enabled_layers = [
51
+ name
52
+ for name, child in pl_module.model.named_children()
53
+ if all(param.requires_grad for param in child.parameters())
54
+ ]
55
+ rank_zero_info(f"Gradient enabled layers: [{', '.join(enabled_layers)}]")
56
 
57
  def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer, opt_idx: int) -> None:
58
  if epoch == self._unfreeze_at_epoch:
realfake/models.py CHANGED
@@ -35,6 +35,7 @@ class RealFakeParams(Args):
35
  epochs: int = Field(6)
36
  base_lr: float = Field(1e-3)
37
  pretrained: bool = Field(True)
 
38
  accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)
39
 
40
 
 
35
  epochs: int = Field(6)
36
  base_lr: float = Field(1e-3)
37
  pretrained: bool = Field(True)
38
+ progress_bar: bool = Field(False)
39
  accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)
40
 
41
 
realfake/train.py CHANGED
@@ -7,7 +7,7 @@ import pytorch_lightning as pl
7
  from pytorch_lightning.callbacks import ModelCheckpoint
8
  from pytorch_lightning.plugins.environments import SLURMEnvironment
9
 
10
- from realfake.callbacks import ConsoleLogger
11
  from realfake.models import RealFakeParams
12
  from realfake.utils import get_checkpoints_dir, find_latest_checkpoint
13
 
@@ -45,7 +45,7 @@ def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
45
  max_epochs=args.epochs,
46
  num_nodes=1,
47
  num_sanity_val_steps=0,
48
- enable_progress_bar=False,
49
  callbacks=[
50
  ConsoleLogger(),
51
  ModelCheckpoint(
@@ -56,12 +56,13 @@ def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
56
  dirpath=checkpoints_dir,
57
  filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
58
  ),
 
59
  ],
60
  resume_from_checkpoint=existing_checkpoint,
61
  )
62
 
63
  if job_id is not None:
64
- trainer_params["enable_progress_bar"] = True
65
  trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
66
  trainer_params["strategy"] = args.accelerator.strategy
67
 
 
7
  from pytorch_lightning.callbacks import ModelCheckpoint
8
  from pytorch_lightning.plugins.environments import SLURMEnvironment
9
 
10
+ from realfake.callbacks import ConsoleLogger, FeatureExtractorFreezeUnfreeze
11
  from realfake.models import RealFakeParams
12
  from realfake.utils import get_checkpoints_dir, find_latest_checkpoint
13
 
 
45
  max_epochs=args.epochs,
46
  num_nodes=1,
47
  num_sanity_val_steps=0,
48
+ enable_progress_bar=args.progress_bar,
49
  callbacks=[
50
  ConsoleLogger(),
51
  ModelCheckpoint(
 
56
  dirpath=checkpoints_dir,
57
  filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
58
  ),
59
+ FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=args.freeze_epochs)
60
  ],
61
  resume_from_checkpoint=existing_checkpoint,
62
  )
63
 
64
  if job_id is not None:
65
+ trainer_params["enable_progress_bar"] = False
66
  trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
67
  trainer_params["strategy"] = args.accelerator.strategy
68