# GaussianAnything-AIGC3D — nsr/train_util_diffusion_dit.py
from .train_util_diffusion import TrainLoop3DDiffusion
import dnnlib
import torch as th
class TrainLoop3DDiffusionDiT(TrainLoop3DDiffusion):
    """Diffusion training loop specialised for a DiT/ViT latent space.

    Differs from the base ``TrainLoop3DDiffusion`` only in *which* latent is
    diffused (the ViT token sequence rather than the decoded triplane) and in
    how sampled latents are rendered back to images: samples are first mapped
    to 2D spatial tokens, then decoded/rendered through the triplane decoder.
    """

    def __init__(self,
                 *,
                 rec_model,
                 denoise_model,
                 diffusion,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 schedule_sampler=None,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 ignore_resume_opt=False,
                 freeze_ae=False,
                 denoised_ae=True,
                 triplane_scaling_divider=10,
                 use_amp=False,
                 **kwargs):
        # All construction is delegated to the base class; this subclass only
        # overrides the latent name and the render behaviour below.
        super().__init__(rec_model=rec_model,
                         denoise_model=denoise_model,
                         diffusion=diffusion,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         schedule_sampler=schedule_sampler,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         ignore_resume_opt=ignore_resume_opt,
                         freeze_ae=freeze_ae,
                         denoised_ae=denoised_ae,
                         triplane_scaling_divider=triplane_scaling_divider,
                         use_amp=use_amp,
                         **kwargs)
        # DiT diffuses the ViT latent tokens, not the decoded triplane.
        self.latent_name = 'latent_from_vit'
        # Sampled latents are translated into 2D spatial tokens, then rendered
        # through the triplane decoder.
        self.render_latent_behaviour = 'vit_postprocess_triplane_dec'

    def eval_ddpm_sample(self):
        """Draw one unconditional DDPM sample in the ViT latent space and
        render it to a video named after the current training step."""
        args = dnnlib.EasyDict(
            dict(
                batch_size=1,
                image_size=224,
                denoise_in_channels=self.ddp_rec_model.module.decoder.
                triplane_decoder.out_chans,  # type: ignore
                clip_denoised=False,
                class_cond=False,
                use_ddim=False))

        model_kwargs = {}

        if args.class_cond:
            # BUG(fixed): the original branch referenced undefined names
            # (NUM_CLASSES, dist_util) and would have raised NameError.
            # Class-conditional sampling is not wired up for this loop; fail
            # with an explicit error instead. Unreachable with the defaults
            # above (class_cond=False), so existing behaviour is unchanged.
            raise NotImplementedError(
                'class-conditional sampling is not supported in '
                'TrainLoop3DDiffusionDiT.eval_ddpm_sample')

        diffusion = self.diffusion
        sample_fn = (diffusion.p_sample_loop
                     if not args.use_ddim else diffusion.ddim_sample_loop)

        # ViT token sequence length: (image_size / patch_size)**2 patches,
        # plus an optional CLS token. Derived from args.image_size rather than
        # re-hard-coding 224. Patch size 14 is assumed here — TODO confirm
        # against the rec model's ViT configuration.
        vit_L = (args.image_size // 14)**2
        if self.ddp_rec_model.module.decoder.vit_decoder.cls_token:
            vit_L += 1

        for i in range(1):
            # Sample a token latent of shape (N, L, C) directly from the
            # denoising model.
            triplane_sample = sample_fn(
                self.ddp_model,
                (args.batch_size, vit_L,
                 self.ddp_rec_model.module.decoder.vit_decoder.embed_dim),
                clip_denoised=args.clip_denoised,
                model_kwargs=model_kwargs,
            )
            th.cuda.empty_cache()

            self.render_video_given_triplane(
                triplane_sample,
                name_prefix=f'{self.step + self.resume_step}_{i}')