devforfu
commited on
Commit
·
12babad
1
Parent(s):
7b1ae8d
Fine-tuning support
Browse files- realfake/callbacks.py +6 -0
- realfake/models.py +1 -0
- realfake/train.py +4 -3
realfake/callbacks.py
CHANGED
@@ -47,6 +47,12 @@ class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
|
|
47 |
def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
|
48 |
rank_zero_info("Freezing backbone")
|
49 |
self.freeze(_get_backbone(pl_module.model))
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer, opt_idx: int) -> None:
|
52 |
if epoch == self._unfreeze_at_epoch:
|
|
|
47 |
def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
|
48 |
rank_zero_info("Freezing backbone")
|
49 |
self.freeze(_get_backbone(pl_module.model))
|
50 |
+
enabled_layers = [
|
51 |
+
name
|
52 |
+
for name, child in pl_module.model.named_children()
|
53 |
+
if all(param.requires_grad for param in child.parameters())
|
54 |
+
]
|
55 |
+
rank_zero_info(f"Gradient enabled layers: [{', '.join(enabled_layers)}]")
|
56 |
|
57 |
def finetune_function(self, pl_module: "pl.LightningModule", epoch: int, optimizer, opt_idx: int) -> None:
|
58 |
if epoch == self._unfreeze_at_epoch:
|
realfake/models.py
CHANGED
@@ -35,6 +35,7 @@ class RealFakeParams(Args):
|
|
35 |
epochs: int = Field(6)
|
36 |
base_lr: float = Field(1e-3)
|
37 |
pretrained: bool = Field(True)
|
|
|
38 |
accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)
|
39 |
|
40 |
|
|
|
35 |
epochs: int = Field(6)
|
36 |
base_lr: float = Field(1e-3)
|
37 |
pretrained: bool = Field(True)
|
38 |
+
progress_bar: bool = Field(False)
|
39 |
accelerator: AcceleratorParams = Field(default_factory=AcceleratorParams)
|
40 |
|
41 |
|
realfake/train.py
CHANGED
@@ -7,7 +7,7 @@ import pytorch_lightning as pl
|
|
7 |
from pytorch_lightning.callbacks import ModelCheckpoint
|
8 |
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
9 |
|
10 |
-
from realfake.callbacks import ConsoleLogger
|
11 |
from realfake.models import RealFakeParams
|
12 |
from realfake.utils import get_checkpoints_dir, find_latest_checkpoint
|
13 |
|
@@ -45,7 +45,7 @@ def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
|
|
45 |
max_epochs=args.epochs,
|
46 |
num_nodes=1,
|
47 |
num_sanity_val_steps=0,
|
48 |
-
enable_progress_bar=
|
49 |
callbacks=[
|
50 |
ConsoleLogger(),
|
51 |
ModelCheckpoint(
|
@@ -56,12 +56,13 @@ def prepare_trainer(args: RealFakeParams) -> pl.Trainer:
|
|
56 |
dirpath=checkpoints_dir,
|
57 |
filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
|
58 |
),
|
|
|
59 |
],
|
60 |
resume_from_checkpoint=existing_checkpoint,
|
61 |
)
|
62 |
|
63 |
if job_id is not None:
|
64 |
-
trainer_params["enable_progress_bar"] =
|
65 |
trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
|
66 |
trainer_params["strategy"] = args.accelerator.strategy
|
67 |
|
|
|
7 |
from pytorch_lightning.callbacks import ModelCheckpoint
|
8 |
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
9 |
|
10 |
+
from realfake.callbacks import ConsoleLogger, FeatureExtractorFreezeUnfreeze
|
11 |
from realfake.models import RealFakeParams
|
12 |
from realfake.utils import get_checkpoints_dir, find_latest_checkpoint
|
13 |
|
|
|
45 |
max_epochs=args.epochs,
|
46 |
num_nodes=1,
|
47 |
num_sanity_val_steps=0,
|
48 |
+
enable_progress_bar=args.progress_bar,
|
49 |
callbacks=[
|
50 |
ConsoleLogger(),
|
51 |
ModelCheckpoint(
|
|
|
56 |
dirpath=checkpoints_dir,
|
57 |
filename="%s-{epoch:02d}-{val_acc:.4f}" % args.model_name,
|
58 |
),
|
59 |
+
FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=args.freeze_epochs)
|
60 |
],
|
61 |
resume_from_checkpoint=existing_checkpoint,
|
62 |
)
|
63 |
|
64 |
if job_id is not None:
|
65 |
+
trainer_params["enable_progress_bar"] = False
|
66 |
trainer_params["plugins"] = SLURMEnvironment(requeue_signal=signal.SIGHUP)
|
67 |
trainer_params["strategy"] = args.accelerator.strategy
|
68 |
|