import copy
import os

import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from peft import LoraConfig
from transformers import AutoTokenizer, CLIPTextModel

from S2I.modules.utils import (
    download_models,
    get_model_path,
    get_s2i_home,
    sc_vae_decoder_fwd,
    sc_vae_encoder_fwd,
)
|
|
class RelationShipConvolution(torch.nn.Module):
    """Blend a frozen copy of the pretrained conv_in with a trainable copy.

    The output is a linear interpolation controlled by the ratio ``r``:
    ``r = 0`` returns only the pretrained branch, ``r = 1`` only the trainable one.
    """

    def __init__(self, conv_in_pretrained, conv_in_curr, r):
        super().__init__()
        self.conv_in_pretrained = copy.deepcopy(conv_in_pretrained)
        self.conv_in_curr = copy.deepcopy(conv_in_curr)
        self.r = r

    def forward(self, x):
        # Detach the pretrained branch so gradients flow only through conv_in_curr.
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * self.r
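
# Sanity-check sketch (illustrative, not part of the module): with r=0 the blend
# reproduces the frozen pretrained branch; with r=1, the trainable copy. The
# 4-in/320-out shape mirrors SD-style UNet conv_in layers but is an assumption
# made purely for illustration.
#
#   conv = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
#   blend = RelationShipConvolution(conv, conv, r=0.0)
#   x = torch.randn(1, 4, 64, 64)
#   assert torch.allclose(blend(x), conv(x))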
|
|
class PrimaryModel:
    def __init__(self, backbone_diffusion_path='stabilityai/sd-turbo'):
        self.backbone_diffusion_path = backbone_diffusion_path
        self.global_unet = None
        self.global_vae = None
        self.global_tokenizer = None
        self.global_text_encoder = None
        self.global_scheduler = None

    @staticmethod
    def _load_model(path, model_class, unet_mode=False):
        # Load the UNet or VAE subfolder of the backbone checkpoint onto the GPU.
        model = model_class.from_pretrained(path, subfolder='unet' if unet_mode else 'vae').to('cuda')
        return model
|
    def one_step_scheduler(self):
        # Configure the backbone's DDPM scheduler for single-step inference.
        noise_scheduler_1step = DDPMScheduler.from_pretrained(self.backbone_diffusion_path, subfolder="scheduler")
        noise_scheduler_1step.set_timesteps(1, device="cuda")
        noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
        return noise_scheduler_1step

    def skip_connections(self, vae):
        # Patch the encoder/decoder forward passes so the decoder can consume
        # encoder activations, then add the 1x1 convs that fuse them in.
        vae.encoder.forward = sc_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = sc_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        return vae
|
    def weights_adapter(self, p_ckpt, model_name):
        # Load the checkpoint; for the 350k adapter, overlay the sketch LoRA
        # weights stored in the S2I home directory on top of it.
        sd = torch.load(p_ckpt, map_location="cpu")
        if model_name == '350k-adapter':
            home = get_s2i_home()
            sd_sketch = torch.load(os.path.join(home, "sketch2image_lora_350k.pkl"), map_location="cpu")
            sd.update(sd_sketch)
        return sd
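
    # Expected checkpoint layout, inferred from the keys read in from_pretrained
    # below (treat this as an assumption, not a spec):
    #
    #   sd = {
    #       "rank_unet": int,                        # LoRA rank for the UNet adapter
    #       "unet_lora_target_modules": [str, ...],  # module names to wrap with LoRA
    #       "rank_vae": int,
    #       "vae_lora_target_modules": [str, ...],
    #       "state_dict_unet": {...},                # UNet weights to overwrite
    #       "state_dict_vae": {...},                 # VAE weights to overwrite
    #   }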
|
    def from_pretrained(self, model_name, r):
        # Lazily build each component once; subsequent calls reuse the globals.
        if self.global_tokenizer is None:
            self.global_tokenizer = AutoTokenizer.from_pretrained("myn0908/stable-diffusion-3",
                                                                  subfolder="tokenizer_2")

        if self.global_text_encoder is None:
            self.global_text_encoder = CLIPTextModel.from_pretrained(self.backbone_diffusion_path,
                                                                     subfolder="text_encoder").to(device='cuda')

        if self.global_scheduler is None:
            self.global_scheduler = self.one_step_scheduler()

        if self.global_vae is None:
            self.global_vae = self._load_model(self.backbone_diffusion_path, AutoencoderKL)
            self.global_vae = self.skip_connections(self.global_vae)

        if self.global_unet is None:
            self.global_unet = self._load_model(self.backbone_diffusion_path, UNet2DConditionModel, unet_mode=True)
            p_ckpt_path = download_models()
            p_ckpt = get_model_path(model_name=model_name, model_paths=p_ckpt_path)
            sd = self.weights_adapter(p_ckpt, model_name)

            # Swap conv_in for a blend of the frozen pretrained conv and a trainable copy.
            conv_in_pretrained = copy.deepcopy(self.global_unet.conv_in)
            self.global_unet.conv_in = RelationShipConvolution(conv_in_pretrained, self.global_unet.conv_in, r)

            # Attach LoRA adapters using the ranks and target modules stored in the checkpoint.
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian",
                                          target_modules=sd["unet_lora_target_modules"])
            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian",
                                         target_modules=sd["vae_lora_target_modules"])
            self.global_vae.add_adapter(vae_lora_config, adapter_name="vae_skip")

            # Overwrite matching entries of each state dict with the checkpoint weights.
            _sd_vae = self.global_vae.state_dict()
            for k in sd["state_dict_vae"]:
                _sd_vae[k] = sd["state_dict_vae"][k]
            self.global_vae.load_state_dict(_sd_vae)

            self.global_unet.add_adapter(unet_lora_config)
            _sd_unet = self.global_unet.state_dict()
            for k in sd["state_dict_unet"]:
                _sd_unet[k] = sd["state_dict_unet"][k]
            self.global_unet.load_state_dict(_sd_unet, strict=False)
|