File size: 4,785 Bytes
55a3c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import copy
from diffusers import DDPMScheduler
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig
from S2I.modules.utils import sc_vae_encoder_fwd, sc_vae_decoder_fwd, download_models, get_model_path


class RelationShipConvolution(torch.nn.Module):
    def __init__(self, conv_in_pretrained, conv_in_curr, r):
        super(RelationShipConvolution, self).__init__()
        self.conv_in_pretrained = copy.deepcopy(conv_in_pretrained)
        self.conv_in_curr = copy.deepcopy(conv_in_curr)
        self.r = r

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * self.r


class PrimaryModel:
    def __init__(self, backbone_diffusion_path='stabilityai/sd-turbo'):
        self.backbone_diffusion_path = backbone_diffusion_path
        self.global_unet = None
        self.global_vae = None
        self.global_tokenizer = None
        self.global_text_encoder = None
        self.global_scheduler = None

    @staticmethod
    def _load_model(path, model_class, unet_mode=False):
        model = model_class.from_pretrained(path, subfolder='unet' if unet_mode else 'vae').to('cuda')
        return model


    def one_step_scheduler(self):
        noise_scheduler_1step = DDPMScheduler.from_pretrained(self.backbone_diffusion_path, subfolder="scheduler")
        noise_scheduler_1step.set_timesteps(1, device="cuda")
        noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
        return noise_scheduler_1step

    def skip_connections(self, vae):
        vae.encoder.forward = sc_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = sc_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        return vae

    def from_pretrained(self, model_name, r):
        if self.global_tokenizer is None:
            # self.global_tokenizer = AutoTokenizer.from_pretrained(self.backbone_diffusion_path,
            #                                                       subfolder="tokenizer")
            self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3", subfolder="tokenizer_2")

        if self.global_text_encoder is None:
            self.global_text_encoder = CLIPTextModel.from_pretrained(self.backbone_diffusion_path,
                                                                     subfolder="text_encoder").to(device='cuda')

        if self.global_scheduler is None:
            self.global_scheduler = self.one_step_scheduler()

        if self.global_vae is None:
            self.global_vae = self._load_model(self.backbone_diffusion_path, AutoencoderKL)
            self.global_vae = self.skip_connections(self.global_vae)

        if self.global_unet is None:
            self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
            p_ckpt_path = download_models()
            p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
            sd = torch.load(p_ckpt, map_location="cpu")
            conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
            self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
                                          target_modules=sd["unet_lora_target_modules"])
            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian",
                                         target_modules=sd["vae_lora_target_modules"])
            self.global_vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            _sd_vae = self.global_vae.state_dict()
            for k in sd["state_dict_vae"]:
                _sd_vae[k] = sd["state_dict_vae"][k]
            self.global_vae.load_state_dict(_sd_vae)
            self.global_unet.add_adapter(unet_lora_config)
            _sd_unet = self.global_unet.state_dict()
            for k in sd["state_dict_unet"]:
                _sd_unet[k] = sd["state_dict_unet"][k]
            self.global_unet.load_state_dict(_sd_unet, strict=False)