Spaces:
Runtime error
Runtime error
File size: 1,512 Bytes
f949b3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
from typing import Union, Any, Dict, List, Optional, Tuple
import pytorch_lightning as pl
class LayerConfig():
    """Per-layer training configuration for a Lightning module.

    Currently this only controls which parameters have ``requires_grad``
    enabled (see :meth:`set_requires_grad`).
    """

    def __init__(self,
                 gradient_setup: Optional[List[Tuple[bool, List[str]]]] = None,
                 ) -> None:
        """Store the gradient rules.

        Args:
            gradient_setup: Ordered list of ``(grad_mode, path)`` rules.
                ``path`` is a list of name fragments; a parameter matches a
                rule when every fragment occurs in its name, in order. An
                empty ``path`` matches every parameter. ``None`` (the default)
                means "leave ``requires_grad`` untouched".
        """
        # Always define the attribute so set_requires_grad() can treat None as
        # "no rules" instead of failing with AttributeError.
        self.gradient_setup = gradient_setup
        if gradient_setup is not None:
            # NOTE(review): only set when a setup is given — confirm no caller
            # expects this attribute unconditionally.
            self.new_config = True
        # TODO add option to specify quantization per layer

    def set_requires_grad(self, pl_module: "pl.LightningModule") -> None:
        """Apply the gradient rules to ``pl_module``'s parameters in place.

        Rules are applied in list order, so for a parameter matched by several
        rules the last one wins. Parameters matched by no rule keep their
        current ``requires_grad`` value. A setup of ``None`` is a no-op.

        Example: ``[(False, []), (True, ["unet", "up"])]`` first freezes every
        parameter, then unfreezes those whose name contains ``"unet"``
        followed somewhere later by ``"up"``.

        Args:
            pl_module: Object whose ``named_parameters()`` yields
                ``(name, parameter)`` pairs; each matched parameter's
                ``requires_grad`` attribute is overwritten.
        """
        # getattr also tolerates instances created before the attribute was
        # always assigned (e.g. unpickled from an older object).
        gradient_setup = getattr(self, "gradient_setup", None)
        if gradient_setup is None:
            return
        for selected_module_setup in gradient_setup:
            # Kept as ``== True`` (not ``bool(...)``) to preserve existing
            # behavior: a string such as "True" compares unequal to True and
            # yields False, so rule flags must be real booleans.
            grad_mode = selected_module_setup[0] == True  # noqa: E712
            selected_module_path = selected_module_setup[1]
            for model_name, p in pl_module.named_parameters():
                if self._path_matches(model_name, selected_module_path):
                    p.requires_grad = grad_mode

    @staticmethod
    def _path_matches(model_name: str, selected_module_path: List[str]) -> bool:
        """Return True if every fragment occurs in ``model_name``, in order.

        Matching is by plain substring, not whole dotted segment, so e.g.
        ``"up"`` also matches ``"upsample"``. An empty path matches anything.
        """
        remaining = model_name
        for selected_module in selected_module_path:
            position = remaining.find(selected_module)
            if position == -1:
                # No later fragment can rescue the match; stop early.
                return False
            # Search for the next fragment only after this one, enforcing order.
            remaining = remaining[position + len(selected_module):]
        return True
|