|
import logging

import torch
import torch.nn as nn


class NeuralModule(nn.Module):
    """Base neural module: an `nn.Module` with utilities for counting trainable
    parameters and for freezing/unfreezing them."""

    @property
    def num_weights(self):
        """
        Utility property that returns the total number of trainable parameters of the
        NeuralModule (i.e., parameters with `requires_grad=True`).
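
        Example:
            A minimal sketch (`module` stands for any NeuralModule instance):

            ```python
            n_trainable = module.num_weights  # counts parameters with requires_grad=True
            ```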
|
""" |
|
return self._num_weights() |
|
|
|
@torch.jit.ignore |
|
def _num_weights(self): |
|
num: int = 0 |
|
for p in self.parameters(): |
|
if p.requires_grad: |
|
num += p.numel() |
|
return num |
|
|
|
    def freeze(self) -> None:
        r"""
        Freeze all parameters for inference.

        This method sets `requires_grad` to False for all parameters of the module and
        puts the module into eval mode. It also records the original `requires_grad`
        state of each parameter in a dictionary, so that `unfreeze(partial=True)` can
        restore that original state later.
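
        Example:
            A minimal usage sketch (`module` stands for any NeuralModule instance):

            ```python
            module.freeze()
            assert module.num_weights == 0  # no trainable parameters remain
            ```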
|
""" |
|
        grad_map = {}

        for pname, param in self.named_parameters():
            # Record the current requires_grad state before disabling it.
            grad_map[pname] = param.requires_grad
            param.requires_grad = False

        # Store (or extend) the recorded states so unfreeze(partial=True) can restore them.
        if not hasattr(self, '_frozen_grad_map'):
            self._frozen_grad_map = grad_map
        else:
            self._frozen_grad_map.update(grad_map)

        self.eval()

    def unfreeze(self, partial: bool = False) -> None:
        """
        Unfreeze all parameters for training.

        Allows either a total unfreeze, or a partial unfreeze (if the module was explicitly
        frozen previously with `freeze()`). The `partial` argument determines whether to
        unfreeze all parameters, or only the parameters that were trainable prior to the
        call to `freeze()`.

        Example:
            Consider a model that has an encoder and a decoder module. Assume we want the
            encoder to always remain frozen.

            ```python
            model.encoder.freeze()  # Freezes all parameters in the encoder explicitly
            ```

            During inference, all parameters of the model should be frozen, so we call the
            model's `freeze()` method. This step records that the encoder parameters were
            already frozen, so if a partial unfreeze is requested later, the encoder
            parameters stay frozen.

            ```python
            model.freeze()  # Freezes all parameters in the model; encoder remains frozen
            ```

            Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder
            frozen. We can do this by calling `unfreeze(partial=True)`.

            ```python
            model.unfreeze(partial=True)  # Unfreezes only the decoder; encoder remains frozen
            ```

        Args:
            partial: If True, only unfreeze parameters that were trainable before `freeze()`
                was called. Parameters that were already frozen when `freeze()` was called
                remain frozen after `unfreeze(partial=True)`.
        """
|
        if partial and not hasattr(self, '_frozen_grad_map'):
            raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")

        for pname, param in self.named_parameters():
            if not partial:
                # Total unfreeze: make every parameter trainable.
                param.requires_grad = True
            else:
                # Partial unfreeze: restore the requires_grad state recorded by freeze().
                if pname in self._frozen_grad_map:
                    param.requires_grad = self._frozen_grad_map[pname]
                else:
                    # Parameter was added after freeze() was called; no recorded state exists.
                    logging.warning(
                        f"Parameter {pname} not found in list of previously frozen parameters. "
                        f"Unfreezing this parameter."
                    )
                    param.requires_grad = True

        # The recorded states have been consumed; discard them.
        if hasattr(self, '_frozen_grad_map'):
            delattr(self, '_frozen_grad_map')

        self.train()
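

# A minimal, self-contained usage sketch (not part of the module above). `ToyBlock`
# and `ToyModel` are hypothetical classes used only to demonstrate the
# freeze() / unfreeze(partial=True) workflow described in the docstrings.
if __name__ == "__main__":

    class ToyBlock(NeuralModule):
        def __init__(self, dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            return self.proj(x)

    class ToyModel(NeuralModule):
        def __init__(self, dim: int = 8):
            super().__init__()
            self.encoder = ToyBlock(dim)
            self.decoder = ToyBlock(dim)

        def forward(self, x):
            return self.decoder(self.encoder(x))

    model = ToyModel()
    print(f"Trainable parameters: {model.num_weights}")

    model.encoder.freeze()  # the encoder is meant to stay frozen
    model.freeze()          # freeze everything for inference
    print(f"After freeze(): {model.num_weights}")  # 0

    model.unfreeze(partial=True)  # decoder becomes trainable; encoder stays frozen
    print(f"After unfreeze(partial=True): {model.num_weights}")  # decoder parameters only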