import os import torch import os import math import torch import logging import random import subprocess import numpy as np import torch.distributed as dist # from torch._six import inf from torch import inf from PIL import Image from typing import Union, Iterable from collections import OrderedDict from torch.utils.tensorboard import SummaryWriter from diffusers.utils import is_bs4_available, is_ftfy_available import html import re import urllib.parse as ul if is_bs4_available(): from bs4 import BeautifulSoup if is_ftfy_available(): import ftfy _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] def find_model(model_name): """ Finds a pre-trained Latte model, downloading it if necessary. Alternatively, loads a model from a local path. """ assert os.path.isfile(model_name), f'Could not find Latte checkpoint at {model_name}' checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) # if "ema" in checkpoint: # supports checkpoints from train.py # print('Using Ema!') # checkpoint = checkpoint["ema"] # else: print('Using model!') checkpoint = checkpoint['model'] return checkpoint ################################################################################# # Training Clip Gradients # ################################################################################# def get_grad_norm( parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor: r""" Copy from torch.nn.utils.clip_grad_norm_ Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place. Args: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized max_norm (float or int): max norm of the gradients norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. error_if_nonfinite (bool): if True, an error is thrown if the total norm of the gradients from :attr:`parameters` is ``nan``, ``inf``, or ``-inf``. Default: False (will switch to True in the future) Returns: Total norm of the parameter gradients (viewed as a single vector). """ if isinstance(parameters, torch.Tensor): parameters = [parameters] grads = [p.grad for p in parameters if p.grad is not None] norm_type = float(norm_type) if len(grads) == 0: return torch.tensor(0.) device = grads[0].device if norm_type == inf: norms = [g.detach().abs().max().to(device) for g in grads] total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) else: total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) return total_norm def clip_grad_norm_( parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, error_if_nonfinite: bool = False, clip_grad=True) -> torch.Tensor: r""" Copy from torch.nn.utils.clip_grad_norm_ Clips gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place. Args: parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized max_norm (float or int): max norm of the gradients norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm. error_if_nonfinite (bool): if True, an error is thrown if the total norm of the gradients from :attr:`parameters` is ``nan``, ``inf``, or ``-inf``. Default: False (will switch to True in the future) Returns: Total norm of the parameter gradients (viewed as a single vector). """ if isinstance(parameters, torch.Tensor): parameters = [parameters] grads = [p.grad for p in parameters if p.grad is not None] max_norm = float(max_norm) norm_type = float(norm_type) if len(grads) == 0: return torch.tensor(0.) device = grads[0].device if norm_type == inf: norms = [g.detach().abs().max().to(device) for g in grads] total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) else: total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) if clip_grad: if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): raise RuntimeError( f'The total norm of order {norm_type} for gradients from ' '`parameters` is non-finite, so it cannot be clipped. To disable ' 'this error and scale the gradients by the non-finite norm anyway, ' 'set `error_if_nonfinite=False`') clip_coef = max_norm / (total_norm + 1e-6) # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization # when the gradients do not reside in CPU memory. clip_coef_clamped = torch.clamp(clip_coef, max=1.0) for g in grads: g.detach().mul_(clip_coef_clamped.to(g.device)) # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) # print(gradient_cliped) return total_norm def get_experiment_dir(root_dir, args): # if args.pretrained is not None and 'Latte-XL-2-256x256.pt' not in args.pretrained: # root_dir += '-WOPRE' if args.use_compile: root_dir += '-Compile' # speedup by torch compile if args.attention_mode: root_dir += f'-{args.attention_mode.upper()}' # if args.enable_xformers_memory_efficient_attention: # root_dir += '-Xfor' if args.gradient_checkpointing: root_dir += '-Gc' if args.mixed_precision: root_dir += f'-{args.mixed_precision.upper()}' root_dir += f'-{args.max_image_size}' return root_dir def get_precision(args): if args.mixed_precision == "bf16": dtype = torch.bfloat16 elif args.mixed_precision == "fp16": dtype = torch.float16 else: dtype = torch.float32 return dtype ################################################################################# # Training Logger # ################################################################################# def create_logger(logging_dir): """ Create a logger that writes to a log file and stdout. """ if dist.get_rank() == 0: # real logger logging.basicConfig( level=logging.INFO, # format='[\033[34m%(asctime)s\033[0m] %(message)s', format='[%(asctime)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] ) logger = logging.getLogger(__name__) else: # dummy logger (does nothing) logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) return logger def create_tensorboard(tensorboard_dir): """ Create a tensorboard that saves losses. """ if dist.get_rank() == 0: # real tensorboard # tensorboard writer = SummaryWriter(tensorboard_dir) return writer def write_tensorboard(writer, *args): ''' write the loss information to a tensorboard file. Only for pytorch DDP mode. ''' if dist.get_rank() == 0: # real tensorboard writer.add_scalar(args[0], args[1], args[2]) ################################################################################# # EMA Update/ DDP Training Utils # ################################################################################# @torch.no_grad() def update_ema(ema_model, model, decay=0.9999): """ Step the EMA model towards the current model. """ ema_params = OrderedDict(ema_model.named_parameters()) model_params = OrderedDict(model.named_parameters()) for name, param in model_params.items(): # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) def requires_grad(model, flag=True): """ Set requires_grad flag for all parameters in a model. """ for p in model.parameters(): p.requires_grad = flag def cleanup(): """ End DDP training. """ dist.destroy_process_group() def setup_distributed(backend="nccl", port=None): """Initialize distributed training environment. support both slurm and torch.distributed.launch see torch.distributed.init_process_group() for more details """ num_gpus = torch.cuda.device_count() if "SLURM_JOB_ID" in os.environ: rank = int(os.environ["SLURM_PROCID"]) world_size = int(os.environ["SLURM_NTASKS"]) node_list = os.environ["SLURM_NODELIST"] addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") # specify master port if port is not None: os.environ["MASTER_PORT"] = str(port) elif "MASTER_PORT" not in os.environ: # os.environ["MASTER_PORT"] = "29566" os.environ["MASTER_PORT"] = str(29567 + num_gpus) if "MASTER_ADDR" not in os.environ: os.environ["MASTER_ADDR"] = addr os.environ["WORLD_SIZE"] = str(world_size) os.environ["LOCAL_RANK"] = str(rank % num_gpus) os.environ["RANK"] = str(rank) else: rank = int(os.environ["RANK"]) world_size = int(os.environ["WORLD_SIZE"]) # torch.cuda.set_device(rank % num_gpus) dist.init_process_group( backend=backend, world_size=world_size, rank=rank, ) ################################################################################# # Testing Utils # ################################################################################# def save_video_grid(video, nrow=None): b, t, h, w, c = video.shape if nrow is None: nrow = math.ceil(math.sqrt(b)) ncol = math.ceil(b / nrow) padding = 1 video_grid = torch.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype=torch.uint8) print(video_grid.shape) for i in range(b): r = i // ncol c = i % ncol start_r = (padding + h) * r start_c = (padding + w) * c video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] return video_grid ################################################################################# # MMCV Utils # ################################################################################# def collect_env(): # Copyright (c) OpenMMLab. All rights reserved. from mmcv.utils import collect_env as collect_base_env from mmcv.utils import get_git_hash """Collect the information of the running environments.""" env_info = collect_base_env() env_info['MMClassification'] = get_git_hash()[:7] for name, val in env_info.items(): print(f'{name}: {val}') print(torch.cuda.get_arch_list()) print(torch.version.cuda) ################################################################################# # Pixart-alpha Utils # ################################################################################# bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa def text_preprocessing(text): # The exact text cleaning as was in the training stage: text = clean_caption(text) text = clean_caption(text) return text def basic_clean(text): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) return text.strip() def clean_caption(caption): caption = str(caption) caption = ul.unquote_plus(caption) caption = caption.strip().lower() caption = re.sub('', 'person', caption) # urls: caption = re.sub( r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa '', caption) # regex for urls caption = re.sub( r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa '', caption) # regex for urls # html: caption = BeautifulSoup(caption, features='html.parser').text # @ caption = re.sub(r'@[\w\d]+\b', '', caption) # 31C0—31EF CJK Strokes # 31F0—31FF Katakana Phonetic Extensions # 3200—32FF Enclosed CJK Letters and Months # 3300—33FF CJK Compatibility # 3400—4DBF CJK Unified Ideographs Extension A # 4DC0—4DFF Yijing Hexagram Symbols # 4E00—9FFF CJK Unified Ideographs caption = re.sub(r'[\u31c0-\u31ef]+', '', caption) caption = re.sub(r'[\u31f0-\u31ff]+', '', caption) caption = re.sub(r'[\u3200-\u32ff]+', '', caption) caption = re.sub(r'[\u3300-\u33ff]+', '', caption) caption = re.sub(r'[\u3400-\u4dbf]+', '', caption) caption = re.sub(r'[\u4dc0-\u4dff]+', '', caption) caption = re.sub(r'[\u4e00-\u9fff]+', '', caption) ####################################################### # все виды тире / all types of dash --> "-" caption = re.sub( r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa '-', caption) # кавычки к одному стандарту caption = re.sub(r'[`´«»“”¨]', '"', caption) caption = re.sub(r'[‘’]', "'", caption) # " caption = re.sub(r'"?', '', caption) # & caption = re.sub(r'&', '', caption) # ip adresses: caption = re.sub(r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}', ' ', caption) # article ids: caption = re.sub(r'\d:\d\d\s+$', '', caption) # \n caption = re.sub(r'\\n', ' ', caption) # "#123" caption = re.sub(r'#\d{1,3}\b', '', caption) # "#12345.." caption = re.sub(r'#\d{5,}\b', '', caption) # "123456.." caption = re.sub(r'\b\d{6,}\b', '', caption) # filenames: caption = re.sub(r'[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)', '', caption) # caption = re.sub(r'[\"\']{2,}', r'"', caption) # """AUSVERKAUFT""" caption = re.sub(r'[\.]{2,}', r' ', caption) # """AUSVERKAUFT""" caption = re.sub(bad_punct_regex, r' ', caption) # ***AUSVERKAUFT***, #AUSVERKAUFT caption = re.sub(r'\s+\.\s+', r' ', caption) # " . " # this-is-my-cute-cat / this_is_my_cute_cat regex2 = re.compile(r'(?:\-|\_)') if len(re.findall(regex2, caption)) > 3: caption = re.sub(regex2, ' ', caption) caption = basic_clean(caption) caption = re.sub(r'\b[a-zA-Z]{1,3}\d{3,15}\b', '', caption) # jc6640 caption = re.sub(r'\b[a-zA-Z]+\d+[a-zA-Z]+\b', '', caption) # jc6640vc caption = re.sub(r'\b\d+[a-zA-Z]+\d+\b', '', caption) # 6640vc231 caption = re.sub(r'(worldwide\s+)?(free\s+)?shipping', '', caption) caption = re.sub(r'(free\s)?download(\sfree)?', '', caption) caption = re.sub(r'\bclick\b\s(?:for|on)\s\w+', '', caption) caption = re.sub(r'\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?', '', caption) caption = re.sub(r'\bpage\s+\d+\b', '', caption) caption = re.sub(r'\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b', r' ', caption) # j2d1a2a... caption = re.sub(r'\b\d+\.?\d*[xх×]\d+\.?\d*\b', '', caption) caption = re.sub(r'\b\s+\:\s+', r': ', caption) caption = re.sub(r'(\D[,\./])\b', r'\1 ', caption) caption = re.sub(r'\s+', ' ', caption) caption.strip() caption = re.sub(r'^[\"\']([\w\W]+)[\"\']$', r'\1', caption) caption = re.sub(r'^[\'\_,\-\:;]', r'', caption) caption = re.sub(r'[\'\_,\-\:\-\+]$', r'', caption) caption = re.sub(r'^\.\S+$', '', caption) return caption.strip()