# Hugging Face Spaces status: Runtime error
import copy | |
import re | |
import torch | |
import util | |
from StableDiffuser import StableDiffuser | |
class FineTunedModel(torch.nn.Module):
    """Hold trainable copies of selected submodules of a ``StableDiffuser``.

    Every submodule of ``model`` whose qualified name matches one of the
    ``modules`` regexes is deep-copied. The copy is unfrozen for training,
    while ``model`` itself is frozen via ``util.freeze``. Using an instance as
    a context manager swaps the fine-tuned copies into ``model`` on entry and
    puts the original submodules back on exit.
    """

    def __init__(self,
                 model: StableDiffuser,
                 modules,
                 frozen_modules=None,
                 ):
        """Select, copy and unfreeze the submodules to fine-tune.

        Args:
            model: The diffuser whose submodules are copied. It is frozen in
                place by this constructor.
            modules: A regex string, or a list of regex strings. Each is
                applied with ``re.search`` to the names yielded by
                ``model.named_modules()``.
            frozen_modules: Regexes applied with ``re.search`` to the fully
                qualified names of submodules inside each fine-tuned copy.
                Matching submodules are frozen again. ``None`` means none.
        """
        super().__init__()

        if isinstance(modules, str):
            modules = [modules]
        if frozen_modules is None:
            frozen_modules = []

        self.model = model
        self.ft_modules = {}    # qualified name -> trainable deep copy
        self.orig_modules = {}  # qualified name -> original submodule of model

        util.freeze(self.model)

        for module_name, module in model.named_modules():
            # Copy each module at most once, even if several regexes match it.
            if not any(re.search(ft_module_regex, module_name) is not None
                       for ft_module_regex in modules):
                continue

            # NOTE(review): if both a module and one of its descendants match,
            # each gets its own independent copy — confirm that is intended.
            ft_module = copy.deepcopy(module)
            self.orig_modules[module_name] = module
            self.ft_modules[module_name] = ft_module
            util.unfreeze(ft_module)
            print(f"=> Finetuning {module_name}")

            # Distinct loop names so the outer ``module`` is never clobbered.
            for sub_name, sub_module in ft_module.named_modules():
                # named_modules() yields ("", root) first; name the root by its
                # own qualified name instead of "<name>." with a trailing dot.
                ft_module_name = (f"{module_name}.{sub_name}"
                                  if sub_name else module_name)
                for freeze_module_name in frozen_modules:
                    if re.search(freeze_module_name, ft_module_name):
                        print(f"=> Freezing {ft_module_name}")
                        util.freeze(sub_module)

        # Plain dicts are not tracked by nn.Module. Wrapping the values in
        # ModuleLists registers them as submodules (e.g. so ``.to(device)``
        # also reaches the copies).
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=None):
        """Build an instance and load weights saved by ``state_dict``.

        Args:
            model: The diffuser to wrap (see ``__init__``).
            checkpoint: A mapping ``{module_name: module_state_dict}`` as
                returned by ``state_dict``, or a path given to ``torch.load``.
            frozen_modules: Forwarded to ``__init__``.

        Returns:
            A new instance whose fine-tuned modules hold the checkpoint weights.
        """
        if isinstance(checkpoint, str):
            # NOTE(review): torch.load unpickles the file; load trusted files only.
            checkpoint = torch.load(checkpoint)
        # Keys are exact module names. Escape the dots (regex wildcards) and
        # anchor both ends so only the checkpointed module itself is selected.
        modules = [f"^{re.escape(key)}$" for key in checkpoint]
        ftm = cls(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)
        return ftm

    def __enter__(self):
        """Swap the fine-tuned copies into ``self.model`` and return ``self``."""
        for key, ft_module in self.ft_modules.items():
            util.set_module(self.model, key, ft_module)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Put the original submodules back into ``self.model``.

        Runs on both normal and exceptional exit. It returns None, so any
        exception still propagates.
        """
        for key, module in self.orig_modules.items():
            util.set_module(self.model, key, module)

    def parameters(self):
        """Return a list of the parameters of the fine-tuned copies only.

        Parameters of ``self.model`` are excluded, so this is what an
        optimizer should receive.
        """
        parameters = []
        for ft_module in self.ft_modules.values():
            parameters.extend(ft_module.parameters())
        return parameters

    def state_dict(self):
        """Return ``{module_name: state_dict}`` for each fine-tuned copy."""
        return {key: module.state_dict() for key, module in self.ft_modules.items()}

    def load_state_dict(self, state_dict):
        """Load per-module weights from a mapping shaped like ``state_dict()``.

        Raises:
            KeyError: if a key does not name a module selected for fine-tuning.
        """
        for key, sd in state_dict.items():
            self.ft_modules[key].load_state_dict(sd)