from typing import Union, Any, Dict, List, Optional, Tuple

import pytorch_lightning as pl

class LayerConfig:
    def __init__(
        self,
        gradient_setup: Optional[List[Tuple[bool, List[str]]]] = None,
    ) -> None:
        # Each entry is (requires_grad, path_parts). The path parts are
        # matched in order as substrings of a parameter name, e.g.
        # (True, ["unet.a.b", "c"]) targets parameters whose name contains
        # "unet.a.b" followed later by "c"; an empty path matches everything.
        self.gradient_setup = gradient_setup if gradient_setup is not None else []
        self.new_config = gradient_setup is not None
        # TODO: add an option to specify quantization per layer
    def set_requires_grad(self, pl_module: pl.LightningModule) -> None:
        # Example setup: [(True, ["unet.a.b", "c"]), (True, [])].
        # Later entries override earlier ones for the parameters they match.
        for grad_mode, selected_module_path in self.gradient_setup:
            for model_name, p in pl_module.named_parameters():
                # Check that every path part occurs in the parameter name,
                # in order and without overlapping.
                path_is_matching = True
                model_name_selection = model_name
                for selected_module in selected_module_path:
                    position = model_name_selection.find(selected_module)
                    if position == -1:
                        path_is_matching = False
                        break
                    # Keep matching after the part that was just found.
                    model_name_selection = model_name_selection[position + len(selected_module):]
                if path_is_matching:
                    p.requires_grad = grad_mode
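

# A minimal usage sketch (TinyModule and its parameter names are
# hypothetical, added for illustration only): first freeze every
# parameter, then re-enable gradients for parameters whose names
# contain "unet" followed by "attn".
if __name__ == "__main__":
    import torch

    class TinyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.unet_attn = torch.nn.Linear(4, 4)  # matches ["unet", "attn"]
            self.decoder = torch.nn.Linear(4, 4)    # matches neither part

    config = LayerConfig(gradient_setup=[
        (False, []),               # empty path matches every parameter: freeze all
        (True, ["unet", "attn"]),  # later entries override: unfreeze matching ones
    ])
    module = TinyModule()
    config.set_requires_grad(module)
    for name, p in module.named_parameters():
        print(name, p.requires_grad)
    # unet_attn.* -> True, decoder.* -> False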