StreamingT2V / t2v_enhanced /model /requires_grad_setter.py
hpoghos's picture
add code
f949b3f
raw
history blame
1.51 kB
from typing import Union, Any, Dict, List, Optional, Tuple
import pytorch_lightning as pl
class LayerConfig():
    """Select which parameters of a module are trainable.

    ``gradient_setup`` is a list of ``(flag, path)`` entries. ``path`` is a
    list of name fragments that must all occur, in the given order and
    without overlapping, inside a parameter name as yielded by
    ``named_parameters()``. Every parameter whose name matches gets
    ``requires_grad`` set to ``flag``. Entries are applied in list order, so
    a later entry overrides an earlier one for parameters matched by both.
    An empty ``path`` matches every parameter.

    Example::

        LayerConfig(gradient_setup=[
            (False, []),               # disable gradients everywhere ...
            (True, ["unet", "attn"]),  # ... then enable them where the name
                                       # contains "unet" followed by "attn"
        ])
    """

    def __init__(self,
                 gradient_setup: Optional[List[Tuple[bool, List[str]]]] = None,
                 ) -> None:
        """Store the gradient setup.

        Args:
            gradient_setup: List of ``(flag, path)`` entries as described on
                the class. When ``None``, neither ``gradient_setup`` nor
                ``new_config`` is set on the instance; attribute presence is
                deliberately left as before in case code outside this file
                checks for it.
        """
        if gradient_setup is not None:
            self.gradient_setup = gradient_setup
            # Marker attribute; its consumers are not visible in this file.
            self.new_config = True
            # TODO add option to specify quantization per layer

    @staticmethod
    def _path_matches(parameter_name: str, fragments) -> bool:
        """Return True if all fragments occur in order in ``parameter_name``."""
        remainder = parameter_name
        for fragment in fragments:
            position = remainder.find(fragment)
            if position == -1:
                return False
            # Continue the search after this hit so that fragments must
            # appear in order and cannot overlap.
            remainder = remainder[position + len(fragment):]
        return True

    def set_requires_grad(self, pl_module: "pl.LightningModule"):
        """Apply the stored gradient setup to the parameters of ``pl_module``.

        Args:
            pl_module: Object whose ``named_parameters()`` yields
                ``(name, parameter)`` pairs; matching parameters get their
                ``requires_grad`` attribute overwritten.

        Does nothing when the instance was created without a gradient setup.
        """
        # __init__ leaves the attribute unset when no setup was given.
        gradient_setup = getattr(self, "gradient_setup", None)
        if not gradient_setup:
            return

        for entry in gradient_setup:
            # Only a value that compares equal to True enables gradients
            # (the string "True", for instance, does not).
            grad_mode = entry[0] == True  # noqa: E712
            fragments = entry[1]
            for parameter_name, parameter in pl_module.named_parameters():
                if self._path_matches(parameter_name, fragments):
                    parameter.requires_grad = grad_mode