Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import math | |
import torch | |
import logging | |
import subprocess | |
import numpy as np | |
import torch.distributed as dist | |
# from torch._six import inf | |
from torch import inf | |
from PIL import Image | |
from typing import Union, Iterable | |
from collections import OrderedDict | |
from torch.utils.tensorboard import SummaryWriter | |
from typing import Dict | |
import torch_dct | |
from diffusers.utils import is_bs4_available, is_ftfy_available | |
import html | |
import re | |
import urllib.parse as ul | |
if is_bs4_available(): | |
from bs4 import BeautifulSoup | |
if is_ftfy_available(): | |
import ftfy | |
import torch.fft as fft | |
# Type alias for arguments that may be either a single tensor or an
# iterable of tensors.
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
################################################################################# | |
# Testing Utils # | |
################################################################################# | |
def find_model(model_name):
    """
    Load a pre-trained checkpoint file and return the model weights inside it.

    The loaded object is unwrapped in this order:
      * if it contains an ``"ema"`` entry (checkpoints from train.py), that
        entry is returned;
      * otherwise, if it contains a ``"model"`` entry, that entry is returned;
      * otherwise the loaded object is treated as a bare state dict and is
        returned unchanged (this case previously raised ``KeyError('model')``).

    :param model_name: path to the checkpoint file.
    :return: the selected state-dict-like object.
    :raises AssertionError: if ``model_name`` is not an existing file.
        NOTE: this is an ``assert``, so the check is skipped under ``python -O``.
    """
    assert os.path.isfile(model_name), f'Could not find DiT checkpoint at {model_name}'
    # Returning `storage` unchanged from map_location loads every tensor on CPU,
    # so checkpoints saved from GPU can be opened on any machine.
    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        print('Using ema ckpt!')
        return checkpoint["ema"]
    if "model" in checkpoint:
        print("Using model ckpt!")
        return checkpoint["model"]
    # Neither wrapper key is present: assume a bare state dict.
    print("Using raw state dict ckpt!")
    return checkpoint
def save_video_grid(video, nrow=None):
    """
    Tile a batch of videos into a single grid video, frame by frame.

    Despite the name, nothing is written to disk; the grid tensor is returned.

    :param video: tensor unpacked as ``(b, t, h, w, c)``; ``video[i]`` is
        copied into the i-th tile.
    :param nrow: number of grid rows; defaults to ``ceil(sqrt(b))``. The
        column count is ``ceil(b / nrow)`` and tiles are filled row by row.
    :return: ``torch.uint8`` tensor of shape
        ``(t, (1 + h) * nrow + 1, (1 + w) * ncol + 1, c)``. Every tile is
        surrounded by a 1-pixel zero border.
    """
    b, t, h, w, c = video.shape
    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)
    padding = 1
    video_grid = torch.zeros((t, (padding + h) * nrow + padding,
                              (padding + w) * ncol + padding, c), dtype=torch.uint8)
    for i in range(b):
        # `row`/`col` instead of reusing `c`, which holds the channel count.
        row, col = divmod(i, ncol)
        # Offset by `padding` so the leading border reserved by the
        # "+ padding" in the grid size is actually used. Without it, tiles start
        # at 0, leaving no top/left border and a doubled bottom/right one.
        start_r = padding + (padding + h) * row
        start_c = padding + (padding + w) * col
        video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
    return video_grid
def save_videos_grid_tav(videos: torch.Tensor, path: str, rescale=False, nrow=None, fps=8):
    """
    Write a batch of videos to ``path`` as one animated grid.

    :param videos: tensor laid out as ``(b, c, t, h, w)`` (the rearrange
        pattern below relies on this).
    :param path: output file, written with ``imageio.mimsave``. A missing
        parent directory is created.
    :param rescale: if True, values are first mapped from [-1, 1] to [0, 1].
    :param nrow: passed as ``nrow`` to ``torchvision.utils.make_grid``;
        defaults to ``ceil(sqrt(b))``.
    :param fps: frames per second passed to ``imageio.mimsave``.
    """
    from einops import rearrange
    import imageio
    import torchvision
    b, _, _, _, _ = videos.shape
    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    # Detach and move to CPU once so `.numpy()` below also works for CUDA
    # tensors and tensors that require grad.
    videos = rearrange(videos.detach().cpu(), "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=nrow)
        # (C, H, W) -> (H, W, C); a trailing channel of size 1 is dropped.
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        # Clamp before the uint8 cast: slightly out-of-range floats would
        # otherwise wrap around instead of saturating at 0 / 255.
        x = (x.clamp(0.0, 1.0) * 255).numpy().astype(np.uint8)
        outputs.append(x)
    # dirname is '' for a bare filename, which os.makedirs would reject.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)
################################################################################# | |
# MMCV Utils # | |
################################################################################# | |
def collect_env():
    """Collect the information of the running environments.

    Prints each entry returned by ``mmcv.utils.collect_env()``, plus an
    ``MMClassification`` entry holding the first 7 characters of
    ``mmcv.utils.get_git_hash()``. It then prints
    ``torch.cuda.get_arch_list()`` and ``torch.version.cuda``.
    Returns None.
    """
    # Copyright (c) OpenMMLab. All rights reserved.
    # Imported inside the function, so mmcv is only needed when this is called.
    from mmcv.utils import collect_env as collect_base_env
    from mmcv.utils import get_git_hash
    env_info = collect_base_env()
    env_info['MMClassification'] = get_git_hash()[:7]
    for name, val in env_info.items():
        print(f'{name}: {val}')
    print(torch.cuda.get_arch_list())
    print(torch.version.cuda)
################################################################################# | |
# DCT Functions # | |
################################################################################# | |
def dct_low_pass_filter(dct_coefficients, percentage=0.3):
    """
    Build a low-pass mask matching the given DCT coefficients.

    The mask has the same shape, dtype and device as ``dct_coefficients``.
    It is 1 in the top-left (low-frequency) block of the last two dimensions
    and 0 everywhere else. All leading dimensions (e.g. ``b, c, f`` of a
    ``[b, c, f, h, w]`` input) are covered in full.

    :param dct_coefficients: tensor with at least 2 dimensions; only its
        shape, dtype and device are used.
    :param percentage: fraction (between 0 and 1) of each of the last two
        dimensions to keep. The cutoff is ``int(size * percentage)``
        (truncated).
    :return: the mask itself, not the filtered coefficients. Multiply it with
        the coefficients to apply the filter.
    """
    # Determine the cutoff indices for both trailing dimensions
    cutoff_x = int(dct_coefficients.shape[-2] * percentage)
    cutoff_y = int(dct_coefficients.shape[-1] * percentage)
    mask = torch.zeros_like(dct_coefficients)
    # Ellipsis indexing supports any rank >= 2. For 5-D input it selects
    # exactly the same region as mask[:, :, :, :cutoff_x, :cutoff_y].
    mask[..., :cutoff_x, :cutoff_y] = 1
    return mask
def normalize(tensor):
    """Min-max normalize a tensor into the range [0, 1].

    The global minimum maps to 0 and the global maximum to 1. A constant
    tensor (max == min) maps to all zeros instead of NaN from 0 / 0.

    :param tensor: non-empty input tensor; min and max are taken over all
        elements.
    :return: the normalized tensor (computed with true division).
    """
    min_val = tensor.min()
    max_val = tensor.max()
    value_range = max_val - min_val
    # Use 1 as the divisor when the range is 0: (tensor - min_val) is then all
    # zeros, so the result is 0 rather than NaN. torch.where keeps this
    # on-device instead of forcing a host sync through a Python `if`.
    safe_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    return (tensor - min_val) / safe_range
def denormalize(tensor, max_val_target, min_val_target):
    """Map values from [0, 1] back onto [min_val_target, max_val_target].

    This is the inverse of min-max normalization: 0 maps to
    ``min_val_target`` and 1 maps to ``max_val_target``. Note the argument
    order: the maximum comes before the minimum.
    """
    span = max_val_target - min_val_target
    scaled = tensor * span
    return scaled + min_val_target
def exchanged_mixed_dct_freq(noise, base_content, LPF_3d, normalized=False):
    """
    Mix two tensors in the 3-D DCT domain and return the inverse-transformed result.

    Both inputs go through ``torch_dct.dct_3d(..., 'ortho')``. The combined
    spectrum is ``base_content``'s coefficients weighted by ``LPF_3d`` plus
    ``noise``'s coefficients weighted by ``1 - LPF_3d``. That spectrum is
    passed through ``torch_dct.idct_3d(..., 'ortho')`` and returned.

    :param noise: tensor supplying the high-pass part.
    :param base_content: tensor supplying the low-pass part. Presumably the
        same shape as ``noise`` (TODO confirm with callers).
    :param LPF_3d: low-pass weights applied to the DCT coefficients. Assumed
        to be broadcastable against them, e.g. a 0/1 mask (confirm with
        callers).
    :param normalized: currently unused; the body never reads it.
    :return: the mixed tensor after the inverse DCT.
    """
    # Complementary weights for the high-frequency band.
    high_pass = 1 - LPF_3d
    noise_spectrum = torch_dct.dct_3d(noise, 'ortho')
    noise_high = noise_spectrum * high_pass
    base_spectrum = torch_dct.dct_3d(base_content, 'ortho')
    base_low = base_spectrum * LPF_3d
    # Sum the two bands (low first, as before) and invert the transform.
    return torch_dct.idct_3d(base_low + noise_high, 'ortho')