|
import random |
|
import torch |
|
from collections import OrderedDict |
|
|
|
import numpy as np |
|
from PIL import Image |
|
import torchvision.transforms as T |
|
from torchvision.transforms import Compose, Resize
|
from torchvision import transforms as tvtrans |
|
|
|
from decord import VideoReader, cpu
|
|
def remove_duplicate_word(tx):
    # Collapse immediately repeated n-grams, e.g. 'the cat the cat sat' -> 'the cat sat'.

    # Merge neighboring entries of `input` into (length+1)-grams.
    def combine_words(input, length):
        combined_inputs = []
        if len(input) > 1:
            for i in range(len(input) - 1):
                combined_inputs.append(
                    input[i] + " " + last_word_of(input[i + 1], length))
        return combined_inputs, length + 1

    # Remove the first immediate repetition of a `length`-gram, then recurse
    # until no repetition remains.
    def remove_duplicates(input, length):
        bool_broke = False
        for i in range(len(input) - length):
            if input[i] == input[i + length]:
                for j in range(0, length):
                    del input[i + length - j]
                bool_broke = True
                break
        if bool_broke:
            return remove_duplicates(input, length)
        return input
|
    # Return the `length`-th word of a whitespace-separated n-gram.
    def last_word_of(input, length):
        splitted = input.split(" ")
        if len(splitted) == 0:
            return input
        else:
            return splitted[length - 1]
|
    # Split on spaces, then peel leading '([{' and trailing punctuation into
    # separate '<puncnext>'-marked tokens so they survive deduplication.
    def split_and_puncsplit(text):
        tx = text.split(" ")
        txnew = []
        for txi in tx:
            if len(txi) == 0:  # consecutive spaces yield empty tokens
                continue
            txqueue = []
            while True:
                if txi[0] in '([{':
                    txqueue.extend([txi[:1], '<puncnext>'])
                    txi = txi[1:]
                    if len(txi) == 0:
                        break
                else:
                    break
            txnew += txqueue
            txstack = []
            if len(txi) == 0:
                continue
            while True:
                if txi[-1] in '?!.,:;}])':
                    txstack = ['<puncnext>', txi[-1:]] + txstack
                    txi = txi[:-1]
                    if len(txi) == 0:
                        break
                else:
                    break
            if len(txi) != 0:
                txnew += [txi]
            txnew += txstack
        return txnew
|
    if tx == '':
        return tx

    splitted_input = split_and_puncsplit(tx)
    if len(splitted_input) == 0:
        return ''
    word_length = 1
    intermediate_output = False
    while len(splitted_input) > 1:
        splitted_input = remove_duplicates(splitted_input, word_length)
        if len(splitted_input) > 1:
            splitted_input, word_length = combine_words(splitted_input, word_length)
        if intermediate_output:
            print(splitted_input)
            print(word_length)

    output = splitted_input[0]
    output = output.replace(' <puncnext> ', '')
    return output
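
# Usage sketch: 'the cat the cat sat' repeats the 2-gram 'the cat', so:
#
#     remove_duplicate_word('the cat the cat sat')   # -> 'the cat sat'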
|
|
def regularize_image(x, image_size=512):
    # Normalize any supported input to a (3, image_size, image_size) tensor
    # scaled to [-1, 1].
    if isinstance(x, str):
        x = Image.open(x).convert('RGB')
        size = min(x.size)
    elif isinstance(x, Image.Image):
        x = x.convert('RGB')
        size = min(x.size)
    elif isinstance(x, np.ndarray):
        x = Image.fromarray(x).convert('RGB')
        size = min(x.size)
    elif isinstance(x, torch.Tensor):
        # Assumed to be a (C, H, W) float tensor already scaled to [0, 1].
        size = min(x.size()[1:])
    else:
        assert False, 'Unknown image type'
|
"""transforms = T.Compose([ |
|
T.RandomCrop(size), |
|
T.Resize( |
|
(image_size, image_size), |
|
interpolation=BICUBIC, |
|
), |
|
T.RandomHorizontalFlip(), |
|
T.ToTensor(), |
|
]) |
|
x = transforms(x) |
|
assert (x.shape[1] == image_size) & (x.shape[2] == image_size), \ |
|
'Wrong image size' |
|
""" |
|
x = x * 2 - 1 |
|
return x |
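
# Usage sketch ('photo.jpg' is a hypothetical path): returns a randomly
# cropped and flipped (3, 512, 512) tensor in [-1, 1]:
#
#     x = regularize_image('photo.jpg')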
|
|
def center_crop(img, new_width=None, new_height=None):
    # Center-crop a frame stack shaped (T, H, W) or (T, H, W, C).
    width = img.shape[2]
    height = img.shape[1]

    if new_width is None:
        new_width = min(width, height)
    if new_height is None:
        new_height = min(width, height)

    left = int(np.ceil((width - new_width) / 2))
    right = width - int(np.floor((width - new_width) / 2))
    top = int(np.ceil((height - new_height) / 2))
    bottom = height - int(np.floor((height - new_height) / 2))

    if len(img.shape) == 3:
        center_cropped_img = img[:, top:bottom, left:right]
    else:
        center_cropped_img = img[:, top:bottom, left:right, ...]

    return center_cropped_img
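
# Usage sketch: crop a (T, H, W, C) frame stack to its largest centered square:
#
#     square = center_crop(frames)   # -> (T, S, S, C) with S = min(H, W)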
|
|
def _transform(n_px):
    return Compose([
        Resize([n_px, n_px], interpolation=T.InterpolationMode.BICUBIC),
    ])
|
|
def regularize_video(video, image_size=256):
    # (T, H, W, C) uint8 frames -> (C, T, image_size, image_size) tensor in [-1, 1].
    min_shape = min(video.shape[1:3])
    video = center_crop(video, min_shape, min_shape)
    video = torch.from_numpy(video).permute(0, 3, 1, 2)
    video = _transform(image_size)(video)
    video = video / 255.0 * 2.0 - 1.0
    return video.permute(1, 0, 2, 3)
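
# Usage sketch, pairing with load_video below ('clip.mp4' is hypothetical):
#
#     frames = load_video('clip.mp4')     # (T, H, W, C) uint8
#     video = regularize_video(frames)    # (C, T, 256, 256) in [-1, 1]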
|
|
def time_to_indices(video_reader, time):
    # Map timestamps (in seconds) to the nearest frame indices of a decord
    # VideoReader, using the mean of each frame's (start, end) timestamps.
    times = video_reader.get_frame_timestamp(range(len(video_reader))).mean(-1)
    indices = np.searchsorted(times, time)
    # Clamp so timestamps past the last frame stay in range.
    indices = np.minimum(indices, len(times) - 1)
    return np.where(np.bitwise_or(indices == 0, times[indices] - time <= time - times[indices - 1]),
                    indices, indices - 1)
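
# Usage sketch (vr is a decord VideoReader):
#
#     start_idx, end_idx = time_to_indices(vr, [0.5, 2.5])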
|
|
def load_video(video_path, sample_duration=8.0, num_frames=8):
    # Sample num_frames evenly spaced frames from a random window of
    # sample_duration seconds (or from the whole clip if it is shorter).
    vr = VideoReader(video_path, ctx=cpu(0))
    framerate = vr.get_avg_fps()
    video_frame_len = len(vr)
    video_len = video_frame_len / framerate
    sample_duration = min(sample_duration, video_len)

    if video_len > sample_duration:
        s = random.random() * (video_len - sample_duration)
        t = s + sample_duration
        start, end = time_to_indices(vr, [s, t])
        end = min(video_frame_len - 1, end)
        start = min(start, end - 1)
        downsample_indices = np.linspace(start, end, num_frames, endpoint=True).astype(int).tolist()
    else:
        downsample_indices = np.linspace(0, video_frame_len - 1, num_frames, endpoint=True).astype(int).tolist()

    video = vr.get_batch(downsample_indices).asnumpy()
    return video
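
# Usage sketch ('clip.mp4' is a hypothetical path):
#
#     frames = load_video('clip.mp4', sample_duration=8.0, num_frames=8)
#     frames.shape   # -> (8, H, W, 3), dtype uint8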
|
|
def atomic_save(cfg, net, opt, step, path):
    if isinstance(net, (torch.nn.DataParallel,
                        torch.nn.parallel.DistributedDataParallel)):
        netm = net.module
    else:
        netm = net
    sd = netm.state_dict()
    # Drop first_stage_model / cond_stage_model weights to slim the checkpoint;
    # load_state_dict below restores them from the network itself.
    slimmed_sd = [(ki, vi) for ki, vi in sd.items()
                  if not ki.startswith('first_stage_model')
                  and not ki.startswith('cond_stage_model')]

    checkpoint = {
        "config": cfg,
        "state_dict": OrderedDict(slimmed_sd),
        "step": step}
    if opt is not None:
        checkpoint['optimizer_states'] = opt.state_dict()
    import io
    import fsspec
    # Serialize fully in memory first so the target file is written in one pass.
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(path, "wb") as f:
        f.write(bytesbuffer.getvalue())
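
# Usage sketch: path may be any fsspec-compatible URL (the one shown is
# hypothetical):
#
#     atomic_save(cfg, net, opt, step=1000, path='ckpt/step_1000.pth')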
|
|
def load_state_dict(net, cfg):
    # Load pretrained weights from exactly one of several mutually exclusive
    # checkpoint sources named in cfg.
    pretrained_pth_full = cfg.get('pretrained_pth_full', None)
|
pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None) |
|
pretrained_pth = cfg.get('pretrained_pth', None) |
|
pretrained_ckpt = cfg.get('pretrained_ckpt', None) |
|
pretrained_pth_dm = cfg.get('pretrained_pth_dm', None) |
|
pretrained_pth_ema = cfg.get('pretrained_pth_ema', None) |
|
strict_sd = cfg.get('strict_sd', False) |
|
errmsg = "Overlapped model state_dict! This is undesired behavior!" |
|
|
|
    if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
        assert (pretrained_pth is None) and \
            (pretrained_ckpt is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth_full is not None:
            target_file = pretrained_pth_full
            sd = torch.load(target_file, map_location='cpu')
        else:
            target_file = pretrained_ckpt_full
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print('Load full model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        net.load_state_dict(sd, strict=strict_sd)
|
|
|
    if pretrained_pth is not None or pretrained_ckpt is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth_dm is None) and \
            (pretrained_pth_ema is None), errmsg
        if pretrained_pth is not None:
            target_file = pretrained_pth
            sd = torch.load(target_file, map_location='cpu')
        else:
            target_file = pretrained_ckpt
            sd = torch.load(target_file, map_location='cpu')['state_dict']
        print('Load model from [{}] strict [{}].'.format(
            target_file, strict_sd))
        # Keep the network's own first_stage_model / cond_stage_model weights,
        # which atomic_save strips from these checkpoints.
        sd_extra = [(ki, vi) for ki, vi in net.state_dict().items()
                    if ki.startswith('first_stage_model') or ki.startswith('cond_stage_model')]
        sd.update(OrderedDict(sd_extra))
        net.load_state_dict(sd, strict=strict_sd)
|
|
|
    if pretrained_pth_dm is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print('Load diffusion model from [{}] strict [{}].'.format(
            pretrained_pth_dm, strict_sd))
        sd = torch.load(pretrained_pth_dm, map_location='cpu')
        net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)
|
|
|
    if pretrained_pth_ema is not None:
        assert (pretrained_ckpt_full is None) and \
            (pretrained_pth_full is None) and \
            (pretrained_pth is None) and \
            (pretrained_ckpt is None), errmsg
        print('Load unet ema model from [{}] strict [{}].'.format(
            pretrained_pth_ema, strict_sd))
        sd = torch.load(pretrained_pth_ema, map_location='cpu')
        net.model_ema.load_state_dict(sd, strict=strict_sd)
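
# Usage sketch: cfg can be any mapping with .get(); set exactly one source
# (the file name is hypothetical):
#
#     load_state_dict(net, {'pretrained_pth_full': 'full_model.pth',
#                           'strict_sd': True})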
|
|
def auto_merge_imlist(imlist, max_num=64):
    # Tile up to max_num equal-sized HxWx3 images into one near-square grid.
    imlist = imlist[0:max_num]
    h, w = imlist[0].shape[0:2]
    num_images = len(imlist)
    num_row = int(np.sqrt(num_images))
    num_col = num_images // num_row + 1 if num_images % num_row != 0 else num_images // num_row
    canvas = np.zeros([num_row * h, num_col * w, 3], dtype=np.uint8)
    for idx, im in enumerate(imlist):
        hi = (idx // num_col) * h
        wi = (idx % num_col) * w
        canvas[hi:hi + h, wi:wi + w, :] = im
    return canvas
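
# Usage sketch (im0..im2 are equal-sized (H, W, 3) uint8 arrays):
#
#     grid = auto_merge_imlist([im0, im1, im2])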
|
|
def latent2im(net, latent):
    # Decode latents to PIL images; accepts a single (C, H, W) latent or a batch.
    single_input = len(latent.shape) == 3
    if single_input:
        latent = latent[None]
    im = net.decode_image(latent.to(net.device))
    im = torch.clamp((im + 1.0) / 2.0, min=0.0, max=1.0)  # [-1, 1] -> [0, 1]
    im = [tvtrans.ToPILImage()(i) for i in im]
    if single_input:
        im = im[0]
    return im
|
|
def im2latent(net, im):
    # Encode a PIL image (or list of them) to latents with net.encode_image.
    single_input = not isinstance(im, list)
    if single_input:
        im = [im]
    im = torch.stack([tvtrans.ToTensor()(i) for i in im], dim=0)
    im = (im * 2 - 1).to(net.device)  # [0, 1] -> [-1, 1]
    z = net.encode_image(im)
    if single_input:
        z = z[0]
    return z
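
# Usage sketch (net is assumed to expose encode_image/decode_image and
# .device, as used above): a latent round trip:
#
#     z = im2latent(net, pil_image)
#     rec = latent2im(net, z)   # -> PIL.Image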
|
|
class color_adjust(object):
    # Transfer the color statistics of `ref_from` images toward `ref_to`.
    def __init__(self, ref_from, ref_to):
        x0, m0, std0 = self.get_data_and_stat(ref_from)
        x1, m1, std1 = self.get_data_and_stat(ref_to)
        self.ref_from_stat = (m0, std0)
        self.ref_to_stat = (m1, std1)
        self.ref_from = self.preprocess(x0).reshape(-1, 3)
        self.ref_to = x1.reshape(-1, 3)
|
    def get_data_and_stat(self, x):
        # Return the image as a float HxWx3 array plus per-channel mean/std.
        if isinstance(x, str):
            x = np.array(Image.open(x))
        elif isinstance(x, Image.Image):
            x = np.array(x)
        elif isinstance(x, torch.Tensor):
            x = torch.clamp(x, min=0.0, max=1.0)
            x = np.array(tvtrans.ToPILImage()(x))
        elif isinstance(x, np.ndarray):
            pass
        else:
            raise ValueError('Unknown image type')
        x = x.astype(float)
        m = np.reshape(x, (-1, 3)).mean(0)
        s = np.reshape(x, (-1, 3)).std(0)
        return x, m, s
|
    def preprocess(self, x):
        # Shift/scale x from the ref_from statistics to the ref_to statistics.
        m0, s0 = self.ref_from_stat
        m1, s1 = self.ref_to_stat
        y = ((x - m0) / s0) * s1 + m1
        return y
|
    def __call__(self, xin, keep=0, simple=False):
        xin, _, _ = self.get_data_and_stat(xin)
        x = self.preprocess(xin)
        if simple:
            # Mean/std matching only, blended with the input by `keep`.
            y = (x * (1 - keep) + xin * keep)
            y = np.clip(y, 0, 255).astype(np.uint8)
            return y

        # Full per-channel histogram (pdf) transfer.
        h, w = x.shape[:2]
        x = x.reshape(-1, 3)
        y = []
        for chi in range(3):
            yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
            y.append(yi)

        y = np.stack(y, axis=1)
        y = y.reshape(h, w, 3)
        y = (y.astype(float) * (1 - keep) + xin.astype(float) * keep)
        y = np.clip(y, 0, 255).astype(np.uint8)
        return y
|
    def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
        # Classic 1-D pdf transfer: match the CDF of arr_fo to that of arr_to
        # over n histogram bins, then map arr_in through the resulting lookup.
        arr = np.concatenate((arr_fo, arr_to))
        min_v = arr.min() - 1e-6
        max_v = arr.max() + 1e-6
        min_vto = arr_to.min() - 1e-6
        max_vto = arr_to.max() + 1e-6
        xs = np.array(
            [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
        hist_fo, _ = np.histogram(arr_fo, xs)
        hist_to, _ = np.histogram(arr_to, xs)
        xs = xs[:-1]

        # Normalized cumulative histograms (empirical CDFs).
        cum_fo = np.cumsum(hist_fo)
        cum_to = np.cumsum(hist_to)
        d_fo = cum_fo / cum_fo[-1]
        d_to = cum_to / cum_to[-1]

        # Invert the target CDF at the source CDF values.
        t_d = np.interp(d_fo, d_to, xs)
        t_d[d_fo <= d_to[0]] = min_vto
        t_d[d_fo >= d_to[-1]] = max_vto
        arr_out = np.interp(arr_in, xs, t_d)
        return arr_out
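
# Usage sketch ('from.jpg' / 'to.jpg' are hypothetical paths): recolor an
# image toward the palette of another:
#
#     adjust = color_adjust(ref_from='from.jpg', ref_to='to.jpg')
#     out = adjust('from.jpg', keep=0.2)   # (H, W, 3) uint8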
|
|