|
from diffusers.schedulers import UniPCMultistepScheduler |
|
from diffusers import AutoencoderKL |
|
from diffusion_module.unet import UNetModel |
|
import torch |
|
from diffusion_module.utils.LSDMPipeline_expandDataset import SDMLDMPipeline |
|
from accelerate import Accelerator |
|
from evolution import random_walk |
|
import cv2 |
|
import numpy as np |
|
|
|
def mask2onehot(data, num_classes):
    """Convert an integer label mask to a one-hot encoding.

    Args:
        data: tensor of shape (bs, 1, h, w) holding integer class labels
            (any dtype castable to int64; values must be in [0, num_classes)).
        num_classes: number of classes C in the encoding.

    Returns:
        Float32 tensor of shape (bs, C, h, w) on ``data``'s device, with a
        1.0 at each pixel's label channel and 0.0 elsewhere.
    """
    # scatter_ requires an int64 index tensor.
    label_map = data.to(dtype=torch.int64)
    bs, _, h, w = label_map.size()
    # Allocate directly on the target device (the original built a CPU
    # FloatTensor, zeroed it, then moved it — an extra copy).
    one_hot = torch.zeros(bs, num_classes, h, w, dtype=torch.float32,
                          device=data.device)
    return one_hot.scatter_(1, label_map, 1.0)
|
|
|
def generate(img, pretrain_weight, seed=None):
    """Generate crack images from a skeleton mask via a latent SDM pipeline.

    Args:
        img: single-channel image (numpy array, any size) whose nonzero
            pixels mark the crack skeleton; resized to 64x64 internally.
        pretrain_weight: path/identifier of the pretrained UNet weights.
        seed: optional int for deterministic sampling; ``None`` for random.

    Returns:
        The ``.images`` list produced by the SDMLDMPipeline call.
    """
    noise_scheduler = UniPCMultistepScheduler()
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
    latent_size = (64, 64)
    unet = UNetModel(
        image_size=latent_size,
        in_channels=vae.config.latent_channels,
        model_channels=256,
        out_channels=vae.config.latent_channels,
        num_res_blocks=2,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 3, 4),
        num_heads=8,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=False,
        num_classes=151,
        mask_emb="resize",
        use_checkpoint=True,
        SPADE_type="spade",
    )
    # NOTE(review): from_pretrained is invoked on the instance and its result
    # replaces the freshly configured unet above — confirm the constructor
    # call is really needed only to reach the class/config.
    unet = unet.from_pretrained(pretrain_weight)

    device = 'cpu'
    # fp16 only makes sense on an accelerator; stay in fp32 on CPU.
    # (device is hard-coded to 'cpu' above, so this resolves to "no".)
    mixed_precision = "fp16" if device != 'cpu' else "no"

    accelerator = Accelerator(
        mixed_precision=mixed_precision,
        # BUG FIX: original used `device is 'cpu'` — identity comparison with
        # a string literal is implementation-dependent and a SyntaxWarning on
        # CPython >= 3.8; equality is the correct check.
        cpu=(device == 'cpu'),
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16

    unet, vae = accelerator.prepare(unet, vae)
    vae.to(device=accelerator.device, dtype=weight_dtype)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=noise_scheduler,
        torch_dtype=weight_dtype,
        resolution_type="crack",
    )
    """
    if accelerator.device != 'cpu':
        pipeline.enable_xformers_memory_efficient_attention()
    """
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=False)

    if seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(seed)

    # Downscale the skeleton to the latent resolution and binarize: any
    # pixel > 1 becomes 255, then scale to {0, 1}.
    resized_s = cv2.resize(img, (64, 64), interpolation=cv2.INTER_AREA)
    _, binary_s = cv2.threshold(resized_s, 1, 255, cv2.THRESH_BINARY)
    tensor_s = torch.from_numpy(binary_s / 255)
    tensor_s = tensor_s.unsqueeze(0).unsqueeze(0)  # -> (1, 1, 64, 64)

    # One-hot with 151 classes to match the UNet's num_classes.
    onehot_skeletons = []
    onehot_s = mask2onehot(tensor_s, 151)
    onehot_skeletons.append(onehot_s)

    onehot_skeletons = torch.stack(onehot_skeletons, dim=1).squeeze(0)
    onehot_skeletons = onehot_skeletons.to(dtype=weight_dtype, device=accelerator.device)

    images = pipeline(
        onehot_skeletons,
        generator=generator,
        batch_size=1,
        num_inference_steps=20,
        s=1.5,
        num_evolution_per_mask=1,
    ).images

    return images