# MedCoDi-M/core/common/utils.py
import random
import torch
from collections import OrderedDict
import numpy as np
from PIL import Image
import torchvision.transforms as T
from torchvision.transforms import Compose, Resize
from torchvision import transforms as tvtrans
from decord import VideoReader, cpu
###############
# text helper #
###############
def remove_duplicate_word(tx):
    def combine_words(input, length):
        combined_inputs = []
        if len(splitted_input) > 1:  # closes over the enclosing splitted_input
            for i in range(len(input) - 1):
                # append the last word of the right-neighbour (overlapping)
                # sequence (before it has expanded), which is the next word
                # in the original sentence
                combined_inputs.append(
                    input[i] + " " + last_word_of(splitted_input[i + 1], length))
        return combined_inputs, length + 1
    def remove_duplicates(input, length):
        bool_broke = False  # stays False if no duplicate was found
        for i in range(len(input) - length):
            if input[i] == input[i + length]:  # found a duplicated stretch of sentence
                for j in range(0, length):  # remove the overlapping sequences in reverse order
                    del input[i + length - j]
                bool_broke = True
                break  # the loop bounds no longer match the shortened list
        if bool_broke:
            # a duplicate was removed, so look for another one of the same length
            return remove_duplicates(input, length)
        return input
def last_word_of(input, length):
splitted = input.split(" ")
if len(splitted) == 0:
return input
else:
return splitted[length - 1]
    def split_and_puncsplit(text):
        tx = text.split(" ")
        txnew = []
        for txi in tx:
            txqueue = []
            # peel opening punctuation off the front; the length guard also
            # skips empty tokens produced by consecutive spaces
            while len(txi) != 0:
                if txi[0] in '([{':
                    txqueue.extend([txi[:1], '<puncnext>'])
                    txi = txi[1:]
                else:
                    break
            txnew += txqueue
            txstack = []
            if len(txi) == 0:
                continue
            # peel closing punctuation off the back
            while True:
                if txi[-1] in '?!.,:;}])':
                    txstack = ['<puncnext>', txi[-1:]] + txstack
                    txi = txi[:-1]
                    if len(txi) == 0:
                        break
                else:
                    break
            if len(txi) != 0:
                txnew += [txi]
            txnew += txstack
        return txnew
if tx == '':
return tx
splitted_input = split_and_puncsplit(tx)
word_length = 1
intermediate_output = False
while len(splitted_input) > 1:
splitted_input = remove_duplicates(splitted_input, word_length)
if len(splitted_input) > 1:
splitted_input, word_length = combine_words(splitted_input, word_length)
if intermediate_output:
print(splitted_input)
print(word_length)
output = splitted_input[0]
output = output.replace(' <puncnext> ', '')
return output
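
# A minimal usage sketch for remove_duplicate_word (hypothetical input):
# the repeated leading phrase "a cat" collapses to a single occurrence.
def _demo_remove_duplicate_word():
    text = "a cat a cat sitting on a mat"
    print(remove_duplicate_word(text))  # -> "a cat sitting on a mat"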
#################
# vision helper #
#################
def regularize_image(x, image_size=512):
    if isinstance(x, str):
        x = Image.open(x).convert('RGB')
        size = min(x.size)
    elif isinstance(x, Image.Image):
        x = x.convert('RGB')
        size = min(x.size)
    elif isinstance(x, np.ndarray):
        x = Image.fromarray(x).convert('RGB')
        size = min(x.size)
    elif isinstance(x, torch.Tensor):
        # tensors are assumed to be CHW and already normalized to [0, 1]
        size = min(x.size()[1:])
    else:
        assert False, 'Unknown image type'
    if not isinstance(x, torch.Tensor):
        # PIL inputs: random square crop of the short side, resize, flip,
        # and convert to a [0, 1] tensor
        transforms = T.Compose([
            T.RandomCrop(size),
            T.Resize((image_size, image_size),
                     interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
        ])
        x = transforms(x)
        assert (x.shape[1] == image_size) and (x.shape[2] == image_size), \
            'Wrong image size'
    x = x * 2 - 1  # map [0, 1] -> [-1, 1]
    return x
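
# A minimal usage sketch for regularize_image; 'example.png' is a
# hypothetical path. The result is a (3, image_size, image_size) tensor
# scaled to [-1, 1].
def _demo_regularize_image():
    x = regularize_image('example.png', image_size=512)
    print(x.shape, x.min().item(), x.max().item())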
def center_crop(img, new_width=None, new_height=None):
    # img is (C, H, W) or (T, H, W, C); H is dim 1 and W is dim 2 either way
    width = img.shape[2]
    height = img.shape[1]
    if new_width is None:
        new_width = min(width, height)
    if new_height is None:
        new_height = min(width, height)
    left = int(np.ceil((width - new_width) / 2))
    right = width - int(np.floor((width - new_width) / 2))
    top = int(np.ceil((height - new_height) / 2))
    bottom = height - int(np.floor((height - new_height) / 2))
    # the same slice works for both layouts since dims 1 and 2 are spatial
    center_cropped_img = img[:, top:bottom, left:right, ...]
    return center_cropped_img
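
# A shape sketch for center_crop on a synthetic (T, H, W, C) clip; the
# dimensions below are illustrative only.
def _demo_center_crop():
    video = np.zeros([8, 240, 320, 3], dtype=np.uint8)
    cropped = center_crop(video)  # defaults to a square crop of the short side
    print(cropped.shape)  # (8, 240, 240, 3)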
def _transform(n_px):
    return Compose([
        Resize([n_px, n_px], interpolation=T.InterpolationMode.BICUBIC),
    ])
def regularize_video(video, image_size=256):
    # video arrives as a (T, H, W, C) uint8 array
    min_shape = min(video.shape[1:3])
    video = center_crop(video, min_shape, min_shape)
    video = torch.from_numpy(video).permute(0, 3, 1, 2)  # (T, C, H, W)
    video = _transform(image_size)(video)
    video = video / 255.0 * 2.0 - 1.0  # map [0, 255] -> [-1, 1]
    return video.permute(1, 0, 2, 3)  # (C, T, H, W)
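
# A shape sketch for regularize_video: a uint8 (T, H, W, C) clip comes out
# as a float tensor of shape (C, T, image_size, image_size) in [-1, 1].
def _demo_regularize_video():
    clip = np.random.randint(0, 256, size=(4, 240, 320, 3), dtype=np.uint8)
    out = regularize_video(clip, image_size=256)
    print(out.shape)  # torch.Size([3, 4, 256, 256])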
def time_to_indices(video_reader, time):
    times = video_reader.get_frame_timestamp(range(len(video_reader))).mean(-1)
    indices = np.searchsorted(times, time)
    # clip so times[indices] below stays in bounds for queries past the last frame
    indices = np.minimum(indices, len(times) - 1)
    # Use `np.bitwise_or` so it works both with scalars and numpy arrays.
    return np.where(
        np.bitwise_or(indices == 0, times[indices] - time <= time - times[indices - 1]),
        indices, indices - 1)
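
# A usage sketch for time_to_indices; 'clip.mp4' is a hypothetical path.
# Timestamps in seconds map to the indices of the nearest frames.
def _demo_time_to_indices():
    vr = VideoReader('clip.mp4', ctx=cpu(0))
    print(time_to_indices(vr, [0.0, 1.5]))  # e.g. [0, 45] at 30 fps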
def load_video(video_path, sample_duration=4.0, num_frames=4):
    vr = VideoReader(video_path, ctx=cpu(0))
    framerate = vr.get_avg_fps()
    video_frame_len = len(vr)
    video_len = video_frame_len / framerate
    sample_duration = min(sample_duration, video_len)
    if video_len > sample_duration:
        # sample a random window of sample_duration seconds
        s = random.random() * (video_len - sample_duration)
        t = s + sample_duration
        start, end = time_to_indices(vr, [s, t])
        end = min(video_frame_len - 1, end)
        start = min(start, end - 1)
        downsample_indices = np.linspace(start, end, num_frames, endpoint=True).astype(int).tolist()
    else:
        downsample_indices = np.linspace(0, video_frame_len - 1, num_frames, endpoint=True).astype(int).tolist()
    video = vr.get_batch(downsample_indices).asnumpy()
    return video
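
# A usage sketch for load_video; 'clip.mp4' is a hypothetical path. The
# result is a (num_frames, H, W, C) uint8 array from a random window.
def _demo_load_video():
    frames = load_video('clip.mp4')
    print(frames.shape, frames.dtype)  # e.g. (4, 240, 320, 3) uint8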
###############
# some helper #
###############
def atomic_save(cfg, net, opt, step, path):
    import io
    import fsspec
    if isinstance(net, (torch.nn.DataParallel,
                        torch.nn.parallel.DistributedDataParallel)):
        netm = net.module
    else:
        netm = net
    sd = netm.state_dict()
    # drop the first/cond stage weights to keep the checkpoint slim
    slimmed_sd = [(ki, vi) for ki, vi in sd.items()
                  if not ki.startswith('first_stage_model')
                  and not ki.startswith('cond_stage_model')]
    checkpoint = {
        "config": cfg,
        "state_dict": OrderedDict(slimmed_sd),
        "step": step}
    if opt is not None:
        checkpoint['optimizer_states'] = opt.state_dict()
    # serialize to memory first, then write in a single call ("atomic" save)
    bytesbuffer = io.BytesIO()
    torch.save(checkpoint, bytesbuffer)
    with fsspec.open(path, "wb") as f:
        f.write(bytesbuffer.getvalue())
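
# A usage sketch for atomic_save; cfg/net/step come from the surrounding
# training loop and the path is hypothetical. Passing opt=None skips the
# optimizer state.
def _demo_atomic_save(cfg, net, step):
    atomic_save(cfg, net, None, step, 'checkpoints/last.pth')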
def load_state_dict(net, cfg):
pretrained_pth_full = cfg.get('pretrained_pth_full', None)
pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
pretrained_pth = cfg.get('pretrained_pth', None)
pretrained_ckpt = cfg.get('pretrained_ckpt', None)
pretrained_pth_dm = cfg.get('pretrained_pth_dm', None)
pretrained_pth_ema = cfg.get('pretrained_pth_ema', None)
strict_sd = cfg.get('strict_sd', False)
errmsg = "Overlapped model state_dict! This is undesired behavior!"
if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
assert (pretrained_pth is None) and \
(pretrained_ckpt is None) and \
(pretrained_pth_dm is None) and \
(pretrained_pth_ema is None), errmsg
if pretrained_pth_full is not None:
target_file = pretrained_pth_full
sd = torch.load(target_file, map_location='cpu')
            assert pretrained_ckpt_full is None, errmsg
else:
target_file = pretrained_ckpt_full
sd = torch.load(target_file, map_location='cpu')['state_dict']
print('Load full model from [{}] strict [{}].'.format(
target_file, strict_sd))
net.load_state_dict(sd, strict=strict_sd)
if pretrained_pth is not None or pretrained_ckpt is not None:
assert (pretrained_ckpt_full is None) and \
(pretrained_pth_full is None) and \
(pretrained_pth_dm is None) and \
(pretrained_pth_ema is None), errmsg
if pretrained_pth is not None:
target_file = pretrained_pth
sd = torch.load(target_file, map_location='cpu')
assert pretrained_ckpt is None, errmsg
else:
target_file = pretrained_ckpt
sd = torch.load(target_file, map_location='cpu')['state_dict']
print('Load model from [{}] strict [{}].'.format(
target_file, strict_sd))
        # the slim checkpoint (see atomic_save) omits first/cond stage
        # weights, so keep the network's own copies when loading
        sd_extra = [(ki, vi) for ki, vi in net.state_dict().items()
                    if ki.find('first_stage_model') == 0 or ki.find('cond_stage_model') == 0]
        sd.update(OrderedDict(sd_extra))
net.load_state_dict(sd, strict=strict_sd)
if pretrained_pth_dm is not None:
assert (pretrained_ckpt_full is None) and \
(pretrained_pth_full is None) and \
(pretrained_pth is None) and \
(pretrained_ckpt is None), errmsg
print('Load diffusion model from [{}] strict [{}].'.format(
pretrained_pth_dm, strict_sd))
sd = torch.load(pretrained_pth_dm, map_location='cpu')
net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)
if pretrained_pth_ema is not None:
assert (pretrained_ckpt_full is None) and \
(pretrained_pth_full is None) and \
(pretrained_pth is None) and \
(pretrained_ckpt is None), errmsg
print('Load unet ema model from [{}] strict [{}].'.format(
pretrained_pth_ema, strict_sd))
sd = torch.load(pretrained_pth_ema, map_location='cpu')
net.model_ema.load_state_dict(sd, strict=strict_sd)
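
# A sketch of the cfg mapping load_state_dict expects; exactly one of the
# checkpoint groups may be set at a time. The path is hypothetical, and a
# plain dict works since only cfg.get(...) is used.
def _demo_load_state_dict(net):
    cfg = {
        'pretrained_pth_full': 'weights/full_model.pth',
        'strict_sd': False,
    }
    load_state_dict(net, cfg)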
def auto_merge_imlist(imlist, max=64):
    imlist = imlist[0:max]
    h, w = imlist[0].shape[0:2]
    num_images = len(imlist)
    # lay the images out on a near-square grid, row-major
    num_row = int(np.sqrt(num_images))
    num_col = int(np.ceil(num_images / num_row))
    canvas = np.zeros([num_row * h, num_col * w, 3], dtype=np.uint8)
    for idx, im in enumerate(imlist):
        hi = (idx // num_col) * h
        wi = (idx % num_col) * w
        canvas[hi:hi + h, wi:wi + w, :] = im
    return canvas
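
# A shape sketch for auto_merge_imlist: nine 64x64 images tile a 3x3 grid.
def _demo_auto_merge_imlist():
    ims = [np.zeros([64, 64, 3], dtype=np.uint8) for _ in range(9)]
    canvas = auto_merge_imlist(ims)
    print(canvas.shape)  # (192, 192, 3)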
def latent2im(net, latent):
single_input = len(latent.shape) == 3
if single_input:
latent = latent[None]
im = net.decode_image(latent.to(net.device))
im = torch.clamp((im + 1.0) / 2.0, min=0.0, max=1.0)
im = [tvtrans.ToPILImage()(i) for i in im]
if single_input:
im = im[0]
return im
def im2latent(net, im):
single_input = not isinstance(im, list)
if single_input:
im = [im]
im = torch.stack([tvtrans.ToTensor()(i) for i in im], dim=0)
im = (im * 2 - 1).to(net.device)
z = net.encode_image(im)
if single_input:
z = z[0]
return z
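
# A round-trip sketch through the autoencoder, assuming `net` exposes
# encode_image/decode_image and a .device attribute as used above.
def _demo_latent_roundtrip(net, pil_image):
    z = im2latent(net, pil_image)  # single PIL image -> single latent
    return latent2im(net, z)       # single latent -> single PIL image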
class color_adjust(object):
def __init__(self, ref_from, ref_to):
x0, m0, std0 = self.get_data_and_stat(ref_from)
x1, m1, std1 = self.get_data_and_stat(ref_to)
self.ref_from_stat = (m0, std0)
self.ref_to_stat = (m1, std1)
self.ref_from = self.preprocess(x0).reshape(-1, 3)
self.ref_to = x1.reshape(-1, 3)
def get_data_and_stat(self, x):
if isinstance(x, str):
            x = np.array(Image.open(x))
        elif isinstance(x, Image.Image):
x = np.array(x)
elif isinstance(x, torch.Tensor):
x = torch.clamp(x, min=0.0, max=1.0)
x = np.array(tvtrans.ToPILImage()(x))
elif isinstance(x, np.ndarray):
pass
else:
raise ValueError
x = x.astype(float)
m = np.reshape(x, (-1, 3)).mean(0)
s = np.reshape(x, (-1, 3)).std(0)
return x, m, s
def preprocess(self, x):
m0, s0 = self.ref_from_stat
m1, s1 = self.ref_to_stat
y = ((x - m0) / s0) * s1 + m1
return y
def __call__(self, xin, keep=0, simple=False):
xin, _, _ = self.get_data_and_stat(xin)
x = self.preprocess(xin)
if simple:
y = (x * (1 - keep) + xin * keep)
y = np.clip(y, 0, 255).astype(np.uint8)
return y
h, w = x.shape[:2]
x = x.reshape(-1, 3)
y = []
for chi in range(3):
yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
y.append(yi)
y = np.stack(y, axis=1)
y = y.reshape(h, w, 3)
y = (y.astype(float) * (1 - keep) + xin.astype(float) * keep)
y = np.clip(y, 0, 255).astype(np.uint8)
return y
def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
arr = np.concatenate((arr_fo, arr_to))
min_v = arr.min() - 1e-6
max_v = arr.max() + 1e-6
min_vto = arr_to.min() - 1e-6
max_vto = arr_to.max() + 1e-6
        # shared histogram grid spanning both sources
        xs = np.array(
            [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
        hist_fo, _ = np.histogram(arr_fo, xs)
        hist_to, _ = np.histogram(arr_to, xs)
        xs = xs[:-1]  # keep the left bin edges
        # empirical CDFs of both sources
        cum_fo = np.cumsum(hist_fo)
        cum_to = np.cumsum(hist_to)
        d_fo = cum_fo / cum_fo[-1]
        d_to = cum_to / cum_to[-1]
# transfer
t_d = np.interp(d_fo, d_to, xs)
t_d[d_fo <= d_to[0]] = min_vto
t_d[d_fo >= d_to[-1]] = max_vto
arr_out = np.interp(arr_in, xs, t_d)
return arr_out
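
# A usage sketch for color_adjust: match the palette of one image to a
# reference. Both paths are hypothetical; any pair of RGB images works.
def _demo_color_adjust():
    adjust = color_adjust(ref_from='render.png', ref_to='photo.png')
    out = adjust('render.png', keep=0.2)  # keep=0.2 blends in 20% of the input
    print(out.shape, out.dtype)  # (H, W, 3) uint8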