File size: 2,780 Bytes
a24b16a
 
 
 
d8ffb68
a24b16a
 
 
 
d8ffb68
a24b16a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5349660
 
a24b16a
 
5349660
a24b16a
5349660
a24b16a
 
 
 
 
5349660
 
 
 
 
 
 
 
 
 
 
 
 
 
a24b16a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import copy
import re
import torch
import util
from StableDiffuser import StableDiffuser

class FineTunedModel(torch.nn.Module):
    """Wrap a model with trainable deep copies of selected submodules.

    Every submodule of ``model`` whose qualified name (as yielded by
    ``model.named_modules()``) matches one of the ``modules`` regexes is
    deep-copied. The copies are unfrozen for training while the originals are
    kept so they can be restored. The wrapped ``model`` itself is frozen via
    ``util.freeze``.

    Use the instance as a context manager: ``__enter__`` swaps the fine-tuned
    copies into ``model`` and ``__exit__`` puts the originals back.
    """

    def __init__(self,
                 model: StableDiffuser,
                 modules,
                 frozen_modules=None
                 ):
        """Select, copy and (un)freeze the submodules to fine-tune.

        Args:
            model: The wrapped model; frozen in place via ``util.freeze``.
            modules: A regex string or an iterable of regex strings. A
                submodule is fine-tuned when ``re.search`` finds any of them
                in its qualified name.
            frozen_modules: Optional iterable of regex strings. Any module
                inside a fine-tuned copy whose qualified name
                (``<module_name>`` for the copy itself,
                ``<module_name>.<sub_name>`` for its descendants) matches one
                of them is frozen again. ``None`` means freeze nothing.
        """
        super().__init__()

        if isinstance(modules, str):
            modules = [modules]
        # None instead of a shared mutable [] default.
        if frozen_modules is None:
            frozen_modules = []

        self.model = model
        self.ft_modules = {}
        self.orig_modules = {}

        util.freeze(self.model)

        for module_name, module in model.named_modules():
            # Fine-tune each module at most once, even if several patterns
            # match its name.
            if not any(re.search(pattern, module_name) for pattern in modules):
                continue

            ft_module = copy.deepcopy(module)

            self.orig_modules[module_name] = module
            self.ft_modules[module_name] = ft_module

            util.unfreeze(ft_module)

            print(f"=> Finetuning {module_name}")

            # Distinct loop variables so `module` (the original) is never
            # rebound to a piece of the copy.
            for sub_name, sub_module in ft_module.named_modules():
                # named_modules() yields ("", ft_module) for the root; don't
                # build a name with a trailing dot for it.
                if sub_name:
                    ft_module_name = f"{module_name}.{sub_name}"
                else:
                    ft_module_name = module_name
                if any(re.search(pattern, ft_module_name) for pattern in frozen_modules):
                    print(f"=> Freezing {ft_module_name}")
                    util.freeze(sub_module)

        # Register the copies and the originals as child modules of this
        # nn.Module.
        self.ft_modules_list = torch.nn.ModuleList(self.ft_modules.values())
        self.orig_modules_list = torch.nn.ModuleList(self.orig_modules.values())


    @classmethod
    def from_checkpoint(cls, model, checkpoint, frozen_modules=None):
        """Rebuild an instance from a checkpoint written from ``state_dict()``.

        Args:
            model: The base model to wrap.
            checkpoint: A path passed to ``torch.load``, or an already-loaded
                mapping of ``{module_name: module_state_dict}``.
            frozen_modules: Forwarded to ``__init__``.

        Returns:
            A new instance whose fine-tuned copies hold the checkpoint weights.
        """
        if isinstance(checkpoint, str):
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(checkpoint)

        # Checkpoint keys are exact module names: match them literally and in
        # full, not as unanchored regexes where '.' matches any character.
        modules = [f"^{re.escape(key)}$" for key in checkpoint]

        ftm = cls(model, modules, frozen_modules=frozen_modules)
        ftm.load_state_dict(checkpoint)

        return ftm

    def __enter__(self):
        """Swap the fine-tuned copies into ``self.model``; return ``self``."""
        for key, ft_module in self.ft_modules.items():
            util.set_module(self.model, key, ft_module)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        """Restore the original submodules. Exceptions are not suppressed."""
        for key, module in self.orig_modules.items():
            util.set_module(self.model, key, module)

    def parameters(self):
        """Return a list of the parameters of all fine-tuned copies.

        Overrides ``nn.Module.parameters`` so the wrapped base model's own
        parameters are not included. Parameters frozen via ``frozen_modules``
        are still included. Returns a list rather than a generator.
        """
        return [
            param
            for ft_module in self.ft_modules.values()
            for param in ft_module.parameters()
        ]

    def state_dict(self):
        """Return ``{module_name: state_dict}`` for each fine-tuned copy."""
        return {key: module.state_dict() for key, module in self.ft_modules.items()}

    def load_state_dict(self, state_dict, strict=True):
        """Load per-module state dicts in the format produced by ``state_dict()``.

        Args:
            state_dict: Mapping of fine-tuned module name to that module's
                state dict.
            strict: Forwarded to each copy's ``load_state_dict``. When False,
                entries naming a module that was not selected for fine-tuning
                are skipped instead of raising ``KeyError``.
        """
        for key, sd in state_dict.items():
            if not strict and key not in self.ft_modules:
                continue
            self.ft_modules[key].load_state_dict(sd, strict=strict)