Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import os | |
import pathlib | |
import random | |
import sys | |
from typing import Any | |
import cv2 | |
import numpy as np | |
import PIL.Image | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.transforms as T | |
import tqdm.auto | |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel | |
from huggingface_hub import hf_hub_download, snapshot_download | |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel | |
# Optional Hugging Face access token; None when the variable is unset.
HF_TOKEN = os.environ.get('HF_TOKEN')

# The ELITE model repository is mirrored into an 'ELITE' directory that
# sits next to this file.
repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir / 'ELITE'

# NOTE: this runs at import time and needs network access on first use.
snapshot_download(
    repo_id='ELITE-library/ELITE',
    repo_type='model',
    local_dir=submodule_dir.as_posix(),
    token=HF_TOKEN,
)

# The downloaded directory is put first on sys.path before importing
# train_local; presumably that module ships inside the snapshot fetched
# above, so this import must stay below the two statements before it.
sys.path.insert(0, submodule_dir.as_posix())

from train_local import (  # noqa: E402
    Mapper,
    MapperLocal,
    inj_forward_crossattention,
    inj_forward_text,
    th2image,
    value_local_list,
)
def get_tensor_clip(normalize=True, toTensor=True):
    """Build the torchvision preprocessing pipeline for image inputs.

    Args:
        normalize: When true, append a per-channel ``T.Normalize`` step
            using the fixed mean/std constants below (the values commonly
            used for CLIP preprocessing).
        toTensor: When true, start the pipeline with ``T.ToTensor()``.

    Returns:
        A ``T.Compose`` holding the selected steps, in the order
        ToTensor -> Normalize. With both flags false the composition is
        empty.
    """
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)

    steps = []
    if toTensor:
        steps.append(T.ToTensor())
    if normalize:
        steps.append(T.Normalize(channel_mean, channel_std))
    return T.Compose(steps)
def process(image: np.ndarray, size: int = 512) -> torch.Tensor:
    """Resize an image array and convert it to a channel-first tensor.

    Args:
        image: Image array; it must have three dimensions after resizing
            because the result is permuted with ``(2, 0, 1)``.
            Presumably height x width x channels with values in 0..255 --
            confirm against callers.
        size: Edge length of the square output, in pixels.

    Returns:
        A float32 tensor with the last array axis moved to the front and
        values mapped through ``x / 127.5 - 1.0`` (0 -> -1, 255 -> 1).
    """
    resized = cv2.resize(image, (size, size), interpolation=cv2.INTER_CUBIC)
    rescaled = np.array(resized).astype(np.float32) / 127.5 - 1.0
    return torch.from_numpy(rescaled).permute(2, 0, 1)
class Model:
    """Loads the ELITE components once and generates images from them.

    On construction the VAE, UNet, CLIP text/vision encoders, tokenizer,
    the global and local mappers and a noise scheduler are loaded and
    moved to ``self.device`` (``cuda:0`` when available, otherwise CPU).

    NOTE(review): all weights are loaded as float16 and inputs are cast
    with ``.half()``; whether every op is supported in half precision on
    CPU depends on the installed torch build -- confirm before relying on
    the CPU fallback.
    """

    def __init__(self):
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        (self.vae, self.unet, self.text_encoder, self.tokenizer,
         self.image_encoder, self.mapper, self.mapper_local,
         self.scheduler) = self.load_model()

    def download_mappers(self) -> tuple[str, str]:
        """Download the two mapper checkpoints from the Hugging Face Hub.

        Returns:
            ``(global_mapper_path, local_mapper_path)`` as returned by
            ``hf_hub_download`` for ``checkpoints/global_mapper.pt`` and
            ``checkpoints/local_mapper.pt``.
        """
        global_mapper_path, local_mapper_path = (
            hf_hub_download('ELITE-library/ELITE',
                            filename,
                            subfolder='checkpoints',
                            repo_type='model',
                            token=HF_TOKEN)
            for filename in ('global_mapper.pt', 'local_mapper.pt'))
        return global_mapper_path, local_mapper_path

    def load_model(
        self,
        scheduler_type=LMSDiscreteScheduler
    ) -> tuple[AutoencoderKL, UNet2DConditionModel, CLIPTextModel,
               CLIPTokenizer, CLIPVisionModel, Mapper, MapperLocal,
               LMSDiscreteScheduler, ]:
        """Load every model component and wire the mappers into the UNet.

        Args:
            scheduler_type: Scheduler class to instantiate; it is called
                with ``beta_start``, ``beta_end``, ``beta_schedule`` and
                ``num_train_timesteps`` keyword arguments.

        Returns:
            ``(vae, unet, text_encoder, tokenizer, image_encoder, mapper,
            mapper_local, scheduler)`` -- in exactly this order, which is
            the order ``__init__`` unpacks.
        """
        diffusion_model_id = 'CompVis/stable-diffusion-v1-4'
        clip_model_id = 'openai/clip-vit-large-patch14'

        vae = AutoencoderKL.from_pretrained(
            diffusion_model_id,
            subfolder='vae',
            torch_dtype=torch.float16,
        )
        # The tokenizer holds no weights, so no dtype is passed to it.
        tokenizer = CLIPTokenizer.from_pretrained(clip_model_id)
        text_encoder = CLIPTextModel.from_pretrained(
            clip_model_id,
            torch_dtype=torch.float16,
        )
        image_encoder = CLIPVisionModel.from_pretrained(
            clip_model_id,
            torch_dtype=torch.float16,
        )

        # Replace __call__ on the text transformer's *class* so the text
        # encoder goes through inj_forward_text. This is a class-level
        # patch and therefore affects every instance of that class.
        for module in text_encoder.modules():
            if module.__class__.__name__ == 'CLIPTextTransformer':
                module.__class__.__call__ = inj_forward_text

        unet = UNet2DConditionModel.from_pretrained(
            diffusion_model_id,
            subfolder='unet',
            torch_dtype=torch.float16,
        )

        mapper = Mapper(input_dim=1024, output_dim=768)
        mapper_local = MapperLocal(input_dim=1024, output_dim=768)

        # Collect the CrossAttention modules once, skipping those whose
        # qualified name contains 'attn1'. The mapper-side prefix is the
        # qualified name with dots replaced by underscores.
        cross_attention = [
            (name.replace('.', '_'), module)
            for name, module in unet.named_modules()
            if module.__class__.__name__ == 'CrossAttention'
            and 'attn1' not in name
        ]

        # Register extra bias-free projections on the mappers so that the
        # checkpoint keys loaded below find matching parameters.
        for prefix, module in cross_attention:
            module.__class__.__call__ = inj_forward_crossattention

            out_features, in_features = module.to_k.weight.shape
            mapper.add_module(
                f'{prefix}_to_k',
                nn.Linear(in_features, out_features, bias=False))

            out_features, in_features = module.to_v.weight.shape
            mapper.add_module(
                f'{prefix}_to_v',
                nn.Linear(in_features, out_features, bias=False))

            # Both local projections are sized from to_v's weight shape;
            # this is kept as-is so parameter shapes stay unchanged for
            # the checkpoint that is loaded below.
            mapper_local.add_module(
                f'{prefix}_to_v',
                nn.Linear(in_features, out_features, bias=False))
            mapper_local.add_module(
                f'{prefix}_to_k',
                nn.Linear(in_features, out_features, bias=False))

        # Checkpoints come from the snapshot directory; download_mappers()
        # fetches the same two file names from the Hub instead.
        global_mapper_path = submodule_dir / 'checkpoints/global_mapper.pt'
        local_mapper_path = submodule_dir / 'checkpoints/local_mapper.pt'

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        mapper.load_state_dict(
            torch.load(global_mapper_path, map_location='cpu'))
        mapper.half()
        mapper_local.load_state_dict(
            torch.load(local_mapper_path, map_location='cpu'))
        mapper_local.half()

        # Attach the (now loaded) projections to the attention modules
        # under fixed attribute names.
        for prefix, module in cross_attention:
            module.add_module('to_k_global',
                              getattr(mapper, f'{prefix}_to_k'))
            module.add_module('to_v_global',
                              getattr(mapper, f'{prefix}_to_v'))
            module.add_module('to_v_local',
                              getattr(mapper_local, f'{prefix}_to_v'))
            module.add_module('to_k_local',
                              getattr(mapper_local, f'{prefix}_to_k'))

        for component in (vae, unet, text_encoder, image_encoder, mapper,
                          mapper_local):
            component.eval().to(self.device)

        scheduler = scheduler_type(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule='scaled_linear',
            num_train_timesteps=1000,
        )
        return (vae, unet, text_encoder, tokenizer, image_encoder, mapper,
                mapper_local, scheduler)

    def prepare_data(self,
                     image: PIL.Image.Image,
                     mask: PIL.Image.Image,
                     text: str,
                     placeholder_string: str = 'S') -> dict[str, Any]:
        """Turn the user inputs into the batch consumed by ``run``.

        Args:
            image: Reference image.
            mask: Mask image; after conversion to RGB it is divided by 255
                and multiplied into the image.
            text: Prompt text.
            placeholder_string: Word in ``text`` that marks the concept.

        Returns:
            A dict with ``text`` (str) and the tensors ``index``,
            ``input_ids``, ``pixel_values``, ``pixel_values_obj``,
            ``pixel_values_clip`` and ``pixel_values_seg``. All tensors
            live on ``self.device`` and carry a leading batch dimension
            of size 1.
        """
        data: dict[str, Any] = {}
        data['text'] = text

        # Index of the last word equal to the placeholder, plus one; 0 if
        # the placeholder is absent. The +1 presumably skips the
        # tokenizer's start token and assumes one token per word -- TODO
        # confirm.
        placeholder_index = 0
        for idx, word in enumerate(text.strip().split(' ')):
            if word == placeholder_string:
                placeholder_index = idx + 1
        data['index'] = torch.tensor(placeholder_index)

        data['input_ids'] = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors='pt',
        ).input_ids[0]

        image_np = np.array(image.convert('RGB'))
        mask_np = np.array(mask.convert('RGB')) / 255.0
        object_np = image_np * mask_np

        data['pixel_values'] = process(image_np)

        ref_object_image = PIL.Image.fromarray(
            object_np.astype('uint8')).resize(
                (224, 224), resample=PIL.Image.Resampling.BICUBIC)
        ref_full_image = PIL.Image.fromarray(
            image_np.astype('uint8')).resize(
                (224, 224), resample=PIL.Image.Resampling.BICUBIC)
        data['pixel_values_obj'] = get_tensor_clip()(ref_object_image)
        data['pixel_values_clip'] = get_tensor_clip()(ref_full_image)

        # The uint8 cast truncates the 0..1 mask, so only pixels that
        # were exactly 255 stay set in the segmentation image.
        ref_seg_image = PIL.Image.fromarray(mask_np.astype('uint8') * 255)
        ref_seg_tensor = get_tensor_clip(normalize=False)(ref_seg_image)
        data['pixel_values_seg'] = F.interpolate(ref_seg_tensor.unsqueeze(0),
                                                 size=(128, 128),
                                                 mode='nearest').squeeze(0)

        # Use the device chosen in __init__ rather than a hard-coded
        # 'cuda:0', so the CPU fallback does not crash here.
        device = self.device
        data['pixel_values'] = data['pixel_values'].to(device)
        data['pixel_values_clip'] = data['pixel_values_clip'].to(device).half()
        data['pixel_values_obj'] = data['pixel_values_obj'].to(device).half()
        data['pixel_values_seg'] = data['pixel_values_seg'].to(device).half()
        data['input_ids'] = data['input_ids'].to(device)
        data['index'] = data['index'].to(device).long()

        # Add the batch dimension to every tensor entry.
        for key, value in list(data.items()):
            if isinstance(value, torch.Tensor):
                data[key] = value.unsqueeze(0)
        return data

    def _clip_hidden_states(self,
                            pixel_values: torch.Tensor) -> list[torch.Tensor]:
        """Return the detached image-encoder outputs fed to the mappers."""
        resized = F.interpolate(pixel_values, (224, 224), mode='bilinear')
        features = self.image_encoder(resized, output_hidden_states=True)
        # Element 0 of the encoder output plus entries 4, 8, 12 and 16 of
        # element 2 (presumably the per-layer hidden states requested via
        # output_hidden_states=True -- confirm against transformers docs).
        selected = [features[0]]
        selected.extend(features[2][layer] for layer in (4, 8, 12, 16))
        return [emb.detach() for emb in selected]

    @torch.no_grad()
    def run(
        self,
        image: dict[str, PIL.Image.Image],
        text: str,
        seed: int,
        guidance_scale: float,
        lambda_: float,
        num_steps: int,
    ) -> PIL.Image.Image:
        """Generate one image for ``text`` conditioned on a masked image.

        Runs with autograd disabled: this is pure inference, and without
        it the graph of every denoising step would be kept alive through
        ``latents``.

        Args:
            image: Mapping with the keys ``'image'`` and ``'mask'``.
            text: Prompt text, forwarded to ``prepare_data``.
            seed: Seed for the initial noise; ``-1`` picks a random seed
                in ``[0, 1000000]``.
            guidance_scale: Weight applied to the difference between the
                conditioned and unconditioned noise predictions.
            lambda_: Passed to the UNet under the ``'LAMBDA'`` key; its
                effect is defined by the patched attention forward.
            num_steps: Number of scheduler timesteps.

        Returns:
            The first decoded image, converted by ``th2image``.
        """
        data = self.prepare_data(image['image'], image['mask'], text)
        batch_size = data['pixel_values'].shape[0]

        uncond_input = self.tokenizer(
            [''] * batch_size,
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            return_tensors='pt',
        )
        uncond_embeddings = self.text_encoder(
            {'input_ids': uncond_input.input_ids.to(self.device)})[0]

        if seed == -1:
            seed = random.randint(0, 1000000)
        # Noise is drawn from a CPU generator, then moved to the device
        # and dtype of the CLIP input tensor.
        generator = torch.Generator().manual_seed(seed)
        latents = torch.randn(
            (batch_size, self.unet.in_channels, 64, 64),
            generator=generator,
        )
        latents = latents.to(data['pixel_values_clip'])
        self.scheduler.set_timesteps(num_steps)
        latents = latents * self.scheduler.init_noise_sigma

        placeholder_idx = data['index']

        # Global branch: only the first token of the mapper output is
        # injected into the text encoder at the placeholder index.
        inj_embedding = self.mapper(
            self._clip_hidden_states(data['pixel_values_clip']))
        inj_embedding = inj_embedding[:, 0:1, :]
        encoder_hidden_states = self.text_encoder({
            'input_ids': data['input_ids'],
            'inj_embedding': inj_embedding,
            'inj_index': placeholder_idx,
        })[0]

        # Local branch: mapper output for the masked image, multiplied by
        # the mask resized to 16x16 and flattened to one value per cell.
        inj_embedding_local = self.mapper_local(
            self._clip_hidden_states(data['pixel_values_obj']))
        mask = F.interpolate(data['pixel_values_seg'], (16, 16),
                             mode='nearest')
        mask = mask[:, 0].reshape(mask.shape[0], -1, 1)
        inj_embedding_local = inj_embedding_local * mask

        for t in tqdm.auto.tqdm(self.scheduler.timesteps):
            # Both UNet passes use the same scaled input, so compute it
            # once per step.
            latent_model_input = self.scheduler.scale_model_input(latents, t)

            noise_pred_text = self.unet(latent_model_input,
                                        t,
                                        encoder_hidden_states={
                                            'CONTEXT_TENSOR':
                                            encoder_hidden_states,
                                            'LOCAL': inj_embedding_local,
                                            'LOCAL_INDEX':
                                            placeholder_idx.detach(),
                                            'LAMBDA': lambda_
                                        }).sample
            # value_local_list is module-level state imported from
            # train_local; it is emptied after every UNet call.
            value_local_list.clear()

            noise_pred_uncond = self.unet(latent_model_input,
                                          t,
                                          encoder_hidden_states={
                                              'CONTEXT_TENSOR':
                                              uncond_embeddings,
                                          }).sample
            value_local_list.clear()

            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Undo the 0.18215 latent scaling before decoding.
        decoded = self.vae.decode(1 / 0.18215 * latents).sample
        return th2image(decoded[0])