import copy
import re

import torch

import util
from StableDiffuser import StableDiffuser


class FineTunedModel(torch.nn.Module):
    """Wraps a StableDiffuser and keeps trainable copies of selected modules.

    The wrapped model is frozen; any module whose name matches a regex in
    `modules` is deep-copied and unfrozen. Used as a context manager, the
    instance swaps the fine-tuned copies into the model on entry and restores
    the original modules on exit.
    """

    def __init__(self,
                 model: StableDiffuser,
                 modules,
                 frozen_modules=[]
                 ):
        super().__init__()

        if isinstance(modules, str):
            modules = [modules]

        self.model = model
        self.ft_modules = {}
        self.orig_modules = {}

        # Freeze the whole model; only the deep-copied modules below train.
        util.freeze(self.model)

        for module_name, module in model.named_modules():
            for ft_module_regex in modules:
                match = re.search(ft_module_regex, module_name)
                if match is not None:
                    # Keep the original module and train an unfrozen copy.
                    ft_module = copy.deepcopy(module)
                    self.orig_modules[module_name] = module
                    self.ft_modules[module_name] = ft_module
                    util.unfreeze(ft_module)
                    print(f"=> Finetuning {module_name}")

                    # Re-freeze any children of the copy that match a
                    # frozen_modules pattern.
                    for child_name, child_module in ft_module.named_modules():
                        child_name = f"{module_name}.{child_name}"
                        for freeze_module_regex in frozen_modules:
                            match = re.search(freeze_module_regex, child_name)
                            if match:
                                print(f"=> Freezing {child_name}")
                                util.freeze(child_module)

        # Register both sets as ModuleLists so device moves and module
        # traversal see them.
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=[]):
        if isinstance(checkpoint, str):
            checkpoint = torch.load(checkpoint)

        # Each checkpoint key names one module to fine-tune; anchor the regex
        # at the end of the name so only that module matches.
        modules = [f"{key}$" for key in list(checkpoint.keys())]

        ftm = FineTunedModel(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)

        return ftm

    def __enter__(self):
        # Swap the fine-tuned copies into the wrapped model.
        for key, ft_module in self.ft_modules.items():
            util.set_module(self.model, key, ft_module)

    def __exit__(self, exc_type, exc_value, tb):
        # Restore the original (frozen) modules.
        for key, module in self.orig_modules.items():
            util.set_module(self.model, key, module)

    def parameters(self):
        # Only the fine-tuned copies contribute trainable parameters.
        parameters = []
        for ft_module in self.ft_modules.values():
            parameters.extend(list(ft_module.parameters()))
        return parameters

    def state_dict(self):
        # One sub-state-dict per fine-tuned module, keyed by module path.
        state_dict = {key: module.state_dict() for key, module in self.ft_modules.items()}
        return state_dict

    def load_state_dict(self, state_dict):
        for key, sd in state_dict.items():
            self.ft_modules[key].load_state_dict(sd)
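

# --- Usage sketch (illustration only, not part of the original file) --------
# A minimal outline of how FineTunedModel is typically driven from a training
# loop: freeze the base StableDiffuser, fine-tune the modules whose names
# match a regex, and swap the copies in via the context manager for the
# forward pass that should use them. The regex "unet$", the StableDiffuser
# constructor arguments, and the loss computation are assumptions for
# illustration; they are not defined in this file.
#
#   model = StableDiffuser().to('cuda')                 # assumed constructor
#   finetuner = FineTunedModel(model, modules=["unet$"])
#
#   optimizer = torch.optim.Adam(finetuner.parameters(), lr=1e-5)
#
#   for step in range(num_steps):
#       optimizer.zero_grad()
#       with finetuner:
#           # forward pass here runs through the fine-tuned copies
#           loss = ...
#       loss.backward()
#       optimizer.step()
#
#   torch.save(finetuner.state_dict(), "finetuned.pt")
#
# The saved checkpoint can later be restored onto a fresh base model with
# FineTunedModel.from_checkpoint(model, "finetuned.pt").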