File size: 3,417 Bytes
15acbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from diffusers.schedulers import UniPCMultistepScheduler    
from diffusers import AutoencoderKL
from diffusion_module.unet import UNetModel
import torch
from diffusion_module.utils.LSDMPipeline_expandDataset import SDMLDMPipeline
from accelerate import Accelerator
from evolution import random_walk
import cv2
import numpy as np 

def mask2onehot(data, num_classes):
    """Convert an integer label map into a one-hot tensor along dim 1.

    Args:
        data: Label tensor, expected shape ``(bs, 1, h, w)``, whose values are
            class indices in ``[0, num_classes)``. Any dtype is accepted;
            values are converted to ``int64`` (float values are truncated)
            before use. Indices outside the range make ``scatter_`` fail.
        num_classes: Number of one-hot channels to produce.

    Returns:
        A ``float32`` tensor of shape ``(bs, num_classes, h, w)`` on the same
        device as ``data``, holding 1.0 in each pixel's class channel and 0.0
        elsewhere.
    """
    # scatter_ requires int64 indices.
    label_map = data.to(dtype=torch.int64)
    bs, _, h, w = label_map.size()
    # Allocate directly on the input's device rather than building the buffer
    # on the CPU and copying it over.
    one_hot = torch.zeros(
        bs, num_classes, h, w, dtype=torch.float32, device=label_map.device
    )
    # In-place scatter: write 1.0 at channel label_map[b, 0, y, x] for each pixel.
    return one_hot.scatter_(1, label_map, 1.0)

def generate(img, pretrain_weight, seed=None):
    """Synthesize an image conditioned on a single binary mask.

    Builds the VAE, the mask-conditioned ``UNetModel`` and an
    ``SDMLDMPipeline`` from scratch on every call. It then binarizes ``img``
    into a two-class label map at the latent resolution and runs 20 sampling
    steps with a ``UniPCMultistepScheduler``.

    Args:
        img: Mask image as a NumPy array accepted by ``cv2.resize``. Pixels
            with value > 1 (after resizing) are treated as class 1, the rest
            as class 0.
            NOTE(review): the reshape to ``(1, 1, h, w)`` below assumes a
            single-channel (grayscale) image. A 3-channel image would produce
            a 5-D tensor and fail inside ``mask2onehot``. Confirm with callers.
        pretrain_weight: Path or identifier passed to
            ``UNetModel.from_pretrained``.
        seed: Optional integer seed for the sampling ``torch.Generator``.
            ``None`` leaves sampling unseeded.

    Returns:
        The ``.images`` attribute of the pipeline output (batch size 1).
    """
    # Must match the UNet's ``num_classes`` and the one-hot channel count.
    num_classes = 151
    latent_size = (64, 64)

    # --- Models -----------------------------------------------------------
    noise_scheduler = UniPCMultistepScheduler()
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
    unet = UNetModel(
        image_size=latent_size,
        in_channels=vae.config.latent_channels,
        model_channels=256,
        out_channels=vae.config.latent_channels,
        num_res_blocks=2,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 3, 4),
        num_heads=8,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=False,
        num_classes=num_classes,
        mask_emb="resize",
        use_checkpoint=True,
        SPADE_type="spade",
    )
    # NOTE(review): if UNetModel.from_pretrained is a classmethod (as in
    # diffusers' ModelMixin), the instance configured above is discarded and
    # only its class is used. Confirm, and drop the explicit construction if so.
    unet = unet.from_pretrained(pretrain_weight)

    # --- Device / precision -----------------------------------------------
    # Hard-coded to CPU, so the fp16 branch is only reachable if this changes.
    device = 'cpu'
    mixed_precision = "fp16" if device != 'cpu' else "no"
    accelerator = Accelerator(
        mixed_precision=mixed_precision,
        # Compare by value; the former `device is 'cpu'` relied on interning.
        cpu=(device == 'cpu'),
    )
    weight_dtype = torch.float16 if accelerator.mixed_precision == "fp16" else torch.float32

    unet, vae = accelerator.prepare(unet, vae)
    vae.to(device=accelerator.device, dtype=weight_dtype)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=noise_scheduler,
        torch_dtype=weight_dtype,
        resolution_type="crack",
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=False)

    generator = None
    if seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(seed)

    # --- Mask preprocessing -----------------------------------------------
    # cv2 dsize is (width, height); latent_size is square, so order is moot.
    resized = cv2.resize(img, latent_size, interpolation=cv2.INTER_AREA)
    # Binarize: pixels > 1 become 255, everything else 0.
    _, binary = cv2.threshold(resized, 1, 255, cv2.THRESH_BINARY)
    # Scale to {0, 1} class indices and add batch/channel dims:
    # (h, w) -> (1, 1, h, w).
    label_map = torch.from_numpy(binary / 255).unsqueeze(0).unsqueeze(0)
    # mask2onehot already returns (1, num_classes, h, w). The previous
    # stack(dim=1)/squeeze(0) over a one-element list was a no-op.
    onehot_mask = mask2onehot(label_map, num_classes).to(
        dtype=weight_dtype, device=accelerator.device
    )

    images = pipeline(
        onehot_mask,
        generator=generator,
        batch_size=1,
        num_inference_steps=20,
        s=1.5,
        num_evolution_per_mask=1,
    ).images
    return images