from __future__ import annotations

import os
import pathlib
import random
import sys
from typing import Any

import cv2
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import tqdm.auto
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel

HF_TOKEN = os.getenv('HF_TOKEN')

repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir / 'ELITE'
snapshot_download('ELITE-library/ELITE',
                  repo_type='model',
                  local_dir=submodule_dir.as_posix(),
                  token=HF_TOKEN)
sys.path.insert(0, submodule_dir.as_posix())

from train_local import (Mapper, MapperLocal, inj_forward_crossattention,
                         inj_forward_text, th2image, value_local_list)


def get_tensor_clip(normalize=True, toTensor=True):
    transform_list = []
    if toTensor:
        transform_list += [T.ToTensor()]
    if normalize:
        transform_list += [
            T.Normalize((0.48145466, 0.4578275, 0.40821073),
                        (0.26862954, 0.26130258, 0.27577711))
        ]
    return T.Compose(transform_list)


def process(image: np.ndarray, size: int = 512) -> torch.Tensor:
    image = cv2.resize(image, (size, size), interpolation=cv2.INTER_CUBIC)
    image = image.astype(np.float32)
    image = image / 127.5 - 1.0
    return torch.from_numpy(image).permute(2, 0, 1)


class Model:
    def __init__(self):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        (self.vae, self.unet, self.text_encoder, self.tokenizer,
         self.image_encoder, self.mapper, self.mapper_local,
         self.scheduler) = self.load_model()

    def download_mappers(self) -> tuple[str, str]:
        global_mapper_path = hf_hub_download('ELITE-library/ELITE',
                                             'global_mapper.pt',
                                             subfolder='checkpoints',
                                             repo_type='model',
                                             token=HF_TOKEN)
        local_mapper_path = hf_hub_download('ELITE-library/ELITE',
                                            'local_mapper.pt',
                                            subfolder='checkpoints',
                                            repo_type='model',
                                            token=HF_TOKEN)
        return global_mapper_path, local_mapper_path

    def load_model(
        self,
        scheduler_type=LMSDiscreteScheduler
    ) -> tuple[AutoencoderKL, UNet2DConditionModel, CLIPTextModel,
               CLIPTokenizer, CLIPVisionModel, Mapper, MapperLocal,
               LMSDiscreteScheduler]:
        diffusion_model_id = 'CompVis/stable-diffusion-v1-4'

        vae = AutoencoderKL.from_pretrained(
            diffusion_model_id,
            subfolder='vae',
            torch_dtype=torch.float16,
        )
        tokenizer = CLIPTokenizer.from_pretrained(
            'openai/clip-vit-large-patch14',
            torch_dtype=torch.float16,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            'openai/clip-vit-large-patch14',
            torch_dtype=torch.float16,
        )
        image_encoder = CLIPVisionModel.from_pretrained(
            'openai/clip-vit-large-patch14',
            torch_dtype=torch.float16,
        )

        # Patch the text transformer so that mapper embeddings can be
        # injected at the placeholder token position.
        for _module in text_encoder.modules():
            if _module.__class__.__name__ == 'CLIPTextTransformer':
                _module.__class__.__call__ = inj_forward_text

        unet = UNet2DConditionModel.from_pretrained(
            diffusion_model_id,
            subfolder='unet',
            torch_dtype=torch.float16,
        )

        mapper = Mapper(input_dim=1024, output_dim=768)
        mapper_local = MapperLocal(input_dim=1024, output_dim=768)

        # Patch every cross-attention block (attn2) in the U-Net and register
        # per-layer key/value projections for the global and local mappers.
        for _name, _module in unet.named_modules():
            if _module.__class__.__name__ == 'CrossAttention':
                if 'attn1' in _name:
                    continue
                _module.__class__.__call__ = inj_forward_crossattention

                shape = _module.to_k.weight.shape
                to_k_global = nn.Linear(shape[1], shape[0], bias=False)
                mapper.add_module(f'{_name.replace(".", "_")}_to_k',
                                  to_k_global)

                shape = _module.to_v.weight.shape
                to_v_global = nn.Linear(shape[1], shape[0], bias=False)
                mapper.add_module(f'{_name.replace(".", "_")}_to_v',
                                  to_v_global)

                to_v_local = nn.Linear(shape[1], shape[0], bias=False)
                mapper_local.add_module(f'{_name.replace(".", "_")}_to_v',
                                        to_v_local)

                to_k_local = nn.Linear(shape[1], shape[0], bias=False)
                mapper_local.add_module(f'{_name.replace(".", "_")}_to_k',
                                        to_k_local)

        # The mapper checkpoints are part of the snapshot downloaded above;
        # alternatively, download_mappers() fetches them individually.
        # global_mapper_path, local_mapper_path = self.download_mappers()
        global_mapper_path = submodule_dir / 'checkpoints/global_mapper.pt'
        local_mapper_path = submodule_dir / 'checkpoints/local_mapper.pt'

        mapper.load_state_dict(
            torch.load(global_mapper_path, map_location='cpu'))
        mapper.half()

        mapper_local.load_state_dict(
            torch.load(local_mapper_path, map_location='cpu'))
        mapper_local.half()

        # Attach the loaded per-layer projections to their cross-attention
        # blocks so the injected forward pass can find them.
        for _name, _module in unet.named_modules():
            if 'attn1' in _name:
                continue
            if _module.__class__.__name__ == 'CrossAttention':
                _module.add_module(
                    'to_k_global',
                    getattr(mapper, f'{_name.replace(".", "_")}_to_k'))
                _module.add_module(
                    'to_v_global',
                    getattr(mapper, f'{_name.replace(".", "_")}_to_v'))
                _module.add_module(
                    'to_v_local',
                    getattr(mapper_local, f'{_name.replace(".", "_")}_to_v'))
                _module.add_module(
                    'to_k_local',
                    getattr(mapper_local, f'{_name.replace(".", "_")}_to_k'))

        vae.eval().to(self.device)
        unet.eval().to(self.device)
        text_encoder.eval().to(self.device)
        image_encoder.eval().to(self.device)
        mapper.eval().to(self.device)
        mapper_local.eval().to(self.device)

        scheduler = scheduler_type(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule='scaled_linear',
            num_train_timesteps=1000,
        )
        return (vae, unet, text_encoder, tokenizer, image_encoder, mapper,
                mapper_local, scheduler)

    def prepare_data(self,
                     image: PIL.Image.Image,
                     mask: PIL.Image.Image,
                     text: str,
                     placeholder_string: str = 'S') -> dict[str, Any]:
        data: dict[str, Any] = {}

        data['text'] = text

        # Locate the placeholder word; the +1 accounts for the start-of-text
        # token prepended by the tokenizer.
        placeholder_index = 0
        words = text.strip().split(' ')
        for idx, word in enumerate(words):
            if word == placeholder_string:
                placeholder_index = idx + 1
        data['index'] = torch.tensor(placeholder_index)

        data['input_ids'] = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors='pt',
        ).input_ids[0]

        image = image.convert('RGB')
        mask = mask.convert('RGB')
        mask = np.array(mask) / 255.0

        image_np = np.array(image)
        object_tensor = image_np * mask
        data['pixel_values'] = process(image_np)

        ref_object_tensor = PIL.Image.fromarray(
            object_tensor.astype('uint8')).resize(
                (224, 224), resample=PIL.Image.Resampling.BICUBIC)
        ref_image_tensor = PIL.Image.fromarray(
            image_np.astype('uint8')).resize(
                (224, 224), resample=PIL.Image.Resampling.BICUBIC)
        data['pixel_values_obj'] = get_tensor_clip()(ref_object_tensor)
        data['pixel_values_clip'] = get_tensor_clip()(ref_image_tensor)

        ref_seg_tensor = PIL.Image.fromarray(mask.astype('uint8') * 255)
        ref_seg_tensor = get_tensor_clip(normalize=False)(ref_seg_tensor)
        data['pixel_values_seg'] = F.interpolate(ref_seg_tensor.unsqueeze(0),
                                                 size=(128, 128),
                                                 mode='nearest').squeeze(0)

        device = self.device
        data['pixel_values'] = data['pixel_values'].to(device)
        data['pixel_values_clip'] = data['pixel_values_clip'].to(device).half()
        data['pixel_values_obj'] = data['pixel_values_obj'].to(device).half()
        data['pixel_values_seg'] = data['pixel_values_seg'].to(device).half()
        data['input_ids'] = data['input_ids'].to(device)
        data['index'] = data['index'].to(device).long()

        # Add a batch dimension to every tensor.
        for key, value in list(data.items()):
            if isinstance(value, torch.Tensor):
                data[key] = value.unsqueeze(0)

        return data

    @torch.inference_mode()
    def run(
        self,
        image: dict[str, PIL.Image.Image],
        text: str,
        seed: int,
        guidance_scale: float,
        lambda_: float,
        num_steps: int,
    ) -> PIL.Image.Image:
        data = self.prepare_data(image['image'], image['mask'], text)

        uncond_input = self.tokenizer(
            [''] * data['pixel_values'].shape[0],
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            return_tensors='pt',
        )
        # The patched text encoder (inj_forward_text) takes a dict input.
        uncond_embeddings = self.text_encoder(
            {'input_ids': uncond_input.input_ids.to(self.device)})[0]

        if seed == -1:
            seed = random.randint(0, 1000000)
        generator = torch.Generator().manual_seed(seed)
        latents = torch.randn(
            (data['pixel_values'].shape[0], self.unet.in_channels, 64, 64),
            generator=generator,
        )
        # Match the device and dtype of the model inputs (fp16 on GPU).
        latents = latents.to(data['pixel_values_clip'])

        self.scheduler.set_timesteps(num_steps)
        latents = latents * self.scheduler.init_noise_sigma

        placeholder_idx = data['index']

        # Global branch: encode the full reference image and map selected
        # CLIP hidden states to a word embedding injected at the placeholder.
        image_clip = F.interpolate(data['pixel_values_clip'], (224, 224),
                                   mode='bilinear')
        image_features = self.image_encoder(image_clip,
                                            output_hidden_states=True)
        image_embeddings = [
            image_features[0],
            image_features[2][4],
            image_features[2][8],
            image_features[2][12],
            image_features[2][16],
        ]
        image_embeddings = [emb.detach() for emb in image_embeddings]
        inj_embedding = self.mapper(image_embeddings)
        inj_embedding = inj_embedding[:, 0:1, :]
        encoder_hidden_states = self.text_encoder({
            'input_ids': data['input_ids'],
            'inj_embedding': inj_embedding,
            'inj_index': placeholder_idx,
        })[0]

        # Local branch: encode the masked object and map it to per-patch
        # features, restricted to the object region.
        image_obj = F.interpolate(data['pixel_values_obj'], (224, 224),
                                  mode='bilinear')
        image_features_obj = self.image_encoder(image_obj,
                                                output_hidden_states=True)
        image_embeddings_obj = [
            image_features_obj[0],
            image_features_obj[2][4],
            image_features_obj[2][8],
            image_features_obj[2][12],
            image_features_obj[2][16],
        ]
        image_embeddings_obj = [emb.detach() for emb in image_embeddings_obj]
        inj_embedding_local = self.mapper_local(image_embeddings_obj)

        mask = F.interpolate(data['pixel_values_seg'], (16, 16),
                             mode='nearest')
        mask = mask[:, 0].reshape(mask.shape[0], -1, 1)
        inj_embedding_local = inj_embedding_local * mask

        for t in tqdm.auto.tqdm(self.scheduler.timesteps):
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            noise_pred_text = self.unet(latent_model_input,
                                        t,
                                        encoder_hidden_states={
                                            'CONTEXT_TENSOR':
                                            encoder_hidden_states,
                                            'LOCAL': inj_embedding_local,
                                            'LOCAL_INDEX':
                                            placeholder_idx.detach(),
                                            'LAMBDA': lambda_,
                                        }).sample
            value_local_list.clear()

            latent_model_input = self.scheduler.scale_model_input(latents, t)
            noise_pred_uncond = self.unet(latent_model_input,
                                          t,
                                          encoder_hidden_states={
                                              'CONTEXT_TENSOR':
                                              uncond_embeddings,
                                          }).sample
            value_local_list.clear()

            # Classifier-free guidance.
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond)

            # Compute the previous noisy sample x_t -> x_t-1.
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Scale back by the VAE latent factor and decode to an image.
        _latents = 1 / 0.18215 * latents.clone()
        images = self.vae.decode(_latents).sample
        return th2image(images[0])
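

# Minimal usage sketch (not part of the original demo, which drives Model.run
# from a Gradio UI). The file paths and hyperparameter values below are
# placeholders chosen for illustration; the reference image and its mask are
# assumed to be the same size, and 'S' is the placeholder word in the prompt.
if __name__ == '__main__':
    model = Model()
    result = model.run(
        image={
            'image': PIL.Image.open('reference.png'),  # hypothetical path
            'mask': PIL.Image.open('reference_mask.png'),  # hypothetical path
        },
        text='a photo of a S',
        seed=0,  # -1 picks a random seed
        guidance_scale=5.0,  # example value
        lambda_=0.6,  # example value; weights the local injection
        num_steps=50,
    )
    result.save('output.png')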