import torch import random import numbers from torchvision.transforms import RandomCrop, RandomResizedCrop from PIL import Image from torchvision.utils import _log_api_usage_once def _is_tensor_video_clip(clip): if not torch.is_tensor(clip): raise TypeError("clip should be Tensor. Got %s" % type(clip)) if not clip.ndimension() == 4: raise ValueError("clip should be 4D. Got %dD" % clip.dim()) return True def center_crop_arr(pil_image, image_size): """ Center cropping implementation from ADM. https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ while min(*pil_image.size) >= 2 * image_size: pil_image = pil_image.resize( tuple(x // 2 for x in pil_image.size), resample=Image.BOX ) scale = image_size / min(*pil_image.size) pil_image = pil_image.resize( tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC ) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 crop_x = (arr.shape[1] - image_size) // 2 return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) def crop(clip, i, j, h, w): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) """ if len(clip.size()) != 4: raise ValueError("clip should be a 4D tensor") return clip[..., i : i + h, j : j + w] def resize(clip, target_size, interpolation_mode): if len(target_size) != 2: raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) def resize_scale(clip, target_size, interpolation_mode): if len(target_size) != 2: raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") H, W = clip.size(-2), clip.size(-1) scale_ = target_size[0] / min(H, W) return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) def resize_with_scale_factor(clip, scale_factor, interpolation_mode): return torch.nn.functional.interpolate(clip, scale_factor=scale_factor, mode=interpolation_mode, align_corners=False) def resize_scale_with_height(clip, target_size, interpolation_mode): H, W = clip.size(-2), clip.size(-1) scale_ = target_size / H return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) def resize_scale_with_weight(clip, target_size, interpolation_mode): H, W = clip.size(-2), clip.size(-1) scale_ = target_size / W return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False) def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"): """ Do spatial cropping and resizing to the video clip Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) i (int): i in (i,j) i.e coordinates of the upper left corner. j (int): j in (i,j) i.e coordinates of the upper left corner. h (int): Height of the cropped region. w (int): Width of the cropped region. size (tuple(int, int)): height and width of resized clip Returns: clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W) """ if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") clip = crop(clip, i, j, h, w) clip = resize(clip, size, interpolation_mode) return clip def center_crop(clip, crop_size): if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") h, w = clip.size(-2), clip.size(-1) # print(clip.shape) th, tw = crop_size if h < th or w < tw: # print(h, w) raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w)) i = int(round((h - th) / 2.0)) j = int(round((w - tw) / 2.0)) return crop(clip, i, j, th, tw), i, j def center_crop_using_short_edge(clip): if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") h, w = clip.size(-2), clip.size(-1) if h < w: th, tw = h, h i = 0 j = int(round((w - tw) / 2.0)) else: th, tw = w, w i = int(round((h - th) / 2.0)) j = 0 return crop(clip, i, j, th, tw) def random_shift_crop(clip): ''' Slide along the long edge, with the short edge as crop size ''' if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") h, w = clip.size(-2), clip.size(-1) if h <= w: long_edge = w short_edge = h else: long_edge = h short_edge =w th, tw = short_edge, short_edge i = torch.randint(0, h - th + 1, size=(1,)).item() j = torch.randint(0, w - tw + 1, size=(1,)).item() return crop(clip, i, j, th, tw), i, j def random_crop(clip, crop_size): if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") h, w = clip.size(-2), clip.size(-1) th, tw = crop_size[-2], crop_size[-1] if h < th or w < tw: raise ValueError("height {} and width {} must be no smaller than crop_size".format(h, w)) i = torch.randint(0, h - th + 1, size=(1,)).item() j = torch.randint(0, w - tw + 1, size=(1,)).item() clip_crop = crop(clip, i, j, th, tw) return clip_crop, i, j def to_tensor(clip): """ Convert tensor data type from uint8 to float, divide value by 255.0 and permute the dimensions of clip tensor Args: clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) Return: clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) """ _is_tensor_video_clip(clip) if not clip.dtype == torch.uint8: raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) # return clip.float().permute(3, 0, 1, 2) / 255.0 return clip.float() / 255.0 def normalize(clip, mean, std, inplace=False): """ Args: clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) mean (tuple): pixel RGB mean. Size is (3) std (tuple): pixel standard deviation. Size is (3) Returns: normalized clip (torch.tensor): Size is (T, C, H, W) """ if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") if not inplace: clip = clip.clone() mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device) # print(mean) std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device) clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) return clip def hflip(clip): """ Args: clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W) Returns: flipped clip (torch.tensor): Size is (T, C, H, W) """ if not _is_tensor_video_clip(clip): raise ValueError("clip should be a 4D torch.tensor") return clip.flip(-1) class RandomCropVideo: def __init__(self, size): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: randomly cropped video clip. size is (T, C, OH, OW) """ i, j, h, w = self.get_params(clip) return crop(clip, i, j, h, w) def get_params(self, clip): h, w = clip.shape[-2:] th, tw = self.size if h < th or w < tw: raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") if w == tw and h == th: return 0, 0, h, w i = torch.randint(0, h - th + 1, size=(1,)).item() j = torch.randint(0, w - tw + 1, size=(1,)).item() return i, j, th, tw def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size})" class CenterCropResizeVideo: ''' First use the short side for cropping length, center crop video, then resize to the specified size ''' def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized / center cropped video clip. size is (T, C, crop_size, crop_size) """ # print(clip.shape) clip_center_crop = center_crop_using_short_edge(clip) # print(clip_center_crop.shape) 320 512 clip_center_crop_resize = resize(clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode) return clip_center_crop_resize def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class SDXL: def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized / center cropped video clip. size is (T, C, crop_size, crop_size) """ # add aditional one pixel for avoiding error in center crop ori_h, ori_w = clip.size(-2), clip.size(-1) tar_h, tar_w = self.size[0] + 1, self.size[1] + 1 # if ori_h >= tar_h and ori_w >= tar_w: # clip_tar_crop, i, j = random_crop(clip=clip, crop_size=self.size) # else: # tar_h_div_ori_h = tar_h / ori_h # tar_w_div_ori_w = tar_w / ori_w # if tar_h_div_ori_h > tar_w_div_ori_w: # clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode) # else: # clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode) # clip_tar_crop, i, j = random_crop(clip, self.size) if ori_h >= tar_h and ori_w >= tar_w: tar_h_div_ori_h = tar_h / ori_h tar_w_div_ori_w = tar_w / ori_w if tar_h_div_ori_h > tar_w_div_ori_w: clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode) else: clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode) ori_h, ori_w = clip.size(-2), clip.size(-1) clip_tar_crop, i, j = random_crop(clip, self.size) else: tar_h_div_ori_h = tar_h / ori_h tar_w_div_ori_w = tar_w / ori_w if tar_h_div_ori_h > tar_w_div_ori_w: clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode) else: clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode) clip_tar_crop, i, j = random_crop(clip, self.size) return clip_tar_crop, ori_h, ori_w, i, j def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class SDXLCenterCrop: def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized / center cropped video clip. size is (T, C, crop_size, crop_size) """ # add aditional one pixel for avoiding error in center crop ori_h, ori_w = clip.size(-2), clip.size(-1) tar_h, tar_w = self.size[0] + 1, self.size[1] + 1 tar_h_div_ori_h = tar_h / ori_h tar_w_div_ori_w = tar_w / ori_w # print('before resize', clip.shape) if tar_h_div_ori_h > tar_w_div_ori_w: clip = resize_with_scale_factor(clip=clip, scale_factor=tar_h_div_ori_h, interpolation_mode=self.interpolation_mode) # print('after h resize', clip.shape) else: clip = resize_with_scale_factor(clip=clip, scale_factor=tar_w_div_ori_w, interpolation_mode=self.interpolation_mode) # print('after resize', clip.shape) # print(clip.shape) # clip_tar_crop, i, j = random_crop(clip, self.size) clip_tar_crop, i, j = center_crop(clip, self.size) # print('after crop', clip_tar_crop.shape) return clip_tar_crop, ori_h, ori_w, i, j def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class InternVideo320512: def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized / center cropped video clip. size is (T, C, crop_size, crop_size) """ # add aditional one pixel for avoiding error in center crop h, w = clip.size(-2), clip.size(-1) # print('before resize', clip.shape) if h < 320: clip = resize_scale_with_height(clip=clip, target_size=321, interpolation_mode=self.interpolation_mode) # print('after h resize', clip.shape) if w < 512: clip = resize_scale_with_weight(clip=clip, target_size=513, interpolation_mode=self.interpolation_mode) # print('after w resize', clip.shape) # print(clip.shape) clip_center_crop = center_crop(clip, self.size) clip_center_crop_no_subtitles = center_crop(clip, (220, 352)) clip_center_resize = resize(clip_center_crop_no_subtitles, target_size=self.size, interpolation_mode=self.interpolation_mode) # print(clip_center_crop.shape) return clip_center_resize def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class CenterCropVideo: ''' First scale to the specified size in equal proportion to the short edge, then center cropping ''' def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized / center cropped video clip. size is (T, C, crop_size, crop_size) """ clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode) clip_center_crop = center_crop(clip_resize, self.size) return clip_center_crop def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class KineticsRandomCropResizeVideo: ''' Slide along the long edge, with the short edge as crop size. And resie to the desired size. ''' def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): clip_random_crop = random_shift_crop(clip) clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode) return clip_resize class ResizeVideo(): ''' First use the short side for cropping length, center crop video, then resize to the specified size ''' def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: scale resized / center cropped video clip. size is (T, C, crop_size, crop_size) """ clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode) return clip_resize def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class CenterCropVideo: def __init__( self, size, interpolation_mode="bilinear", ): if isinstance(size, tuple): if len(size) != 2: raise ValueError(f"size should be tuple (height, width), instead got {size}") self.size = size else: self.size = (size, size) self.interpolation_mode = interpolation_mode def __call__(self, clip): """ Args: clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) Returns: torch.tensor: center cropped video clip. size is (T, C, crop_size, crop_size) """ clip_center_crop = center_crop(clip, self.size) return clip_center_crop def __repr__(self) -> str: return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" class NormalizeVideo: """ Normalize the video clip by mean subtraction and division by standard deviation Args: mean (3-tuple): pixel RGB mean std (3-tuple): pixel RGB standard deviation inplace (boolean): whether do in-place normalization """ def __init__(self, mean, std, inplace=False): self.mean = mean self.std = std self.inplace = inplace def __call__(self, clip): """ Args: clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W) """ return normalize(clip, self.mean, self.std, self.inplace) def __repr__(self) -> str: return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})" class ToTensorVideo: """ Convert tensor data type from uint8 to float, divide value by 255.0 and permute the dimensions of clip tensor """ def __init__(self): pass def __call__(self, clip): """ Args: clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) Return: clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) """ return to_tensor(clip) def __repr__(self) -> str: return self.__class__.__name__ class RandomHorizontalFlipVideo: """ Flip the video clip along the horizontal direction with a given probability Args: p (float): probability of the clip being flipped. Default value is 0.5 """ def __init__(self, p=0.5): self.p = p def __call__(self, clip): """ Args: clip (torch.tensor): Size is (T, C, H, W) Return: clip (torch.tensor): Size is (T, C, H, W) """ if random.random() < self.p: clip = hflip(clip) return clip def __repr__(self) -> str: return f"{self.__class__.__name__}(p={self.p})" class Compose: """Composes several transforms together. This transform does not support torchscript. Please, see the note below. Args: transforms (list of ``Transform`` objects): list of transforms to compose. Example: >>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.PILToTensor(), >>> transforms.ConvertImageDtype(torch.float), >>> ]) .. note:: In order to script the transformations, please use ``torch.nn.Sequential`` as below. >>> transforms = torch.nn.Sequential( >>> transforms.CenterCrop(10), >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), >>> ) >>> scripted_transforms = torch.jit.script(transforms) Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require `lambda` functions or ``PIL.Image``. """ def __init__(self, transforms): if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(self) self.transforms = transforms def __call__(self, img): for t in self.transforms: if isinstance(t, SDXLCenterCrop) or isinstance(t, SDXL): img, ori_h, ori_w, crops_coords_top, crops_coords_left = t(img) else: img = t(img) return img, ori_h, ori_w, crops_coords_top, crops_coords_left def __repr__(self) -> str: format_string = self.__class__.__name__ + "(" for t in self.transforms: format_string += "\n" format_string += f" {t}" format_string += "\n)" return format_string # ------------------------------------------------------------ # --------------------- Sampling --------------------------- # ------------------------------------------------------------ class TemporalRandomCrop(object): """Temporally crop the given frame indices at a random location. Args: size (int): Desired length of frames will be seen in the model. """ def __init__(self, size): self.size = size def __call__(self, total_frames): rand_end = max(0, total_frames - self.size - 1) begin_index = random.randint(0, rand_end) end_index = min(begin_index + self.size, total_frames) return begin_index, end_index if __name__ == '__main__': from torchvision import transforms import torchvision.io as io import numpy as np from torchvision.utils import save_image import os vframes, aframes, info = io.read_video( filename='./v_Archery_g01_c03.avi', pts_unit='sec', output_format='TCHW' ) trans = transforms.Compose([ ToTensorVideo(), RandomHorizontalFlipVideo(), UCFCenterCropVideo(512), # NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) ]) target_video_len = 32 frame_interval = 1 total_frames = len(vframes) print(total_frames) temporal_sample = TemporalRandomCrop(target_video_len * frame_interval) # Sampling video frames start_frame_ind, end_frame_ind = temporal_sample(total_frames) # print(start_frame_ind) # print(end_frame_ind) assert end_frame_ind - start_frame_ind >= target_video_len frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int) print(frame_indice) select_vframes = vframes[frame_indice] print(select_vframes.shape) print(select_vframes.dtype) select_vframes_trans = trans(select_vframes) print(select_vframes_trans.shape) print(select_vframes_trans.dtype) select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8) print(select_vframes_trans_int.dtype) print(select_vframes_trans_int.permute(0, 2, 3, 1).shape) io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8) for i in range(target_video_len): save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True, value_range=(-1, 1))