File size: 5,324 Bytes
2f4febc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import PIL
import torch
import requests
import torchvision
from math import ceil
from io import BytesIO
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import math
from tqdm import tqdm
def download_image(url, timeout=30):
    """Download an image over HTTP(S) and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait on the server before giving up; forwarded to
            ``requests.get``. Pass ``None`` to wait indefinitely (the previous
            behaviour).

    Returns:
        A ``PIL.Image.Image`` converted to ``"RGB"`` mode.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    # ``import PIL`` alone does not guarantee that the ``PIL.Image`` submodule
    # is loaded, so import it explicitly rather than relying on a side effect
    # of another library having imported it first.
    from PIL import Image

    with requests.get(url, timeout=timeout) as response:
        response.raise_for_status()
        # ``response.content`` honours Content-Encoding (gzip/deflate), which
        # ``response.raw`` does not, and the ``with`` releases the connection.
        return Image.open(BytesIO(response.content)).convert("RGB")


def resize_image(image, size=768):
    """Turn ``image`` into a tensor and resize it with antialiasing enabled.

    ``size`` is forwarded untouched to torchvision's functional ``resize``
    (per torchvision, an int rescales the shorter side to that length).
    """
    return F.resize(F.to_tensor(image), size, antialias=True)


def downscale_images(images, factor=3/4):
    """Shrink ``images`` by ``factor`` with each side snapped down to a multiple of 32.

    Height and width are read from the last two dimensions of ``images``;
    resampling uses nearest-neighbour interpolation.
    """
    def _snap(length):
        # Largest multiple of 32 not exceeding length * factor (for positive inputs).
        return int((length * factor // 32) * 32)

    target_size = (_snap(images.size(-2)), _snap(images.size(-1)))
    return torchvision.transforms.functional.resize(
        images,
        target_size,
        interpolation=torchvision.transforms.InterpolationMode.NEAREST,
    )



def calculate_latent_sizes(height=1024, width=1024, batch_size=4, compression_factor_b=42.67, compression_factor_a=4.0):
    """Return the stage C and stage B latent shapes for an image of ``height x width``.

    Each spatial side is divided by the relevant compression factor and
    rounded up.

    Args:
        height, width: Target image size in pixels.
        batch_size: Leading dimension of both shapes.
        compression_factor_b: Divisor for the 16-channel stage C latent
            (note: the ``_b`` factor is the one used for stage C).
        compression_factor_a: Divisor for the 4-channel stage B latent.

    Returns:
        ``(stage_c_latent_shape, stage_b_latent_shape)``, each a 4-tuple
        ``(batch_size, channels, latent_height, latent_width)``.
    """
    stage_c_latent_shape = (
        batch_size,
        16,
        ceil(height / compression_factor_b),
        ceil(width / compression_factor_b),
    )
    stage_b_latent_shape = (
        batch_size,
        4,
        ceil(height / compression_factor_a),
        ceil(width / compression_factor_a),
    )
    return stage_c_latent_shape, stage_b_latent_shape


def get_views(H, W, window_size=64, stride=16):
    """Return overlapping square windows that tile an ``H x W`` grid.

    - H, W: height and width of the latent

    Windows are ``window_size`` cells wide and start every ``stride`` cells.
    They are listed row-major as ``(h_start, h_end, w_start, w_end)`` tuples.
    Every cell is covered by at least one window:

      * if ``stride`` does not evenly divide ``length - window_size``, one
        extra window aligned to the far edge is appended on that axis;
      * if a side is not longer than ``window_size``, a single window
        starting at 0 is used on that axis. Its end may exceed the side, so
        callers that slice with it get a clipped window.

    When ``stride`` divides ``length - window_size`` exactly, the result is
    the plain stride grid.

    Raises:
        ValueError: if ``window_size`` or ``stride`` is not positive.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError(
            f"window_size and stride must be positive, got {window_size} and {stride}"
        )

    def _starts(length):
        # Start offsets along one axis that together cover [0, length).
        length = int(length)
        if length <= window_size:
            return [0]
        starts = list(range(0, length - window_size + 1, stride))
        if starts[-1] + window_size < length:
            # Trailing cells not reached by the stride grid: add an edge window.
            starts.append(length - window_size)
        return starts

    return [
        (h_start, h_start + window_size, w_start, w_start + window_size)
        for h_start in _starts(H)
        for w_start in _starts(W)
    ]

       

def show_images(images, rows=None, cols=None, **kwargs):
    """Convert a batch of image tensors into a list of PIL images.

    Despite the name, nothing is displayed; the images are returned to the
    caller.

    Args:
        images: 4-D tensor indexed ``(batch, channels, height, width)``.
            Single-channel input is repeated to 3 channels, input with more
            than 3 channels keeps only the first 3, and values are clamped to
            [0, 1] before conversion.
        rows, cols, **kwargs: Accepted for call compatibility; unused.

    Returns:
        A list with one ``PIL.Image.Image`` per batch element.

    Raises:
        ValueError: if ``images`` is not 4-dimensional.
    """
    if images.dim() != 4:
        raise ValueError(
            f"expected a 4-D (batch, channels, height, width) tensor, got shape {tuple(images.shape)}"
        )

    channels = images.size(1)
    if channels == 1:
        images = images.repeat(1, 3, 1, 1)
    elif channels > 3:
        images = images[:, :3]

    return [
        torchvision.transforms.functional.to_pil_image(img.clamp(0, 1))
        for img in images
    ]
    


def decode_b(conditions_b, unconditions_b, models_b, bshape, extras_b, device,
             stage_a_tiled=False, num_instance=4, patch_size=256, stride=24):
    """Run the stage B sampler, then decode its final latent with stage A.

    Args:
        conditions_b, unconditions_b: Conditioning forwarded to
            ``extras_b.gdf.sample``.
        models_b: Object exposing ``generator`` (the stage B model) and
            ``stage_a`` (the decoder).
        bshape: Latent shape handed to the sampler.
        extras_b: Provides ``gdf.sample`` and ``sampling_configs``; the
            latter is splatted into the sampler and must contain
            ``'timesteps'`` (used for the progress bar total).
        device: Device forwarded to the sampler.
        stage_a_tiled: If truthy, decode in overlapping latent tiles of
            ``patch_size`` with step ``stride`` and average the overlaps;
            otherwise make a single ``stage_a.decode`` call.
        num_instance: Unused; kept for call compatibility.
        patch_size, stride: Tile size and step (latent cells) for tiled decoding.

    Returns:
        The decoded images as a float32 tensor.

    Raises:
        RuntimeError: if the sampler yields no steps.
    """
    sampler = extras_b.gdf.sample(
        models_b.generator.half(), conditions_b, bshape,
        unconditions_b, device=device,
        **extras_b.sampling_configs,
    )
    # ``.cuda()`` runs after the sampler object is built but before it is
    # iterated; presumably ``gdf.sample`` is lazy (it is consumed as an
    # iterator below), so the model is on the GPU for every step -- confirm.
    sampled_b = None
    try:
        models_b.generator.cuda()
        for sampled_b, _, _ in tqdm(sampler, total=extras_b.sampling_configs['timesteps']):
            pass  # only the latent from the final step is kept
    finally:
        # Park the generator back on the CPU even if sampling fails, so a
        # crash does not leave it occupying GPU memory.
        models_b.generator.cpu()
        torch.cuda.empty_cache()

    if sampled_b is None:
        raise RuntimeError("stage B sampler yielded no samples")

    if stage_a_tiled:
        sampled = _decode_stage_a_tiled(models_b.stage_a, sampled_b, patch_size, stride)
    else:
        sampled = models_b.stage_a.decode(sampled_b, tiled_decoding=stage_a_tiled)
    return sampled.float()


def _decode_stage_a_tiled(stage_a, sampled_b, patch_size, stride):
    """Decode ``sampled_b`` tile by tile with ``stage_a`` and average the overlaps.

    The latent is reflect-padded by ``2 * stride`` on every side before tiling,
    and the matching (4x scaled) border is cropped from the result. The output
    buffers assume each decoded tile has 3 channels and is 4x the latent tile's
    spatial size.
    """
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pad = stride * 2
        sampled_b = torch.nn.functional.pad(sampled_b, (pad, pad, pad, pad), mode='reflect')
        batch, _, latent_h, latent_w = sampled_b.shape
        out_shape = (batch, 3, latent_h * 4, latent_w * 4)
        # Accumulate in float32: summed decodes and per-pixel hit counts.
        sampled = torch.zeros(out_shape, device=sampled_b.device)
        count = torch.zeros(out_shape, device=sampled_b.device)

        views = get_views(latent_h, latent_w, window_size=patch_size, stride=stride)
        for h_start, h_end, w_start, w_end in tqdm(views, total=len(views)):
            tile = sampled_b[:, :, h_start:h_end, w_start:w_end]
            hs, he, ws, we = h_start * 4, h_end * 4, w_start * 4, w_end * 4
            sampled[:, :, hs:he, ws:we] += stage_a.decode(tile).float()
            count[:, :, hs:he, ws:we] += 1

        sampled /= count
        crop = pad * 4
        sampled = sampled[:, :, crop:-crop, crop:-crop]
    return sampled


def generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device, conditions=None, unconditions=None):
    """Run the stage C sampler and return the latent from its final step.

    Args:
        batch, models, extras, core: Passed to ``core.get_conditions`` when
            conditions need to be built; ``models.generator`` and
            ``extras.gdf.sample`` drive sampling.
        stage_c_latent_shape, stage_c_latent_shape_lr: Both forwarded to the
            sampler as-is (the second is presumably a lower-resolution shape
            -- confirm against ``gdf.sample``).
        device: Device forwarded to the sampler.
        conditions, unconditions: Precomputed conditioning; built via
            ``core.get_conditions`` (eval mode, no image embeds) when ``None``.

    Returns:
        The ``sampled_c`` value yielded by the sampler's last step.

    Raises:
        RuntimeError: if the sampler yields no steps.
    """
    if conditions is None:
        conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
    if unconditions is None:
        unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)

    sampler = extras.gdf.sample(
        models.generator, conditions, stage_c_latent_shape, stage_c_latent_shape_lr,
        unconditions, device=device, **extras.sampling_configs,
    )
    sampled_c = None
    # Each step yields a 4-tuple; only the first element of the last step is kept.
    for sampled_c, _sampled_c_curr, _, _ in tqdm(sampler, total=extras.sampling_configs['timesteps']):
        pass
    if sampled_c is None:
        raise RuntimeError("stage C sampler yielded no samples")
    return sampled_c
    
def get_target_lr_size(ratio, std_size=24):
    """Return a ``(height, width)`` pixel size, both multiples of 32, for aspect ``ratio``.

    ``std_size`` is a base edge in 32-pixel units. The height is scaled by
    sqrt(ratio) and the width by 1/sqrt(ratio), so height / width is roughly
    ``ratio``. Each is truncated to an int before multiplying by 32.
    """
    root = math.sqrt(ratio)
    height_units = int(std_size * root)
    width_units = int(std_size / root)
    return (height_units * 32, width_units * 32)