Spaces:
Runtime error
Runtime error
File size: 2,780 Bytes
a24b16a d8ffb68 a24b16a d8ffb68 a24b16a 5349660 a24b16a 5349660 a24b16a 5349660 a24b16a 5349660 a24b16a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import copy
import re
import torch
import util
from StableDiffuser import StableDiffuser
class FineTunedModel(torch.nn.Module):
    """Hold trainable copies of selected submodules of a ``StableDiffuser``.

    Every submodule of ``model`` whose qualified name (as yielded by
    ``model.named_modules()``) matches one of the ``modules`` regexes is
    deep-copied. The copy is unfrozen for training while ``model`` itself is
    frozen. Entering the instance as a context manager swaps the copies into
    ``model`` (via ``util.set_module``); exiting puts the originals back.
    """

    def __init__(self,
                 model: "StableDiffuser",
                 modules,
                 frozen_modules=None,
                 ):
        """Select, copy and (un)freeze the submodules to fine-tune.

        Args:
            model: Model whose submodules are fine-tuned. It is frozen in place
                with ``util.freeze``.
            modules: A regex string or an iterable of regex strings. Each is
                applied with ``re.search`` (so it may match anywhere in the
                name) against every qualified submodule name.
            frozen_modules: Optional regex string or iterable of regex strings.
                Any submodule inside a fine-tuned copy whose qualified name
                matches is frozen again: it is copied but not trained.
        """
        super().__init__()

        if isinstance(modules, str):
            modules = [modules]
        if frozen_modules is None:
            frozen_modules = []
        elif isinstance(frozen_modules, str):
            # A bare string would otherwise be iterated character by character.
            frozen_modules = [frozen_modules]

        self.model = model
        # Qualified module name -> trainable deep copy / untouched original.
        self.ft_modules = {}
        self.orig_modules = {}

        util.freeze(self.model)

        for module_name, module in model.named_modules():
            # Make exactly one copy per module, even if several patterns match.
            if not any(re.search(pattern, module_name) is not None
                       for pattern in modules):
                continue

            ft_module = copy.deepcopy(module)
            self.orig_modules[module_name] = module
            self.ft_modules[module_name] = ft_module

            util.unfreeze(ft_module)

            print(f"=> Finetuning {module_name}")

            # Use fresh loop names so `module` above is never rebound.
            for sub_name, sub_module in ft_module.named_modules():
                full_name = self._qualified_name(module_name, sub_name)
                for freeze_pattern in frozen_modules:
                    if re.search(freeze_pattern, full_name):
                        print(f"=> Freezing {full_name}")
                        util.freeze(sub_module)

        # Wrap in ModuleLists so these modules are registered as children of
        # this nn.Module.
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())

    @staticmethod
    def _qualified_name(parent, child):
        """Join a module name and a name relative to it with a dot.

        ``named_modules()`` yields ``""`` for the module itself; in that case the
        parent name is returned as-is instead of gaining a trailing dot.
        """
        return f"{parent}.{child}" if child else parent

    @staticmethod
    def _exact_name_pattern(name):
        """Return a regex (for ``re.search``) matching exactly the name ``name``."""
        return f"^{re.escape(name)}$"

    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=None):
        """Rebuild an instance from a checkpoint produced by ``state_dict``.

        Args:
            model: Base model to wrap (see ``__init__``).
            checkpoint: A path passed to ``torch.load``, or an already-loaded
                mapping of qualified module name -> that module's state dict.
            frozen_modules: Forwarded to ``__init__``.

        Returns:
            A new instance whose fine-tuned copies hold the checkpoint weights.
        """
        if isinstance(checkpoint, str):
            # NOTE: torch.load unpickles data; only load trusted checkpoints.
            checkpoint = torch.load(checkpoint)

        # Module names contain '.', a regex wildcard, so escape each key and
        # anchor it at both ends to select exactly the saved modules.
        modules = [cls._exact_name_pattern(key) for key in checkpoint]

        ftm = cls(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)
        return ftm

    def __enter__(self):
        """Install the fine-tuned copies into ``self.model`` and return ``self``."""
        for key, ft_module in self.ft_modules.items():
            util.set_module(self.model, key, ft_module)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Restore the original modules; exceptions are not suppressed."""
        for key, module in self.orig_modules.items():
            util.set_module(self.model, key, module)

    def parameters(self, recurse=True):
        """Return the parameters of the fine-tuned copies only, as a list.

        This override excludes the base model's parameters. Parameters of
        submodules re-frozen via ``frozen_modules`` are still included.
        ``recurse`` is forwarded to each copy's ``parameters`` call.
        """
        return [
            param
            for ft_module in self.ft_modules.values()
            for param in ft_module.parameters(recurse=recurse)
        ]

    def state_dict(self):
        """Return ``{qualified module name: copy.state_dict()}`` for every copy."""
        return {key: module.state_dict() for key, module in self.ft_modules.items()}

    def load_state_dict(self, state_dict, strict=True):
        """Load weights produced by ``state_dict`` into the fine-tuned copies.

        Args:
            state_dict: Mapping of qualified module name -> module state dict.
            strict: If True (default), an unknown module name raises
                ``KeyError`` and each copy loads with ``strict=True``. If False,
                unknown names are skipped and copies load with ``strict=False``.
        """
        for key, sd in state_dict.items():
            ft_module = self.ft_modules.get(key)
            if ft_module is None:
                if strict:
                    raise KeyError(f"no fine-tuned module named {key!r}")
                continue
            ft_module.load_state_dict(sd, strict=strict)