# NOTE(review): the following stray text was scraped from the hosting page
# ("Spaces: Running on Zero / Running on Zero"); it is not Python and broke
# the module at import time, so it is preserved here as a comment only.
from .train_util_diffusion import TrainLoop3DDiffusion | |
import dnnlib | |
import torch as th | |
class TrainLoop3DDiffusionDiT(TrainLoop3DDiffusion):
    """Diffusion training loop whose latent space is a ViT/DiT token sequence.

    Differs from the base ``TrainLoop3DDiffusion`` only in how latents are
    interpreted: the denoiser operates on ViT tokens (``latent_from_vit``),
    which are post-processed back into 2D spatial tokens and then rendered
    through the triplane decoder (``vit_postprocess_triplane_dec``).
    """

    # ViT patch size used to derive the token-sequence length from the image
    # size (224 // 14 = 16 tokens per side).
    # NOTE(review): hard-coded — confirm it matches the rec_model's ViT.
    _VIT_PATCH_SIZE = 14

    def __init__(self,
                 *,
                 rec_model,
                 denoise_model,
                 diffusion,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 schedule_sampler=None,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 ignore_resume_opt=False,
                 freeze_ae=False,
                 denoised_ae=True,
                 triplane_scaling_divider=10,
                 use_amp=False,
                 **kwargs):
        """Forward every argument to the base loop, then switch the latent
        name and render behaviour to the ViT-token variants."""
        super().__init__(rec_model=rec_model,
                         denoise_model=denoise_model,
                         diffusion=diffusion,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         schedule_sampler=schedule_sampler,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         ignore_resume_opt=ignore_resume_opt,
                         freeze_ae=freeze_ae,
                         denoised_ae=denoised_ae,
                         triplane_scaling_divider=triplane_scaling_divider,
                         use_amp=use_amp,
                         **kwargs)
        self.latent_name = 'latent_from_vit'
        self.render_latent_behaviour = 'vit_postprocess_triplane_dec'  # translate latent into 2D spatial tokens, then triplane render

    def eval_ddpm_sample(self):
        """Draw one unconditional sample in the ViT-token latent space and
        render it to a video.

        Side effects: runs the DDPM (or DDIM) sampling loop on
        ``self.ddp_model``, frees cached CUDA memory, and writes a video via
        ``self.render_video_given_triplane``.
        """
        args = dnnlib.EasyDict(
            dict(
                batch_size=1,
                image_size=224,
                # Channel count of the triplane decoder output.
                denoise_in_channels=self.ddp_rec_model.module.decoder.
                triplane_decoder.out_chans,  # type: ignore
                clip_denoised=False,
                class_cond=False,
                use_ddim=False))

        model_kwargs = {}

        if args.class_cond:
            # BUGFIX: the original branch referenced undefined names
            # (NUM_CLASSES, dist_util) and would have raised NameError if it
            # were ever enabled; fail with an explicit error instead.
            raise NotImplementedError(
                'class-conditional sampling is not supported by '
                'TrainLoop3DDiffusionDiT')

        diffusion = self.diffusion
        sample_fn = (diffusion.p_sample_loop
                     if not args.use_ddim else diffusion.ddim_sample_loop)

        # ViT sequence length: one token per patch (derived from image_size
        # rather than the original hard-coded 224), plus the optional [CLS]
        # token.
        vit_L = (args.image_size // self._VIT_PATCH_SIZE)**2
        if self.ddp_rec_model.module.decoder.vit_decoder.cls_token:
            vit_L += 1

        for i in range(1):  # a single sample per eval call
            triplane_sample = sample_fn(
                self.ddp_model,
                # (N, L, C): batch, token sequence length, ViT embed dim.
                (args.batch_size, vit_L,
                 self.ddp_rec_model.module.decoder.vit_decoder.embed_dim),
                clip_denoised=args.clip_denoised,
                model_kwargs=model_kwargs,
            )
            th.cuda.empty_cache()
            self.render_video_given_triplane(
                triplane_sample,
                name_prefix=f'{self.step + self.resume_step}_{i}')