GaussianAnything-AIGC3D / scripts /vit_triplane_sit_train.py
yslan's picture
init
7f51798
"""
Train a diffusion model on images.
"""
import json
import sys
import os
sys.path.append('.')
# from dnnlib import EasyDict
import traceback
import torch as th
# from xformers.triton import FusedLayerNorm as LayerNorm
import torch.multiprocessing as mp
import torch.distributed as dist
import numpy as np
import argparse
import dnnlib
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util import (
args_to_dict,
add_dict_to_argparser,
continuous_diffusion_defaults,
control_net_defaults,
model_and_diffusion_defaults,
create_model_and_diffusion,
)
from guided_diffusion.continuous_diffusion import make_diffusion as make_sde_diffusion
import nsr
import nsr.lsgm
# from nsr.train_util_diffusion import TrainLoop3DDiffusion as TrainLoop
from datasets.eg3d_dataset import LMDBDataset_MV_Compressed_eg3d
from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, rendering_options_defaults, eg3d_options_default, dataset_defaults
from datasets.shapenet import load_data, load_eval_data, load_memory_data
from nsr.losses.builder import E3DGELossClass
from torch_utils import legacy, misc
from torch.utils.data import Subset
from pdb import set_trace as st
from dnnlib.util import EasyDict, InfiniteSampler
# from .vit_triplane_train_FFHQ import init_dataset_kwargs
from datasets.eg3d_dataset import init_dataset_kwargs
th.backends.cudnn.enabled = True # https://zhuanlan.zhihu.com/p/635824460
th.backends.cudnn.benchmark = True
from transport import create_transport, Sampler
from transport.train_utils import parse_transport_args
# from torch.utils.tensorboard import SummaryWriter
SEED = 0
def training_loop(args):
# def training_loop(args):
logger.log("dist setup...")
# th.multiprocessing.set_start_method('spawn')
th.autograd.set_detect_anomaly(False) # type: ignore
# th.autograd.set_detect_anomaly(True) # type: ignore
# st()
th.cuda.set_device(
args.local_rank) # set this line to avoid extra memory on rank 0
th.cuda.empty_cache()
th.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
dist_util.setup_dist(args)
# st() # mark
th.backends.cuda.matmul.allow_tf32 = args.allow_tf32
th.backends.cudnn.allow_tf32 = args.allow_tf32
# st()
# logger.configure(dir=args.logdir, format_strs=["tensorboard", "csv"])
logger.configure(dir=args.logdir)
logger.log("creating ViT encoder and NSR decoder...")
# st() # mark
device = dist_util.dev()
args.img_size = [args.image_size_encoder]
logger.log("creating model and diffusion...")
# * set denoise model args
if args.denoise_in_channels == -1:
args.diffusion_input_size = args.image_size_encoder
args.denoise_in_channels = args.out_chans
args.denoise_out_channels = args.out_chans
else:
assert args.denoise_out_channels != -1
# args.image_size = args.image_size_encoder # 224, follow the triplane size
# if args.diffusion_input_size == -1:
# else:
# args.image_size = args.diffusion_input_size
if args.pred_type == 'v': # for lsgm training
assert args.predict_v == True # for DDIM sampling
# if not args.create_dit:
denoise_model, diffusion = create_model_and_diffusion(
**args_to_dict(args,
model_and_diffusion_defaults().keys()))
opts = eg3d_options_default()
if args.sr_training:
args.sr_kwargs = dnnlib.EasyDict(
channel_base=opts.cbase,
channel_max=opts.cmax,
fused_modconv_default='inference_only',
use_noise=True
) # ! close noise injection? since noise_mode='none' in eg3d
logger.log("creating encoder and NSR decoder...")
auto_encoder = create_3DAE_model(
**args_to_dict(args,
encoder_and_nsr_defaults().keys()))
auto_encoder.to(device)
auto_encoder.eval()
# * load G_ema modules into autoencoder
# * clone G_ema.decoder to auto_encoder triplane
# logger.log("AE triplane decoder reuses G_ema decoder...")
# auto_encoder.decoder.register_buffer('w_avg', G_ema.backbone.mapping.w_avg)
# auto_encoder.decoder.triplane_decoder.decoder.load_state_dict( # type: ignore
# G_ema.decoder.state_dict()) # type: ignore
# set grad=False in this manner suppresses the DDP forward no grad error.
# if args.sr_training:
# logger.log("AE triplane decoder reuses G_ema SR module...")
# # auto_encoder.decoder.triplane_decoder.superresolution.load_state_dict( # type: ignore
# # G_ema.superresolution.state_dict()) # type: ignore
# # set grad=False in this manner suppresses the DDP forward no grad error.
# logger.log("freeze SR module...")
# for param in auto_encoder.decoder.superresolution.parameters(): # type: ignore
# param.requires_grad_(False)
# # del G_ema
# th.cuda.empty_cache()
if args.freeze_triplane_decoder:
logger.log("freeze triplane decoder...")
for param in auto_encoder.decoder.triplane_decoder.parameters(
): # type: ignore
# for param in auto_encoder.decoder.triplane_decoder.decoder.parameters(): # type: ignore
param.requires_grad_(False)
if args.cfg in ('afhq', 'ffhq'):
if args.sr_training:
logger.log("AE triplane decoder reuses G_ema SR module...")
auto_encoder.decoder.triplane_decoder.superresolution.load_state_dict( # type: ignore
G_ema.superresolution.state_dict()) # type: ignore
# set grad=False in this manner suppresses the DDP forward no grad error.
for param in auto_encoder.decoder.triplane_decoder.superresolution.parameters(
): # type: ignore
param.requires_grad_(False)
# ! load data
if args.use_lmdb:
logger.log("creating LMDB eg3d data loader...")
training_set = LMDBDataset_MV_Compressed_eg3d(
args.data_dir,
args.image_size,
args.image_size_encoder,
)
else:
logger.log("creating eg3d data loader...")
training_set_kwargs, dataset_name = init_dataset_kwargs(
data=args.data_dir,
class_name='datasets.eg3d_dataset.ImageFolderDataset',
reso_gt=args.image_size) # only load pose here
# if args.cond and not training_set_kwargs.use_labels:
# raise Exception('check here')
# training_set_kwargs.use_labels = args.cond
training_set_kwargs.use_labels = True
training_set_kwargs.xflip = False
training_set_kwargs.random_seed = SEED
training_set_kwargs.max_size = args.dataset_size
# desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'
# * construct ffhq/afhq dataset
training_set = dnnlib.util.construct_class_by_name(
**training_set_kwargs) # subclass of training.dataset.Dataset
training_set_sampler = InfiniteSampler(
dataset=training_set,
rank=dist_util.get_rank(),
num_replicas=dist_util.get_world_size(),
seed=SEED)
data = iter(
th.utils.data.DataLoader(
dataset=training_set,
sampler=training_set_sampler,
batch_size=args.batch_size,
pin_memory=True,
num_workers=args.num_workers,
persistent_workers=args.num_workers > 0,
prefetch_factor=max(8 // args.batch_size, 2),
))
# prefetch_factor=2))
eval_data = th.utils.data.DataLoader(dataset=Subset(
training_set, np.arange(8)),
batch_size=args.eval_batch_size,
num_workers=1)
else:
logger.log("creating data loader...")
if args.objv_dataset:
from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data, load_wds_data, load_data_cls
else: # shapenet
from datasets.shapenet import load_data, load_eval_data, load_memory_data
# TODO, load shapenet data
# data = load_data(
# st() mark
# if args.overfitting:
# logger.log("create overfitting memory dataset")
# data = load_memory_data(
# file_path=args.eval_data_dir,
# batch_size=args.batch_size,
# reso=args.image_size,
# reso_encoder=args.image_size_encoder, # 224 -> 128
# num_workers=args.num_workers,
# load_depth=True # for evaluation
# )
# else:
if args.use_wds:
if args.data_dir == 'NONE':
with open(args.shards_lst) as f:
shards_lst = [url.strip() for url in f.readlines()]
data = load_wds_data(
shards_lst, args.image_size, args.image_size_encoder,
args.batch_size, args.num_workers,
**args_to_dict(args,
dataset_defaults().keys()))
else:
data = load_wds_data(
args.data_dir, args.image_size, args.image_size_encoder,
args.batch_size, args.num_workers,
**args_to_dict(args,
dataset_defaults().keys()))
# eval_data = load_wds_data(
# args.data_dir,
# args.image_size,
# args.image_size_encoder,
# args.eval_batch_size,
# args.num_workers,
# decode_encode_img_only=args.decode_encode_img_only,
# load_wds_diff=args.load_wds_diff)
if args.eval_data_dir == 'NONE':
with open(args.eval_shards_lst) as f:
eval_shards_lst = [url.strip() for url in f.readlines()]
else:
eval_shards_lst = args.eval_data_dir # auto expanded
eval_data = load_wds_data(
eval_shards_lst,
args.image_size,
args.image_size_encoder,
args.eval_batch_size,
args.num_workers,
plucker_embedding=args.plucker_embedding,
decode_encode_img_only=args.decode_encode_img_only,
mv_input=args.mv_input,
load_wds_diff=False,
load_instance=True)
else:
logger.log("create all instances dataset")
# st() mark
# data = load_data(
# file_path=args.data_dir,
# batch_size=args.batch_size,
# reso=args.image_size,
# reso_encoder=args.image_size_encoder, # 224 -> 128
# num_workers=args.num_workers,
# load_depth=args.load_depth,
# preprocess=auto_encoder.preprocess, # clip
# dataset_size=args.dataset_size,
# use_lmdb=args.use_lmdb,
# trainer_name=args.trainer_name,
# # load_depth=True # for evaluation
# )
# eval_data = load_eval_data(
# file_path=args.eval_data_dir,
# batch_size=args.eval_batch_size,
# reso=args.image_size,
# reso_encoder=args.image_size_encoder, # 224 -> 128
# num_workers=args.num_workers,
# load_depth=True, # for evaluation
# interval=args.interval,
# use_lmdb=args.use_lmdb,
# )
data = load_data(
file_path=args.data_dir,
batch_size=args.batch_size,
reso=args.image_size,
reso_encoder=args.image_size_encoder, # 224 -> 128
num_workers=args.num_workers,
load_latent=True,
**args_to_dict(args,
dataset_defaults().keys())
# load_depth=args.load_depth,
# preprocess=auto_encoder.preprocess, # clip
# dataset_size=args.dataset_size,
# use_lmdb=args.use_lmdb,
# trainer_name=args.trainer_name,
# load_depth=True # for evaluation
)
# eval_dataset = load_data_cls(
# file_path=args.data_dir,
# batch_size=args.batch_size,
# reso=args.image_size,
# reso_encoder=args.image_size_encoder, # 224 -> 128
# num_workers=args.num_workers,
# load_latent=True,
# return_dataset=True,
# **args_to_dict(args,
# dataset_defaults().keys())
# # load_depth=args.load_depth,
# # preprocess=auto_encoder.preprocess, # clip
# # dataset_size=args.dataset_size,
# # use_lmdb=args.use_lmdb,
# # trainer_name=args.trainer_name,
# # load_depth=True # for evaluation
# )
eval_dataset = None
# st()
# eval_data = data
# eval_data = load_eval_data(
# file_path=args.eval_data_dir,
# batch_size=args.eval_batch_size,
# reso=args.image_size,
# reso_encoder=args.image_size_encoder, # 224 -> 128
# num_workers=args.num_workers,
# load_depth=True, # for evaluation
# interval=args.interval,
# use_lmdb=args.use_lmdb,
# )
# let all processes sync up before starting with a new epoch of training
if dist_util.get_rank() == 0:
with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
json.dump(vars(args), f, indent=2)
args.schedule_sampler = create_named_schedule_sampler(
args.schedule_sampler, diffusion)
opt = dnnlib.EasyDict(args_to_dict(args, loss_defaults().keys()))
loss_class = E3DGELossClass(device, opt).to(device)
logger.log("training...")
TrainLoop = {
'flow_matching':
nsr.lsgm.flow_matching_trainer.FlowMatchingEngine,
'flow_matching_gs':
nsr.lsgm.flow_matching_trainer.FlowMatchingEngine_gs, # slightly modified sampling and rendering for gs
'flow_matching_gs_clay':
nsr.lsgm.flow_matching_trainer.FlowMatchingEngine_gs_clay, # slightly modified sampling and rendering for gs
}[args.trainer_name]
if 'vpsde' in args.trainer_name:
sde_diffusion = make_sde_diffusion(
dnnlib.EasyDict(
args_to_dict(args,
continuous_diffusion_defaults().keys())))
# assert args.mixed_prediction, 'enable mixed_prediction by default'
logger.log('create VPSDE diffusion.')
else:
sde_diffusion = None
if 'cldm' in args.trainer_name:
assert isinstance(denoise_model, tuple)
denoise_model, controlNet = denoise_model
controlNet.to(dist_util.dev())
controlNet.train()
else:
controlNet = None
# st()
denoise_model.to(dist_util.dev())
denoise_model.train()
auto_encoder.decoder.rendering_kwargs = args.rendering_kwargs
TrainLoop(rec_model=auto_encoder,
denoise_model=denoise_model,
control_model=controlNet,
diffusion=diffusion,
sde_diffusion=sde_diffusion,
loss_class=loss_class,
data=data,
eval_data=None, # return dataset
# eval_data=eval_dataset, # return dataset
# eval_data=eval_data, # return dataset
**vars(args)).run_loop()
dist_util.synchronize()
def create_argparser(**kwargs):
# defaults.update(model_and_diffusion_defaults())
defaults = dict(
dataset_size=-1,
diffusion_input_size=-1,
trainer_name='adm',
use_amp=False,
train_vae=True, # jldm?
triplane_scaling_divider=1.0, # divide by this value
overfitting=False,
num_workers=4,
image_size=128,
image_size_encoder=224,
iterations=150000,
schedule_sampler="uniform",
anneal_lr=False,
lr=5e-5,
weight_decay=0.0,
lr_anneal_steps=0,
batch_size=1,
eval_batch_size=12,
microbatch=-1, # -1 disables microbatches
ema_rate="0.9999", # comma-separated list of EMA values
log_interval=50,
eval_interval=2500,
save_interval=10000,
resume_checkpoint="",
resume_checkpoint_EG3D="",
use_fp16=False,
fp16_scale_growth=1e-3,
data_dir="",
eval_data_dir="",
load_depth=True, # TODO
logdir="/mnt/lustre/yslan/logs/nips23/",
load_submodule_name='', # for loading pretrained auto_encoder model
ignore_resume_opt=False,
# freeze_ae=False,
denoised_ae=True,
diffusion_ce_anneal=False,
use_lmdb=False,
interval=1,
freeze_triplane_decoder=False,
objv_dataset=False,
use_eos_feature=False,
clip_grad_throld=1.0,
allow_tf32=True,
)
defaults.update(model_and_diffusion_defaults())
defaults.update(continuous_diffusion_defaults())
defaults.update(encoder_and_nsr_defaults()) # type: ignore
defaults.update(dataset_defaults()) # type: ignore
defaults.update(loss_defaults())
defaults.update(control_net_defaults())
parser = argparse.ArgumentParser()
add_dict_to_argparser(parser, defaults)
# ! add transport args
parse_transport_args(parser)
return parser
if __name__ == "__main__":
# os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
# os.environ["NCCL_DEBUG"] = "INFO"
th.multiprocessing.set_start_method('spawn')
th.multiprocessing.set_sharing_strategy('file_system') # avoid "too many files opened"
# https://stackoverflow.com/questions/76300957/multiprocessing-pool-too-many-file-descriptors-using-pytorch
# https://stackoverflow.com/questions/48250053/pytorchs-dataloader-too-many-open-files-error-when-no-files-should-be-open
os.environ[
"TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" # set to DETAIL for runtime logging.
args = create_argparser().parse_args()
args.local_rank = int(os.environ["LOCAL_RANK"])
args.gpus = th.cuda.device_count()
# opts = dnnlib.EasyDict(vars(args)) # compatiable with triplane original settings
# opts = args
args.rendering_kwargs = rendering_options_defaults(args)
# Launch processes.
logger.log('Launching processes...')
logger.log('Available devices ', th.cuda.device_count())
logger.log('Current cuda device ', th.cuda.current_device())
# logger.log('GPU Device name:', th.cuda.get_device_name(th.cuda.current_device()))
try:
training_loop(args)
# except KeyboardInterrupt as e:
except Exception as e:
# print(e)
traceback.print_exc()
dist_util.cleanup() # clean port and socket when ctrl+c