diff --git a/README.md b/README.md index c9701c9d5c9fa874a4fd75170ffaa706937072c9..7f7028112b6c15e6ba6730ff03f619b3539c685d 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ --- -title: ShoeGen + V3D +title: InstantMesh emoji: 🏆 colorFrom: red colorTo: gray sdk: gradio -sdk_version: 4.21.0 +sdk_version: 4.26.0 app_file: app.py pinned: false --- diff --git a/app.py b/app.py index f5530cebbc617111eaabb0d22edfe2d0942701e1..a7d5f010ebb060dd4d5864d9e57b082982caf0cb 100644 --- a/app.py +++ b/app.py @@ -1,271 +1,367 @@ -# TODO +import spaces + +import os +import imageio import numpy as np -import argparse import torch -from torchvision.utils import make_grid -import tempfile -import gradio as gr -from omegaconf import OmegaConf -from einops import rearrange -from scripts.pub.V3D_512 import ( - sample_one, - get_batch, - get_unique_embedder_keys_from_conditioner, - load_model, -) -from sgm.util import default, instantiate_from_config -from safetensors.torch import load_file as load_safetensors +import rembg from PIL import Image -from kiui.op import recenter -from torchvision.transforms import ToTensor +from torchvision.transforms import v2 +from pytorch_lightning import seed_everything +from omegaconf import OmegaConf from einops import rearrange, repeat -import rembg -import os -from glob import glob -from mediapy import write_video -from pathlib import Path -import spaces +from tqdm import tqdm +from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler + +from src.utils.train_util import instantiate_from_config +from src.utils.camera_util import ( + FOV_to_intrinsics, + get_zero123plus_input_cameras, + get_circular_camera_poses, +) +from src.utils.mesh_util import save_obj +from src.utils.infer_util import remove_background, resize_foreground, images_to_video + +import tempfile +from functools import partial + from huggingface_hub import hf_hub_download -import imageio -import cv2 +import gradio as gr -@spaces.GPU -def do_sample( - image, - num_frames, - num_steps, - decoding_t, - border_ratio, - ignore_alpha, - output_folder, - seed, -): - # if image.mode == "RGBA": - # image = image.convert("RGB") - torch.manual_seed(seed) - image = Image.fromarray(image) - w, h = image.size - - if border_ratio > 0: - if image.mode != "RGBA" or ignore_alpha: - image = image.convert("RGB") - image = np.asarray(image) - carved_image = rembg.remove(image, session=rembg_session) # [H, W, 4] - else: - image = np.asarray(image) - carved_image = image - mask = carved_image[..., -1] > 0 - image = recenter(carved_image, mask, border_ratio=border_ratio) - image = image.astype(np.float32) / 255.0 - if image.shape[-1] == 4: - image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) - image = Image.fromarray((image * 255).astype(np.uint8)) +def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): + """ + Get the rendering camera parameters. + """ + c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation) + if is_flexicubes: + cameras = torch.linalg.inv(c2ws) + cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1) else: - print("Ignore border ratio") - image = image.resize((512, 512)) - - image = ToTensor()(image) - image = image * 2.0 - 1.0 - - image = image.unsqueeze(0).to(device) - H, W = image.shape[2:] - assert image.shape[1] == 3 - F = 8 - C = 4 - shape = (num_frames, C, H // F, W // F) - - value_dict = {} - value_dict["motion_bucket_id"] = 0 - value_dict["fps_id"] = 0 - value_dict["cond_aug"] = 0.05 - value_dict["cond_frames_without_noise"] = clip_model(image) - value_dict["cond_frames"] = ae_model.encode(image) - value_dict["cond_frames"] += 0.05 * torch.randn_like(value_dict["cond_frames"]) - value_dict["cond_aug"] = 0.05 - - print(device) - with torch.no_grad(): - with torch.autocast(device_type="cuda"): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [1, num_frames], - T=num_frames, - device=device, - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=[ - "cond_frames", - "cond_frames_without_noise", - ], - ) - - for k in ["crossattn", "concat"]: - uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) - uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) - c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) - c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) - - randn = torch.randn(shape, device=device) - randn = randn.to(device) - - additional_model_inputs = {} - additional_model_inputs["image_only_indicator"] = torch.zeros( - 2, num_frames - ).to(device) - additional_model_inputs["num_video_frames"] = batch["num_video_frames"] - - def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) + extrinsics = c2ws.flatten(-2) + intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1) + return cameras - samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) - model.en_and_decode_n_samples_a_time = decoding_t - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - os.makedirs(output_folder, exist_ok=True) - base_count = len(glob(os.path.join(output_folder, "*.mp4"))) - video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") - - frames = ( - (rearrange(samples, "t c h w -> t h w c") * 255) - .cpu() - .numpy() - .astype(np.uint8) - ) - # write_video(video_path, frames, fps=6) - # writer = cv2.VideoWriter( - # video_path, - # cv2.VideoWriter_fourcc("m", "p", "4", "v"), - # 6, - # (frames.shape[-1], frames.shape[-2]), - # ) - # for fr in frames: - # writer.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR)) - # writer.release() - imageio.mimwrite(video_path, frames, fps=6) - - return video_path - - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -# download -V3D_ckpt_path = hf_hub_download(repo_id="heheyas/V3D", filename="V3D.ckpt") -svd_xt_ckpt_path = hf_hub_download( - repo_id="stabilityai/stable-video-diffusion-img2vid-xt", - filename="svd_xt.safetensors", -) -model_config = "./scripts/pub/configs/V3D_512.yaml" -num_frames = OmegaConf.load( - model_config -).model.params.sampler_config.params.guider_config.params.num_frames -print("Detected num_frames:", num_frames) -# num_steps = default(num_steps, 25) -num_steps = 25 -output_folder = "outputs/V3D_512" - -sd = load_safetensors(svd_xt_ckpt_path) -clip_model_config = OmegaConf.load("./configs/embedder/clip_image.yaml") -clip_model = instantiate_from_config(clip_model_config).eval() -clip_sd = dict() -for k, v in sd.items(): - if "conditioner.embedders.0" in k: - clip_sd[k.replace("conditioner.embedders.0.", "")] = v -clip_model.load_state_dict(clip_sd) -clip_model = clip_model.to(device) - -ae_model_config = OmegaConf.load("./configs/ae/video.yaml") -ae_model = instantiate_from_config(ae_model_config).eval() -encoder_sd = dict() -for k, v in sd.items(): - if "first_stage_model" in k: - encoder_sd[k.replace("first_stage_model.", "")] = v -ae_model.load_state_dict(encoder_sd) -ae_model = ae_model.to(device) -rembg_session = rembg.new_session() - -model, _ = load_model( - model_config, - device, - num_frames, - num_steps, - min_cfg=3.5, - max_cfg=3.5, - ckpt_path=V3D_ckpt_path, +def images_to_video(images, output_path, fps=30): + # images: (N, C, H, W) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + frames = [] + for i in range(images.shape[0]): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255) + assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ + f"Frame shape mismatch: {frame.shape} vs {images.shape}" + assert frame.min() >= 0 and frame.max() <= 255, \ + f"Frame value out of range: {frame.min()} ~ {frame.max()}" + frames.append(frame) + imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264') + + +############################################################################### +# Configuration. +############################################################################### + +import shutil + +def find_cuda(): + # Check if CUDA_HOME or CUDA_PATH environment variables are set + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + + if cuda_home and os.path.exists(cuda_home): + return cuda_home + + # Search for the nvcc executable in the system's PATH + nvcc_path = shutil.which('nvcc') + + if nvcc_path: + # Remove the 'bin/nvcc' part to get the CUDA installation path + cuda_path = os.path.dirname(os.path.dirname(nvcc_path)) + return cuda_path + + return None + +cuda_path = find_cuda() + +if cuda_path: + print(f"CUDA installation found at: {cuda_path}") +else: + print("CUDA installation not found") + +config_path = 'configs/instant-mesh-large.yaml' +config = OmegaConf.load(config_path) +config_name = os.path.basename(config_path).replace('.yaml', '') +model_config = config.model_config +infer_config = config.infer_config + +IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False + +device = torch.device('cuda') + +# load diffusion model +print('Loading diffusion model ...') +pipeline = DiffusionPipeline.from_pretrained( + "sudo-ai/zero123plus-v1.2", + custom_pipeline="zero123plus", + torch_dtype=torch.float16, +) +pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( + pipeline.scheduler.config, timestep_spacing='trailing' ) + +# load custom white-background UNet +unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model") +state_dict = torch.load(unet_ckpt_path, map_location='cpu') +pipeline.unet.load_state_dict(state_dict, strict=True) + +pipeline = pipeline.to(device) + +# load reconstruction model +print('Loading reconstruction model ...') +model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model") +model = instantiate_from_config(model_config) +state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict'] +state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k} +model.load_state_dict(state_dict, strict=True) + model = model.to(device) -with gr.Blocks(title="V3D", theme=gr.themes.Monochrome()) as demo: - with gr.Row(equal_height=True): +print('Loading Finished!') + + +def check_input_image(input_image): + if input_image is None: + raise gr.Error("No image uploaded!") + + +def preprocess(input_image, do_remove_background): + + rembg_session = rembg.new_session() if do_remove_background else None + + if do_remove_background: + input_image = remove_background(input_image, rembg_session) + input_image = resize_foreground(input_image, 0.85) + + return input_image + + +@spaces.GPU +def generate_mvs(input_image, sample_steps, sample_seed): + + seed_everything(sample_seed) + + # sampling + z123_image = pipeline( + input_image, + num_inference_steps=sample_steps + ).images[0] + + show_image = np.asarray(z123_image, dtype=np.uint8) + show_image = torch.from_numpy(show_image) # (960, 640, 3) + show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2) + show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3) + show_image = Image.fromarray(show_image.numpy()) + + return z123_image, show_image + + +@spaces.GPU +def make3d(images): + + global model + if IS_FLEXICUBES: + model.init_flexicubes_geometry(device, use_renderer=False) + model = model.eval() + + images = np.asarray(images, dtype=np.float32) / 255.0 + images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640) + images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) + + input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device) + render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device) + + images = images.unsqueeze(0).to(device) + images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) + + mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name + print(mesh_fpath) + mesh_basename = os.path.basename(mesh_fpath).split('.')[0] + mesh_dirname = os.path.dirname(mesh_fpath) + video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4") + + with torch.no_grad(): + # get triplane + planes = model.forward_planes(images, input_cameras) + + # # get video + # chunk_size = 20 if IS_FLEXICUBES else 1 + # render_size = 384 + + # frames = [] + # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)): + # if IS_FLEXICUBES: + # frame = model.forward_geometry( + # planes, + # render_cameras[:, i:i+chunk_size], + # render_size=render_size, + # )['img'] + # else: + # frame = model.synthesizer( + # planes, + # cameras=render_cameras[:, i:i+chunk_size], + # render_size=render_size, + # )['images_rgb'] + # frames.append(frame) + # frames = torch.cat(frames, dim=1) + + # images_to_video( + # frames[0], + # video_fpath, + # fps=30, + # ) + + # print(f"Video saved to {video_fpath}") + + # get mesh + mesh_out = model.extract_mesh( + planes, + use_texture_map=False, + **infer_config, + ) + + vertices, faces, vertex_colors = mesh_out + vertices = vertices[:, [1, 2, 0]] + vertices[:, -1] *= -1 + faces = faces[:, [2, 1, 0]] + + save_obj(vertices, faces, vertex_colors, mesh_fpath) + + print(f"Mesh saved to {mesh_fpath}") + + return mesh_fpath + + +_HEADER_ = ''' +

Official 🤗 Gradio Demo

InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models

+''' + +_LINKS_ = ''' +

Code is available at GitHub

+

Report is available at ArXiv

+''' + +_CITE_ = r""" +```bibtex +@article{xu2024instantmesh, + title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models}, + author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying}, + journal={arXiv preprint arXiv:2404.07191}, + year={2024} +} +``` +""" + + +with gr.Blocks() as demo: + gr.Markdown(_HEADER_) + with gr.Row(variant="panel"): with gr.Column(): - input_image = gr.Image(value=None, label="Input Image") - - border_ratio_slider = gr.Slider( - value=0.3, - label="Border Ratio", - minimum=0.05, - maximum=0.5, - step=0.05, - ) - seed_input = gr.Number(value=42) - decoding_t_slider = gr.Slider( - value=1, - label="Number of Decoding frames", - minimum=1, - maximum=num_frames, - step=1, - ) - min_guidance_slider = gr.Slider( - value=3.5, - label="Min CFG Value", - minimum=0.05, - maximum=5, - step=0.05, - ) - max_guidance_slider = gr.Slider( - value=3.5, - label="Max CFG Value", - minimum=0.05, - maximum=5, - step=0.05, - ) - run_button = gr.Button(value="Run V3D") + with gr.Row(): + input_image = gr.Image( + label="Input Image", + image_mode="RGBA", + sources="upload", + #width=256, + #height=256, + type="pil", + elem_id="content_image", + ) + processed_image = gr.Image( + label="Processed Image", + image_mode="RGBA", + #width=256, + #height=256, + type="pil", + interactive=False + ) + with gr.Row(): + with gr.Group(): + do_remove_background = gr.Checkbox( + label="Remove Background", value=True + ) + sample_seed = gr.Number(value=42, label="Seed Value", precision=0) + + sample_steps = gr.Slider( + label="Sample Steps", + minimum=30, + maximum=75, + value=75, + step=5 + ) + + with gr.Row(): + submit = gr.Button("Generate", elem_id="generate", variant="primary") + + with gr.Row(variant="panel"): + gr.Examples( + examples=[ + os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples")) + ], + inputs=[input_image], + label="Examples", + cache_examples=False, + examples_per_page=12 + ) with gr.Column(): - output_video = gr.Video(value=None, label="Output Orbit Video") - - @run_button.click( - inputs=[ - input_image, - border_ratio_slider, - min_guidance_slider, - max_guidance_slider, - decoding_t_slider, - seed_input, - ], - outputs=[output_video], - ) - def _(image, border_ratio, min_guidance, max_guidance, decoding_t, seed): - model.sampler.guider.max_scale = max_guidance - model.sampler.guider.min_scale = min_guidance - return do_sample( - image, - num_frames, - num_steps, - int(decoding_t), - border_ratio, - False, - output_folder, - seed, - ) + with gr.Row(): + + with gr.Column(): + mv_show_images = gr.Image( + label="Generated Multi-views", + type="pil", + width=379, + interactive=False + ) + + # with gr.Column(): + # output_video = gr.Video( + # label="video", format="mp4", + # width=379, + # autoplay=True, + # interactive=False + # ) + + with gr.Row(): + output_model_obj = gr.Model3D( + label="Output Model (OBJ Format)", + interactive=False, + ) + + with gr.Row(): + gr.Markdown('''Try a different seed value if the result is unsatisfying (Default: 42).''') + + gr.Markdown(_LINKS_) + gr.Markdown(_CITE_) + + mv_images = gr.State() + + submit.click(fn=check_input_image, inputs=[input_image]).success( + fn=preprocess, + inputs=[input_image, do_remove_background], + outputs=[processed_image], + ).success( + fn=generate_mvs, + inputs=[processed_image, sample_steps, sample_seed], + outputs=[mv_images, mv_show_images] + + ).success( + fn=make3d, + inputs=[mv_images], + outputs=[output_model_obj] + ) -demo.launch() +demo.launch() \ No newline at end of file diff --git a/ckpts/shoes.safetensors b/ckpts/shoes.safetensors deleted file mode 100644 index 7576f35d2861ef021cc9eccb611983f818942e34..0000000000000000000000000000000000000000 --- a/ckpts/shoes.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e66a57b2174aff462c3bc0c9f9e3b1142617d856a1f5ddbada3b696dcc057b73 -size 170543188 diff --git a/ckpts/snckrsgen.safetensors b/ckpts/snckrsgen.safetensors deleted file mode 100644 index 855ee1c67696d0785be1619e264c9361f1143b82..0000000000000000000000000000000000000000 --- a/ckpts/snckrsgen.safetensors +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e80bf5f4ded84793d74c9939b0fc1a09b76af31bafe2ac3190c21c9be5eb6965 -size 151112168 diff --git a/configs/ae/video.yaml b/configs/ae/video.yaml deleted file mode 100644 index ecc65942a2203ddb468763cff2fc894616fc47a3..0000000000000000000000000000000000000000 --- a/configs/ae/video.yaml +++ /dev/null @@ -1,35 +0,0 @@ -target: sgm.models.autoencoder.AutoencodingEngine -params: - loss_config: - target: torch.nn.Identity - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - encoder_config: - target: sgm.modules.diffusionmodules.model.Encoder - params: - attn_type: vanilla - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - decoder_config: - target: sgm.modules.autoencoding.temporal_ae.VideoDecoder - params: - attn_type: vanilla - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - video_kernel_size: [3, 1, 1] \ No newline at end of file diff --git a/configs/embedder/clip_image.yaml b/configs/embedder/clip_image.yaml deleted file mode 100644 index 54a2a92c162d9c950c16b0f12170d1d73d999212..0000000000000000000000000000000000000000 --- a/configs/embedder/clip_image.yaml +++ /dev/null @@ -1,8 +0,0 @@ -target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder -params: - n_cond_frames: 1 - n_copies: 1 - open_clip_embedding_config: - target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder - params: - freeze: True \ No newline at end of file diff --git a/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml b/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml deleted file mode 100644 index 731f55930fba00cb9de758c90eefbcd1afd59d47..0000000000000000000000000000000000000000 --- a/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +++ /dev/null @@ -1,104 +0,0 @@ -model: - base_learning_rate: 4.5e-6 - target: sgm.models.autoencoder.AutoencodingEngine - params: - input_key: jpg - monitor: val/rec_loss - - loss_config: - target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator - params: - perceptual_weight: 0.25 - disc_start: 20001 - disc_weight: 0.5 - learn_logvar: True - - regularization_weights: - kl_loss: 1.0 - - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - - encoder_config: - target: sgm.modules.diffusionmodules.model.Encoder - params: - attn_type: none - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4] - num_res_blocks: 4 - attn_resolutions: [] - dropout: 0.0 - - decoder_config: - target: sgm.modules.diffusionmodules.model.Decoder - params: ${model.params.encoder_config.params} - -data: - target: sgm.data.dataset.StableDataModuleFromConfig - params: - train: - datapipeline: - urls: - - DATA-PATH - pipeline_config: - shardshuffle: 10000 - sample_shuffle: 10000 - - decoders: - - pil - - postprocessors: - - target: sdata.mappers.TorchVisionImageTransforms - params: - key: jpg - transforms: - - target: torchvision.transforms.Resize - params: - size: 256 - interpolation: 3 - - target: torchvision.transforms.ToTensor - - target: sdata.mappers.Rescaler - - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare - params: - h_key: height - w_key: width - - loader: - batch_size: 8 - num_workers: 4 - - -lightning: - strategy: - target: pytorch_lightning.strategies.DDPStrategy - params: - find_unused_parameters: True - - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 50000 - - image_logger: - target: main.ImageLogger - params: - enable_autocast: False - batch_frequency: 1000 - max_images: 8 - increase_log_steps: True - - trainer: - devices: 0, - limit_val_batches: 50 - benchmark: True - accumulate_grad_batches: 1 - val_check_interval: 10000 \ No newline at end of file diff --git a/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml b/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml deleted file mode 100644 index 39c7c9df5da1c657d2ce72ac8b6269ae86185e91..0000000000000000000000000000000000000000 --- a/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml +++ /dev/null @@ -1,105 +0,0 @@ -model: - base_learning_rate: 4.5e-6 - target: sgm.models.autoencoder.AutoencodingEngine - params: - input_key: jpg - monitor: val/loss/rec - disc_start_iter: 0 - - encoder_config: - target: sgm.modules.diffusionmodules.model.Encoder - params: - attn_type: vanilla-xformers - double_z: true - z_channels: 8 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - - decoder_config: - target: sgm.modules.diffusionmodules.model.Decoder - params: ${model.params.encoder_config.params} - - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - - loss_config: - target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator - params: - perceptual_weight: 0.25 - disc_start: 20001 - disc_weight: 0.5 - learn_logvar: True - - regularization_weights: - kl_loss: 1.0 - -data: - target: sgm.data.dataset.StableDataModuleFromConfig - params: - train: - datapipeline: - urls: - - DATA-PATH - pipeline_config: - shardshuffle: 10000 - sample_shuffle: 10000 - - decoders: - - pil - - postprocessors: - - target: sdata.mappers.TorchVisionImageTransforms - params: - key: jpg - transforms: - - target: torchvision.transforms.Resize - params: - size: 256 - interpolation: 3 - - target: torchvision.transforms.ToTensor - - target: sdata.mappers.Rescaler - - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare - params: - h_key: height - w_key: width - - loader: - batch_size: 8 - num_workers: 4 - - -lightning: - strategy: - target: pytorch_lightning.strategies.DDPStrategy - params: - find_unused_parameters: True - - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 50000 - - image_logger: - target: main.ImageLogger - params: - enable_autocast: False - batch_frequency: 1000 - max_images: 8 - increase_log_steps: True - - trainer: - devices: 0, - limit_val_batches: 50 - benchmark: True - accumulate_grad_batches: 1 - val_check_interval: 10000 diff --git a/configs/example_training/imagenet-f8_cond.yaml b/configs/example_training/imagenet-f8_cond.yaml deleted file mode 100644 index 23cded00a72e2883df1a4bf2b639a49cda763a8e..0000000000000000000000000000000000000000 --- a/configs/example_training/imagenet-f8_cond.yaml +++ /dev/null @@ -1,185 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.13025 - disable_first_stage_autocast: True - log_keys: - - cls - - scheduler_config: - target: sgm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [10000] - cycle_lengths: [10000000000000] - f_start: [1.e-6] - f_max: [1.] - f_min: [1.] - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - in_channels: 4 - out_channels: 4 - model_channels: 256 - attention_resolutions: [1, 2, 4] - num_res_blocks: 2 - channel_mult: [1, 2, 4] - num_head_channels: 64 - num_classes: sequential - adm_in_channels: 1024 - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: True - input_key: cls - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.ClassEmbedder - params: - add_sequence_dim: True - embed_dim: 1024 - n_classes: 1000 - - - is_trainable: False - ucg_rate: 0.2 - input_key: original_size_as_tuple - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - is_trainable: False - input_key: crop_coords_top_left - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKL - params: - ckpt_path: CKPT_PATH - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling - params: - num_idx: 1000 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFG - params: - scale: 5.0 - -data: - target: sgm.data.dataset.StableDataModuleFromConfig - params: - train: - datapipeline: - urls: - # USER: adapt this path the root of your custom dataset - - DATA_PATH - pipeline_config: - shardshuffle: 10000 - sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM - - decoders: - - pil - - postprocessors: - - target: sdata.mappers.TorchVisionImageTransforms - params: - key: jpg # USER: you might wanna adapt this for your custom dataset - transforms: - - target: torchvision.transforms.Resize - params: - size: 256 - interpolation: 3 - - target: torchvision.transforms.ToTensor - - target: sdata.mappers.Rescaler - - - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare - params: - h_key: height # USER: you might wanna adapt this for your custom dataset - w_key: width # USER: you might wanna adapt this for your custom dataset - - loader: - batch_size: 64 - num_workers: 6 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - enable_autocast: False - batch_frequency: 1000 - max_images: 8 - increase_log_steps: True - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 8 - n_rows: 2 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 1000 \ No newline at end of file diff --git a/configs/example_training/toy/cifar10_cond.yaml b/configs/example_training/toy/cifar10_cond.yaml deleted file mode 100644 index fca9958464488a66ed2a54d57c59228215690606..0000000000000000000000000000000000000000 --- a/configs/example_training/toy/cifar10_cond.yaml +++ /dev/null @@ -1,98 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling - params: - sigma_data: 1.0 - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - in_channels: 3 - out_channels: 3 - model_channels: 32 - attention_resolutions: [] - num_res_blocks: 4 - channel_mult: [1, 2, 2] - num_head_channels: 32 - num_classes: sequential - adm_in_channels: 128 - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: True - input_key: cls - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.ClassEmbedder - params: - embed_dim: 128 - n_classes: 10 - - first_stage_config: - target: sgm.models.autoencoder.IdentityFirstStage - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting - params: - sigma_data: 1.0 - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFG - params: - scale: 3.0 - -data: - target: sgm.data.cifar10.CIFAR10Loader - params: - batch_size: 512 - num_workers: 1 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - batch_frequency: 1000 - max_images: 64 - increase_log_steps: True - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 64 - n_rows: 8 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/toy/mnist.yaml b/configs/example_training/toy/mnist.yaml deleted file mode 100644 index a86d05ca1efa537b57646c3923c1f54ac0d6ccf4..0000000000000000000000000000000000000000 --- a/configs/example_training/toy/mnist.yaml +++ /dev/null @@ -1,79 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling - params: - sigma_data: 1.0 - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - in_channels: 1 - out_channels: 1 - model_channels: 32 - attention_resolutions: [] - num_res_blocks: 4 - channel_mult: [1, 2, 2] - num_head_channels: 32 - - first_stage_config: - target: sgm.models.autoencoder.IdentityFirstStage - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting - params: - sigma_data: 1.0 - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - -data: - target: sgm.data.mnist.MNISTLoader - params: - batch_size: 512 - num_workers: 1 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - batch_frequency: 1000 - max_images: 64 - increase_log_steps: False - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 64 - n_rows: 8 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 10 \ No newline at end of file diff --git a/configs/example_training/toy/mnist_cond.yaml b/configs/example_training/toy/mnist_cond.yaml deleted file mode 100644 index 8378acd7acd4c23039a659789b6e6ff5de1a1058..0000000000000000000000000000000000000000 --- a/configs/example_training/toy/mnist_cond.yaml +++ /dev/null @@ -1,98 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling - params: - sigma_data: 1.0 - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - in_channels: 1 - out_channels: 1 - model_channels: 32 - attention_resolutions: [] - num_res_blocks: 4 - channel_mult: [1, 2, 2] - num_head_channels: 32 - num_classes: sequential - adm_in_channels: 128 - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: True - input_key: cls - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.ClassEmbedder - params: - embed_dim: 128 - n_classes: 10 - - first_stage_config: - target: sgm.models.autoencoder.IdentityFirstStage - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting - params: - sigma_data: 1.0 - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFG - params: - scale: 3.0 - -data: - target: sgm.data.mnist.MNISTLoader - params: - batch_size: 512 - num_workers: 1 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - batch_frequency: 1000 - max_images: 16 - increase_log_steps: True - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 16 - n_rows: 4 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/toy/mnist_cond_discrete_eps.yaml b/configs/example_training/toy/mnist_cond_discrete_eps.yaml deleted file mode 100644 index e58aae58dd108887d8d2ac06933a31f84ea61509..0000000000000000000000000000000000000000 --- a/configs/example_training/toy/mnist_cond_discrete_eps.yaml +++ /dev/null @@ -1,103 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - in_channels: 1 - out_channels: 1 - model_channels: 32 - attention_resolutions: [] - num_res_blocks: 4 - channel_mult: [1, 2, 2] - num_head_channels: 32 - num_classes: sequential - adm_in_channels: 128 - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: True - input_key: cls - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.ClassEmbedder - params: - embed_dim: 128 - n_classes: 10 - - first_stage_config: - target: sgm.models.autoencoder.IdentityFirstStage - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling - params: - num_idx: 1000 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFG - params: - scale: 5.0 - -data: - target: sgm.data.mnist.MNISTLoader - params: - batch_size: 512 - num_workers: 1 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - batch_frequency: 1000 - max_images: 16 - increase_log_steps: True - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 16 - n_rows: 4 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/toy/mnist_cond_l1_loss.yaml b/configs/example_training/toy/mnist_cond_l1_loss.yaml deleted file mode 100644 index ee2f780358b7fe100efa226ae20f6ac58b441632..0000000000000000000000000000000000000000 --- a/configs/example_training/toy/mnist_cond_l1_loss.yaml +++ /dev/null @@ -1,99 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling - params: - sigma_data: 1.0 - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - in_channels: 1 - out_channels: 1 - model_channels: 32 - attention_resolutions: [] - num_res_blocks: 4 - channel_mult: [1, 2, 2] - num_head_channels: 32 - num_classes: sequential - adm_in_channels: 128 - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: True - input_key: cls - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.ClassEmbedder - params: - embed_dim: 128 - n_classes: 10 - - first_stage_config: - target: sgm.models.autoencoder.IdentityFirstStage - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_type: l1 - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting - params: - sigma_data: 1.0 - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFG - params: - scale: 3.0 - -data: - target: sgm.data.mnist.MNISTLoader - params: - batch_size: 512 - num_workers: 1 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - batch_frequency: 1000 - max_images: 64 - increase_log_steps: True - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 64 - n_rows: 8 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/toy/mnist_cond_with_ema.yaml b/configs/example_training/toy/mnist_cond_with_ema.yaml deleted file mode 100644 index c666e7143b7cb0a920d384f3f6294231b8bb1726..0000000000000000000000000000000000000000 --- a/configs/example_training/toy/mnist_cond_with_ema.yaml +++ /dev/null @@ -1,100 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - use_ema: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling - params: - sigma_data: 1.0 - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - in_channels: 1 - out_channels: 1 - model_channels: 32 - attention_resolutions: [] - num_res_blocks: 4 - channel_mult: [1, 2, 2] - num_head_channels: 32 - num_classes: sequential - adm_in_channels: 128 - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: True - input_key: cls - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.ClassEmbedder - params: - embed_dim: 128 - n_classes: 10 - - first_stage_config: - target: sgm.models.autoencoder.IdentityFirstStage - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting - params: - sigma_data: 1.0 - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFG - params: - scale: 3.0 - -data: - target: sgm.data.mnist.MNISTLoader - params: - batch_size: 512 - num_workers: 1 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - batch_frequency: 1000 - max_images: 64 - increase_log_steps: True - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 64 - n_rows: 8 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml b/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml deleted file mode 100644 index 0f268c3295bd57888de3efc736d307903ee80a8f..0000000000000000000000000000000000000000 --- a/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml +++ /dev/null @@ -1,182 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.13025 - disable_first_stage_autocast: True - log_keys: - - txt - - scheduler_config: - target: sgm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [10000] - cycle_lengths: [10000000000000] - f_start: [1.e-6] - f_max: [1.] - f_min: [1.] - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [1, 2, 4] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - num_classes: sequential - adm_in_channels: 1792 - num_heads: 1 - transformer_depth: 1 - context_dim: 768 - spatial_transformer_attn_type: softmax-xformers - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: True - input_key: txt - ucg_rate: 0.1 - legacy_ucg_value: "" - target: sgm.modules.encoders.modules.FrozenCLIPEmbedder - params: - always_return_pooled: True - - - is_trainable: False - ucg_rate: 0.1 - input_key: original_size_as_tuple - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - is_trainable: False - input_key: crop_coords_top_left - ucg_rate: 0.1 - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKL - params: - ckpt_path: CKPT_PATH - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [ 1, 2, 4, 4 ] - num_res_blocks: 2 - attn_resolutions: [ ] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling - params: - num_idx: 1000 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFG - params: - scale: 7.5 - -data: - target: sgm.data.dataset.StableDataModuleFromConfig - params: - train: - datapipeline: - urls: - # USER: adapt this path the root of your custom dataset - - DATA_PATH - pipeline_config: - shardshuffle: 10000 - sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM - - decoders: - - pil - - postprocessors: - - target: sdata.mappers.TorchVisionImageTransforms - params: - key: jpg # USER: you might wanna adapt this for your custom dataset - transforms: - - target: torchvision.transforms.Resize - params: - size: 256 - interpolation: 3 - - target: torchvision.transforms.ToTensor - - target: sdata.mappers.Rescaler - - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare - # USER: you might wanna use non-default parameters due to your custom dataset - - loader: - batch_size: 64 - num_workers: 6 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - enable_autocast: False - batch_frequency: 1000 - max_images: 8 - increase_log_steps: True - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 8 - n_rows: 2 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 1000 \ No newline at end of file diff --git a/configs/example_training/txt2img-clipl.yaml b/configs/example_training/txt2img-clipl.yaml deleted file mode 100644 index cb66ede901b1aa1acb18d162b88912a2e6eab0ce..0000000000000000000000000000000000000000 --- a/configs/example_training/txt2img-clipl.yaml +++ /dev/null @@ -1,184 +0,0 @@ -model: - base_learning_rate: 1.0e-4 - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.13025 - disable_first_stage_autocast: True - log_keys: - - txt - - scheduler_config: - target: sgm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [10000] - cycle_lengths: [10000000000000] - f_start: [1.e-6] - f_max: [1.] - f_min: [1.] - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [1, 2, 4] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - num_classes: sequential - adm_in_channels: 1792 - num_heads: 1 - transformer_depth: 1 - context_dim: 768 - spatial_transformer_attn_type: softmax-xformers - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: True - input_key: txt - ucg_rate: 0.1 - legacy_ucg_value: "" - target: sgm.modules.encoders.modules.FrozenCLIPEmbedder - params: - always_return_pooled: True - - - is_trainable: False - ucg_rate: 0.1 - input_key: original_size_as_tuple - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - is_trainable: False - input_key: crop_coords_top_left - ucg_rate: 0.1 - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKL - params: - ckpt_path: CKPT_PATH - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling - params: - num_idx: 1000 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 50 - - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - guider_config: - target: sgm.modules.diffusionmodules.guiders.VanillaCFG - params: - scale: 7.5 - -data: - target: sgm.data.dataset.StableDataModuleFromConfig - params: - train: - datapipeline: - urls: - # USER: adapt this path the root of your custom dataset - - DATA_PATH - pipeline_config: - shardshuffle: 10000 - sample_shuffle: 10000 - - - decoders: - - pil - - postprocessors: - - target: sdata.mappers.TorchVisionImageTransforms - params: - key: jpg # USER: you might wanna adapt this for your custom dataset - transforms: - - target: torchvision.transforms.Resize - params: - size: 256 - interpolation: 3 - - target: torchvision.transforms.ToTensor - - target: sdata.mappers.Rescaler - # USER: you might wanna use non-default parameters due to your custom dataset - - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare - # USER: you might wanna use non-default parameters due to your custom dataset - - loader: - batch_size: 64 - num_workers: 6 - -lightning: - modelcheckpoint: - params: - every_n_train_steps: 5000 - - callbacks: - metrics_over_trainsteps_checkpoint: - params: - every_n_train_steps: 25000 - - image_logger: - target: main.ImageLogger - params: - disabled: False - enable_autocast: False - batch_frequency: 1000 - max_images: 8 - increase_log_steps: True - log_first_step: False - log_images_kwargs: - use_ema_scope: False - N: 8 - n_rows: 2 - - trainer: - devices: 0, - benchmark: True - num_sanity_val_steps: 0 - accumulate_grad_batches: 1 - max_epochs: 1000 \ No newline at end of file diff --git a/configs/inference/sd_2_1.yaml b/configs/inference/sd_2_1.yaml deleted file mode 100644 index 6531c6c49fab2d5d9f21c75e53b0370cb8dad8dc..0000000000000000000000000000000000000000 --- a/configs/inference/sd_2_1.yaml +++ /dev/null @@ -1,60 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.18215 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2, 1] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: False - input_key: txt - target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: true - layer: penultimate - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity \ No newline at end of file diff --git a/configs/inference/sd_2_1_768.yaml b/configs/inference/sd_2_1_768.yaml deleted file mode 100644 index e2f9910a192745781edb2a8505fae1d3e1916f87..0000000000000000000000000000000000000000 --- a/configs/inference/sd_2_1_768.yaml +++ /dev/null @@ -1,60 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.18215 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - use_checkpoint: True - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2, 1] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: False - input_key: txt - target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder - params: - freeze: true - layer: penultimate - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity \ No newline at end of file diff --git a/configs/inference/sd_xl_base.yaml b/configs/inference/sd_xl_base.yaml deleted file mode 100644 index 6047379753a05224bb5b3f6746130fb7fb9f40aa..0000000000000000000000000000000000000000 --- a/configs/inference/sd_xl_base.yaml +++ /dev/null @@ -1,93 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.13025 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - adm_in_channels: 2816 - num_classes: sequential - use_checkpoint: True - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2] - num_res_blocks: 2 - channel_mult: [1, 2, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: [1, 2, 10] - context_dim: 2048 - spatial_transformer_attn_type: softmax-xformers - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: False - input_key: txt - target: sgm.modules.encoders.modules.FrozenCLIPEmbedder - params: - layer: hidden - layer_idx: 11 - - - is_trainable: False - input_key: txt - target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 - params: - arch: ViT-bigG-14 - version: laion2b_s39b_b160k - freeze: True - layer: penultimate - always_return_pooled: True - legacy: False - - - is_trainable: False - input_key: original_size_as_tuple - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - is_trainable: False - input_key: crop_coords_top_left - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - is_trainable: False - input_key: target_size_as_tuple - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity diff --git a/configs/inference/sd_xl_refiner.yaml b/configs/inference/sd_xl_refiner.yaml deleted file mode 100644 index 2d5ab44e748c55f5f2e34ae5aefdb78a921a8d3f..0000000000000000000000000000000000000000 --- a/configs/inference/sd_xl_refiner.yaml +++ /dev/null @@ -1,86 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.13025 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser - params: - num_idx: 1000 - - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization - - network_config: - target: sgm.modules.diffusionmodules.openaimodel.UNetModel - params: - adm_in_channels: 2560 - num_classes: sequential - use_checkpoint: True - in_channels: 4 - out_channels: 4 - model_channels: 384 - attention_resolutions: [4, 2] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 4 - context_dim: [1280, 1280, 1280, 1280] - spatial_transformer_attn_type: softmax-xformers - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: False - input_key: txt - target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 - params: - arch: ViT-bigG-14 - version: laion2b_s39b_b160k - legacy: False - freeze: True - layer: penultimate - always_return_pooled: True - - - is_trainable: False - input_key: original_size_as_tuple - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - is_trainable: False - input_key: crop_coords_top_left - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - is_trainable: False - input_key: aesthetic_score - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity diff --git a/configs/inference/svd.yaml b/configs/inference/svd.yaml deleted file mode 100644 index 2a0819ea77f1ed95dfedb2ab6ccdded4e6414e43..0000000000000000000000000000000000000000 --- a/configs/inference/svd.yaml +++ /dev/null @@ -1,131 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.18215 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise - - network_config: - target: sgm.modules.diffusionmodules.video_model.VideoUNet - params: - adm_in_channels: 768 - num_classes: sequential - use_checkpoint: True - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2, 1] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - extra_ff_mix_layer: True - use_spatial_context: True - merge_strategy: learned_with_images - video_kernel_size: [3, 1, 1] - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: False - input_key: cond_frames_without_noise - target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder - params: - n_cond_frames: 1 - n_copies: 1 - open_clip_embedding_config: - target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder - params: - freeze: True - - - input_key: fps_id - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - input_key: motion_bucket_id - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - input_key: cond_frames - is_trainable: False - target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder - params: - disable_encoder_autocast: True - n_cond_frames: 1 - n_copies: 1 - is_ae: True - encoder_config: - target: sgm.models.autoencoder.AutoencoderKLModeOnly - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - - input_key: cond_aug - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: sgm.models.autoencoder.AutoencodingEngine - params: - loss_config: - target: torch.nn.Identity - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - encoder_config: - target: sgm.modules.diffusionmodules.model.Encoder - params: - attn_type: vanilla - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - decoder_config: - target: sgm.modules.autoencoding.temporal_ae.VideoDecoder - params: - attn_type: vanilla - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - video_kernel_size: [3, 1, 1] \ No newline at end of file diff --git a/configs/inference/svd_image_decoder.yaml b/configs/inference/svd_image_decoder.yaml deleted file mode 100644 index bb09177ad77a8154c20fcbb2e2fdfc0ac9b6c491..0000000000000000000000000000000000000000 --- a/configs/inference/svd_image_decoder.yaml +++ /dev/null @@ -1,114 +0,0 @@ -model: - target: sgm.models.diffusion.DiffusionEngine - params: - scale_factor: 0.18215 - disable_first_stage_autocast: True - - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise - - network_config: - target: sgm.modules.diffusionmodules.video_model.VideoUNet - params: - adm_in_channels: 768 - num_classes: sequential - use_checkpoint: True - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: [4, 2, 1] - num_res_blocks: 2 - channel_mult: [1, 2, 4, 4] - num_head_channels: 64 - use_linear_in_transformer: True - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - extra_ff_mix_layer: True - use_spatial_context: True - merge_strategy: learned_with_images - video_kernel_size: [3, 1, 1] - - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: False - input_key: cond_frames_without_noise - target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder - params: - n_cond_frames: 1 - n_copies: 1 - open_clip_embedding_config: - target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder - params: - freeze: True - - - input_key: fps_id - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - input_key: motion_bucket_id - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - - input_key: cond_frames - is_trainable: False - target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder - params: - disable_encoder_autocast: True - n_cond_frames: 1 - n_copies: 1 - is_ae: True - encoder_config: - target: sgm.models.autoencoder.AutoencoderKLModeOnly - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - - input_key: cond_aug - is_trainable: False - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - first_stage_config: - target: sgm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: True - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: [1, 2, 4, 4] - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity \ No newline at end of file diff --git a/configs/inference/svd_mv.yaml b/configs/inference/svd_mv.yaml deleted file mode 100644 index a343094433e2a486d730f8c3a4561a371f3ad777..0000000000000000000000000000000000000000 --- a/configs/inference/svd_mv.yaml +++ /dev/null @@ -1,202 +0,0 @@ -model: - base_learning_rate: 1.0e-05 - target: sgm.models.video_diffusion.DiffusionEngine - params: - ckpt_path: ckpts/svd_xt.safetensors - scale_factor: 0.18215 - disable_first_stage_autocast: true - scheduler_config: - target: sgm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: - - 1 - cycle_lengths: - - 10000000000000 - f_start: - - 1.0e-06 - f_max: - - 1.0 - f_min: - - 1.0 - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise - network_config: - target: sgm.modules.diffusionmodules.video_model.VideoUNet - params: - adm_in_channels: 768 - num_classes: sequential - use_checkpoint: true - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: - - 4 - - 2 - - 1 - num_res_blocks: 2 - channel_mult: - - 1 - - 2 - - 4 - - 4 - num_head_channels: 64 - use_linear_in_transformer: true - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - extra_ff_mix_layer: true - use_spatial_context: true - merge_strategy: learned_with_images - video_kernel_size: - - 3 - - 1 - - 1 - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: false - ucg_rate: 0.2 - input_key: cond_frames_without_noise - target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder - params: - n_cond_frames: 1 - n_copies: 1 - open_clip_embedding_config: - target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder - params: - freeze: true - - input_key: fps_id - is_trainable: true - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - input_key: motion_bucket_id - is_trainable: true - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - input_key: cond_frames - is_trainable: false - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder - params: - disable_encoder_autocast: true - n_cond_frames: 1 - n_copies: 1 - is_ae: true - encoder_config: - target: sgm.models.autoencoder.AutoencoderKLModeOnly - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - attn_type: vanilla-xformers - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - input_key: cond_aug - is_trainable: true - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - first_stage_config: - target: sgm.models.autoencoder.AutoencodingEngine - params: - loss_config: - target: torch.nn.Identity - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - encoder_config: - target: sgm.modules.diffusionmodules.model.Encoder - params: - attn_type: vanilla - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - decoder_config: - target: sgm.modules.autoencoding.temporal_ae.VideoDecoder - params: - attn_type: vanilla - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - video_kernel_size: - - 3 - - 1 - - 1 - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 30 - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - params: - sigma_max: 700.0 - guider_config: - target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider - params: - max_scale: 2.5 - min_scale: 1.0 - num_frames: 24 - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - batch2model_keys: - - num_video_frames - - image_only_indicator - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting - params: - sigma_data: 1.0 - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling - params: - p_mean: 0.3 - p_std: 1.2 -data: - target: sgm.data.objaverse.ObjaverseSpiralDataset - params: - root_dir: /mnt/mfs/zilong.chen/Downloads/objaverse-ndd-samples - random_front: true - batch_size: 2 - num_workers: 16 - cond_aug_mean: -0.0 diff --git a/configs/instant-mesh-base.yaml b/configs/instant-mesh-base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ad4f4c0cd0d3c6f4d3038b657a41dab82c048dd1 --- /dev/null +++ b/configs/instant-mesh-base.yaml @@ -0,0 +1,22 @@ +model_config: + target: src.models.lrm_mesh.InstantMesh + params: + encoder_feat_dim: 768 + encoder_freeze: false + encoder_model_name: facebook/dino-vitb16 + transformer_dim: 1024 + transformer_layers: 12 + transformer_heads: 16 + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 40 + rendering_samples_per_ray: 96 + grid_res: 128 + grid_scale: 2.1 + + +infer_config: + unet_path: ckpts/diffusion_pytorch_model.bin + model_path: ckpts/instant_mesh_base.ckpt + texture_resolution: 1024 + render_resolution: 512 \ No newline at end of file diff --git a/configs/instant-mesh-large.yaml b/configs/instant-mesh-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e296bc89f6d0d0649136ba2ce0e34490f76a5e41 --- /dev/null +++ b/configs/instant-mesh-large.yaml @@ -0,0 +1,22 @@ +model_config: + target: src.models.lrm_mesh.InstantMesh + params: + encoder_feat_dim: 768 + encoder_freeze: false + encoder_model_name: facebook/dino-vitb16 + transformer_dim: 1024 + transformer_layers: 16 + transformer_heads: 16 + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 80 + rendering_samples_per_ray: 128 + grid_res: 128 + grid_scale: 2.1 + + +infer_config: + unet_path: ckpts/diffusion_pytorch_model.bin + model_path: ckpts/instant_mesh_large.ckpt + texture_resolution: 1024 + render_resolution: 512 \ No newline at end of file diff --git a/configs/instant-nerf-base.yaml b/configs/instant-nerf-base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ded3d484751127d430891fc28eb2de664aecd5e1 --- /dev/null +++ b/configs/instant-nerf-base.yaml @@ -0,0 +1,21 @@ +model_config: + target: src.models.lrm.InstantNeRF + params: + encoder_feat_dim: 768 + encoder_freeze: false + encoder_model_name: facebook/dino-vitb16 + transformer_dim: 1024 + transformer_layers: 12 + transformer_heads: 16 + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 40 + rendering_samples_per_ray: 96 + + +infer_config: + unet_path: ckpts/diffusion_pytorch_model.bin + model_path: ckpts/instant_nerf_base.ckpt + mesh_threshold: 10.0 + mesh_resolution: 256 + render_resolution: 384 \ No newline at end of file diff --git a/configs/instant-nerf-large.yaml b/configs/instant-nerf-large.yaml new file mode 100644 index 0000000000000000000000000000000000000000..57494b69d74ee78dca2e2cead2ef68ddfd0fd531 --- /dev/null +++ b/configs/instant-nerf-large.yaml @@ -0,0 +1,21 @@ +model_config: + target: src.models.lrm.InstantNeRF + params: + encoder_feat_dim: 768 + encoder_freeze: false + encoder_model_name: facebook/dino-vitb16 + transformer_dim: 1024 + transformer_layers: 16 + transformer_heads: 16 + triplane_low_res: 32 + triplane_high_res: 64 + triplane_dim: 80 + rendering_samples_per_ray: 128 + + +infer_config: + unet_path: ckpts/diffusion_pytorch_model.bin + model_path: ckpts/instant_nerf_large.ckpt + mesh_threshold: 10.0 + mesh_resolution: 256 + render_resolution: 384 \ No newline at end of file diff --git a/examples/bird.jpg b/examples/bird.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ac70a36ebefb87fb283f3bb95d07fe71700702a3 Binary files /dev/null and b/examples/bird.jpg differ diff --git a/examples/bubble_mart_blue.png b/examples/bubble_mart_blue.png new file mode 100644 index 0000000000000000000000000000000000000000..af870322d4a8a2f237546fbea9560bb8e5f50364 Binary files /dev/null and b/examples/bubble_mart_blue.png differ diff --git a/examples/cartoon_dinosaur.png b/examples/cartoon_dinosaur.png new file mode 100644 index 0000000000000000000000000000000000000000..598964626b767eb6470a28a68537c091fc5de2f8 Binary files /dev/null and b/examples/cartoon_dinosaur.png differ diff --git a/examples/cartoon_girl.jpg b/examples/cartoon_girl.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0664c7bb99a569faee449482ece59e3c32f06642 Binary files /dev/null and b/examples/cartoon_girl.jpg differ diff --git a/examples/chair_armed.png b/examples/chair_armed.png new file mode 100644 index 0000000000000000000000000000000000000000..2ab67e95ed57fbc5ebcd7d934827fd7fb03ab3ff Binary files /dev/null and b/examples/chair_armed.png differ diff --git a/examples/chair_comfort.jpg b/examples/chair_comfort.jpg new file mode 100644 index 0000000000000000000000000000000000000000..918347fe51773d7ecaa7fb929274db8d7d5d3e19 Binary files /dev/null and b/examples/chair_comfort.jpg differ diff --git a/examples/chair_wood.jpg b/examples/chair_wood.jpg new file mode 100644 index 0000000000000000000000000000000000000000..bc60569896fb02a46185aabb85086890f0f400d7 Binary files /dev/null and b/examples/chair_wood.jpg differ diff --git a/examples/chest.jpg b/examples/chest.jpg new file mode 100644 index 0000000000000000000000000000000000000000..26ae0b145887e43b850d298b94fe54828e909492 Binary files /dev/null and b/examples/chest.jpg differ diff --git a/examples/fruit_bycycle.jpg b/examples/fruit_bycycle.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4695b19dc1078776f5e10b055c7fffcedbb984d3 Binary files /dev/null and b/examples/fruit_bycycle.jpg differ diff --git a/examples/fruit_elephant.jpg b/examples/fruit_elephant.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ef8eaf3b88ae0a38272b34802fe40032055afa58 Binary files /dev/null and b/examples/fruit_elephant.jpg differ diff --git a/examples/genshin_building.png b/examples/genshin_building.png new file mode 100644 index 0000000000000000000000000000000000000000..00b6a949d01283e1ae30fac4bd6040e13f18a055 Binary files /dev/null and b/examples/genshin_building.png differ diff --git a/examples/genshin_teapot.png b/examples/genshin_teapot.png new file mode 100644 index 0000000000000000000000000000000000000000..1f13a6edfe67ced810b4513117279067f0360fae Binary files /dev/null and b/examples/genshin_teapot.png differ diff --git a/examples/hatsune_miku.png b/examples/hatsune_miku.png new file mode 100644 index 0000000000000000000000000000000000000000..2fecf005fdd56a396c4894256fbb98fcc1c4dd8f Binary files /dev/null and b/examples/hatsune_miku.png differ diff --git a/examples/house2.jpg b/examples/house2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2eb8d63a6b91d5b16e729710c8b703aa5c11f9e5 Binary files /dev/null and b/examples/house2.jpg differ diff --git a/examples/mushroom_teapot.jpg b/examples/mushroom_teapot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a6c767354305f5467a4c0d5f199eee2a120f4501 Binary files /dev/null and b/examples/mushroom_teapot.jpg differ diff --git a/examples/pikachu.png b/examples/pikachu.png new file mode 100644 index 0000000000000000000000000000000000000000..e7579c16957a3e13b80d53cf0a41ddfdfd47b92d Binary files /dev/null and b/examples/pikachu.png differ diff --git a/examples/plant.jpg b/examples/plant.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3519c1639c3f837d9f1147cba1172e6aaab25a23 Binary files /dev/null and b/examples/plant.jpg differ diff --git a/examples/robot.jpg b/examples/robot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..929450fba69a20389f39d46cb51d27facc1bba6d Binary files /dev/null and b/examples/robot.jpg differ diff --git a/examples/sea_turtle.png b/examples/sea_turtle.png new file mode 100644 index 0000000000000000000000000000000000000000..27c3e2a9c7d44cb33914422b410ef41cf6591433 Binary files /dev/null and b/examples/sea_turtle.png differ diff --git a/examples/skating_shoe.jpg b/examples/skating_shoe.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5f21cb1d43e9d42d2836118963fc1d2874523748 Binary files /dev/null and b/examples/skating_shoe.jpg differ diff --git a/examples/sorting_board.png b/examples/sorting_board.png new file mode 100644 index 0000000000000000000000000000000000000000..a40fb8362afce0e323dd4517bba784cc652f5f6c Binary files /dev/null and b/examples/sorting_board.png differ diff --git a/examples/sword.png b/examples/sword.png new file mode 100644 index 0000000000000000000000000000000000000000..3068cb9bdbbd9ed3c0a143fd5c741abbc58508e3 Binary files /dev/null and b/examples/sword.png differ diff --git a/examples/toy_car.jpg b/examples/toy_car.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ffa72aa6c1510e200e5d640461b779d2e7bf4997 Binary files /dev/null and b/examples/toy_car.jpg differ diff --git a/examples/watermelon.png b/examples/watermelon.png new file mode 100644 index 0000000000000000000000000000000000000000..52b39917abcbd2f1eef9b7c8cf9aa602bddde1bf Binary files /dev/null and b/examples/watermelon.png differ diff --git a/examples/whitedog.png b/examples/whitedog.png new file mode 100644 index 0000000000000000000000000000000000000000..16c598a8133643898408ea806b69d5b18c53be7d Binary files /dev/null and b/examples/whitedog.png differ diff --git a/examples/x_teapot.jpg b/examples/x_teapot.jpg new file mode 100644 index 0000000000000000000000000000000000000000..4e1cb46c5541dcc4ea544864e2eeebd42dfcb18a Binary files /dev/null and b/examples/x_teapot.jpg differ diff --git a/examples/x_toyduck.jpg b/examples/x_toyduck.jpg new file mode 100644 index 0000000000000000000000000000000000000000..5e60d43bd76d7511e44568c4f9bba2a11a1a4f04 Binary files /dev/null and b/examples/x_toyduck.jpg differ diff --git a/requirements.txt b/requirements.txt index 1596346cdf34d1787bd4b8032002f1044a9f507f..cb552261e6dd882cdf9410654cb143d83f3b7fca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,48 +1,23 @@ -black==23.7.0 -chardet==5.1.0 -clip @ git+https://github.com/openai/CLIP.git -einops>=0.6.1 -fairscale>=0.4.13 -fire>=0.5.0 -fsspec>=2023.6.0 -invisible-watermark>=0.2.0 -kornia==0.6.9 -matplotlib>=3.7.2 -natsort>=8.4.0 -ninja>=1.11.1 -numpy>=1.24.4 -omegaconf>=2.3.0 -open-clip-torch>=2.20.0 -opencv-python==4.6.0.66 -pandas>=2.0.3 -pillow>=9.5.0 -pudb>=2022.1.3 -pytorch-lightning==2.0.1 -pyyaml>=6.0.1 -scipy>=1.10.1 -streamlit>=0.73.1 -tensorboardx==2.6 -timm>=0.9.2 -tokenizers==0.12.1 -torch==2.0.1 -torchaudio>=2.0.2 -torchdata==0.6.1 -torchmetrics>=1.0.1 -torchvision>=0.15.2 -tqdm>=4.65.0 -transformers==4.19.1 -# triton==2.0.0 -urllib3<1.27,>=1.25.4 -wandb>=0.15.6 -webdataset>=0.2.33 -wheel>=0.41.0 -xformers>=0.0.20 -streamlit-keyup==0.2.0 -mediapy -tyro -wget +torch==2.1.0 +torchvision==0.16.0 +torchaudio==2.1.0 +pytorch-lightning==2.1.2 +einops +omegaconf +deepspeed +torchmetrics +webdataset +accelerate +tensorboard +PyMCubes +trimesh rembg -kiui -spaces -imageio==2.19.3 -imageio-ffmpeg==0.4.7 \ No newline at end of file +transformers==4.34.1 +diffusers==0.19.3 +bitsandbytes +imageio[ffmpeg] +xatlas +plyfile +xformers==0.0.22.post7 +git+https://github.com/NVlabs/nvdiffrast/ +huggingface-hub \ No newline at end of file diff --git a/scripts/pub/V3D_512.py b/scripts/pub/V3D_512.py deleted file mode 100644 index ae1fd579348d477d395958e13f7e7002bb9be1f2..0000000000000000000000000000000000000000 --- a/scripts/pub/V3D_512.py +++ /dev/null @@ -1,317 +0,0 @@ -import math -import os -from glob import glob -from pathlib import Path -from typing import Optional - -import cv2 -import numpy as np -import torch -from einops import rearrange, repeat -from fire import Fire -import tyro -from omegaconf import OmegaConf -from PIL import Image -from torchvision.transforms import ToTensor -from mediapy import write_video -import rembg -from kiui.op import recenter -from safetensors.torch import load_file as load_safetensors -from typing import Any - -from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering -from sgm.inference.helpers import embed_watermark -from sgm.util import default, instantiate_from_config - - -def get_unique_embedder_keys_from_conditioner(conditioner): - return list(set([x.input_key for x in conditioner.embedders])) - - -def get_batch(keys, value_dict, N, T, device): - batch = {} - batch_uc = {} - - for key in keys: - if key == "fps_id": - batch[key] = ( - torch.tensor([value_dict["fps_id"]]) - .to(device) - .repeat(int(math.prod(N))) - ) - elif key == "motion_bucket_id": - batch[key] = ( - torch.tensor([value_dict["motion_bucket_id"]]) - .to(device) - .repeat(int(math.prod(N))) - ) - elif key == "cond_aug": - batch[key] = repeat( - torch.tensor([value_dict["cond_aug"]]).to(device), - "1 -> b", - b=math.prod(N), - ) - elif key == "cond_frames": - batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) - elif key == "cond_frames_without_noise": - batch[key] = repeat( - value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0] - ) - else: - batch[key] = value_dict[key] - - if T is not None: - batch["num_video_frames"] = T - - for key in batch.keys(): - if key not in batch_uc and isinstance(batch[key], torch.Tensor): - batch_uc[key] = torch.clone(batch[key]) - return batch, batch_uc - - -def load_model( - config: str, - device: str, - num_frames: int, - num_steps: int, - ckpt_path: Optional[str] = None, - min_cfg: Optional[float] = None, - max_cfg: Optional[float] = None, - sigma_max: Optional[float] = None, -): - config = OmegaConf.load(config) - - config.model.params.sampler_config.params.num_steps = num_steps - config.model.params.sampler_config.params.guider_config.params.num_frames = ( - num_frames - ) - if max_cfg is not None: - config.model.params.sampler_config.params.guider_config.params.max_scale = ( - max_cfg - ) - if min_cfg is not None: - config.model.params.sampler_config.params.guider_config.params.min_scale = ( - min_cfg - ) - if sigma_max is not None: - print("Overriding sigma_max to ", sigma_max) - config.model.params.sampler_config.params.discretization_config.params.sigma_max = ( - sigma_max - ) - - config.model.params.from_scratch = False - - if ckpt_path is not None: - config.model.params.ckpt_path = str(ckpt_path) - if device == "cuda": - with torch.device(device): - model = instantiate_from_config(config.model).to(device).eval() - else: - model = instantiate_from_config(config.model).to(device).eval() - - return model, None - - -def sample_one( - input_path: str = "assets/test_image.png", # Can either be image file or folder with image files - checkpoint_path: Optional[str] = None, - num_frames: Optional[int] = None, - num_steps: Optional[int] = None, - fps_id: int = 1, - motion_bucket_id: int = 300, - cond_aug: float = 0.02, - seed: int = 23, - decoding_t: int = 24, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. - device: str = "cuda", - output_folder: Optional[str] = None, - noise: torch.Tensor = None, - save: bool = False, - cached_model: Any = None, - border_ratio: float = 0.3, - min_guidance_scale: float = 3.5, - max_guidance_scale: float = 3.5, - sigma_max: float = None, - ignore_alpha: bool = False, -): - model_config = "scripts/pub/configs/V3D_512.yaml" - num_frames = OmegaConf.load( - model_config - ).model.params.sampler_config.params.guider_config.params.num_frames - print("Detected num_frames:", num_frames) - num_steps = default(num_steps, 25) - output_folder = default(output_folder, f"outputs/V3D_512") - decoding_t = min(decoding_t, num_frames) - - sd = load_safetensors("./ckpts/svd_xt.safetensors") - clip_model_config = OmegaConf.load("configs/embedder/clip_image.yaml") - clip_model = instantiate_from_config(clip_model_config).eval() - clip_sd = dict() - for k, v in sd.items(): - if "conditioner.embedders.0" in k: - clip_sd[k.replace("conditioner.embedders.0.", "")] = v - clip_model.load_state_dict(clip_sd) - clip_model = clip_model.to(device) - - ae_model_config = OmegaConf.load("configs/ae/video.yaml") - ae_model = instantiate_from_config(ae_model_config).eval() - encoder_sd = dict() - for k, v in sd.items(): - if "first_stage_model" in k: - encoder_sd[k.replace("first_stage_model.", "")] = v - ae_model.load_state_dict(encoder_sd) - ae_model = ae_model.to(device) - - if cached_model is None: - model, filter = load_model( - model_config, - device, - num_frames, - num_steps, - ckpt_path=checkpoint_path, - min_cfg=min_guidance_scale, - max_cfg=max_guidance_scale, - sigma_max=sigma_max, - ) - else: - model = cached_model - torch.manual_seed(seed) - - need_return = True - path = Path(input_path) - if path.is_file(): - if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]): - all_img_paths = [input_path] - else: - raise ValueError("Path is not valid image file.") - elif path.is_dir(): - all_img_paths = sorted( - [ - f - for f in path.iterdir() - if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] - ] - ) - need_return = False - if len(all_img_paths) == 0: - raise ValueError("Folder does not contain any images.") - else: - raise ValueError - - for input_path in all_img_paths: - with Image.open(input_path) as image: - # if image.mode == "RGBA": - # image = image.convert("RGB") - w, h = image.size - - if border_ratio > 0: - if image.mode != "RGBA" or ignore_alpha: - image = image.convert("RGB") - image = np.asarray(image) - carved_image = rembg.remove(image) # [H, W, 4] - else: - image = np.asarray(image) - carved_image = image - mask = carved_image[..., -1] > 0 - image = recenter(carved_image, mask, border_ratio=border_ratio) - image = image.astype(np.float32) / 255.0 - if image.shape[-1] == 4: - image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) - image = Image.fromarray((image * 255).astype(np.uint8)) - else: - print("Ignore border ratio") - image = image.resize((512, 512)) - - image = ToTensor()(image) - image = image * 2.0 - 1.0 - - image = image.unsqueeze(0).to(device) - H, W = image.shape[2:] - assert image.shape[1] == 3 - F = 8 - C = 4 - shape = (num_frames, C, H // F, W // F) - - value_dict = {} - value_dict["motion_bucket_id"] = motion_bucket_id - value_dict["fps_id"] = fps_id - value_dict["cond_aug"] = cond_aug - value_dict["cond_frames_without_noise"] = clip_model(image) - value_dict["cond_frames"] = ae_model.encode(image) - value_dict["cond_frames"] += cond_aug * torch.randn_like( - value_dict["cond_frames"] - ) - value_dict["cond_aug"] = cond_aug - - with torch.no_grad(): - with torch.autocast(device): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [1, num_frames], - T=num_frames, - device=device, - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=[ - "cond_frames", - "cond_frames_without_noise", - ], - ) - - for k in ["crossattn", "concat"]: - uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) - uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) - c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) - c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) - - randn = torch.randn(shape, device=device) if noise is None else noise - randn = randn.to(device) - - additional_model_inputs = {} - additional_model_inputs["image_only_indicator"] = torch.zeros( - 2, num_frames - ).to(device) - additional_model_inputs["num_video_frames"] = batch["num_video_frames"] - - def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) - - samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) - model.en_and_decode_n_samples_a_time = decoding_t - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - os.makedirs(output_folder, exist_ok=True) - base_count = len(glob(os.path.join(output_folder, "*.mp4"))) - video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") - # writer = cv2.VideoWriter( - # video_path, - # cv2.VideoWriter_fourcc(*"MP4V"), - # fps_id + 1, - # (samples.shape[-1], samples.shape[-2]), - # ) - - frames = ( - (rearrange(samples, "t c h w -> t h w c") * 255) - .cpu() - .numpy() - .astype(np.uint8) - ) - - if save: - write_video(video_path, frames, fps=3) - - images = [] - for frame in frames: - images.append(Image.fromarray(frame)) - - if need_return: - return images, model - - -if __name__ == "__main__": - tyro.cli(sample_one) diff --git a/scripts/pub/configs/V3D_512.yaml b/scripts/pub/configs/V3D_512.yaml deleted file mode 100644 index aee4108e741a50a75336d277e72a72d9b1df8ade..0000000000000000000000000000000000000000 --- a/scripts/pub/configs/V3D_512.yaml +++ /dev/null @@ -1,161 +0,0 @@ -model: - base_learning_rate: 1.0e-04 - target: sgm.models.video_diffusion.DiffusionEngine - params: - ckpt_path: ckpts/V3D_512.ckpt - scale_factor: 0.18215 - disable_first_stage_autocast: true - input_key: latents - log_keys: [] - scheduler_config: - target: sgm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: - - 1 - cycle_lengths: - - 10000000000000 - f_start: - - 1.0e-06 - f_max: - - 1.0 - f_min: - - 1.0 - denoiser_config: - target: sgm.modules.diffusionmodules.denoiser.Denoiser - params: - scaling_config: - target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise - network_config: - target: sgm.modules.diffusionmodules.video_model.VideoUNet - params: - adm_in_channels: 768 - num_classes: sequential - use_checkpoint: true - in_channels: 8 - out_channels: 4 - model_channels: 320 - attention_resolutions: - - 4 - - 2 - - 1 - num_res_blocks: 2 - channel_mult: - - 1 - - 2 - - 4 - - 4 - num_head_channels: 64 - use_linear_in_transformer: true - transformer_depth: 1 - context_dim: 1024 - spatial_transformer_attn_type: softmax-xformers - extra_ff_mix_layer: true - use_spatial_context: true - merge_strategy: learned_with_images - video_kernel_size: - - 3 - - 1 - - 1 - conditioner_config: - target: sgm.modules.GeneralConditioner - params: - emb_models: - - is_trainable: false - ucg_rate: 0.2 - input_key: cond_frames_without_noise - target: sgm.modules.encoders.modules.IdentityEncoder - - input_key: fps_id - is_trainable: true - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - input_key: motion_bucket_id - is_trainable: true - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - - input_key: cond_frames - is_trainable: false - ucg_rate: 0.2 - target: sgm.modules.encoders.modules.IdentityEncoder - - input_key: cond_aug - is_trainable: true - target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND - params: - outdim: 256 - first_stage_config: - target: sgm.models.autoencoder.AutoencodingEngine - params: - loss_config: - target: torch.nn.Identity - regularizer_config: - target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer - encoder_config: - target: sgm.modules.diffusionmodules.model.Encoder - params: - attn_type: vanilla - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - decoder_config: - target: sgm.modules.autoencoding.temporal_ae.VideoDecoder - params: - attn_type: vanilla - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - video_kernel_size: - - 3 - - 1 - - 1 - sampler_config: - target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler - params: - num_steps: 30 - discretization_config: - target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization - params: - sigma_max: 700.0 - guider_config: - target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider - params: - max_scale: 3.5 - min_scale: 3.5 - num_frames: 18 - loss_fn_config: - target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss - params: - batch2model_keys: - - num_video_frames - - image_only_indicator - loss_weighting_config: - target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting - params: - sigma_data: 1.0 - sigma_sampler_config: - target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling - params: - p_mean: 1.5 - p_std: 2.0 \ No newline at end of file diff --git a/scripts/tests/attention.py b/scripts/tests/attention.py deleted file mode 100644 index d7c3f7c8da27c577a7ce0ea3a01ab7f9e9c1baa2..0000000000000000000000000000000000000000 --- a/scripts/tests/attention.py +++ /dev/null @@ -1,319 +0,0 @@ -import einops -import torch -import torch.nn.functional as F -import torch.utils.benchmark as benchmark -from torch.backends.cuda import SDPBackend - -from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer - - -def benchmark_attn(): - # Lets define a helpful benchmarking function: - # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html - device = "cuda" if torch.cuda.is_available() else "cpu" - - def benchmark_torch_function_in_microseconds(f, *args, **kwargs): - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - return t0.blocked_autorange().mean * 1e6 - - # Lets define the hyper-parameters of our input - batch_size = 32 - max_sequence_len = 1024 - num_heads = 32 - embed_dimension = 32 - - dtype = torch.float16 - - query = torch.rand( - batch_size, - num_heads, - max_sequence_len, - embed_dimension, - device=device, - dtype=dtype, - ) - key = torch.rand( - batch_size, - num_heads, - max_sequence_len, - embed_dimension, - device=device, - dtype=dtype, - ) - value = torch.rand( - batch_size, - num_heads, - max_sequence_len, - embed_dimension, - device=device, - dtype=dtype, - ) - - print(f"q/k/v shape:", query.shape, key.shape, value.shape) - - # Lets explore the speed of each of the 3 implementations - from torch.backends.cuda import SDPBackend, sdp_kernel - - # Helpful arguments mapper - backend_map = { - SDPBackend.MATH: { - "enable_math": True, - "enable_flash": False, - "enable_mem_efficient": False, - }, - SDPBackend.FLASH_ATTENTION: { - "enable_math": False, - "enable_flash": True, - "enable_mem_efficient": False, - }, - SDPBackend.EFFICIENT_ATTENTION: { - "enable_math": False, - "enable_flash": False, - "enable_mem_efficient": True, - }, - } - - from torch.profiler import ProfilerActivity, profile, record_function - - activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] - - print( - f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" - ) - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("Default detailed stats"): - for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - print( - f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" - ) - with sdp_kernel(**backend_map[SDPBackend.MATH]): - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("Math implmentation stats"): - for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): - try: - print( - f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" - ) - except RuntimeError: - print("FlashAttention is not supported. See warnings for reasons.") - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("FlashAttention stats"): - for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): - try: - print( - f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" - ) - except RuntimeError: - print("EfficientAttention is not supported. See warnings for reasons.") - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("EfficientAttention stats"): - for _ in range(25): - o = F.scaled_dot_product_attention(query, key, value) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - -def run_model(model, x, context): - return model(x, context) - - -def benchmark_transformer_blocks(): - device = "cuda" if torch.cuda.is_available() else "cpu" - import torch.utils.benchmark as benchmark - - def benchmark_torch_function_in_microseconds(f, *args, **kwargs): - t0 = benchmark.Timer( - stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} - ) - return t0.blocked_autorange().mean * 1e6 - - checkpoint = True - compile = False - - batch_size = 32 - h, w = 64, 64 - context_len = 77 - embed_dimension = 1024 - context_dim = 1024 - d_head = 64 - - transformer_depth = 4 - - n_heads = embed_dimension // d_head - - dtype = torch.float16 - - model_native = SpatialTransformer( - embed_dimension, - n_heads, - d_head, - context_dim=context_dim, - use_linear=True, - use_checkpoint=checkpoint, - attn_type="softmax", - depth=transformer_depth, - sdp_backend=SDPBackend.FLASH_ATTENTION, - ).to(device) - model_efficient_attn = SpatialTransformer( - embed_dimension, - n_heads, - d_head, - context_dim=context_dim, - use_linear=True, - depth=transformer_depth, - use_checkpoint=checkpoint, - attn_type="softmax-xformers", - ).to(device) - if not checkpoint and compile: - print("compiling models") - model_native = torch.compile(model_native) - model_efficient_attn = torch.compile(model_efficient_attn) - - x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype) - c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype) - - from torch.profiler import ProfilerActivity, profile, record_function - - activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] - - with torch.autocast("cuda"): - print( - f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds" - ) - print( - f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds" - ) - - print(75 * "+") - print("NATIVE") - print(75 * "+") - torch.cuda.reset_peak_memory_stats() - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("NativeAttention stats"): - for _ in range(25): - model_native(x, c) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block") - - print(75 * "+") - print("Xformers") - print(75 * "+") - torch.cuda.reset_peak_memory_stats() - with profile( - activities=activities, record_shapes=False, profile_memory=True - ) as prof: - with record_function("xformers stats"): - for _ in range(25): - model_efficient_attn(x, c) - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block") - - -def test01(): - # conv1x1 vs linear - from sgm.util import count_params - - conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda() - print(count_params(conv)) - linear = torch.nn.Linear(3, 32).cuda() - print(count_params(linear)) - - print(conv.weight.shape) - - # use same initialization - linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1)) - linear.bias = torch.nn.Parameter(conv.bias) - - print(linear.weight.shape) - - x = torch.randn(11, 3, 64, 64).cuda() - - xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous() - print(xr.shape) - out_linear = linear(xr) - print(out_linear.mean(), out_linear.shape) - - out_conv = conv(x) - print(out_conv.mean(), out_conv.shape) - print("done with test01.\n") - - -def test02(): - # try cosine flash attention - import time - - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.backends.cudnn.benchmark = True - print("testing cosine flash attention...") - DIM = 1024 - SEQLEN = 4096 - BS = 16 - - print(" softmax (vanilla) first...") - model = BasicTransformerBlock( - dim=DIM, - n_heads=16, - d_head=64, - dropout=0.0, - context_dim=None, - attn_mode="softmax", - ).cuda() - try: - x = torch.randn(BS, SEQLEN, DIM).cuda() - tic = time.time() - y = model(x) - toc = time.time() - print(y.shape, toc - tic) - except RuntimeError as e: - # likely oom - print(str(e)) - - print("\n now flash-cosine...") - model = BasicTransformerBlock( - dim=DIM, - n_heads=16, - d_head=64, - dropout=0.0, - context_dim=None, - attn_mode="flash-cosine", - ).cuda() - x = torch.randn(BS, SEQLEN, DIM).cuda() - tic = time.time() - y = model(x) - toc = time.time() - print(y.shape, toc - tic) - print("done with test02.\n") - - -if __name__ == "__main__": - # test01() - # test02() - # test03() - - # benchmark_attn() - benchmark_transformer_blocks() - - print("done.") diff --git a/scripts/util/detection/nsfw_and_watermark_dectection.py b/scripts/util/detection/nsfw_and_watermark_dectection.py deleted file mode 100644 index 1096b8177d8e3dbcf8e913f924e98d5ce58cb120..0000000000000000000000000000000000000000 --- a/scripts/util/detection/nsfw_and_watermark_dectection.py +++ /dev/null @@ -1,110 +0,0 @@ -import os - -import clip -import numpy as np -import torch -import torchvision.transforms as T -from PIL import Image - -RESOURCES_ROOT = "scripts/util/detection/" - - -def predict_proba(X, weights, biases): - logits = X @ weights.T + biases - proba = np.where( - logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits)) - ) - return proba.T - - -def load_model_weights(path: str): - model_weights = np.load(path) - return model_weights["weights"], model_weights["biases"] - - -def clip_process_images(images: torch.Tensor) -> torch.Tensor: - min_size = min(images.shape[-2:]) - return T.Compose( - [ - T.CenterCrop(min_size), # TODO: this might affect the watermark, check this - T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True), - T.Normalize( - (0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711), - ), - ] - )(images) - - -class DeepFloydDataFiltering(object): - def __init__( - self, verbose: bool = False, device: torch.device = torch.device("cpu") - ): - super().__init__() - self.verbose = verbose - self._device = None - self.clip_model, _ = clip.load("ViT-L/14", device=device) - self.clip_model.eval() - - self.cpu_w_weights, self.cpu_w_biases = load_model_weights( - os.path.join(RESOURCES_ROOT, "w_head_v1.npz") - ) - self.cpu_p_weights, self.cpu_p_biases = load_model_weights( - os.path.join(RESOURCES_ROOT, "p_head_v1.npz") - ) - self.w_threshold, self.p_threshold = 0.5, 0.5 - - @torch.inference_mode() - def __call__(self, images: torch.Tensor) -> torch.Tensor: - imgs = clip_process_images(images) - if self._device is None: - self._device = next(p for p in self.clip_model.parameters()).device - image_features = self.clip_model.encode_image(imgs.to(self._device)) - image_features = image_features.detach().cpu().numpy().astype(np.float16) - p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases) - w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases) - print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None - query = p_pred > self.p_threshold - if query.sum() > 0: - print(f"Hit for p_threshold: {p_pred}") if self.verbose else None - images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) - query = w_pred > self.w_threshold - if query.sum() > 0: - print(f"Hit for w_threshold: {w_pred}") if self.verbose else None - images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) - return images - - -def load_img(path: str) -> torch.Tensor: - image = Image.open(path) - if not image.mode == "RGB": - image = image.convert("RGB") - image_transforms = T.Compose( - [ - T.ToTensor(), - ] - ) - return image_transforms(image)[None, ...] - - -def test(root): - from einops import rearrange - - filter = DeepFloydDataFiltering(verbose=True) - for p in os.listdir((root)): - print(f"running on {p}...") - img = load_img(os.path.join(root, p)) - filtered_img = filter(img) - filtered_img = rearrange( - 255.0 * (filtered_img.numpy())[0], "c h w -> h w c" - ).astype(np.uint8) - Image.fromarray(filtered_img).save( - os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg") - ) - - -if __name__ == "__main__": - import fire - - fire.Fire(test) - print("done.") diff --git a/scripts/util/detection/p_head_v1.npz b/scripts/util/detection/p_head_v1.npz deleted file mode 100644 index c1a824795d85811de3192d8ac20403444e19510b..0000000000000000000000000000000000000000 --- a/scripts/util/detection/p_head_v1.npz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b4653a64d5f85d8d4c5f6c5ec175f1c5c5e37db8f38d39b2ed8b5979da7fdc76 -size 3588 diff --git a/scripts/util/detection/w_head_v1.npz b/scripts/util/detection/w_head_v1.npz deleted file mode 100644 index 57789e17153038c529439b38f9a540ba0cb8bbac..0000000000000000000000000000000000000000 --- a/scripts/util/detection/w_head_v1.npz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b6af23687aa347073e692025f405ccc48c14aadc5dbe775b3312041006d496d1 -size 3588 diff --git a/sgm/__init__.py b/sgm/__init__.py deleted file mode 100644 index 24bc84af8b1041de34b9816e0507cb1ac207bd13..0000000000000000000000000000000000000000 --- a/sgm/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .models import AutoencodingEngine, DiffusionEngine -from .util import get_configs_path, instantiate_from_config - -__version__ = "0.1.0" diff --git a/sgm/data/__init__.py b/sgm/data/__init__.py deleted file mode 100644 index 7664a25c655c376bd1a7b0ccbaca7b983a2bf9ad..0000000000000000000000000000000000000000 --- a/sgm/data/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .dataset import StableDataModuleFromConfig diff --git a/sgm/data/cam_utils.py b/sgm/data/cam_utils.py deleted file mode 100644 index 6d44b38721dafc771c092887d93726b38e1ec0a6..0000000000000000000000000000000000000000 --- a/sgm/data/cam_utils.py +++ /dev/null @@ -1,1253 +0,0 @@ -''' -Common camera utilities -''' - -import math -import numpy as np -import torch -import torch.nn as nn -from pytorch3d.renderer import PerspectiveCameras -from pytorch3d.renderer.cameras import look_at_view_transform -from pytorch3d.renderer.implicit.raysampling import _xy_to_ray_bundle - -class RelativeCameraLoader(nn.Module): - def __init__(self, - query_batch_size=1, - rand_query=True, - relative=True, - center_at_origin=False, - ): - super().__init__() - - self.query_batch_size = query_batch_size - self.rand_query = rand_query - self.relative = relative - self.center_at_origin = center_at_origin - - def plot_cameras(self, cameras_1, cameras_2): - ''' - Helper function to plot cameras - - Args: - cameras_1 (PyTorch3D camera): cameras object to plot - cameras_2 (PyTorch3D camera): cameras object to plot - ''' - from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene - import plotly.graph_objects as go - plotlyplot = plot_scene( - { - 'scene_batch': { - 'cameras': cameras_1.to('cpu'), - 'rel_cameras': cameras_2.to('cpu'), - } - }, - camera_scale=.5,#0.05, - pointcloud_max_points=10000, - pointcloud_marker_size=1.0, - raybundle_max_rays=100 - ) - plotlyplot.show() - - def concat_cameras(self, camera_list): - ''' - Returns a concatenation of a list of cameras - - Args: - camera_list (List[PyTorch3D camera]): a list of PyTorch3D cameras - ''' - R_list, T_list, f_list, c_list, size_list = [], [], [], [], [] - for cameras in camera_list: - R_list.append(cameras.R) - T_list.append(cameras.T) - f_list.append(cameras.focal_length) - c_list.append(cameras.principal_point) - size_list.append(cameras.image_size) - - camera_slice = PerspectiveCameras( - R = torch.cat(R_list), - T = torch.cat(T_list), - focal_length = torch.cat(f_list), - principal_point = torch.cat(c_list), - image_size = torch.cat(size_list), - device = camera_list[0].device, - ) - return camera_slice - - def get_camera_slice(self, scene_cameras, indices): - ''' - Return a subset of cameras from a super set given indices - - Args: - scene_cameras (PyTorch3D Camera): cameras object - indices (tensor or List): a flat list or tensor of indices - - Returns: - camera_slice (PyTorch3D Camera) - cameras subset - ''' - camera_slice = PerspectiveCameras( - R = scene_cameras.R[indices], - T = scene_cameras.T[indices], - focal_length = scene_cameras.focal_length[indices], - principal_point = scene_cameras.principal_point[indices], - image_size = scene_cameras.image_size[indices], - device = scene_cameras.device, - ) - return camera_slice - - - def get_relative_camera(self, scene_cameras:PerspectiveCameras, query_idx, center_at_origin=False): - """ - Transform context cameras relative to a base query camera - - Args: - scene_cameras (PyTorch3D Camera): cameras object - query_idx (tensor or List): a length 1 list defining query idx - - Returns: - cams_relative (PyTorch3D Camera): cameras object relative to query camera - """ - - query_camera = self.get_camera_slice(scene_cameras, query_idx) - query_world2view = query_camera.get_world_to_view_transform() - all_world2view = scene_cameras.get_world_to_view_transform() - - if center_at_origin: - identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=query_camera.T) - else: - T = torch.zeros((1, 3)) - identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=T) - - identity_world2view = identity_cam.get_world_to_view_transform() - - # compose the relative transformation as g_i^{-1} g_j - relative_world2view = identity_world2view.inverse().compose(all_world2view) - - # generate a camera from the relative transform - relative_matrix = relative_world2view.get_matrix() - cams_relative = PerspectiveCameras( - R = relative_matrix[:, :3, :3], - T = relative_matrix[:, 3, :3], - focal_length = scene_cameras.focal_length, - principal_point = scene_cameras.principal_point, - image_size = scene_cameras.image_size, - device = scene_cameras.device, - ) - return cams_relative - - def forward(self, scene_cameras, scene_rgb=None, scene_masks=None, query_idx=None, context_size=3, context_idx=None, return_context=False): - ''' - Return a sampled batch of query and context cameras (used in training) - - Args: - scene_cameras (PyTorch3D Camera): a batch of PyTorch3D cameras - scene_rgb (Tensor): a batch of rgb - scene_masks (Tensor): a batch of masks (optional) - query_idx (List or Tensor): desired query idx (optional) - context_size (int): number of views for context - - Returns: - query_cameras, query_rgb, query_masks: random query view - context_cameras, context_rgb, context_masks: context views - ''' - - if query_idx is None: - query_idx = [0] - if self.rand_query: - rand = torch.randperm(len(scene_cameras)) - query_idx = rand[:1] - - if context_idx is None: - rand = torch.randperm(len(scene_cameras)) - context_idx = rand[:context_size] - - - if self.relative: - rel_cameras = self.get_relative_camera(scene_cameras, query_idx, center_at_origin=self.center_at_origin) - else: - rel_cameras = scene_cameras - - query_cameras = self.get_camera_slice(rel_cameras, query_idx) - query_rgb = None - if scene_rgb is not None: - query_rgb = scene_rgb[query_idx] - query_masks = None - if scene_masks is not None: - query_masks = scene_masks[query_idx] - - context_cameras = self.get_camera_slice(rel_cameras, context_idx) - context_rgb = None - if scene_rgb is not None: - context_rgb = scene_rgb[context_idx] - context_masks = None - if scene_masks is not None: - context_masks = scene_masks[context_idx] - - if return_context: - return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks, context_idx - return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks - - -def get_interpolated_path(cameras: PerspectiveCameras, n=50, method='circle', theta_offset_max=0.0): - ''' - Given a camera object containing a set of cameras, fit a circle and get - interpolated cameras - - Args: - cameras (PyTorch3D Camera): input camera object - n (int): length of cameras in new path - method (str): 'circle' - theta_offset_max (int): max camera jitter in radians - - Returns: - path_cameras (PyTorch3D Camera): interpolated cameras - ''' - device = cameras.device - cameras = cameras.cpu() - - if method == 'circle': - - #@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ - #@ Fit plane - P = cameras.get_camera_center().cpu() - P_mean = P.mean(axis=0) - P_centered = P - P_mean - U,s,V = torch.linalg.svd(P_centered) - normal = V[2,:] - if (normal*2 - P_mean).norm() < (normal - P_mean).norm(): - normal = - normal - d = -torch.dot(P_mean, normal) # d = - - - #@ Project pts to plane - P_xy = rodrigues_rot(P_centered, normal, torch.tensor([0.0,0.0,1.0])) - - #@ Fit circle in 2D - xc, yc, r = fit_circle_2d(P_xy[:,0], P_xy[:,1]) - t = torch.linspace(0, 2*math.pi, 100) - xx = xc + r*torch.cos(t) - yy = yc + r*torch.sin(t) - - #@ Project circle to 3D - C = rodrigues_rot(torch.tensor([xc,yc,0.0]), torch.tensor([0.0,0.0,1.0]), normal) + P_mean - C = C.flatten() - - #@ Get pts n 3D - t = torch.linspace(0, 2*math.pi, n) - u = P[0] - C - new_camera_centers = generate_circle_by_vectors(t, C, r, normal, u) - - #@ OPTIONAL THETA OFFSET - if theta_offset_max > 0.0: - aug_theta = (torch.rand((new_camera_centers.shape[0])) * (2*theta_offset_max)) - theta_offset_max - new_camera_centers = rodrigues_rot2(new_camera_centers, normal, aug_theta) - - #@ Get camera look at - new_camera_look_at = get_nearest_centroid(cameras) - - #@ Get R T - up_vec = -normal - R, T = look_at_view_transform(eye=new_camera_centers, at=new_camera_look_at.unsqueeze(0), up=up_vec.unsqueeze(0), device=cameras.device) - else: - raise NotImplementedError - - c = (cameras.principal_point).mean(dim=0, keepdim=True).expand(R.shape[0],-1) - f = (cameras.focal_length).mean(dim=0, keepdim=True).expand(R.shape[0],-1) - image_size = cameras.image_size[:1].expand(R.shape[0],-1) - - - path_cameras = PerspectiveCameras(R=R,T=T,focal_length=f,principal_point=c,image_size=image_size, device=device) - cameras = cameras.to(device) - return path_cameras - -def np_normalize(vec, axis=-1): - vec = vec / (np.linalg.norm(vec, axis=axis, keepdims=True) + 1e-9) - return vec - - -#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ -#------------------------------------------------------------------------------- -# Generate points on circle -# P(t) = r*cos(t)*u + r*sin(t)*(n x u) + C -#------------------------------------------------------------------------------- -def generate_circle_by_vectors(t, C, r, n, u): - n = n/torch.linalg.norm(n) - u = u/torch.linalg.norm(u) - P_circle = r*torch.cos(t)[:,None]*u + r*torch.sin(t)[:,None]*torch.cross(n,u) + C - return P_circle - -#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ -#------------------------------------------------------------------------------- -# FIT CIRCLE 2D -# - Find center [xc, yc] and radius r of circle fitting to set of 2D points -# - Optionally specify weights for points -# -# - Implicit circle function: -# (x-xc)^2 + (y-yc)^2 = r^2 -# (2*xc)*x + (2*yc)*y + (r^2-xc^2-yc^2) = x^2+y^2 -# c[0]*x + c[1]*y + c[2] = x^2+y^2 -# -# - Solution by method of least squares: -# A*c = b, c' = argmin(||A*c - b||^2) -# A = [x y 1], b = [x^2+y^2] -#------------------------------------------------------------------------------- -def fit_circle_2d(x, y, w=[]): - - A = torch.stack([x, y, torch.ones(len(x))]).T - b = x**2 + y**2 - - # Modify A,b for weighted least squares - if len(w) == len(x): - W = torch.diag(w) - A = torch.dot(W,A) - b = torch.dot(W,b) - - # Solve by method of least squares - c = torch.linalg.lstsq(A,b,rcond=None)[0] - - # Get circle parameters from solution c - xc = c[0]/2 - yc = c[1]/2 - r = torch.sqrt(c[2] + xc**2 + yc**2) - return xc, yc, r - -#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ -#------------------------------------------------------------------------------- -# RODRIGUES ROTATION -# - Rotate given points based on a starting and ending vector -# - Axis k and angle of rotation theta given by vectors n0,n1 -# P_rot = P*cos(theta) + (k x P)*sin(theta) + k**(1-cos(theta)) -#------------------------------------------------------------------------------- -def rodrigues_rot(P, n0, n1): - - # If P is only 1d array (coords of single point), fix it to be matrix - if P.ndim == 1: - P = P[None,...] - - # Get vector of rotation k and angle theta - n0 = n0/torch.linalg.norm(n0) - n1 = n1/torch.linalg.norm(n1) - k = torch.cross(n0,n1) - k = k/torch.linalg.norm(k) - theta = torch.arccos(torch.dot(n0,n1)) - - # Compute rotated points - P_rot = torch.zeros((len(P),3)) - for i in range(len(P)): - P_rot[i] = P[i]*torch.cos(theta) + torch.cross(k,P[i])*torch.sin(theta) + k*torch.dot(k,P[i])*(1-torch.cos(theta)) - - return P_rot - -def rodrigues_rot2(P, n1, theta): - ''' - Rotate points P wrt axis k by theta radians - ''' - - # If P is only 1d array (coords of single point), fix it to be matrix - if P.ndim == 1: - P = P[None,...] - - k = torch.cross(P, n1.unsqueeze(0)) - k = k/torch.linalg.norm(k) - - # Compute rotated points - P_rot = torch.zeros((len(P),3)) - for i in range(len(P)): - P_rot[i] = P[i]*torch.cos(theta[i]) + torch.cross(k[i],P[i])*torch.sin(theta[i]) + k[i]*torch.dot(k[i],P[i])*(1-torch.cos(theta[i])) - - return P_rot - -#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ -#------------------------------------------------------------------------------- -# ANGLE BETWEEN -# - Get angle between vectors u,v with sign based on plane with unit normal n -#------------------------------------------------------------------------------- -def angle_between(u, v, n=None): - if n is None: - return torch.arctan2(torch.linalg.norm(torch.cross(u,v)), torch.dot(u,v)) - else: - return torch.arctan2(torch.dot(n,torch.cross(u,v)), torch.dot(u,v)) - -#@ https://www.crewes.org/Documents/ResearchReports/2010/CRR201032.pdf -def get_nearest_centroid(cameras: PerspectiveCameras): - ''' - Given PyTorch3D cameras, find the nearest point along their principal ray - ''' - - #@ GET CAMERA CENTERS AND DIRECTIONS - camera_centers = cameras.get_camera_center() - - c_mean = (cameras.principal_point).mean(dim=0) - xy_grid = c_mean.unsqueeze(0).unsqueeze(0) - ray_vis = _xy_to_ray_bundle(cameras, xy_grid.expand(len(cameras),-1,-1), 1.0, 15.0, 20, True) - camera_directions = ray_vis.directions - - #@ CONSTRUCT MATRICIES - A = torch.zeros((3*len(cameras)), len(cameras)+3) - b = torch.zeros((3*len(cameras), 1)) - A[:,:3] = torch.eye(3).repeat(len(cameras),1) - for ci in range(len(camera_directions)): - A[3*ci:3*ci+3, ci+3] = -camera_directions[ci] - b[3*ci:3*ci+3, 0] = camera_centers[ci] - #' A (3*N, 3*N+3) b (3*N, 1) - - #@ SVD - U, s, VT = torch.linalg.svd(A) - Sinv = torch.diag(1/s) - if len(s) < 3*len(cameras): - Sinv = torch.cat((Sinv, torch.zeros((Sinv.shape[0], 3*len(cameras) - Sinv.shape[1]), device=Sinv.device)), dim=1) - x = torch.matmul(VT.T, torch.matmul(Sinv,torch.matmul(U.T, b))) - - centroid = x[:3,0] - return centroid - - -def get_angles(target_camera: PerspectiveCameras, context_cameras: PerspectiveCameras, centroid=None): - ''' - Get angles between cameras wrt a centroid - - Args: - target_camera (Pytorch3D Camera): a camera object with a single camera - context_cameras (PyTorch3D Camera): a camera object - - Returns: - theta_deg (Tensor): a tensor containing angles in degrees - ''' - a1 = target_camera.get_camera_center() - b1 = context_cameras.get_camera_center() - - a = a1 - centroid.unsqueeze(0) - a = a.expand(len(context_cameras), -1) - b = b1 - centroid.unsqueeze(0) - - ab_dot = (a*b).sum(dim=-1) - theta = torch.acos((ab_dot)/(torch.linalg.norm(a, dim=-1) * torch.linalg.norm(b, dim=-1))) - theta_deg = theta * 180 / math.pi - - return theta_deg - - -import math -from typing import List, Literal, Optional, Tuple - -import numpy as np -import torch -from jaxtyping import Float -from numpy.typing import NDArray -from torch import Tensor - -_EPS = np.finfo(float).eps * 4.0 - - -def unit_vector(data: NDArray, axis: Optional[int] = None) -> np.ndarray: - """Return ndarray normalized by length, i.e. Euclidean norm, along axis. - - Args: - axis: the axis along which to normalize into unit vector - out: where to write out the data to. If None, returns a new np ndarray - """ - data = np.array(data, dtype=np.float64, copy=True) - if data.ndim == 1: - data /= math.sqrt(np.dot(data, data)) - return data - length = np.atleast_1d(np.sum(data * data, axis)) - np.sqrt(length, length) - if axis is not None: - length = np.expand_dims(length, axis) - data /= length - return data - - -def quaternion_from_matrix(matrix: NDArray, isprecise: bool = False) -> np.ndarray: - """Return quaternion from rotation matrix. - - Args: - matrix: rotation matrix to obtain quaternion - isprecise: if True, input matrix is assumed to be precise rotation matrix and a faster algorithm is used. - """ - M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4] - if isprecise: - q = np.empty((4,)) - t = np.trace(M) - if t > M[3, 3]: - q[0] = t - q[3] = M[1, 0] - M[0, 1] - q[2] = M[0, 2] - M[2, 0] - q[1] = M[2, 1] - M[1, 2] - else: - i, j, k = 1, 2, 3 - if M[1, 1] > M[0, 0]: - i, j, k = 2, 3, 1 - if M[2, 2] > M[i, i]: - i, j, k = 3, 1, 2 - t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] - q[i] = t - q[j] = M[i, j] + M[j, i] - q[k] = M[k, i] + M[i, k] - q[3] = M[k, j] - M[j, k] - q *= 0.5 / math.sqrt(t * M[3, 3]) - else: - m00 = M[0, 0] - m01 = M[0, 1] - m02 = M[0, 2] - m10 = M[1, 0] - m11 = M[1, 1] - m12 = M[1, 2] - m20 = M[2, 0] - m21 = M[2, 1] - m22 = M[2, 2] - # symmetric matrix K - K = [ - [m00 - m11 - m22, 0.0, 0.0, 0.0], - [m01 + m10, m11 - m00 - m22, 0.0, 0.0], - [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0], - [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22], - ] - K = np.array(K) - K /= 3.0 - # quaternion is eigenvector of K that corresponds to largest eigenvalue - w, V = np.linalg.eigh(K) - q = V[np.array([3, 0, 1, 2]), np.argmax(w)] - if q[0] < 0.0: - np.negative(q, q) - return q - - -def quaternion_slerp( - quat0: NDArray, quat1: NDArray, fraction: float, spin: int = 0, shortestpath: bool = True -) -> np.ndarray: - """Return spherical linear interpolation between two quaternions. - Args: - quat0: first quaternion - quat1: second quaternion - fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1) - spin: how much of an additional spin to place on the interpolation - shortestpath: whether to return the short or long path to rotation - """ - q0 = unit_vector(quat0[:4]) - q1 = unit_vector(quat1[:4]) - if q0 is None or q1 is None: - raise ValueError("Input quaternions invalid.") - if fraction == 0.0: - return q0 - if fraction == 1.0: - return q1 - d = np.dot(q0, q1) - if abs(abs(d) - 1.0) < _EPS: - return q0 - if shortestpath and d < 0.0: - # invert rotation - d = -d - np.negative(q1, q1) - angle = math.acos(d) + spin * math.pi - if abs(angle) < _EPS: - return q0 - isin = 1.0 / math.sin(angle) - q0 *= math.sin((1.0 - fraction) * angle) * isin - q1 *= math.sin(fraction * angle) * isin - q0 += q1 - return q0 - - -def quaternion_matrix(quaternion: NDArray) -> np.ndarray: - """Return homogeneous rotation matrix from quaternion. - - Args: - quaternion: value to convert to matrix - """ - q = np.array(quaternion, dtype=np.float64, copy=True) - n = np.dot(q, q) - if n < _EPS: - return np.identity(4) - q *= math.sqrt(2.0 / n) - q = np.outer(q, q) - return np.array( - [ - [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0], - [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0], - [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0], - [0.0, 0.0, 0.0, 1.0], - ] - ) - - -def get_interpolated_poses(pose_a: NDArray, pose_b: NDArray, steps: int = 10) -> List[float]: - """Return interpolation of poses with specified number of steps. - Args: - pose_a: first pose - pose_b: second pose - steps: number of steps the interpolated pose path should contain - """ - - quat_a = quaternion_from_matrix(pose_a[:3, :3]) - quat_b = quaternion_from_matrix(pose_b[:3, :3]) - - ts = np.linspace(0, 1, steps) - quats = [quaternion_slerp(quat_a, quat_b, t) for t in ts] - trans = [(1 - t) * pose_a[:3, 3] + t * pose_b[:3, 3] for t in ts] - - poses_ab = [] - for quat, tran in zip(quats, trans): - pose = np.identity(4) - pose[:3, :3] = quaternion_matrix(quat)[:3, :3] - pose[:3, 3] = tran - poses_ab.append(pose[:3]) - return poses_ab - - -def get_interpolated_k( - k_a: Float[Tensor, "3 3"], k_b: Float[Tensor, "3 3"], steps: int = 10 -) -> List[Float[Tensor, "3 4"]]: - """ - Returns interpolated path between two camera poses with specified number of steps. - - Args: - k_a: camera matrix 1 - k_b: camera matrix 2 - steps: number of steps the interpolated pose path should contain - - Returns: - List of interpolated camera poses - """ - Ks: List[Float[Tensor, "3 3"]] = [] - ts = np.linspace(0, 1, steps) - for t in ts: - new_k = k_a * (1.0 - t) + k_b * t - Ks.append(new_k) - return Ks - - -def get_ordered_poses_and_k( - poses: Float[Tensor, "num_poses 3 4"], - Ks: Float[Tensor, "num_poses 3 3"], -) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]: - """ - Returns ordered poses and intrinsics by euclidian distance between poses. - - Args: - poses: list of camera poses - Ks: list of camera intrinsics - - Returns: - tuple of ordered poses and intrinsics - - """ - - poses_num = len(poses) - - ordered_poses = torch.unsqueeze(poses[0], 0) - ordered_ks = torch.unsqueeze(Ks[0], 0) - - # remove the first pose from poses - poses = poses[1:] - Ks = Ks[1:] - - for _ in range(poses_num - 1): - distances = torch.norm(ordered_poses[-1][:, 3] - poses[:, :, 3], dim=1) - idx = torch.argmin(distances) - ordered_poses = torch.cat((ordered_poses, torch.unsqueeze(poses[idx], 0)), dim=0) - ordered_ks = torch.cat((ordered_ks, torch.unsqueeze(Ks[idx], 0)), dim=0) - poses = torch.cat((poses[0:idx], poses[idx + 1 :]), dim=0) - Ks = torch.cat((Ks[0:idx], Ks[idx + 1 :]), dim=0) - - return ordered_poses, ordered_ks - - -def get_interpolated_poses_many( - poses: Float[Tensor, "num_poses 3 4"], - Ks: Float[Tensor, "num_poses 3 3"], - steps_per_transition: int = 10, - order_poses: bool = False, -) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]: - """Return interpolated poses for many camera poses. - - Args: - poses: list of camera poses - Ks: list of camera intrinsics - steps_per_transition: number of steps per transition - order_poses: whether to order poses by euclidian distance - - Returns: - tuple of new poses and intrinsics - """ - traj = [] - k_interp = [] - - if order_poses: - poses, Ks = get_ordered_poses_and_k(poses, Ks) - - for idx in range(poses.shape[0] - 1): - pose_a = poses[idx].cpu().numpy() - pose_b = poses[idx + 1].cpu().numpy() - poses_ab = get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition) - traj += poses_ab - k_interp += get_interpolated_k(Ks[idx], Ks[idx + 1], steps=steps_per_transition) - - traj = np.stack(traj, axis=0) - k_interp = torch.stack(k_interp, dim=0) - - return torch.tensor(traj, dtype=torch.float32), torch.tensor(k_interp, dtype=torch.float32) - - -def normalize(x: torch.Tensor) -> Float[Tensor, "*batch"]: - """Returns a normalized vector.""" - return x / torch.linalg.norm(x) - - -def normalize_with_norm(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]: - """Normalize tensor along axis and return normalized value with norms. - - Args: - x: tensor to normalize. - dim: axis along which to normalize. - - Returns: - Tuple of normalized tensor and corresponding norm. - """ - - norm = torch.maximum(torch.linalg.vector_norm(x, dim=dim, keepdims=True), torch.tensor([_EPS]).to(x)) - return x / norm, norm - - -def viewmatrix(lookat: torch.Tensor, up: torch.Tensor, pos: torch.Tensor) -> Float[Tensor, "*batch"]: - """Returns a camera transformation matrix. - - Args: - lookat: The direction the camera is looking. - up: The upward direction of the camera. - pos: The position of the camera. - - Returns: - A camera transformation matrix. - """ - vec2 = normalize(lookat) - vec1_avg = normalize(up) - vec0 = normalize(torch.cross(vec1_avg, vec2)) - vec1 = normalize(torch.cross(vec2, vec0)) - m = torch.stack([vec0, vec1, vec2, pos], 1) - return m - - -def get_distortion_params( - k1: float = 0.0, - k2: float = 0.0, - k3: float = 0.0, - k4: float = 0.0, - p1: float = 0.0, - p2: float = 0.0, -) -> Float[Tensor, "*batch"]: - """Returns a distortion parameters matrix. - - Args: - k1: The first radial distortion parameter. - k2: The second radial distortion parameter. - k3: The third radial distortion parameter. - k4: The fourth radial distortion parameter. - p1: The first tangential distortion parameter. - p2: The second tangential distortion parameter. - Returns: - torch.Tensor: A distortion parameters matrix. - """ - return torch.Tensor([k1, k2, k3, k4, p1, p2]) - - -def _compute_residual_and_jacobian( - x: torch.Tensor, - y: torch.Tensor, - xd: torch.Tensor, - yd: torch.Tensor, - distortion_params: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Auxiliary function of radial_and_tangential_undistort() that computes residuals and jacobians. - Adapted from MultiNeRF: - https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L427-L474 - - Args: - x: The updated x coordinates. - y: The updated y coordinates. - xd: The distorted x coordinates. - yd: The distorted y coordinates. - distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2]. - - Returns: - The residuals (fx, fy) and jacobians (fx_x, fx_y, fy_x, fy_y). - """ - - k1 = distortion_params[..., 0] - k2 = distortion_params[..., 1] - k3 = distortion_params[..., 2] - k4 = distortion_params[..., 3] - p1 = distortion_params[..., 4] - p2 = distortion_params[..., 5] - - # let r(x, y) = x^2 + y^2; - # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 + - # k4 * r(x, y)^4; - r = x * x + y * y - d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4))) - - # The perfect projection is: - # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2); - # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2); - # - # Let's define - # - # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd; - # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd; - # - # We are looking for a solution that satisfies - # fx(x, y) = fy(x, y) = 0; - fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd - fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd - - # Compute derivative of d over [x, y] - d_r = k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4)) - d_x = 2.0 * x * d_r - d_y = 2.0 * y * d_r - - # Compute derivative of fx over x and y. - fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x - fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y - - # Compute derivative of fy over x and y. - fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x - fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y - - return fx, fy, fx_x, fx_y, fy_x, fy_y - - -# @torch_compile(dynamic=True, mode="reduce-overhead", backend="eager") -def radial_and_tangential_undistort( - coords: torch.Tensor, - distortion_params: torch.Tensor, - eps: float = 1e-3, - max_iterations: int = 10, -) -> torch.Tensor: - """Computes undistorted coords given opencv distortion parameters. - Adapted from MultiNeRF - https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L477-L509 - - Args: - coords: The distorted coordinates. - distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2]. - eps: The epsilon for the convergence. - max_iterations: The maximum number of iterations to perform. - - Returns: - The undistorted coordinates. - """ - - # Initialize from the distorted point. - x = coords[..., 0] - y = coords[..., 1] - - for _ in range(max_iterations): - fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian( - x=x, y=y, xd=coords[..., 0], yd=coords[..., 1], distortion_params=distortion_params - ) - denominator = fy_x * fx_y - fx_x * fy_y - x_numerator = fx * fy_y - fy * fx_y - y_numerator = fy * fx_x - fx * fy_x - step_x = torch.where(torch.abs(denominator) > eps, x_numerator / denominator, torch.zeros_like(denominator)) - step_y = torch.where(torch.abs(denominator) > eps, y_numerator / denominator, torch.zeros_like(denominator)) - - x = x + step_x - y = y + step_y - - return torch.stack([x, y], dim=-1) - - -def rotation_matrix(a: Float[Tensor, "3"], b: Float[Tensor, "3"]) -> Float[Tensor, "3 3"]: - """Compute the rotation matrix that rotates vector a to vector b. - - Args: - a: The vector to rotate. - b: The vector to rotate to. - Returns: - The rotation matrix. - """ - a = a / torch.linalg.norm(a) - b = b / torch.linalg.norm(b) - v = torch.cross(a, b) - c = torch.dot(a, b) - # If vectors are exactly opposite, we add a little noise to one of them - if c < -1 + 1e-8: - eps = (torch.rand(3) - 0.5) * 0.01 - return rotation_matrix(a + eps, b) - s = torch.linalg.norm(v) - skew_sym_mat = torch.Tensor( - [ - [0, -v[2], v[1]], - [v[2], 0, -v[0]], - [-v[1], v[0], 0], - ] - ) - return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8)) - - -def focus_of_attention(poses: Float[Tensor, "*num_poses 4 4"], initial_focus: Float[Tensor, "3"]) -> Float[Tensor, "3"]: - """Compute the focus of attention of a set of cameras. Only cameras - that have the focus of attention in front of them are considered. - - Args: - poses: The poses to orient. - initial_focus: The 3D point views to decide which cameras are initially activated. - - Returns: - The 3D position of the focus of attention. - """ - # References to the same method in third-party code: - # https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L145 - # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L197 - active_directions = -poses[:, :3, 2:3] - active_origins = poses[:, :3, 3:4] - # initial value for testing if the focus_pt is in front or behind - focus_pt = initial_focus - # Prune cameras which have the current have the focus_pt behind them. - active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0 - done = False - # We need at least two active cameras, else fallback on the previous solution. - # This may be the "poses" solution if no cameras are active on first iteration, e.g. - # they are in an outward-looking configuration. - while torch.sum(active.int()) > 1 and not done: - active_directions = active_directions[active] - active_origins = active_origins[active] - # https://en.wikipedia.org/wiki/Line–line_intersection#In_more_than_two_dimensions - m = torch.eye(3) - active_directions * torch.transpose(active_directions, -2, -1) - mt_m = torch.transpose(m, -2, -1) @ m - focus_pt = torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0)[:, 0] - active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0 - if active.all(): - # the set of active cameras did not change, so we're done. - done = True - return focus_pt - - -def auto_orient_and_center_poses( - poses: Float[Tensor, "*num_poses 4 4"], - method: Literal["pca", "up", "vertical", "none"] = "up", - center_method: Literal["poses", "focus", "none"] = "poses", -) -> Tuple[Float[Tensor, "*num_poses 3 4"], Float[Tensor, "3 4"]]: - """Orients and centers the poses. - - We provide three methods for orientation: - - - pca: Orient the poses so that the principal directions of the camera centers are aligned - with the axes, Z corresponding to the smallest principal component. - This method works well when all of the cameras are in the same plane, for example when - images are taken using a mobile robot. - - up: Orient the poses so that the average up vector is aligned with the z axis. - This method works well when images are not at arbitrary angles. - - vertical: Orient the poses so that the Z 3D direction projects close to the - y axis in images. This method works better if cameras are not all - looking in the same 3D direction, which may happen in camera arrays or in LLFF. - - There are two centering methods: - - - poses: The poses are centered around the origin. - - focus: The origin is set to the focus of attention of all cameras (the - closest point to cameras optical axes). Recommended for inward-looking - camera configurations. - - Args: - poses: The poses to orient. - method: The method to use for orientation. - center_method: The method to use to center the poses. - - Returns: - Tuple of the oriented poses and the transform matrix. - """ - - origins = poses[..., :3, 3] - - mean_origin = torch.mean(origins, dim=0) - translation_diff = origins - mean_origin - - if center_method == "poses": - translation = mean_origin - elif center_method == "focus": - translation = focus_of_attention(poses, mean_origin) - elif center_method == "none": - translation = torch.zeros_like(mean_origin) - else: - raise ValueError(f"Unknown value for center_method: {center_method}") - - if method == "pca": - _, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff) - eigvec = torch.flip(eigvec, dims=(-1,)) - - if torch.linalg.det(eigvec) < 0: - eigvec[:, 2] = -eigvec[:, 2] - - transform = torch.cat([eigvec, eigvec @ -translation[..., None]], dim=-1) - oriented_poses = transform @ poses - - if oriented_poses.mean(dim=0)[2, 1] < 0: - oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3] - elif method in ("up", "vertical"): - up = torch.mean(poses[:, :3, 1], dim=0) - up = up / torch.linalg.norm(up) - if method == "vertical": - # If cameras are not all parallel (e.g. not in an LLFF configuration), - # we can find the 3D direction that most projects vertically in all - # cameras by minimizing ||Xu|| s.t. ||u||=1. This total least squares - # problem is solved by SVD. - x_axis_matrix = poses[:, :3, 0] - _, S, Vh = torch.linalg.svd(x_axis_matrix, full_matrices=False) - # Singular values are S_i=||Xv_i|| for each right singular vector v_i. - # ||S|| = sqrt(n) because lines of X are all unit vectors and the v_i - # are an orthonormal basis. - # ||Xv_i|| = sqrt(sum(dot(x_axis_j,v_i)^2)), thus S_i/sqrt(n) is the - # RMS of cosines between x axes and v_i. If the second smallest singular - # value corresponds to an angle error less than 10° (cos(80°)=0.17), - # this is probably a degenerate camera configuration (typical values - # are around 5° average error for the true vertical). In this case, - # rather than taking the vector corresponding to the smallest singular - # value, we project the "up" vector on the plane spanned by the two - # best singular vectors. We could also just fallback to the "up" - # solution. - if S[1] > 0.17 * math.sqrt(poses.shape[0]): - # regular non-degenerate configuration - up_vertical = Vh[2, :] - # It may be pointing up or down. Use "up" to disambiguate the sign. - up = up_vertical if torch.dot(up_vertical, up) > 0 else -up_vertical - else: - # Degenerate configuration: project "up" on the plane spanned by - # the last two right singular vectors (which are orthogonal to the - # first). v_0 is a unit vector, no need to divide by its norm when - # projecting. - up = up - Vh[0, :] * torch.dot(up, Vh[0, :]) - # re-normalize - up = up / torch.linalg.norm(up) - - rotation = rotation_matrix(up, torch.Tensor([0, 0, 1])) - transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1) - oriented_poses = transform @ poses - elif method == "none": - transform = torch.eye(4) - transform[:3, 3] = -translation - transform = transform[:3, :] - oriented_poses = transform @ poses - else: - raise ValueError(f"Unknown value for method: {method}") - - return oriented_poses, transform - - -@torch.jit.script -def fisheye624_project(xyz, params): - """ - Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera - model project() function. - Inputs: - xyz: BxNx3 tensor of 3D points to be projected - params: Bx16 tensor of Fisheye624 parameters formatted like this: - [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] - or Bx15 tensor of Fisheye624 parameters formatted like this: - [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] - Outputs: - uv: BxNx2 tensor of 2D projections of xyz in image plane - Model for fisheye cameras with radial, tangential, and thin-prism distortion. - This model allows fu != fv. - Specifically, the model is: - uvDistorted = [x_r] + tangentialDistortion + thinPrismDistortion - [y_r] - proj = diag(fu,fv) * uvDistorted + [cu;cv]; - where: - a = x/z, b = y/z, r = (a^2+b^2)^(1/2) - th = atan(r) - cosPhi = a/r, sinPhi = b/r - [x_r] = (th+ k0 * th^3 + k1* th^5 + ...) [cosPhi] - [y_r] [sinPhi] - the number of terms in the series is determined by the template parameter numK. - tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1] - [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0] - where rd^2 = x_r^2 + y_r^2 - thinPrismDistortion = [s0 * rd^2 + s1 rd^4] - [s2 * rd^2 + s3 rd^4] - Author: Daniel DeTone (ddetone@meta.com) - """ - - assert xyz.ndim == 3 - assert params.ndim == 2 - assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy" - eps = 1e-9 - B, N = xyz.shape[0], xyz.shape[1] - - # Radial correction. - z = xyz[:, :, 2].reshape(B, N, 1) - z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z) - ab = xyz[:, :, :2] / z - r = torch.norm(ab, dim=-1, p=2, keepdim=True) - th = torch.atan(r) - th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r) - th_k = th.reshape(B, N, 1).clone() - for i in range(6): - th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2) - xr_yr = th_k * th_divr - uv_dist = xr_yr - - # Tangential correction. - p0 = params[:, -6].reshape(B, 1) - p1 = params[:, -5].reshape(B, 1) - xr = xr_yr[:, :, 0].reshape(B, N) - yr = xr_yr[:, :, 1].reshape(B, N) - xr_yr_sq = torch.square(xr_yr) - xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) - yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) - rd_sq = xr_sq + yr_sq - uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1) - uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0) - uv_dist = torch.stack([uv_dist_tu, uv_dist_tv], dim=-1) # Avoids in-place complaint. - - # Thin Prism correction. - s0 = params[:, -4].reshape(B, 1) - s1 = params[:, -3].reshape(B, 1) - s2 = params[:, -2].reshape(B, 1) - s3 = params[:, -1].reshape(B, 1) - rd_4 = torch.square(rd_sq) - uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) - uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) - - # Finally, apply standard terms: focal length and camera centers. - if params.shape[-1] == 15: - fx_fy = params[:, 0].reshape(B, 1, 1) - cx_cy = params[:, 1:3].reshape(B, 1, 2) - else: - fx_fy = params[:, 0:2].reshape(B, 1, 2) - cx_cy = params[:, 2:4].reshape(B, 1, 2) - result = uv_dist * fx_fy + cx_cy - - return result - - -# Core implementation of fisheye 624 unprojection. More details are documented here: -# https://facebookresearch.github.io/projectaria_tools/docs/tech_insights/camera_intrinsic_models#the-fisheye62-model -@torch.jit.script -def fisheye624_unproject_helper(uv, params, max_iters: int = 5): - """ - Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera - model. There is no analytical solution for the inverse of the project() - function so this solves an optimization problem using Newton's method to get - the inverse. - Inputs: - uv: BxNx2 tensor of 2D pixels to be unprojected - params: Bx16 tensor of Fisheye624 parameters formatted like this: - [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] - or Bx15 tensor of Fisheye624 parameters formatted like this: - [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] - Outputs: - xyz: BxNx3 tensor of 3D rays of uv points with z = 1. - Model for fisheye cameras with radial, tangential, and thin-prism distortion. - This model assumes fu=fv. This unproject function holds that: - X = unproject(project(X)) [for X=(x,y,z) in R^3, z>0] - and - x = project(unproject(s*x)) [for s!=0 and x=(u,v) in R^2] - Author: Daniel DeTone (ddetone@meta.com) - """ - - assert uv.ndim == 3, "Expected batched input shaped BxNx3" - assert params.ndim == 2 - assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy" - eps = 1e-6 - B, N = uv.shape[0], uv.shape[1] - - if params.shape[-1] == 15: - fx_fy = params[:, 0].reshape(B, 1, 1) - cx_cy = params[:, 1:3].reshape(B, 1, 2) - else: - fx_fy = params[:, 0:2].reshape(B, 1, 2) - cx_cy = params[:, 2:4].reshape(B, 1, 2) - - uv_dist = (uv - cx_cy) / fx_fy - - # Compute xr_yr using Newton's method. - xr_yr = uv_dist.clone() # Initial guess. - for _ in range(max_iters): - uv_dist_est = xr_yr.clone() - # Tangential terms. - p0 = params[:, -6].reshape(B, 1) - p1 = params[:, -5].reshape(B, 1) - xr = xr_yr[:, :, 0].reshape(B, N) - yr = xr_yr[:, :, 1].reshape(B, N) - xr_yr_sq = torch.square(xr_yr) - xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) - yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) - rd_sq = xr_sq + yr_sq - uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1) - uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0) - # Thin Prism terms. - s0 = params[:, -4].reshape(B, 1) - s1 = params[:, -3].reshape(B, 1) - s2 = params[:, -2].reshape(B, 1) - s3 = params[:, -1].reshape(B, 1) - rd_4 = torch.square(rd_sq) - uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) - uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) - # Compute the derivative of uv_dist w.r.t. xr_yr. - duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) - duv_dist_dxr_yr[:, :, 0, 0] = 1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1 - offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0) - duv_dist_dxr_yr[:, :, 0, 1] = offdiag - duv_dist_dxr_yr[:, :, 1, 0] = offdiag - duv_dist_dxr_yr[:, :, 1, 1] = 1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0 - xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1] - temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) - duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + (xr_yr[:, :, 0] * temp1) - duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + (xr_yr[:, :, 1] * temp1) - temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) - duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + (xr_yr[:, :, 0] * temp2) - duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + (xr_yr[:, :, 1] * temp2) - # Compute 2x2 inverse manually here since torch.inverse() is very slow. - # Because this is slow: inv = duv_dist_dxr_yr.inverse() - # About a 10x reduction in speed with above line. - mat = duv_dist_dxr_yr.reshape(-1, 2, 2) - a = mat[:, 0, 0].reshape(-1, 1, 1) - b = mat[:, 0, 1].reshape(-1, 1, 1) - c = mat[:, 1, 0].reshape(-1, 1, 1) - d = mat[:, 1, 1].reshape(-1, 1, 1) - det = 1.0 / ((a * d) - (b * c)) - top = torch.cat([d, -b], dim=2) - bot = torch.cat([-c, a], dim=2) - inv = det * torch.cat([top, bot], dim=1) - inv = inv.reshape(B, N, 2, 2) - # Manually compute 2x2 @ 2x1 matrix multiply. - # Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0] - diff = uv_dist - uv_dist_est - a = inv[:, :, 0, 0] - b = inv[:, :, 0, 1] - c = inv[:, :, 1, 0] - d = inv[:, :, 1, 1] - e = diff[:, :, 0] - f = diff[:, :, 1] - step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) - # Newton step. - xr_yr = xr_yr + step - - # Compute theta using Newton's method. - xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) - th = xr_yr_norm.clone() - for _ in range(max_iters): - th_radial = uv.new_ones(B, N, 1) - dthd_th = uv.new_ones(B, N, 1) - for k in range(6): - r_k = params[:, -12 + k].reshape(B, 1, 1) - th_radial = th_radial + (r_k * torch.pow(th, 2 + k * 2)) - dthd_th = dthd_th + ((3.0 + 2.0 * k) * r_k * torch.pow(th, 2 + k * 2)) - th_radial = th_radial * th - step = (xr_yr_norm - th_radial) / dthd_th - # handle dthd_th close to 0. - step = torch.where(dthd_th.abs() > eps, step, torch.sign(step) * eps * 10.0) - th = th + step - # Compute the ray direction using theta and xr_yr. - close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps) - ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr) - ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2) - return ray - - -# unproject 2D point to 3D with fisheye624 model -def fisheye624_unproject(coords: torch.Tensor, distortion_params: torch.Tensor) -> torch.Tensor: - dirs = fisheye624_unproject_helper(coords.unsqueeze(0), distortion_params[0].unsqueeze(0)) - # correct for camera space differences: - dirs[..., 1] = -dirs[..., 1] - dirs[..., 2] = -dirs[..., 2] - return dirs diff --git a/sgm/data/cifar10.py b/sgm/data/cifar10.py deleted file mode 100644 index 6083646f136bad308a0485843b89234cf7a9d6cd..0000000000000000000000000000000000000000 --- a/sgm/data/cifar10.py +++ /dev/null @@ -1,67 +0,0 @@ -import pytorch_lightning as pl -import torchvision -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms - - -class CIFAR10DataDictWrapper(Dataset): - def __init__(self, dset): - super().__init__() - self.dset = dset - - def __getitem__(self, i): - x, y = self.dset[i] - return {"jpg": x, "cls": y} - - def __len__(self): - return len(self.dset) - - -class CIFAR10Loader(pl.LightningDataModule): - def __init__(self, batch_size, num_workers=0, shuffle=True): - super().__init__() - - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] - ) - - self.batch_size = batch_size - self.num_workers = num_workers - self.shuffle = shuffle - self.train_dataset = CIFAR10DataDictWrapper( - torchvision.datasets.CIFAR10( - root=".data/", train=True, download=True, transform=transform - ) - ) - self.test_dataset = CIFAR10DataDictWrapper( - torchvision.datasets.CIFAR10( - root=".data/", train=False, download=True, transform=transform - ) - ) - - def prepare_data(self): - pass - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - ) - - def test_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - ) - - def val_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - ) diff --git a/sgm/data/co3d.py b/sgm/data/co3d.py deleted file mode 100644 index ba95cfbb540e4664b0fdb313f67bb5013bdea6bf..0000000000000000000000000000000000000000 --- a/sgm/data/co3d.py +++ /dev/null @@ -1,1367 +0,0 @@ -""" -adopted from SparseFusion -Wrapper for the full CO3Dv2 dataset -#@ Modified from https://github.com/facebookresearch/pytorch3d -""" - -import json -import logging -import math -import os -import random -import time -import warnings -from collections import defaultdict -from itertools import islice -from typing import ( - Any, - ClassVar, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - TypedDict, - Union, -) -from einops import rearrange, repeat - -import numpy as np -import torch -import torch.nn.functional as F -import torchvision.transforms.functional as TF -from pytorch3d.utils import opencv_from_cameras_projection -from pytorch3d.implicitron.dataset import types -from pytorch3d.implicitron.dataset.dataset_base import DatasetBase -from sgm.data.json_index_dataset import ( - FrameAnnotsEntry, - _bbox_xywh_to_xyxy, - _bbox_xyxy_to_xywh, - _clamp_box_to_image_bounds_and_round, - _crop_around_box, - _get_1d_bounds, - _get_bbox_from_mask, - _get_clamp_bbox, - _load_1bit_png_mask, - _load_16big_png_depth, - _load_depth, - _load_depth_mask, - _load_image, - _load_mask, - _load_pointcloud, - _rescale_bbox, - _safe_as_tensor, - _seq_name_to_seed, -) -from sgm.data.objaverse import video_collate_fn -from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import ( - get_available_subset_names, -) -from pytorch3d.renderer.cameras import PerspectiveCameras - -logger = logging.getLogger(__name__) - - -from dataclasses import dataclass, field, fields - -from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras -from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch -from pytorch_lightning import LightningDataModule -from torch.utils.data import DataLoader - -CO3D_ALL_CATEGORIES = list( - reversed( - [ - "baseballbat", - "banana", - "bicycle", - "microwave", - "tv", - "cellphone", - "toilet", - "hairdryer", - "couch", - "kite", - "pizza", - "umbrella", - "wineglass", - "laptop", - "hotdog", - "stopsign", - "frisbee", - "baseballglove", - "cup", - "parkingmeter", - "backpack", - "toyplane", - "toybus", - "handbag", - "chair", - "keyboard", - "car", - "motorcycle", - "carrot", - "bottle", - "sandwich", - "remote", - "bowl", - "skateboard", - "toaster", - "mouse", - "toytrain", - "book", - "toytruck", - "orange", - "broccoli", - "plant", - "teddybear", - "suitcase", - "bench", - "ball", - "cake", - "vase", - "hydrant", - "apple", - "donut", - ] - ) -) - -CO3D_ALL_TEN = [ - "donut", - "apple", - "hydrant", - "vase", - "cake", - "ball", - "bench", - "suitcase", - "teddybear", - "plant", -] - - -# @ FROM https://github.com/facebookresearch/pytorch3d -@dataclass -class FrameData(Mapping[str, Any]): - """ - A type of the elements returned by indexing the dataset object. - It can represent both individual frames and batches of thereof; - in this documentation, the sizes of tensors refer to single frames; - add the first batch dimension for the collation result. - Args: - frame_number: The number of the frame within its sequence. - 0-based continuous integers. - sequence_name: The unique name of the frame's sequence. - sequence_category: The object category of the sequence. - frame_timestamp: The time elapsed since the start of a sequence in sec. - image_size_hw: The size of the image in pixels; (height, width) tensor - of shape (2,). - image_path: The qualified path to the loaded image (with dataset_root). - image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image - of the frame; elements are floats in [0, 1]. - mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image - regions. Regions can be invalid (mask_crop[i,j]=0) in case they - are a result of zero-padding of the image after cropping around - the object bounding box; elements are floats in {0.0, 1.0}. - depth_path: The qualified path to the frame's depth map. - depth_map: A float Tensor of shape `(1, H, W)` holding the depth map - of the frame; values correspond to distances from the camera; - use `depth_mask` and `mask_crop` to filter for valid pixels. - depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the - depth map that are valid for evaluation, they have been checked for - consistency across views; elements are floats in {0.0, 1.0}. - mask_path: A qualified path to the foreground probability mask. - fg_probability: A Tensor of `(1, H, W)` denoting the probability of the - pixels belonging to the captured object; elements are floats - in [0, 1]. - bbox_xywh: The bounding box tightly enclosing the foreground object in the - format (x0, y0, width, height). The convention assumes that - `x0+width` and `y0+height` includes the boundary of the box. - I.e., to slice out the corresponding crop from an image tensor `I` - we execute `crop = I[..., y0:y0+height, x0:x0+width]` - crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` - in the original image coordinates in the format (x0, y0, width, height). - The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs - from `bbox_xywh` due to padding (which can happen e.g. due to - setting `JsonIndexDataset.box_crop_context > 0`) - camera: A PyTorch3D camera object corresponding the frame's viewpoint, - corrected for cropping if it happened. - camera_quality_score: The score proportional to the confidence of the - frame's camera estimation (the higher the more accurate). - point_cloud_quality_score: The score proportional to the accuracy of the - frame's sequence point cloud (the higher the more accurate). - sequence_point_cloud_path: The path to the sequence's point cloud. - sequence_point_cloud: A PyTorch3D Pointclouds object holding the - point cloud corresponding to the frame's sequence. When the object - represents a batch of frames, point clouds may be deduplicated; - see `sequence_point_cloud_idx`. - sequence_point_cloud_idx: Integer indices mapping frame indices to the - corresponding point clouds in `sequence_point_cloud`; to get the - corresponding point cloud to `image_rgb[i]`, use - `sequence_point_cloud[sequence_point_cloud_idx[i]]`. - frame_type: The type of the loaded frame specified in - `subset_lists_file`, if provided. - meta: A dict for storing additional frame information. - """ - - frame_number: Optional[torch.LongTensor] - sequence_name: Union[str, List[str]] - sequence_category: Union[str, List[str]] - frame_timestamp: Optional[torch.Tensor] = None - image_size_hw: Optional[torch.Tensor] = None - image_path: Union[str, List[str], None] = None - image_rgb: Optional[torch.Tensor] = None - # masks out padding added due to cropping the square bit - mask_crop: Optional[torch.Tensor] = None - depth_path: Union[str, List[str], None] = "" - depth_map: Optional[torch.Tensor] = torch.zeros(1) - depth_mask: Optional[torch.Tensor] = torch.zeros(1) - mask_path: Union[str, List[str], None] = None - fg_probability: Optional[torch.Tensor] = None - bbox_xywh: Optional[torch.Tensor] = None - crop_bbox_xywh: Optional[torch.Tensor] = None - camera: Optional[PerspectiveCameras] = None - camera_quality_score: Optional[torch.Tensor] = None - point_cloud_quality_score: Optional[torch.Tensor] = None - sequence_point_cloud_path: Union[str, List[str], None] = "" - sequence_point_cloud: Optional[Pointclouds] = torch.zeros(1) - sequence_point_cloud_idx: Optional[torch.Tensor] = torch.zeros(1) - frame_type: Union[str, List[str], None] = "" # known | unseen - meta: dict = field(default_factory=lambda: {}) - valid_region: Optional[torch.Tensor] = None - category_one_hot: Optional[torch.Tensor] = None - - def to(self, *args, **kwargs): - new_params = {} - for f in fields(self): - value = getattr(self, f.name) - if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): - new_params[f.name] = value.to(*args, **kwargs) - else: - new_params[f.name] = value - return type(self)(**new_params) - - def cpu(self): - return self.to(device=torch.device("cpu")) - - def cuda(self): - return self.to(device=torch.device("cuda")) - - # the following functions make sure **frame_data can be passed to functions - def __iter__(self): - for f in fields(self): - yield f.name - - def __getitem__(self, key): - return getattr(self, key) - - def __len__(self): - return len(fields(self)) - - @classmethod - def collate(cls, batch): - """ - Given a list objects `batch` of class `cls`, collates them into a batched - representation suitable for processing with deep networks. - """ - - elem = batch[0] - - if isinstance(elem, cls): - pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] - id_to_idx = defaultdict(list) - for i, pc_id in enumerate(pointcloud_ids): - id_to_idx[pc_id].append(i) - - sequence_point_cloud = [] - sequence_point_cloud_idx = -np.ones((len(batch),)) - for i, ind in enumerate(id_to_idx.values()): - sequence_point_cloud_idx[ind] = i - sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) - assert (sequence_point_cloud_idx >= 0).all() - - override_fields = { - "sequence_point_cloud": sequence_point_cloud, - "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), - } - # note that the pre-collate value of sequence_point_cloud_idx is unused - - collated = {} - for f in fields(elem): - list_values = override_fields.get( - f.name, [getattr(d, f.name) for d in batch] - ) - collated[f.name] = ( - cls.collate(list_values) - if all(list_value is not None for list_value in list_values) - else None - ) - return cls(**collated) - - elif isinstance(elem, Pointclouds): - return join_pointclouds_as_batch(batch) - - elif isinstance(elem, CamerasBase): - # TODO: don't store K; enforce working in NDC space - return join_cameras_as_batch(batch) - else: - return torch.utils.data._utils.collate.default_collate(batch) - - -# @ MODIFIED FROM https://github.com/facebookresearch/pytorch3d -class CO3Dv2Wrapper(torch.utils.data.Dataset): - def __init__( - self, - root_dir="/drive/datasets/co3d/", - category="hydrant", - subset="fewview_train", - stage="train", - sample_batch_size=20, - image_size=256, - masked=False, - deprecated_val_region=False, - return_frame_data_list=False, - reso: int = 256, - mask_type: str = "random", - cond_aug_mean=-3.0, - cond_aug_std=0.5, - condition_on_elevation=False, - fps_id=0.0, - motion_bucket_id=300.0, - num_frames: int = 20, - use_mask: bool = True, - load_pixelnerf: bool = True, - scale_pose: bool = True, - max_n_cond: int = 5, - min_n_cond: int = 2, - cond_on_multi: bool = False, - ): - root = root_dir - from typing import List - - from co3d.dataset.data_types import ( - FrameAnnotation, - SequenceAnnotation, - load_dataclass_jgzip, - ) - - self.dataset_root = root - self.path_manager = None - self.subset = subset - self.stage = stage - self.subset_lists_file: List[str] = [ - f"{self.dataset_root}/{category}/set_lists/set_lists_{subset}.json" - ] - self.subsets: Optional[List[str]] = [subset] - self.sample_batch_size = sample_batch_size - self.limit_to: int = 0 - self.limit_sequences_to: int = 0 - self.pick_sequence: Tuple[str, ...] = () - self.exclude_sequence: Tuple[str, ...] = () - self.limit_category_to: Tuple[int, ...] = () - self.load_images: bool = True - self.load_depths: bool = False - self.load_depth_masks: bool = False - self.load_masks: bool = True - self.load_point_clouds: bool = False - self.max_points: int = 0 - self.mask_images: bool = False - self.mask_depths: bool = False - self.image_height: Optional[int] = image_size - self.image_width: Optional[int] = image_size - self.box_crop: bool = True - self.box_crop_mask_thr: float = 0.4 - self.box_crop_context: float = 0.3 - self.remove_empty_masks: bool = True - self.n_frames_per_sequence: int = -1 - self.seed: int = 0 - self.sort_frames: bool = False - self.eval_batches: Any = None - - self.img_h = self.image_height - self.img_w = self.image_width - self.masked = masked - self.deprecated_val_region = deprecated_val_region - self.return_frame_data_list = return_frame_data_list - - self.reso = reso - self.num_frames = num_frames - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - self.condition_on_elevation = condition_on_elevation - self.fps_id = fps_id - self.motion_bucket_id = motion_bucket_id - self.mask_type = mask_type - self.use_mask = use_mask - self.load_pixelnerf = load_pixelnerf - self.scale_pose = scale_pose - self.max_n_cond = max_n_cond - self.min_n_cond = min_n_cond - self.cond_on_multi = cond_on_multi - - if self.cond_on_multi: - assert self.min_n_cond == self.max_n_cond - - start_time = time.time() - if "all_" in category or category == "all": - self.category_frame_annotations = [] - self.category_sequence_annotations = [] - self.subset_lists_file = [] - - if category == "all": - cats = CO3D_ALL_CATEGORIES - elif category == "all_four": - cats = ["hydrant", "teddybear", "motorcycle", "bench"] - elif category == "all_ten": - cats = [ - "donut", - "apple", - "hydrant", - "vase", - "cake", - "ball", - "bench", - "suitcase", - "teddybear", - "plant", - ] - elif category == "all_15": - cats = [ - "hydrant", - "teddybear", - "motorcycle", - "bench", - "hotdog", - "remote", - "suitcase", - "donut", - "plant", - "toaster", - "keyboard", - "handbag", - "toyplane", - "tv", - "orange", - ] - else: - print("UNSPECIFIED CATEGORY SUBSET") - cats = ["hydrant", "teddybear"] - print("loading", cats) - for cat in cats: - self.category_frame_annotations.extend( - load_dataclass_jgzip( - f"{self.dataset_root}/{cat}/frame_annotations.jgz", - List[FrameAnnotation], - ) - ) - self.category_sequence_annotations.extend( - load_dataclass_jgzip( - f"{self.dataset_root}/{cat}/sequence_annotations.jgz", - List[SequenceAnnotation], - ) - ) - self.subset_lists_file.append( - f"{self.dataset_root}/{cat}/set_lists/set_lists_{subset}.json" - ) - - else: - self.category_frame_annotations = load_dataclass_jgzip( - f"{self.dataset_root}/{category}/frame_annotations.jgz", - List[FrameAnnotation], - ) - self.category_sequence_annotations = load_dataclass_jgzip( - f"{self.dataset_root}/{category}/sequence_annotations.jgz", - List[SequenceAnnotation], - ) - - self.subset_to_image_path = None - self._load_frames() - self._load_sequences() - self._sort_frames() - self._load_subset_lists() - self._filter_db() # also computes sequence indices - # self._extract_and_set_eval_batches() - # print(self.eval_batches) - logger.info(str(self)) - - self.seq_to_frames = {} - for fi, item in enumerate(self.frame_annots): - if item["frame_annotation"].sequence_name in self.seq_to_frames: - self.seq_to_frames[item["frame_annotation"].sequence_name].append(fi) - else: - self.seq_to_frames[item["frame_annotation"].sequence_name] = [fi] - - if self.stage != "test" or self.subset != "fewview_test": - count = 0 - new_seq_to_frames = {} - for item in self.seq_to_frames: - if len(self.seq_to_frames[item]) > 10: - count += 1 - new_seq_to_frames[item] = self.seq_to_frames[item] - self.seq_to_frames = new_seq_to_frames - - self.seq_list = list(self.seq_to_frames.keys()) - - # @ REMOVE A FEW TRAINING SEQ THAT CAUSES BUG - remove_list = ["411_55952_107659", "376_42884_85882"] - for remove_idx in remove_list: - if remove_idx in self.seq_to_frames: - self.seq_list.remove(remove_idx) - print("removing", remove_idx) - - print("total training seq", len(self.seq_to_frames)) - print("data loading took", time.time() - start_time, "seconds") - - self.all_category_list = list(CO3D_ALL_CATEGORIES) - self.all_category_list.sort() - self.cat_to_idx = {} - for ci, cname in enumerate(self.all_category_list): - self.cat_to_idx[cname] = ci - - def __len__(self): - return len(self.seq_list) - - def __getitem__(self, index): - seq_index = self.seq_list[index] - - if self.subset == "fewview_test" and self.stage == "test": - batch_idx = torch.arange(len(self.seq_to_frames[seq_index])) - - elif self.stage == "test": - batch_idx = ( - torch.linspace( - 0, len(self.seq_to_frames[seq_index]) - 1, self.sample_batch_size - ) - .long() - .tolist() - ) - else: - rand = torch.randperm(len(self.seq_to_frames[seq_index])) - batch_idx = rand[: min(len(rand), self.sample_batch_size)] - - frame_data_list = [] - idx_list = [] - timestamp_list = [] - for idx in batch_idx: - idx_list.append(self.seq_to_frames[seq_index][idx]) - timestamp_list.append( - self.frame_annots[self.seq_to_frames[seq_index][idx]][ - "frame_annotation" - ].frame_timestamp - ) - frame_data_list.append( - self._get_frame(int(self.seq_to_frames[seq_index][idx])) - ) - - time_order = torch.argsort(torch.tensor(timestamp_list)) - frame_data_list = [frame_data_list[i] for i in time_order] - - frame_data = FrameData.collate(frame_data_list) - image_size = torch.Tensor([self.image_height]).repeat( - frame_data.camera.R.shape[0], 2 - ) - frame_dict = { - "R": frame_data.camera.R, - "T": frame_data.camera.T, - "f": frame_data.camera.focal_length, - "c": frame_data.camera.principal_point, - "images": frame_data.image_rgb * frame_data.fg_probability - + (1 - frame_data.fg_probability), - "valid_region": frame_data.mask_crop, - "bbox": frame_data.valid_region, - "image_size": image_size, - "frame_type": frame_data.frame_type, - "idx": seq_index, - "category": frame_data.category_one_hot, - } - if not self.masked: - frame_dict["images_full"] = frame_data.image_rgb - frame_dict["masks"] = frame_data.fg_probability - frame_dict["mask_crop"] = frame_data.mask_crop - - cond_aug = np.exp( - np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean - ) - - def _pad(input): - return torch.cat([input, torch.flip(input, dims=[0])], dim=0)[ - : self.num_frames - ] - - if len(frame_dict["images"]) < self.num_frames: - for k in frame_dict: - if isinstance(frame_dict[k], torch.Tensor): - frame_dict[k] = _pad(frame_dict[k]) - - data = dict() - if "images_full" in frame_dict: - frames = frame_dict["images_full"] * 2 - 1 - else: - frames = frame_dict["images"] * 2 - 1 - data["frames"] = frames - cond = frames[0] - data["cond_frames_without_noise"] = cond - data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames) - data["cond_frames"] = cond + cond_aug * torch.randn_like(cond) - data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames) - data["motion_bucket_id"] = torch.as_tensor( - [self.motion_bucket_id] * self.num_frames - ) - data["num_video_frames"] = self.num_frames - data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames) - - if self.load_pixelnerf: - data["pixelnerf_input"] = dict() - # Rs = frame_dict["R"].transpose(-1, -2) - # Ts = frame_dict["T"] - # Rs[:, :, 2] *= -1 - # Rs[:, :, 0] *= -1 - # Ts[:, 2] *= -1 - # Ts[:, 0] *= -1 - # c2ws = torch.zeros(Rs.shape[0], 4, 4) - # c2ws[:, :3, :3] = Rs - # c2ws[:, :3, 3] = Ts - # c2ws[:, 3, 3] = 1 - # c2ws = c2ws.inverse() - # # c2ws[..., 0] *= -1 - # # c2ws[..., 2] *= -1 - # cx = frame_dict["c"][:, 0] - # cy = frame_dict["c"][:, 1] - # fx = frame_dict["f"][:, 0] - # fy = frame_dict["f"][:, 1] - # intrinsics = torch.zeros(cx.shape[0], 3, 3) - # intrinsics[:, 2, 2] = 1 - # intrinsics[:, 0, 0] = fx - # intrinsics[:, 1, 1] = fy - # intrinsics[:, 0, 2] = cx - # intrinsics[:, 1, 2] = cy - - scene_cameras = PerspectiveCameras( - R=frame_dict["R"], - T=frame_dict["T"], - focal_length=frame_dict["f"], - principal_point=frame_dict["c"], - image_size=frame_dict["image_size"], - ) - R, T, intrinsics = opencv_from_cameras_projection( - scene_cameras, frame_dict["image_size"] - ) - c2ws = torch.zeros(R.shape[0], 4, 4) - c2ws[:, :3, :3] = R - c2ws[:, :3, 3] = T - c2ws[:, 3, 3] = 1.0 - c2ws = c2ws.inverse() - c2ws[..., 1:3] *= -1 - intrinsics[:, :2] /= 256 - - cameras = torch.zeros(c2ws.shape[0], 25) - cameras[..., :16] = c2ws.reshape(-1, 16) - cameras[..., 16:] = intrinsics.reshape(-1, 9) - if self.scale_pose: - c2ws = cameras[..., :16].reshape(-1, 4, 4) - center = c2ws[:, :3, 3].mean(0) - radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max() - scale = 1.5 / radius - c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale - cameras[..., :16] = c2ws.reshape(-1, 16) - - data["pixelnerf_input"]["frames"] = frames - data["pixelnerf_input"]["cameras"] = cameras - data["pixelnerf_input"]["rgb"] = ( - F.interpolate( - frames, - (self.image_width // 8, self.image_height // 8), - mode="bilinear", - align_corners=False, - ) - + 1 - ) * 0.5 - - return data - # if self.return_frame_data_list: - # return (frame_dict, frame_data_list) - # return frame_dict - - def collate_fn(self, batch): - # a hack to add source index and keep consistent within a batch - if self.max_n_cond > 1: - # TODO implement this - n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1) - # debug - # source_index = [0] - if n_cond > 1: - for b in batch: - source_index = [0] + np.random.choice( - np.arange(1, self.num_frames), - self.max_n_cond - 1, - replace=False, - ).tolist() - b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index) - b["pixelnerf_input"]["n_cond"] = n_cond - b["pixelnerf_input"]["source_images"] = b["frames"][source_index] - b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][ - "cameras" - ][source_index] - - if self.cond_on_multi: - b["cond_frames_without_noise"] = b["frames"][source_index] - - ret = video_collate_fn(batch) - - if self.cond_on_multi: - ret["cond_frames_without_noise"] = rearrange( - ret["cond_frames_without_noise"], "b t ... -> (b t) ..." - ) - - return ret - - def _get_frame(self, index): - # if index >= len(self.frame_annots): - # raise IndexError(f"index {index} out of range {len(self.frame_annots)}") - - entry = self.frame_annots[index]["frame_annotation"] - # pyre-ignore[16] - point_cloud = self.seq_annots[entry.sequence_name].point_cloud - frame_data = FrameData( - frame_number=_safe_as_tensor(entry.frame_number, torch.long), - frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float), - sequence_name=entry.sequence_name, - sequence_category=self.seq_annots[entry.sequence_name].category, - camera_quality_score=_safe_as_tensor( - self.seq_annots[entry.sequence_name].viewpoint_quality_score, - torch.float, - ), - point_cloud_quality_score=_safe_as_tensor( - point_cloud.quality_score, torch.float - ) - if point_cloud is not None - else None, - ) - - # The rest of the fields are optional - frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) - - ( - frame_data.fg_probability, - frame_data.mask_path, - frame_data.bbox_xywh, - clamp_bbox_xyxy, - frame_data.crop_bbox_xywh, - ) = self._load_crop_fg_probability(entry) - - scale = 1.0 - if self.load_images and entry.image is not None: - # original image size - frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) - - ( - frame_data.image_rgb, - frame_data.image_path, - frame_data.mask_crop, - scale, - ) = self._load_crop_images( - entry, frame_data.fg_probability, clamp_bbox_xyxy - ) - # print(frame_data.fg_probability.sum()) - # print('scale', scale) - - #! INSERT - if self.deprecated_val_region: - # print(frame_data.crop_bbox_xywh) - valid_bbox = _bbox_xywh_to_xyxy(frame_data.crop_bbox_xywh).float() - # print(valid_bbox, frame_data.image_size_hw) - valid_bbox[0] = torch.clip( - ( - valid_bbox[0] - - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor") - ) - / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"), - -1.0, - 1.0, - ) - valid_bbox[1] = torch.clip( - ( - valid_bbox[1] - - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor") - ) - / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"), - -1.0, - 1.0, - ) - valid_bbox[2] = torch.clip( - ( - valid_bbox[2] - - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor") - ) - / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"), - -1.0, - 1.0, - ) - valid_bbox[3] = torch.clip( - ( - valid_bbox[3] - - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor") - ) - / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"), - -1.0, - 1.0, - ) - # print(valid_bbox) - frame_data.valid_region = valid_bbox - else: - #! UPDATED VALID BBOX - if self.stage == "train": - assert self.image_height == 256 and self.image_width == 256 - valid = torch.nonzero(frame_data.mask_crop[0]) - min_y = valid[:, 0].min() - min_x = valid[:, 1].min() - max_y = valid[:, 0].max() - max_x = valid[:, 1].max() - valid_bbox = torch.tensor( - [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device - ).unsqueeze(0) - valid_bbox = torch.clip( - (valid_bbox - (256 // 2)) / (256 // 2), -1.0, 1.0 - ) - frame_data.valid_region = valid_bbox[0] - else: - valid = torch.nonzero(frame_data.mask_crop[0]) - min_y = valid[:, 0].min() - min_x = valid[:, 1].min() - max_y = valid[:, 0].max() - max_x = valid[:, 1].max() - valid_bbox = torch.tensor( - [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device - ).unsqueeze(0) - valid_bbox = torch.clip( - (valid_bbox - (self.image_height // 2)) / (self.image_height // 2), - -1.0, - 1.0, - ) - frame_data.valid_region = valid_bbox[0] - - #! SET CLASS ONEHOT - frame_data.category_one_hot = torch.zeros( - (len(self.all_category_list)), device=frame_data.image_rgb.device - ) - frame_data.category_one_hot[self.cat_to_idx[frame_data.sequence_category]] = 1 - - if self.load_depths and entry.depth is not None: - ( - frame_data.depth_map, - frame_data.depth_path, - frame_data.depth_mask, - ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) - - if entry.viewpoint is not None: - frame_data.camera = self._get_pytorch3d_camera( - entry, - scale, - clamp_bbox_xyxy, - ) - - if self.load_point_clouds and point_cloud is not None: - frame_data.sequence_point_cloud_path = pcl_path = os.path.join( - self.dataset_root, point_cloud.path - ) - frame_data.sequence_point_cloud = _load_pointcloud( - self._local_path(pcl_path), max_points=self.max_points - ) - - # for key in frame_data: - # if frame_data[key] == None: - # print(key) - return frame_data - - def _extract_and_set_eval_batches(self): - """ - Sets eval_batches based on input eval_batch_index. - """ - if self.eval_batch_index is not None: - if self.eval_batches is not None: - raise ValueError( - "Cannot define both eval_batch_index and eval_batches." - ) - self.eval_batches = self.seq_frame_index_to_dataset_index( - self.eval_batch_index - ) - - def _load_crop_fg_probability( - self, entry: types.FrameAnnotation - ) -> Tuple[ - Optional[torch.Tensor], - Optional[str], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: - fg_probability = None - full_path = None - bbox_xywh = None - clamp_bbox_xyxy = None - crop_box_xywh = None - - if (self.load_masks or self.box_crop) and entry.mask is not None: - full_path = os.path.join(self.dataset_root, entry.mask.path) - mask = _load_mask(self._local_path(full_path)) - - if mask.shape[-2:] != entry.image.size: - raise ValueError( - f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" - ) - - bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) - - if self.box_crop: - clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( - _get_clamp_bbox( - bbox_xywh, - image_path=entry.image.path, - box_crop_context=self.box_crop_context, - ), - image_size_hw=tuple(mask.shape[-2:]), - ) - crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) - - mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) - - fg_probability, _, _ = self._resize_image(mask, mode="nearest") - - return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh - - def _load_crop_images( - self, - entry: types.FrameAnnotation, - fg_probability: Optional[torch.Tensor], - clamp_bbox_xyxy: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: - assert self.dataset_root is not None and entry.image is not None - path = os.path.join(self.dataset_root, entry.image.path) - image_rgb = _load_image(self._local_path(path)) - - if image_rgb.shape[-2:] != entry.image.size: - raise ValueError( - f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" - ) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) - - image_rgb, scale, mask_crop = self._resize_image(image_rgb) - - if self.mask_images: - assert fg_probability is not None - image_rgb *= fg_probability - - return image_rgb, path, mask_crop, scale - - def _load_mask_depth( - self, - entry: types.FrameAnnotation, - clamp_bbox_xyxy: Optional[torch.Tensor], - fg_probability: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor]: - entry_depth = entry.depth - assert entry_depth is not None - path = os.path.join(self.dataset_root, entry_depth.path) - depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] - ) - depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) - - depth_map, _, _ = self._resize_image(depth_map, mode="nearest") - - if self.mask_depths: - assert fg_probability is not None - depth_map *= fg_probability - - if self.load_depth_masks: - assert entry_depth.mask_path is not None - mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) - depth_mask = _load_depth_mask(self._local_path(mask_path)) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_mask_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] - ) - depth_mask = _crop_around_box( - depth_mask, depth_mask_bbox_xyxy, mask_path - ) - - depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") - else: - depth_mask = torch.ones_like(depth_map) - - return depth_map, path, depth_mask - - def _get_pytorch3d_camera( - self, - entry: types.FrameAnnotation, - scale: float, - clamp_bbox_xyxy: Optional[torch.Tensor], - ) -> PerspectiveCameras: - entry_viewpoint = entry.viewpoint - assert entry_viewpoint is not None - # principal point and focal length - principal_point = torch.tensor( - entry_viewpoint.principal_point, dtype=torch.float - ) - focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) - - half_image_size_wh_orig = ( - torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 - ) - - # first, we convert from the dataset's NDC convention to pixels - format = entry_viewpoint.intrinsics_format - if format.lower() == "ndc_norm_image_bounds": - # this is e.g. currently used in CO3D for storing intrinsics - rescale = half_image_size_wh_orig - elif format.lower() == "ndc_isotropic": - rescale = half_image_size_wh_orig.min() - else: - raise ValueError(f"Unknown intrinsics format: {format}") - - # principal point and focal length in pixels - principal_point_px = half_image_size_wh_orig - principal_point * rescale - focal_length_px = focal_length * rescale - if self.box_crop: - assert clamp_bbox_xyxy is not None - principal_point_px -= clamp_bbox_xyxy[:2] - - # now, convert from pixels to PyTorch3D v0.5+ NDC convention - if self.image_height is None or self.image_width is None: - out_size = list(reversed(entry.image.size)) - else: - out_size = [self.image_width, self.image_height] - - half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 - half_min_image_size_output = half_image_size_output.min() - - # rescaled principal point and focal length in ndc - principal_point = ( - half_image_size_output - principal_point_px * scale - ) / half_min_image_size_output - focal_length = focal_length_px * scale / half_min_image_size_output - - return PerspectiveCameras( - focal_length=focal_length[None], - principal_point=principal_point[None], - R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], - T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], - ) - - def _load_frames(self) -> None: - self.frame_annots = [ - FrameAnnotsEntry(frame_annotation=a, subset=None) - for a in self.category_frame_annotations - ] - - def _load_sequences(self) -> None: - self.seq_annots = { - entry.sequence_name: entry for entry in self.category_sequence_annotations - } - - def _load_subset_lists(self) -> None: - logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.") - if not self.subset_lists_file: - return - - frame_path_to_subset = {} - - for subset_list_file in self.subset_lists_file: - with open(self._local_path(subset_list_file), "r") as f: - subset_to_seq_frame = json.load(f) - - #! PRINT SUBSET_LIST STATS - # if len(self.subset_lists_file) == 1: - # print('train frames', len(subset_to_seq_frame['train'])) - # print('val frames', len(subset_to_seq_frame['val'])) - # print('test frames', len(subset_to_seq_frame['test'])) - - for set_ in subset_to_seq_frame: - for _, _, path in subset_to_seq_frame[set_]: - if path in frame_path_to_subset: - frame_path_to_subset[path].add(set_) - else: - frame_path_to_subset[path] = {set_} - - # pyre-ignore[16] - for frame in self.frame_annots: - frame["subset"] = frame_path_to_subset.get( - frame["frame_annotation"].image.path, None - ) - - if frame["subset"] is None: - continue - warnings.warn( - "Subset lists are given but don't include " - + frame["frame_annotation"].image.path - ) - - def _sort_frames(self) -> None: - # Sort frames to have them grouped by sequence, ordered by timestamp - # pyre-ignore[16] - self.frame_annots = sorted( - self.frame_annots, - key=lambda f: ( - f["frame_annotation"].sequence_name, - f["frame_annotation"].frame_timestamp or 0, - ), - ) - - def _filter_db(self) -> None: - if self.remove_empty_masks: - logger.info("Removing images with empty masks.") - # pyre-ignore[16] - old_len = len(self.frame_annots) - - msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." - - def positive_mass(frame_annot: types.FrameAnnotation) -> bool: - mask = frame_annot.mask - if mask is None: - return False - if mask.mass is None: - raise ValueError(msg) - return mask.mass > 1 - - self.frame_annots = [ - frame - for frame in self.frame_annots - if positive_mass(frame["frame_annotation"]) - ] - logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots))) - - # this has to be called after joining with categories!! - subsets = self.subsets - if subsets: - if not self.subset_lists_file: - raise ValueError( - "Subset filter is on but subset_lists_file was not given" - ) - - logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.") - - # truncate the list of subsets to the valid one - self.frame_annots = [ - entry - for entry in self.frame_annots - if (entry["subset"] is not None and self.stage in entry["subset"]) - ] - - if len(self.frame_annots) == 0: - raise ValueError(f"There are no frames in the '{subsets}' subsets!") - - self._invalidate_indexes(filter_seq_annots=True) - - if len(self.limit_category_to) > 0: - logger.info(f"Limiting dataset to categories: {self.limit_category_to}") - # pyre-ignore[16] - self.seq_annots = { - name: entry - for name, entry in self.seq_annots.items() - if entry.category in self.limit_category_to - } - - # sequence filters - for prefix in ("pick", "exclude"): - orig_len = len(self.seq_annots) - attr = f"{prefix}_sequence" - arr = getattr(self, attr) - if len(arr) > 0: - logger.info(f"{attr}: {str(arr)}") - self.seq_annots = { - name: entry - for name, entry in self.seq_annots.items() - if (name in arr) == (prefix == "pick") - } - logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots))) - - if self.limit_sequences_to > 0: - self.seq_annots = dict( - islice(self.seq_annots.items(), self.limit_sequences_to) - ) - - # retain only frames from retained sequences - self.frame_annots = [ - f - for f in self.frame_annots - if f["frame_annotation"].sequence_name in self.seq_annots - ] - - self._invalidate_indexes() - - if self.n_frames_per_sequence > 0: - logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") - keep_idx = [] - # pyre-ignore[16] - for seq, seq_indices in self._seq_to_idx.items(): - # infer the seed from the sequence name, this is reproducible - # and makes the selection differ for different sequences - seed = _seq_name_to_seed(seq) + self.seed - seq_idx_shuffled = random.Random(seed).sample( - sorted(seq_indices), len(seq_indices) - ) - keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence]) - - logger.info( - "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)) - ) - self.frame_annots = [self.frame_annots[i] for i in keep_idx] - self._invalidate_indexes(filter_seq_annots=False) - # sequences are not decimated, so self.seq_annots is valid - - if self.limit_to > 0 and self.limit_to < len(self.frame_annots): - logger.info( - "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to) - ) - self.frame_annots = self.frame_annots[: self.limit_to] - self._invalidate_indexes(filter_seq_annots=True) - - def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: - # update _seq_to_idx and filter seq_meta according to frame_annots change - # if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx - self._invalidate_seq_to_idx() - - if filter_seq_annots: - # pyre-ignore[16] - self.seq_annots = { - k: v - for k, v in self.seq_annots.items() - # pyre-ignore[16] - if k in self._seq_to_idx - } - - def _invalidate_seq_to_idx(self) -> None: - seq_to_idx = defaultdict(list) - # pyre-ignore[16] - for idx, entry in enumerate(self.frame_annots): - seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) - # pyre-ignore[16] - self._seq_to_idx = seq_to_idx - - def _resize_image( - self, image, mode="bilinear" - ) -> Tuple[torch.Tensor, float, torch.Tensor]: - image_height, image_width = self.image_height, self.image_width - if image_height is None or image_width is None: - # skip the resizing - imre_ = torch.from_numpy(image) - return imre_, 1.0, torch.ones_like(imre_[:1]) - # takes numpy array, returns pytorch tensor - minscale = min( - image_height / image.shape[-2], - image_width / image.shape[-1], - ) - imre = torch.nn.functional.interpolate( - torch.from_numpy(image)[None], - scale_factor=minscale, - mode=mode, - align_corners=False if mode == "bilinear" else None, - recompute_scale_factor=True, - )[0] - # pyre-fixme[19]: Expected 1 positional argument. - imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) - imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre - # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`. - # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`. - mask = torch.zeros(1, self.image_height, self.image_width) - mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 - return imre_, minscale, mask - - def _local_path(self, path: str) -> str: - if self.path_manager is None: - return path - return self.path_manager.get_local_path(path) - - def get_frame_numbers_and_timestamps( - self, idxs: Sequence[int] - ) -> List[Tuple[int, float]]: - out: List[Tuple[int, float]] = [] - for idx in idxs: - # pyre-ignore[16] - frame_annotation = self.frame_annots[idx]["frame_annotation"] - out.append( - (frame_annotation.frame_number, frame_annotation.frame_timestamp) - ) - return out - - def get_eval_batches(self) -> Optional[List[List[int]]]: - return self.eval_batches - - def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: - return entry["frame_annotation"].meta["frame_type"] - - -class CO3DDataset(LightningDataModule): - def __init__( - self, - root_dir, - batch_size=2, - shuffle=True, - num_workers=10, - prefetch_factor=2, - category="hydrant", - **kwargs, - ): - super().__init__() - - self.batch_size = batch_size - self.num_workers = num_workers - self.prefetch_factor = prefetch_factor - self.shuffle = shuffle - - self.train_dataset = CO3Dv2Wrapper( - root_dir=root_dir, - stage="train", - category=category, - **kwargs, - ) - - self.test_dataset = CO3Dv2Wrapper( - root_dir=root_dir, - stage="test", - subset="fewview_dev", - category=category, - **kwargs, - ) - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=self.train_dataset.collate_fn, - ) - - def test_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=self.test_dataset.collate_fn, - ) - - def val_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=video_collate_fn, - ) diff --git a/sgm/data/colmap.py b/sgm/data/colmap.py deleted file mode 100644 index b739f2e9637c0c96b80c42fce05dfeab6c5e1228..0000000000000000000000000000000000000000 --- a/sgm/data/colmap.py +++ /dev/null @@ -1,605 +0,0 @@ -# Copyright (c) 2023, ETH Zurich and UNC Chapel Hill. -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# -# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of -# its contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. - - -import os -import collections -import numpy as np -import struct -import argparse - - -CameraModel = collections.namedtuple( - "CameraModel", ["model_id", "model_name", "num_params"] -) -Camera = collections.namedtuple( - "Camera", ["id", "model", "width", "height", "params"] -) -BaseImage = collections.namedtuple( - "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] -) -Point3D = collections.namedtuple( - "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] -) - - -class Image(BaseImage): - def qvec2rotmat(self): - return qvec2rotmat(self.qvec) - - -CAMERA_MODELS = { - CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), - CameraModel(model_id=1, model_name="PINHOLE", num_params=4), - CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), - CameraModel(model_id=3, model_name="RADIAL", num_params=5), - CameraModel(model_id=4, model_name="OPENCV", num_params=8), - CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), - CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), - CameraModel(model_id=7, model_name="FOV", num_params=5), - CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), - CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), - CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), -} -CAMERA_MODEL_IDS = dict( - [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] -) -CAMERA_MODEL_NAMES = dict( - [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] -) - - -def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): - """Read and unpack the next bytes from a binary file. - :param fid: - :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. - :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. - :param endian_character: Any of {@, =, <, >, !} - :return: Tuple of read and unpacked values. - """ - data = fid.read(num_bytes) - return struct.unpack(endian_character + format_char_sequence, data) - - -def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): - """pack and write to a binary file. - :param fid: - :param data: data to send, if multiple elements are sent at the same time, - they should be encapsuled either in a list or a tuple - :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. - should be the same length as the data list or tuple - :param endian_character: Any of {@, =, <, >, !} - """ - if isinstance(data, (list, tuple)): - bytes = struct.pack(endian_character + format_char_sequence, *data) - else: - bytes = struct.pack(endian_character + format_char_sequence, data) - fid.write(bytes) - - -def read_cameras_text(path): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::WriteCamerasText(const std::string& path) - void Reconstruction::ReadCamerasText(const std::string& path) - """ - cameras = {} - with open(path, "r") as fid: - while True: - line = fid.readline() - if not line: - break - line = line.strip() - if len(line) > 0 and line[0] != "#": - elems = line.split() - camera_id = int(elems[0]) - model = elems[1] - width = int(elems[2]) - height = int(elems[3]) - params = np.array(tuple(map(float, elems[4:]))) - cameras[camera_id] = Camera( - id=camera_id, - model=model, - width=width, - height=height, - params=params, - ) - return cameras - - -def read_cameras_binary(path_to_model_file): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::WriteCamerasBinary(const std::string& path) - void Reconstruction::ReadCamerasBinary(const std::string& path) - """ - cameras = {} - with open(path_to_model_file, "rb") as fid: - num_cameras = read_next_bytes(fid, 8, "Q")[0] - for _ in range(num_cameras): - camera_properties = read_next_bytes( - fid, num_bytes=24, format_char_sequence="iiQQ" - ) - camera_id = camera_properties[0] - model_id = camera_properties[1] - model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name - width = camera_properties[2] - height = camera_properties[3] - num_params = CAMERA_MODEL_IDS[model_id].num_params - params = read_next_bytes( - fid, - num_bytes=8 * num_params, - format_char_sequence="d" * num_params, - ) - cameras[camera_id] = Camera( - id=camera_id, - model=model_name, - width=width, - height=height, - params=np.array(params), - ) - assert len(cameras) == num_cameras - return cameras - - -def write_cameras_text(cameras, path): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::WriteCamerasText(const std::string& path) - void Reconstruction::ReadCamerasText(const std::string& path) - """ - HEADER = ( - "# Camera list with one line of data per camera:\n" - + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" - + "# Number of cameras: {}\n".format(len(cameras)) - ) - with open(path, "w") as fid: - fid.write(HEADER) - for _, cam in cameras.items(): - to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] - line = " ".join([str(elem) for elem in to_write]) - fid.write(line + "\n") - - -def write_cameras_binary(cameras, path_to_model_file): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::WriteCamerasBinary(const std::string& path) - void Reconstruction::ReadCamerasBinary(const std::string& path) - """ - with open(path_to_model_file, "wb") as fid: - write_next_bytes(fid, len(cameras), "Q") - for _, cam in cameras.items(): - model_id = CAMERA_MODEL_NAMES[cam.model].model_id - camera_properties = [cam.id, model_id, cam.width, cam.height] - write_next_bytes(fid, camera_properties, "iiQQ") - for p in cam.params: - write_next_bytes(fid, float(p), "d") - return cameras - - -def read_images_text(path): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::ReadImagesText(const std::string& path) - void Reconstruction::WriteImagesText(const std::string& path) - """ - images = {} - with open(path, "r") as fid: - while True: - line = fid.readline() - if not line: - break - line = line.strip() - if len(line) > 0 and line[0] != "#": - elems = line.split() - image_id = int(elems[0]) - qvec = np.array(tuple(map(float, elems[1:5]))) - tvec = np.array(tuple(map(float, elems[5:8]))) - camera_id = int(elems[8]) - image_name = elems[9] - elems = fid.readline().split() - xys = np.column_stack( - [ - tuple(map(float, elems[0::3])), - tuple(map(float, elems[1::3])), - ] - ) - point3D_ids = np.array(tuple(map(int, elems[2::3]))) - images[image_id] = Image( - id=image_id, - qvec=qvec, - tvec=tvec, - camera_id=camera_id, - name=image_name, - xys=xys, - point3D_ids=point3D_ids, - ) - return images - - -def read_images_binary(path_to_model_file): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::ReadImagesBinary(const std::string& path) - void Reconstruction::WriteImagesBinary(const std::string& path) - """ - images = {} - with open(path_to_model_file, "rb") as fid: - num_reg_images = read_next_bytes(fid, 8, "Q")[0] - for _ in range(num_reg_images): - binary_image_properties = read_next_bytes( - fid, num_bytes=64, format_char_sequence="idddddddi" - ) - image_id = binary_image_properties[0] - qvec = np.array(binary_image_properties[1:5]) - tvec = np.array(binary_image_properties[5:8]) - camera_id = binary_image_properties[8] - binary_image_name = b"" - current_char = read_next_bytes(fid, 1, "c")[0] - while current_char != b"\x00": # look for the ASCII 0 entry - binary_image_name += current_char - current_char = read_next_bytes(fid, 1, "c")[0] - image_name = binary_image_name.decode("utf-8") - num_points2D = read_next_bytes( - fid, num_bytes=8, format_char_sequence="Q" - )[0] - x_y_id_s = read_next_bytes( - fid, - num_bytes=24 * num_points2D, - format_char_sequence="ddq" * num_points2D, - ) - xys = np.column_stack( - [ - tuple(map(float, x_y_id_s[0::3])), - tuple(map(float, x_y_id_s[1::3])), - ] - ) - point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) - images[image_id] = Image( - id=image_id, - qvec=qvec, - tvec=tvec, - camera_id=camera_id, - name=image_name, - xys=xys, - point3D_ids=point3D_ids, - ) - return images - - -def write_images_text(images, path): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::ReadImagesText(const std::string& path) - void Reconstruction::WriteImagesText(const std::string& path) - """ - if len(images) == 0: - mean_observations = 0 - else: - mean_observations = sum( - (len(img.point3D_ids) for _, img in images.items()) - ) / len(images) - HEADER = ( - "# Image list with two lines of data per image:\n" - + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" - + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" - + "# Number of images: {}, mean observations per image: {}\n".format( - len(images), mean_observations - ) - ) - - with open(path, "w") as fid: - fid.write(HEADER) - for _, img in images.items(): - image_header = [ - img.id, - *img.qvec, - *img.tvec, - img.camera_id, - img.name, - ] - first_line = " ".join(map(str, image_header)) - fid.write(first_line + "\n") - - points_strings = [] - for xy, point3D_id in zip(img.xys, img.point3D_ids): - points_strings.append(" ".join(map(str, [*xy, point3D_id]))) - fid.write(" ".join(points_strings) + "\n") - - -def write_images_binary(images, path_to_model_file): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::ReadImagesBinary(const std::string& path) - void Reconstruction::WriteImagesBinary(const std::string& path) - """ - with open(path_to_model_file, "wb") as fid: - write_next_bytes(fid, len(images), "Q") - for _, img in images.items(): - write_next_bytes(fid, img.id, "i") - write_next_bytes(fid, img.qvec.tolist(), "dddd") - write_next_bytes(fid, img.tvec.tolist(), "ddd") - write_next_bytes(fid, img.camera_id, "i") - for char in img.name: - write_next_bytes(fid, char.encode("utf-8"), "c") - write_next_bytes(fid, b"\x00", "c") - write_next_bytes(fid, len(img.point3D_ids), "Q") - for xy, p3d_id in zip(img.xys, img.point3D_ids): - write_next_bytes(fid, [*xy, p3d_id], "ddq") - - -def read_points3D_text(path): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::ReadPoints3DText(const std::string& path) - void Reconstruction::WritePoints3DText(const std::string& path) - """ - points3D = {} - with open(path, "r") as fid: - while True: - line = fid.readline() - if not line: - break - line = line.strip() - if len(line) > 0 and line[0] != "#": - elems = line.split() - point3D_id = int(elems[0]) - xyz = np.array(tuple(map(float, elems[1:4]))) - rgb = np.array(tuple(map(int, elems[4:7]))) - error = float(elems[7]) - image_ids = np.array(tuple(map(int, elems[8::2]))) - point2D_idxs = np.array(tuple(map(int, elems[9::2]))) - points3D[point3D_id] = Point3D( - id=point3D_id, - xyz=xyz, - rgb=rgb, - error=error, - image_ids=image_ids, - point2D_idxs=point2D_idxs, - ) - return points3D - - -def read_points3D_binary(path_to_model_file): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::ReadPoints3DBinary(const std::string& path) - void Reconstruction::WritePoints3DBinary(const std::string& path) - """ - points3D = {} - with open(path_to_model_file, "rb") as fid: - num_points = read_next_bytes(fid, 8, "Q")[0] - for _ in range(num_points): - binary_point_line_properties = read_next_bytes( - fid, num_bytes=43, format_char_sequence="QdddBBBd" - ) - point3D_id = binary_point_line_properties[0] - xyz = np.array(binary_point_line_properties[1:4]) - rgb = np.array(binary_point_line_properties[4:7]) - error = np.array(binary_point_line_properties[7]) - track_length = read_next_bytes( - fid, num_bytes=8, format_char_sequence="Q" - )[0] - track_elems = read_next_bytes( - fid, - num_bytes=8 * track_length, - format_char_sequence="ii" * track_length, - ) - image_ids = np.array(tuple(map(int, track_elems[0::2]))) - point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) - points3D[point3D_id] = Point3D( - id=point3D_id, - xyz=xyz, - rgb=rgb, - error=error, - image_ids=image_ids, - point2D_idxs=point2D_idxs, - ) - return points3D - - -def write_points3D_text(points3D, path): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::ReadPoints3DText(const std::string& path) - void Reconstruction::WritePoints3DText(const std::string& path) - """ - if len(points3D) == 0: - mean_track_length = 0 - else: - mean_track_length = sum( - (len(pt.image_ids) for _, pt in points3D.items()) - ) / len(points3D) - HEADER = ( - "# 3D point list with one line of data per point:\n" - + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" - + "# Number of points: {}, mean track length: {}\n".format( - len(points3D), mean_track_length - ) - ) - - with open(path, "w") as fid: - fid.write(HEADER) - for _, pt in points3D.items(): - point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] - fid.write(" ".join(map(str, point_header)) + " ") - track_strings = [] - for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): - track_strings.append(" ".join(map(str, [image_id, point2D]))) - fid.write(" ".join(track_strings) + "\n") - - -def write_points3D_binary(points3D, path_to_model_file): - """ - see: src/colmap/scene/reconstruction.cc - void Reconstruction::ReadPoints3DBinary(const std::string& path) - void Reconstruction::WritePoints3DBinary(const std::string& path) - """ - with open(path_to_model_file, "wb") as fid: - write_next_bytes(fid, len(points3D), "Q") - for _, pt in points3D.items(): - write_next_bytes(fid, pt.id, "Q") - write_next_bytes(fid, pt.xyz.tolist(), "ddd") - write_next_bytes(fid, pt.rgb.tolist(), "BBB") - write_next_bytes(fid, pt.error, "d") - track_length = pt.image_ids.shape[0] - write_next_bytes(fid, track_length, "Q") - for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): - write_next_bytes(fid, [image_id, point2D_id], "ii") - - -def detect_model_format(path, ext): - if ( - os.path.isfile(os.path.join(path, "cameras" + ext)) - and os.path.isfile(os.path.join(path, "images" + ext)) - and os.path.isfile(os.path.join(path, "points3D" + ext)) - ): - print("Detected model format: '" + ext + "'") - return True - - return False - - -def read_model(path, ext=""): - # try to detect the extension automatically - if ext == "": - if detect_model_format(path, ".bin"): - ext = ".bin" - elif detect_model_format(path, ".txt"): - ext = ".txt" - else: - print("Provide model format: '.bin' or '.txt'") - return - - if ext == ".txt": - cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) - images = read_images_text(os.path.join(path, "images" + ext)) - points3D = read_points3D_text(os.path.join(path, "points3D") + ext) - else: - cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) - images = read_images_binary(os.path.join(path, "images" + ext)) - points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) - return cameras, images, points3D - - -def write_model(cameras, images, points3D, path, ext=".bin"): - if ext == ".txt": - write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) - write_images_text(images, os.path.join(path, "images" + ext)) - write_points3D_text(points3D, os.path.join(path, "points3D") + ext) - else: - write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) - write_images_binary(images, os.path.join(path, "images" + ext)) - write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) - return cameras, images, points3D - - -def qvec2rotmat(qvec): - return np.array( - [ - [ - 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, - 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], - 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], - ], - [ - 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], - 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, - 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], - ], - [ - 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], - 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], - 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, - ], - ] - ) - - -def rotmat2qvec(R): - Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat - K = ( - np.array( - [ - [Rxx - Ryy - Rzz, 0, 0, 0], - [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], - [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], - [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], - ] - ) - / 3.0 - ) - eigvals, eigvecs = np.linalg.eigh(K) - qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] - if qvec[0] < 0: - qvec *= -1 - return qvec - - -def main(): - parser = argparse.ArgumentParser( - description="Read and write COLMAP binary and text models" - ) - parser.add_argument("--input_model", help="path to input model folder") - parser.add_argument( - "--input_format", - choices=[".bin", ".txt"], - help="input model format", - default="", - ) - parser.add_argument("--output_model", help="path to output model folder") - parser.add_argument( - "--output_format", - choices=[".bin", ".txt"], - help="outut model format", - default=".txt", - ) - args = parser.parse_args() - - cameras, images, points3D = read_model( - path=args.input_model, ext=args.input_format - ) - - print("num_cameras:", len(cameras)) - print("num_images:", len(images)) - print("num_points3D:", len(points3D)) - - if args.output_model is not None: - write_model( - cameras, - images, - points3D, - path=args.output_model, - ext=args.output_format, - ) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/sgm/data/dataset.py b/sgm/data/dataset.py deleted file mode 100644 index b726149996591c6c3db69230e1bb68c07d2faa12..0000000000000000000000000000000000000000 --- a/sgm/data/dataset.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import Optional - -import torchdata.datapipes.iter -import webdataset as wds -from omegaconf import DictConfig -from pytorch_lightning import LightningDataModule - -try: - from sdata import create_dataset, create_dummy_dataset, create_loader -except ImportError as e: - print("#" * 100) - print("Datasets not yet available") - print("to enable, we need to add stable-datasets as a submodule") - print("please use ``git submodule update --init --recursive``") - print("and do ``pip install -e stable-datasets/`` from the root of this repo") - print("#" * 100) - exit(1) - - -class StableDataModuleFromConfig(LightningDataModule): - def __init__( - self, - train: DictConfig, - validation: Optional[DictConfig] = None, - test: Optional[DictConfig] = None, - skip_val_loader: bool = False, - dummy: bool = False, - ): - super().__init__() - self.train_config = train - assert ( - "datapipeline" in self.train_config and "loader" in self.train_config - ), "train config requires the fields `datapipeline` and `loader`" - - self.val_config = validation - if not skip_val_loader: - if self.val_config is not None: - assert ( - "datapipeline" in self.val_config and "loader" in self.val_config - ), "validation config requires the fields `datapipeline` and `loader`" - else: - print( - "Warning: No Validation datapipeline defined, using that one from training" - ) - self.val_config = train - - self.test_config = test - if self.test_config is not None: - assert ( - "datapipeline" in self.test_config and "loader" in self.test_config - ), "test config requires the fields `datapipeline` and `loader`" - - self.dummy = dummy - if self.dummy: - print("#" * 100) - print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") - print("#" * 100) - - def setup(self, stage: str) -> None: - print("Preparing datasets") - if self.dummy: - data_fn = create_dummy_dataset - else: - data_fn = create_dataset - - self.train_datapipeline = data_fn(**self.train_config.datapipeline) - if self.val_config: - self.val_datapipeline = data_fn(**self.val_config.datapipeline) - if self.test_config: - self.test_datapipeline = data_fn(**self.test_config.datapipeline) - - def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe: - loader = create_loader(self.train_datapipeline, **self.train_config.loader) - return loader - - def val_dataloader(self) -> wds.DataPipeline: - return create_loader(self.val_datapipeline, **self.val_config.loader) - - def test_dataloader(self) -> wds.DataPipeline: - return create_loader(self.test_datapipeline, **self.test_config.loader) diff --git a/sgm/data/joint3d.py b/sgm/data/joint3d.py deleted file mode 100644 index 0569210466a2391bdbb3be358c5cd8f8477aeba1..0000000000000000000000000000000000000000 --- a/sgm/data/joint3d.py +++ /dev/null @@ -1,10 +0,0 @@ -import torch -from torch.utils.data import Dataset - -default_sub_data_config = {} - - -class Joint3D(Dataset): - def __init__(self, sub_data_config: dict) -> None: - super().__init__() - self.sub_data_config = sub_data_config diff --git a/sgm/data/json_index_dataset.py b/sgm/data/json_index_dataset.py deleted file mode 100644 index 16f1dbf3bbae4fb6861f45703d1493914ffaf791..0000000000000000000000000000000000000000 --- a/sgm/data/json_index_dataset.py +++ /dev/null @@ -1,1080 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import copy -import functools -import gzip -import hashlib -import json -import logging -import os -import random -import warnings -from collections import defaultdict -from itertools import islice -from pathlib import Path -from typing import ( - Any, - ClassVar, - Dict, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TYPE_CHECKING, - Union, -) - -import numpy as np -import torch -from PIL import Image -from pytorch3d.implicitron.tools.config import registry, ReplaceableBase -from pytorch3d.io import IO -from pytorch3d.renderer.camera_utils import join_cameras_as_batch -from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras -from pytorch3d.structures.pointclouds import Pointclouds -from tqdm import tqdm - -from pytorch3d.implicitron.dataset import types -from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData -from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar - - -logger = logging.getLogger(__name__) - - -if TYPE_CHECKING: - from typing import TypedDict - - class FrameAnnotsEntry(TypedDict): - subset: Optional[str] - frame_annotation: types.FrameAnnotation - -else: - FrameAnnotsEntry = dict - - -@registry.register -class JsonIndexDataset(DatasetBase, ReplaceableBase): - """ - A dataset with annotations in json files like the Common Objects in 3D - (CO3D) dataset. - - Args: - frame_annotations_file: A zipped json file containing metadata of the - frames in the dataset, serialized List[types.FrameAnnotation]. - sequence_annotations_file: A zipped json file containing metadata of the - sequences in the dataset, serialized List[types.SequenceAnnotation]. - subset_lists_file: A json file containing the lists of frames corresponding - corresponding to different subsets (e.g. train/val/test) of the dataset; - format: {subset: (sequence_name, frame_id, file_path)}. - subsets: Restrict frames/sequences only to the given list of subsets - as defined in subset_lists_file (see above). - limit_to: Limit the dataset to the first #limit_to frames (after other - filters have been applied). - limit_sequences_to: Limit the dataset to the first - #limit_sequences_to sequences (after other sequence filters have been - applied but before frame-based filters). - pick_sequence: A list of sequence names to restrict the dataset to. - exclude_sequence: A list of the names of the sequences to exclude. - limit_category_to: Restrict the dataset to the given list of categories. - dataset_root: The root folder of the dataset; all the paths in jsons are - specified relative to this root (but not json paths themselves). - load_images: Enable loading the frame RGB data. - load_depths: Enable loading the frame depth maps. - load_depth_masks: Enable loading the frame depth map masks denoting the - depth values used for evaluation (the points consistent across views). - load_masks: Enable loading frame foreground masks. - load_point_clouds: Enable loading sequence-level point clouds. - max_points: Cap on the number of loaded points in the point cloud; - if reached, they are randomly sampled without replacement. - mask_images: Whether to mask the images with the loaded foreground masks; - 0 value is used for background. - mask_depths: Whether to mask the depth maps with the loaded foreground - masks; 0 value is used for background. - image_height: The height of the returned images, masks, and depth maps; - aspect ratio is preserved during cropping/resizing. - image_width: The width of the returned images, masks, and depth maps; - aspect ratio is preserved during cropping/resizing. - box_crop: Enable cropping of the image around the bounding box inferred - from the foreground region of the loaded segmentation mask; masks - and depth maps are cropped accordingly; cameras are corrected. - box_crop_mask_thr: The threshold used to separate pixels into foreground - and background based on the foreground_probability mask; if no value - is greater than this threshold, the loader lowers it and repeats. - box_crop_context: The amount of additional padding added to each - dimension of the cropping bounding box, relative to box size. - remove_empty_masks: Removes the frames with no active foreground pixels - in the segmentation mask after thresholding (see box_crop_mask_thr). - n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence - frames in each sequences uniformly without replacement if it has - more frames than that; applied before other frame-level filters. - seed: The seed of the random generator sampling #n_frames_per_sequence - random frames per sequence. - sort_frames: Enable frame annotations sorting to group frames from the - same sequences together and order them by timestamps - eval_batches: A list of batches that form the evaluation set; - list of batch-sized lists of indices corresponding to __getitem__ - of this class, thus it can be used directly as a batch sampler. - eval_batch_index: - ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] ) - A list of batches of frames described as (sequence_name, frame_idx) - that can form the evaluation set, `eval_batches` will be set from this. - - """ - - frame_annotations_type: ClassVar[ - Type[types.FrameAnnotation] - ] = types.FrameAnnotation - - path_manager: Any = None - frame_annotations_file: str = "" - sequence_annotations_file: str = "" - subset_lists_file: str = "" - subsets: Optional[List[str]] = None - limit_to: int = 0 - limit_sequences_to: int = 0 - pick_sequence: Tuple[str, ...] = () - exclude_sequence: Tuple[str, ...] = () - limit_category_to: Tuple[int, ...] = () - dataset_root: str = "" - load_images: bool = True - load_depths: bool = True - load_depth_masks: bool = True - load_masks: bool = True - load_point_clouds: bool = False - max_points: int = 0 - mask_images: bool = False - mask_depths: bool = False - image_height: Optional[int] = 800 - image_width: Optional[int] = 800 - box_crop: bool = True - box_crop_mask_thr: float = 0.4 - box_crop_context: float = 0.3 - remove_empty_masks: bool = True - n_frames_per_sequence: int = -1 - seed: int = 0 - sort_frames: bool = False - eval_batches: Any = None - eval_batch_index: Any = None - # frame_annots: List[FrameAnnotsEntry] = field(init=False) - # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) - - def __post_init__(self) -> None: - # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`. - self.subset_to_image_path = None - self._load_frames() - self._load_sequences() - if self.sort_frames: - self._sort_frames() - self._load_subset_lists() - self._filter_db() # also computes sequence indices - self._extract_and_set_eval_batches() - logger.info(str(self)) - - def _extract_and_set_eval_batches(self): - """ - Sets eval_batches based on input eval_batch_index. - """ - if self.eval_batch_index is not None: - if self.eval_batches is not None: - raise ValueError( - "Cannot define both eval_batch_index and eval_batches." - ) - self.eval_batches = self.seq_frame_index_to_dataset_index( - self.eval_batch_index - ) - - def join(self, other_datasets: Iterable[DatasetBase]) -> None: - """ - Join the dataset with other JsonIndexDataset objects. - - Args: - other_datasets: A list of JsonIndexDataset objects to be joined - into the current dataset. - """ - if not all(isinstance(d, JsonIndexDataset) for d in other_datasets): - raise ValueError("This function can only join a list of JsonIndexDataset") - # pyre-ignore[16] - self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots]) - # pyre-ignore[16] - self.seq_annots.update( - # https://gist.github.com/treyhunner/f35292e676efa0be1728 - functools.reduce( - lambda a, b: {**a, **b}, - [d.seq_annots for d in other_datasets], # pyre-ignore[16] - ) - ) - all_eval_batches = [ - self.eval_batches, - # pyre-ignore - *[d.eval_batches for d in other_datasets], - ] - if not ( - all(ba is None for ba in all_eval_batches) - or all(ba is not None for ba in all_eval_batches) - ): - raise ValueError( - "When joining datasets, either all joined datasets have to have their" - " eval_batches defined, or all should have their eval batches undefined." - ) - if self.eval_batches is not None: - self.eval_batches = sum(all_eval_batches, []) - self._invalidate_indexes(filter_seq_annots=True) - - def is_filtered(self) -> bool: - """ - Returns `True` in case the dataset has been filtered and thus some frame annotations - stored on the disk might be missing in the dataset object. - - Returns: - is_filtered: `True` if the dataset has been filtered, else `False`. - """ - return ( - self.remove_empty_masks - or self.limit_to > 0 - or self.limit_sequences_to > 0 - or len(self.pick_sequence) > 0 - or len(self.exclude_sequence) > 0 - or len(self.limit_category_to) > 0 - or self.n_frames_per_sequence > 0 - ) - - def seq_frame_index_to_dataset_index( - self, - seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]], - allow_missing_indices: bool = False, - remove_missing_indices: bool = False, - suppress_missing_index_warning: bool = True, - ) -> List[List[Union[Optional[int], int]]]: - """ - Obtain indices into the dataset object given a list of frame ids. - - Args: - seq_frame_index: The list of frame ids specified as - `List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally, - Image paths relative to the dataset_root can be stored specified as well: - `List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]` - allow_missing_indices: If `False`, throws an IndexError upon reaching the first - entry from `seq_frame_index` which is missing in the dataset. - Otherwise, depending on `remove_missing_indices`, either returns `None` - in place of missing entries or removes the indices of missing entries. - remove_missing_indices: Active when `allow_missing_indices=True`. - If `False`, returns `None` in place of `seq_frame_index` entries that - are not present in the dataset. - If `True` removes missing indices from the returned indices. - suppress_missing_index_warning: - Active if `allow_missing_indices==True`. Suppressess a warning message - in case an entry from `seq_frame_index` is missing in the dataset - (expected in certain cases - e.g. when setting - `self.remove_empty_masks=True`). - - Returns: - dataset_idx: Indices of dataset entries corresponding to`seq_frame_index`. - """ - _dataset_seq_frame_n_index = { - seq: { - # pyre-ignore[16] - self.frame_annots[idx]["frame_annotation"].frame_number: idx - for idx in seq_idx - } - # pyre-ignore[16] - for seq, seq_idx in self._seq_to_idx.items() - } - - def _get_dataset_idx( - seq_name: str, frame_no: int, path: Optional[str] = None - ) -> Optional[int]: - idx_seq = _dataset_seq_frame_n_index.get(seq_name, None) - idx = idx_seq.get(frame_no, None) if idx_seq is not None else None - if idx is None: - msg = ( - f"sequence_name={seq_name} / frame_number={frame_no}" - " not in the dataset!" - ) - if not allow_missing_indices: - raise IndexError(msg) - if not suppress_missing_index_warning: - warnings.warn(msg) - return idx - if path is not None: - # Check that the loaded frame path is consistent - # with the one stored in self.frame_annots. - assert os.path.normpath( - # pyre-ignore[16] - self.frame_annots[idx]["frame_annotation"].image.path - ) == os.path.normpath( - path - ), f"Inconsistent frame indices {seq_name, frame_no, path}." - return idx - - dataset_idx = [ - [_get_dataset_idx(*b) for b in batch] # pyre-ignore [6] - for batch in seq_frame_index - ] - - if allow_missing_indices and remove_missing_indices: - # remove all None indices, and also batches with only None entries - valid_dataset_idx = [ - [b for b in batch if b is not None] for batch in dataset_idx - ] - return [ # pyre-ignore[7] - batch for batch in valid_dataset_idx if len(batch) > 0 - ] - - return dataset_idx - - def subset_from_frame_index( - self, - frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]], - allow_missing_indices: bool = True, - ) -> "JsonIndexDataset": - """ - Generate a dataset subset given the list of frames specified in `frame_index`. - - Args: - frame_index: The list of frame indentifiers (as stored in the metadata) - specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally, - Image paths relative to the dataset_root can be stored specified as well: - `List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`, - in the latter case, if imaga_path do not match the stored paths, an error - is raised. - allow_missing_indices: If `False`, throws an IndexError upon reaching the first - entry from `frame_index` which is missing in the dataset. - Otherwise, generates a subset consisting of frames entries that actually - exist in the dataset. - """ - # Get the indices into the frame annots. - dataset_indices = self.seq_frame_index_to_dataset_index( - [frame_index], - allow_missing_indices=self.is_filtered() and allow_missing_indices, - )[0] - valid_dataset_indices = [i for i in dataset_indices if i is not None] - - # Deep copy the whole dataset except frame_annots, which are large so we - # deep copy only the requested subset of frame_annots. - memo = {id(self.frame_annots): None} # pyre-ignore[16] - dataset_new = copy.deepcopy(self, memo) - dataset_new.frame_annots = copy.deepcopy( - [self.frame_annots[i] for i in valid_dataset_indices] - ) - - # This will kill all unneeded sequence annotations. - dataset_new._invalidate_indexes(filter_seq_annots=True) - - # Finally annotate the frame annotations with the name of the subset - # stored in meta. - for frame_annot in dataset_new.frame_annots: - frame_annotation = frame_annot["frame_annotation"] - if frame_annotation.meta is not None: - frame_annot["subset"] = frame_annotation.meta.get("frame_type", None) - - # A sanity check - this will crash in case some entries from frame_index are missing - # in dataset_new. - valid_frame_index = [ - fi for fi, di in zip(frame_index, dataset_indices) if di is not None - ] - dataset_new.seq_frame_index_to_dataset_index( - [valid_frame_index], allow_missing_indices=False - ) - - return dataset_new - - def __str__(self) -> str: - # pyre-ignore[16] - return f"JsonIndexDataset #frames={len(self.frame_annots)}" - - def __len__(self) -> int: - # pyre-ignore[16] - return len(self.frame_annots) - - def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: - return entry["subset"] - - def get_all_train_cameras(self) -> CamerasBase: - """ - Returns the cameras corresponding to all the known frames. - """ - logger.info("Loading all train cameras.") - cameras = [] - # pyre-ignore[16] - for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)): - frame_type = self._get_frame_type(frame_annot) - if frame_type is None: - raise ValueError("subsets not loaded") - if is_known_frame_scalar(frame_type): - cameras.append(self[frame_idx].camera) - return join_cameras_as_batch(cameras) - - def __getitem__(self, index) -> FrameData: - # pyre-ignore[16] - if index >= len(self.frame_annots): - raise IndexError(f"index {index} out of range {len(self.frame_annots)}") - - entry = self.frame_annots[index]["frame_annotation"] - # pyre-ignore[16] - point_cloud = self.seq_annots[entry.sequence_name].point_cloud - frame_data = FrameData( - frame_number=_safe_as_tensor(entry.frame_number, torch.long), - frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float), - sequence_name=entry.sequence_name, - sequence_category=self.seq_annots[entry.sequence_name].category, - camera_quality_score=_safe_as_tensor( - self.seq_annots[entry.sequence_name].viewpoint_quality_score, - torch.float, - ), - point_cloud_quality_score=_safe_as_tensor( - point_cloud.quality_score, torch.float - ) - if point_cloud is not None - else None, - ) - - # The rest of the fields are optional - frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) - - ( - frame_data.fg_probability, - frame_data.mask_path, - frame_data.bbox_xywh, - clamp_bbox_xyxy, - frame_data.crop_bbox_xywh, - ) = self._load_crop_fg_probability(entry) - - scale = 1.0 - if self.load_images and entry.image is not None: - # original image size - frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) - - ( - frame_data.image_rgb, - frame_data.image_path, - frame_data.mask_crop, - scale, - ) = self._load_crop_images( - entry, frame_data.fg_probability, clamp_bbox_xyxy - ) - - if self.load_depths and entry.depth is not None: - ( - frame_data.depth_map, - frame_data.depth_path, - frame_data.depth_mask, - ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) - - if entry.viewpoint is not None: - frame_data.camera = self._get_pytorch3d_camera( - entry, - scale, - clamp_bbox_xyxy, - ) - - if self.load_point_clouds and point_cloud is not None: - pcl_path = self._fix_point_cloud_path(point_cloud.path) - frame_data.sequence_point_cloud = _load_pointcloud( - self._local_path(pcl_path), max_points=self.max_points - ) - frame_data.sequence_point_cloud_path = pcl_path - - return frame_data - - def _fix_point_cloud_path(self, path: str) -> str: - """ - Fix up a point cloud path from the dataset. - Some files in Co3Dv2 have an accidental absolute path stored. - """ - unwanted_prefix = ( - "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/" - ) - if path.startswith(unwanted_prefix): - path = path[len(unwanted_prefix) :] - return os.path.join(self.dataset_root, path) - - def _load_crop_fg_probability( - self, entry: types.FrameAnnotation - ) -> Tuple[ - Optional[torch.Tensor], - Optional[str], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - ]: - fg_probability = None - full_path = None - bbox_xywh = None - clamp_bbox_xyxy = None - crop_box_xywh = None - - if (self.load_masks or self.box_crop) and entry.mask is not None: - full_path = os.path.join(self.dataset_root, entry.mask.path) - mask = _load_mask(self._local_path(full_path)) - - if mask.shape[-2:] != entry.image.size: - raise ValueError( - f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" - ) - - bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) - - if self.box_crop: - clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( - _get_clamp_bbox( - bbox_xywh, - image_path=entry.image.path, - box_crop_context=self.box_crop_context, - ), - image_size_hw=tuple(mask.shape[-2:]), - ) - crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) - - mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) - - fg_probability, _, _ = self._resize_image(mask, mode="nearest") - - return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh - - def _load_crop_images( - self, - entry: types.FrameAnnotation, - fg_probability: Optional[torch.Tensor], - clamp_bbox_xyxy: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: - assert self.dataset_root is not None and entry.image is not None - path = os.path.join(self.dataset_root, entry.image.path) - image_rgb = _load_image(self._local_path(path)) - - if image_rgb.shape[-2:] != entry.image.size: - raise ValueError( - f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" - ) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) - - image_rgb, scale, mask_crop = self._resize_image(image_rgb) - - if self.mask_images: - assert fg_probability is not None - image_rgb *= fg_probability - - return image_rgb, path, mask_crop, scale - - def _load_mask_depth( - self, - entry: types.FrameAnnotation, - clamp_bbox_xyxy: Optional[torch.Tensor], - fg_probability: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str, torch.Tensor]: - entry_depth = entry.depth - assert entry_depth is not None - path = os.path.join(self.dataset_root, entry_depth.path) - depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] - ) - depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) - - depth_map, _, _ = self._resize_image(depth_map, mode="nearest") - - if self.mask_depths: - assert fg_probability is not None - depth_map *= fg_probability - - if self.load_depth_masks: - assert entry_depth.mask_path is not None - mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) - depth_mask = _load_depth_mask(self._local_path(mask_path)) - - if self.box_crop: - assert clamp_bbox_xyxy is not None - depth_mask_bbox_xyxy = _rescale_bbox( - clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] - ) - depth_mask = _crop_around_box( - depth_mask, depth_mask_bbox_xyxy, mask_path - ) - - depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") - else: - depth_mask = torch.ones_like(depth_map) - - return depth_map, path, depth_mask - - def _get_pytorch3d_camera( - self, - entry: types.FrameAnnotation, - scale: float, - clamp_bbox_xyxy: Optional[torch.Tensor], - ) -> PerspectiveCameras: - entry_viewpoint = entry.viewpoint - assert entry_viewpoint is not None - # principal point and focal length - principal_point = torch.tensor( - entry_viewpoint.principal_point, dtype=torch.float - ) - focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) - - half_image_size_wh_orig = ( - torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 - ) - - # first, we convert from the dataset's NDC convention to pixels - format = entry_viewpoint.intrinsics_format - if format.lower() == "ndc_norm_image_bounds": - # this is e.g. currently used in CO3D for storing intrinsics - rescale = half_image_size_wh_orig - elif format.lower() == "ndc_isotropic": - rescale = half_image_size_wh_orig.min() - else: - raise ValueError(f"Unknown intrinsics format: {format}") - - # principal point and focal length in pixels - principal_point_px = half_image_size_wh_orig - principal_point * rescale - focal_length_px = focal_length * rescale - if self.box_crop: - assert clamp_bbox_xyxy is not None - principal_point_px -= clamp_bbox_xyxy[:2] - - # now, convert from pixels to PyTorch3D v0.5+ NDC convention - if self.image_height is None or self.image_width is None: - out_size = list(reversed(entry.image.size)) - else: - out_size = [self.image_width, self.image_height] - - half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 - half_min_image_size_output = half_image_size_output.min() - - # rescaled principal point and focal length in ndc - principal_point = ( - half_image_size_output - principal_point_px * scale - ) / half_min_image_size_output - focal_length = focal_length_px * scale / half_min_image_size_output - - return PerspectiveCameras( - focal_length=focal_length[None], - principal_point=principal_point[None], - R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], - T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], - ) - - def _load_frames(self) -> None: - logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.") - local_file = self._local_path(self.frame_annotations_file) - with gzip.open(local_file, "rt", encoding="utf8") as zipfile: - frame_annots_list = types.load_dataclass( - zipfile, List[self.frame_annotations_type] - ) - if not frame_annots_list: - raise ValueError("Empty dataset!") - # pyre-ignore[16] - self.frame_annots = [ - FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list - ] - - def _load_sequences(self) -> None: - logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.") - local_file = self._local_path(self.sequence_annotations_file) - with gzip.open(local_file, "rt", encoding="utf8") as zipfile: - seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation]) - if not seq_annots: - raise ValueError("Empty sequences file!") - # pyre-ignore[16] - self.seq_annots = {entry.sequence_name: entry for entry in seq_annots} - - def _load_subset_lists(self) -> None: - logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.") - if not self.subset_lists_file: - return - - with open(self._local_path(self.subset_lists_file), "r") as f: - subset_to_seq_frame = json.load(f) - - frame_path_to_subset = { - path: subset - for subset, frames in subset_to_seq_frame.items() - for _, _, path in frames - } - # pyre-ignore[16] - for frame in self.frame_annots: - frame["subset"] = frame_path_to_subset.get( - frame["frame_annotation"].image.path, None - ) - if frame["subset"] is None: - warnings.warn( - "Subset lists are given but don't include " - + frame["frame_annotation"].image.path - ) - - def _sort_frames(self) -> None: - # Sort frames to have them grouped by sequence, ordered by timestamp - # pyre-ignore[16] - self.frame_annots = sorted( - self.frame_annots, - key=lambda f: ( - f["frame_annotation"].sequence_name, - f["frame_annotation"].frame_timestamp or 0, - ), - ) - - def _filter_db(self) -> None: - if self.remove_empty_masks: - logger.info("Removing images with empty masks.") - # pyre-ignore[16] - old_len = len(self.frame_annots) - - msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." - - def positive_mass(frame_annot: types.FrameAnnotation) -> bool: - mask = frame_annot.mask - if mask is None: - return False - if mask.mass is None: - raise ValueError(msg) - return mask.mass > 1 - - self.frame_annots = [ - frame - for frame in self.frame_annots - if positive_mass(frame["frame_annotation"]) - ] - logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots))) - - # this has to be called after joining with categories!! - subsets = self.subsets - if subsets: - if not self.subset_lists_file: - raise ValueError( - "Subset filter is on but subset_lists_file was not given" - ) - - logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.") - - # truncate the list of subsets to the valid one - self.frame_annots = [ - entry for entry in self.frame_annots if entry["subset"] in subsets - ] - if len(self.frame_annots) == 0: - raise ValueError(f"There are no frames in the '{subsets}' subsets!") - - self._invalidate_indexes(filter_seq_annots=True) - - if len(self.limit_category_to) > 0: - logger.info(f"Limiting dataset to categories: {self.limit_category_to}") - # pyre-ignore[16] - self.seq_annots = { - name: entry - for name, entry in self.seq_annots.items() - if entry.category in self.limit_category_to - } - - # sequence filters - for prefix in ("pick", "exclude"): - orig_len = len(self.seq_annots) - attr = f"{prefix}_sequence" - arr = getattr(self, attr) - if len(arr) > 0: - logger.info(f"{attr}: {str(arr)}") - self.seq_annots = { - name: entry - for name, entry in self.seq_annots.items() - if (name in arr) == (prefix == "pick") - } - logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots))) - - if self.limit_sequences_to > 0: - self.seq_annots = dict( - islice(self.seq_annots.items(), self.limit_sequences_to) - ) - - # retain only frames from retained sequences - self.frame_annots = [ - f - for f in self.frame_annots - if f["frame_annotation"].sequence_name in self.seq_annots - ] - - self._invalidate_indexes() - - if self.n_frames_per_sequence > 0: - logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") - keep_idx = [] - # pyre-ignore[16] - for seq, seq_indices in self._seq_to_idx.items(): - # infer the seed from the sequence name, this is reproducible - # and makes the selection differ for different sequences - seed = _seq_name_to_seed(seq) + self.seed - seq_idx_shuffled = random.Random(seed).sample( - sorted(seq_indices), len(seq_indices) - ) - keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence]) - - logger.info( - "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)) - ) - self.frame_annots = [self.frame_annots[i] for i in keep_idx] - self._invalidate_indexes(filter_seq_annots=False) - # sequences are not decimated, so self.seq_annots is valid - - if self.limit_to > 0 and self.limit_to < len(self.frame_annots): - logger.info( - "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to) - ) - self.frame_annots = self.frame_annots[: self.limit_to] - self._invalidate_indexes(filter_seq_annots=True) - - def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: - # update _seq_to_idx and filter seq_meta according to frame_annots change - # if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx - self._invalidate_seq_to_idx() - - if filter_seq_annots: - # pyre-ignore[16] - self.seq_annots = { - k: v - for k, v in self.seq_annots.items() - # pyre-ignore[16] - if k in self._seq_to_idx - } - - def _invalidate_seq_to_idx(self) -> None: - seq_to_idx = defaultdict(list) - # pyre-ignore[16] - for idx, entry in enumerate(self.frame_annots): - seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) - # pyre-ignore[16] - self._seq_to_idx = seq_to_idx - - def _resize_image( - self, image, mode="bilinear" - ) -> Tuple[torch.Tensor, float, torch.Tensor]: - image_height, image_width = self.image_height, self.image_width - if image_height is None or image_width is None: - # skip the resizing - imre_ = torch.from_numpy(image) - return imre_, 1.0, torch.ones_like(imre_[:1]) - # takes numpy array, returns pytorch tensor - minscale = min( - image_height / image.shape[-2], - image_width / image.shape[-1], - ) - imre = torch.nn.functional.interpolate( - torch.from_numpy(image)[None], - scale_factor=minscale, - mode=mode, - align_corners=False if mode == "bilinear" else None, - recompute_scale_factor=True, - )[0] - # pyre-fixme[19]: Expected 1 positional argument. - imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) - imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre - # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`. - # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`. - mask = torch.zeros(1, self.image_height, self.image_width) - mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 - return imre_, minscale, mask - - def _local_path(self, path: str) -> str: - if self.path_manager is None: - return path - return self.path_manager.get_local_path(path) - - def get_frame_numbers_and_timestamps( - self, idxs: Sequence[int] - ) -> List[Tuple[int, float]]: - out: List[Tuple[int, float]] = [] - for idx in idxs: - # pyre-ignore[16] - frame_annotation = self.frame_annots[idx]["frame_annotation"] - out.append( - (frame_annotation.frame_number, frame_annotation.frame_timestamp) - ) - return out - - def category_to_sequence_names(self) -> Dict[str, List[str]]: - c2seq = defaultdict(list) - # pyre-ignore - for sequence_name, sa in self.seq_annots.items(): - c2seq[sa.category].append(sequence_name) - return dict(c2seq) - - def get_eval_batches(self) -> Optional[List[List[int]]]: - return self.eval_batches - - -def _seq_name_to_seed(seq_name) -> int: - return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16) - - -def _load_image(path) -> np.ndarray: - with Image.open(path) as pil_im: - im = np.array(pil_im.convert("RGB")) - im = im.transpose((2, 0, 1)) - im = im.astype(np.float32) / 255.0 - return im - - -def _load_16big_png_depth(depth_png) -> np.ndarray: - with Image.open(depth_png) as depth_pil: - # the image is stored with 16-bit depth but PIL reads it as I (32 bit). - # we cast it to uint16, then reinterpret as float16, then cast to float32 - depth = ( - np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) - .astype(np.float32) - .reshape((depth_pil.size[1], depth_pil.size[0])) - ) - return depth - - -def _load_1bit_png_mask(file: str) -> np.ndarray: - with Image.open(file) as pil_im: - mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32) - return mask - - -def _load_depth_mask(path: str) -> np.ndarray: - if not path.lower().endswith(".png"): - raise ValueError('unsupported depth mask file name "%s"' % path) - m = _load_1bit_png_mask(path) - return m[None] # fake feature channel - - -def _load_depth(path, scale_adjustment) -> np.ndarray: - if not path.lower().endswith(".png"): - raise ValueError('unsupported depth file name "%s"' % path) - - d = _load_16big_png_depth(path) * scale_adjustment - d[~np.isfinite(d)] = 0.0 - return d[None] # fake feature channel - - -def _load_mask(path) -> np.ndarray: - with Image.open(path) as pil_im: - mask = np.array(pil_im) - mask = mask.astype(np.float32) / 255.0 - return mask[None] # fake feature channel - - -def _get_1d_bounds(arr) -> Tuple[int, int]: - nz = np.flatnonzero(arr) - return nz[0], nz[-1] + 1 - - -def _get_bbox_from_mask( - mask, thr, decrease_quant: float = 0.05 -) -> Tuple[int, int, int, int]: - # bbox in xywh - masks_for_box = np.zeros_like(mask) - while masks_for_box.sum() <= 1.0: - masks_for_box = (mask > thr).astype(np.float32) - thr -= decrease_quant - if thr <= 0.0: - warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.") - - x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) - y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) - - return x0, y0, x1 - x0, y1 - y0 - - -def _get_clamp_bbox( - bbox: torch.Tensor, - box_crop_context: float = 0.0, - image_path: str = "", -) -> torch.Tensor: - # box_crop_context: rate of expansion for bbox - # returns possibly expanded bbox xyxy as float - - bbox = bbox.clone() # do not edit bbox in place - - # increase box size - if box_crop_context > 0.0: - c = box_crop_context - bbox = bbox.float() - bbox[0] -= bbox[2] * c / 2 - bbox[1] -= bbox[3] * c / 2 - bbox[2] += bbox[2] * c - bbox[3] += bbox[3] * c - - if (bbox[2:] <= 1.0).any(): - raise ValueError( - f"squashed image {image_path}!! The bounding box contains no pixels." - ) - - bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes - bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2) - - return bbox_xyxy - - -def _crop_around_box(tensor, bbox, impath: str = ""): - # bbox is xyxy, where the upper bound is corrected with +1 - bbox = _clamp_box_to_image_bounds_and_round( - bbox, - image_size_hw=tensor.shape[-2:], - ) - tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] - assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" - return tensor - - -def _clamp_box_to_image_bounds_and_round( - bbox_xyxy: torch.Tensor, - image_size_hw: Tuple[int, int], -) -> torch.LongTensor: - bbox_xyxy = bbox_xyxy.clone() - bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) - bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) - if not isinstance(bbox_xyxy, torch.LongTensor): - bbox_xyxy = bbox_xyxy.round().long() - return bbox_xyxy # pyre-ignore [7] - - -def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: - assert bbox is not None - assert np.prod(orig_res) > 1e-8 - # average ratio of dimensions - rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 - return bbox * rel_size - - -def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: - wh = xyxy[2:] - xyxy[:2] - xywh = torch.cat([xyxy[:2], wh]) - return xywh - - -def _bbox_xywh_to_xyxy( - xywh: torch.Tensor, clamp_size: Optional[int] = None -) -> torch.Tensor: - xyxy = xywh.clone() - if clamp_size is not None: - xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) - xyxy[2:] += xyxy[:2] - return xyxy - - -def _safe_as_tensor(data, dtype): - if data is None: - return None - return torch.tensor(data, dtype=dtype) - - -# NOTE this cache is per-worker; they are implemented as processes. -# each batch is loaded and collated by a single worker; -# since sequences tend to co-occur within batches, this is useful. -@functools.lru_cache(maxsize=256) -def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds: - pcl = IO().load_pointcloud(pcl_path) - if max_points > 0: - pcl = pcl.subsample(max_points) - - return pcl \ No newline at end of file diff --git a/sgm/data/latent_objaverse.py b/sgm/data/latent_objaverse.py deleted file mode 100644 index 8819c1e7529efb1fcf44a6f95f92df3d73869517..0000000000000000000000000000000000000000 --- a/sgm/data/latent_objaverse.py +++ /dev/null @@ -1,52 +0,0 @@ -import numpy as np -from pathlib import Path -from PIL import Image -import json -import torch -from torch.utils.data import Dataset, DataLoader, default_collate -from torchvision.transforms import ToTensor, Normalize, Compose, Resize -from pytorch_lightning import LightningDataModule -from einops import rearrange - - -class LatentObjaverseSpiral(Dataset): - def __init__( - self, - root_dir, - split="train", - transform=None, - random_front=False, - max_item=None, - cond_aug_mean=-3.0, - cond_aug_std=0.5, - condition_on_elevation=False, - **unused_kwargs, - ): - print("Using LVIS subset with precomputed Latents") - self.root_dir = Path(root_dir) - self.split = split - self.random_front = random_front - self.transform = transform - - self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512") - - self.ids = json.load(open("./assets/lvis_uids.json", "r")) - self.n_views = 18 - valid_ids = [] - for idx in self.ids: - if (self.root_dir / idx).exists(): - valid_ids.append(idx) - self.ids = valid_ids - print("=" * 30) - print("Number of valid ids: ", len(self.ids)) - print("=" * 30) - - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - self.condition_on_elevation = condition_on_elevation - - if max_item is not None: - self.ids = self.ids[:max_item] - - ## debug - self.ids = self.ids * 10000 diff --git a/sgm/data/mnist.py b/sgm/data/mnist.py deleted file mode 100644 index dea4d7e670666bec80ecb22aa89603345e173d09..0000000000000000000000000000000000000000 --- a/sgm/data/mnist.py +++ /dev/null @@ -1,85 +0,0 @@ -import pytorch_lightning as pl -import torchvision -from torch.utils.data import DataLoader, Dataset -from torchvision import transforms - - -class MNISTDataDictWrapper(Dataset): - def __init__(self, dset): - super().__init__() - self.dset = dset - - def __getitem__(self, i): - x, y = self.dset[i] - return {"jpg": x, "cls": y} - - def __len__(self): - return len(self.dset) - - -class MNISTLoader(pl.LightningDataModule): - def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True): - super().__init__() - - transform = transforms.Compose( - [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] - ) - - self.batch_size = batch_size - self.num_workers = num_workers - self.prefetch_factor = prefetch_factor if num_workers > 0 else 0 - self.shuffle = shuffle - self.train_dataset = MNISTDataDictWrapper( - torchvision.datasets.MNIST( - root=".data/", train=True, download=True, transform=transform - ) - ) - self.test_dataset = MNISTDataDictWrapper( - torchvision.datasets.MNIST( - root=".data/", train=False, download=True, transform=transform - ) - ) - - def prepare_data(self): - pass - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - ) - - def test_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - ) - - def val_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - ) - - -if __name__ == "__main__": - dset = MNISTDataDictWrapper( - torchvision.datasets.MNIST( - root=".data/", - train=False, - download=True, - transform=transforms.Compose( - [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] - ), - ) - ) - ex = dset[0] diff --git a/sgm/data/mvimagenet.py b/sgm/data/mvimagenet.py deleted file mode 100644 index b20c398c08dd976c8bef1455845022f181cfcb73..0000000000000000000000000000000000000000 --- a/sgm/data/mvimagenet.py +++ /dev/null @@ -1,408 +0,0 @@ -import numpy as np -import torch -from torch.utils.data import Dataset, DataLoader, default_collate -from pathlib import Path -from PIL import Image -from scipy.spatial.transform import Rotation -import rembg -from rembg import remove, new_session -from einops import rearrange - -from torchvision.transforms import ToTensor, Normalize, Compose, Resize -from torchvision.transforms.functional import to_tensor -from pytorch_lightning import LightningDataModule - -from sgm.data.colmap import read_cameras_binary, read_images_binary -from sgm.data.objaverse import video_collate_fn, FLATTEN_FIELDS, flatten_for_video - - -def qvec2rotmat(qvec): - return np.array( - [ - [ - 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, - 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], - 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], - ], - [ - 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], - 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, - 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], - ], - [ - 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], - 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], - 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, - ], - ] - ) - - -def qt2c2w(q, t): - # NOTE: remember to convert to opengl coordinate system - # rot = Rotation.from_quat(q).as_matrix() - rot = qvec2rotmat(q) - c2w = np.eye(4) - c2w[:3, :3] = np.transpose(rot) - c2w[:3, 3] = -np.transpose(rot) @ t - c2w[..., 1:3] *= -1 - return c2w - - -def random_crop(): - pass - - -class MVImageNet(Dataset): - def __init__( - self, - root_dir, - split, - transform, - reso: int = 256, - mask_type: str = "random", - cond_aug_mean=-3.0, - cond_aug_std=0.5, - condition_on_elevation=False, - fps_id=0.0, - motion_bucket_id=300.0, - num_frames: int = 24, - use_mask: bool = True, - load_pixelnerf: bool = False, - scale_pose: bool = False, - max_n_cond: int = 1, - min_n_cond: int = 1, - cond_on_multi: bool = False, - ) -> None: - super().__init__() - - self.root_dir = Path(root_dir) - self.split = split - - avails = self.root_dir.glob("*/*") - self.ids = list( - map( - lambda x: str(x.relative_to(self.root_dir)), - filter(lambda x: x.is_dir(), avails), - ) - ) - - self.transform = transform - self.reso = reso - self.num_frames = num_frames - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - self.condition_on_elevation = condition_on_elevation - self.fps_id = fps_id - self.motion_bucket_id = motion_bucket_id - self.mask_type = mask_type - self.use_mask = use_mask - self.load_pixelnerf = load_pixelnerf - self.scale_pose = scale_pose - self.max_n_cond = max_n_cond - self.min_n_cond = min_n_cond - self.cond_on_multi = cond_on_multi - - if self.cond_on_multi: - assert self.min_n_cond == self.max_n_cond - self.session = new_session() - - def __getitem__(self, index: int): - # mvimgnet starts with idx==1 - idx_list = np.arange(0, self.num_frames) - this_image_dir = self.root_dir / self.ids[index] / "images" - this_camera_dir = self.root_dir / self.ids[index] / "sparse/0" - - # while not this_camera_dir.exists(): - # index = (index + 1) % len(self.ids) - # this_image_dir = self.root_dir / self.ids[index] / "images" - # this_camera_dir = self.root_dir / self.ids[index] / "sparse/0" - if not this_camera_dir.exists(): - index = 0 - this_image_dir = self.root_dir / self.ids[index] / "images" - this_camera_dir = self.root_dir / self.ids[index] / "sparse/0" - - this_images = read_images_binary(this_camera_dir / "images.bin") - # filenames = list(map(lambda x: f"{x:03d}", this_images.keys())) - filenames = list(this_images.keys()) - - if len(filenames) == 0: - index = 0 - this_image_dir = self.root_dir / self.ids[index] / "images" - this_camera_dir = self.root_dir / self.ids[index] / "sparse/0" - this_images = read_images_binary(this_camera_dir / "images.bin") - # filenames = list(map(lambda x: f"{x:03d}", this_images.keys())) - filenames = list(this_images.keys()) - - filenames = list( - filter(lambda x: (this_image_dir / this_images[x].name).exists(), filenames) - ) - - filenames = sorted(filenames, key=lambda x: this_images[x].name) - - # # debug - # names = [] - # for v in filenames: - # names.append(this_images[v].name) - # breakpoint() - - while len(filenames) < self.num_frames: - num_surpass = self.num_frames - len(filenames) - filenames += list(reversed(filenames[-num_surpass:])) - - if len(filenames) < self.num_frames: - print(f"\n\n{self.ids[index]}\n\n") - - frames = [] - cameras = [] - downsampled_rgb = [] - for view_idx in idx_list: - this_id = filenames[view_idx] - frame = Image.open(this_image_dir / this_images[this_id].name) - w, h = frame.size - - if self.mask_type == "random": - image_size = min(h, w) - left = np.random.randint(0, w - image_size + 1) - right = left + image_size - top = np.random.randint(0, h - image_size + 1) - bottom = top + image_size - ## need to assign left, right, top, bottom, image_size - elif self.mask_type == "object": - pass - elif self.mask_type == "rembg": - image_size = min(h, w) - if ( - cached := this_image_dir - / f"{this_images[this_id].name[:-4]}_rembg.png" - ).exists(): - try: - mask = np.asarray(Image.open(cached, formats=["png"]))[..., 3] - except: - mask = remove(frame, session=self.session) - mask.save(cached) - mask = np.asarray(mask)[..., 3] - else: - mask = remove(frame, session=self.session) - mask.save(cached) - mask = np.asarray(mask)[..., 3] - # in h,w order - y, x = np.array(mask.nonzero()) - bbox_cx = x.mean() - bbox_cy = y.mean() - - if bbox_cy - image_size / 2 < 0: - top = 0 - elif bbox_cy + image_size / 2 > h: - top = h - image_size - else: - top = int(bbox_cy - image_size / 2) - - if bbox_cx - image_size / 2 < 0: - left = 0 - elif bbox_cx + image_size / 2 > w: - left = w - image_size - else: - left = int(bbox_cx - image_size / 2) - - # top = max(int(bbox_cy - image_size / 2), 0) - # left = max(int(bbox_cx - image_size / 2), 0) - bottom = top + image_size - right = left + image_size - else: - raise ValueError(f"Unknown mask type: {self.mask_type}") - - frame = frame.crop((left, top, right, bottom)) - frame = frame.resize((self.reso, self.reso)) - frames.append(self.transform(frame)) - - if self.load_pixelnerf: - # extrinsics - extrinsics = this_images[this_id] - c2w = qt2c2w(extrinsics.qvec, extrinsics.tvec) - # intrinsics - intrinsics = read_cameras_binary(this_camera_dir / "cameras.bin") - assert len(intrinsics) == 1 - intrinsics = intrinsics[1] - f, cx, cy, _ = intrinsics.params - f *= 1 / image_size - cx -= left - cy -= top - cx *= 1 / image_size - cy *= 1 / image_size # all are relative values - intrinsics = np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]]) - - this_camera = np.zeros(25) - this_camera[:16] = c2w.reshape(-1) - this_camera[16:] = intrinsics.reshape(-1) - - cameras.append(this_camera) - downsampled = frame.resize((self.reso // 8, self.reso // 8)) - downsampled_rgb.append((self.transform(downsampled) + 1.0) * 0.5) - - data = dict() - - cond_aug = np.exp( - np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean - ) - frames = torch.stack(frames) - cond = frames[0] - # setting all things in data - data["frames"] = frames - data["cond_frames_without_noise"] = cond - data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames) - data["cond_frames"] = cond + cond_aug * torch.randn_like(cond) - data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames) - data["motion_bucket_id"] = torch.as_tensor( - [self.motion_bucket_id] * self.num_frames - ) - data["num_video_frames"] = self.num_frames - data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames) - - if self.load_pixelnerf: - # TODO: normalize camera poses - data["pixelnerf_input"] = dict() - data["pixelnerf_input"]["frames"] = frames - data["pixelnerf_input"]["rgb"] = torch.stack(downsampled_rgb) - - cameras = torch.from_numpy(np.stack(cameras)).float() - if self.scale_pose: - c2ws = cameras[..., :16].reshape(-1, 4, 4) - center = c2ws[:, :3, 3].mean(0) - radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max() - scale = 1.5 / radius - c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale - cameras[..., :16] = c2ws.reshape(-1, 16) - - # if self.max_n_cond > 1: - # # TODO implement this - # n_cond = np.random.randint(1, self.max_n_cond + 1) - # # debug - # source_index = [0] - # if n_cond > 1: - # source_index += np.random.choice( - # np.arange(1, self.num_frames), - # self.max_n_cond - 1, - # replace=False, - # ).tolist() - # data["pixelnerf_input"]["source_index"] = torch.as_tensor( - # source_index - # ) - # data["pixelnerf_input"]["n_cond"] = n_cond - # data["pixelnerf_input"]["source_images"] = frames[source_index] - # data["pixelnerf_input"]["source_cameras"] = cameras[source_index] - - data["pixelnerf_input"]["cameras"] = cameras - - return data - - def __len__(self): - return len(self.ids) - - def collate_fn(self, batch): - # a hack to add source index and keep consistent within a batch - if self.max_n_cond > 1: - # TODO implement this - n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1) - # debug - # source_index = [0] - if n_cond > 1: - for b in batch: - source_index = [0] + np.random.choice( - np.arange(1, self.num_frames), - self.max_n_cond - 1, - replace=False, - ).tolist() - b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index) - b["pixelnerf_input"]["n_cond"] = n_cond - b["pixelnerf_input"]["source_images"] = b["frames"][source_index] - b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][ - "cameras" - ][source_index] - - if self.cond_on_multi: - b["cond_frames_without_noise"] = b["frames"][source_index] - - ret = video_collate_fn(batch) - - if self.cond_on_multi: - ret["cond_frames_without_noise"] = rearrange(ret["cond_frames_without_noise"], "b t ... -> (b t) ...") - - return ret - - -class MVImageNetFixedCond(MVImageNet): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - -class MVImageNetDataset(LightningDataModule): - def __init__( - self, - root_dir, - batch_size=2, - shuffle=True, - num_workers=10, - prefetch_factor=2, - **kwargs, - ): - super().__init__() - - self.batch_size = batch_size - self.num_workers = num_workers - self.prefetch_factor = prefetch_factor - self.shuffle = shuffle - - self.transform = Compose( - [ - ToTensor(), - Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - - self.train_dataset = MVImageNet( - root_dir=root_dir, - split="train", - transform=self.transform, - **kwargs, - ) - - self.test_dataset = MVImageNet( - root_dir=root_dir, - split="test", - transform=self.transform, - **kwargs, - ) - - def train_dataloader(self): - def worker_init_fn(worker_id): - np.random.seed(np.random.get_state()[1][0]) - - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=self.train_dataset.collate_fn, - ) - - def test_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=self.test_dataset.collate_fn, - ) - - def val_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=video_collate_fn, - ) diff --git a/sgm/data/objaverse.py b/sgm/data/objaverse.py deleted file mode 100644 index e9ae0730ab09dc4e5ad87e3212b3f2ae22581934..0000000000000000000000000000000000000000 --- a/sgm/data/objaverse.py +++ /dev/null @@ -1,882 +0,0 @@ -import numpy as np -from pathlib import Path -from PIL import Image -import json -import torch -import torch.nn.functional as F -from torch.utils.data import Dataset, DataLoader, default_collate -from torchvision.transforms import ToTensor, Normalize, Compose, Resize -from torchvision.transforms.functional import to_tensor -from pytorch_lightning import LightningDataModule -from einops import rearrange - - -def read_camera_matrix_single(json_file): - # for gobjaverse - with open(json_file, "r", encoding="utf8") as reader: - json_content = json.load(reader) - - # negative sign for opencv to opengl - camera_matrix = torch.zeros(3, 4) - camera_matrix[:3, 0] = torch.tensor(json_content["x"]) - camera_matrix[:3, 1] = -torch.tensor(json_content["y"]) - camera_matrix[:3, 2] = -torch.tensor(json_content["z"]) - camera_matrix[:3, 3] = torch.tensor(json_content["origin"]) - """ - camera_matrix = np.eye(4) - camera_matrix[:3, 0] = np.array(json_content['x']) - camera_matrix[:3, 1] = np.array(json_content['y']) - camera_matrix[:3, 2] = np.array(json_content['z']) - camera_matrix[:3, 3] = np.array(json_content['origin']) - # print(camera_matrix) - """ - - return camera_matrix - - -def read_camera_instrinsics_single(json_file, h: int, w: int, scale: float = 1.0): - with open(json_file, "r", encoding="utf8") as reader: - json_content = json.load(reader) - - h = int(h * scale) - w = int(w * scale) - - y_fov = json_content["y_fov"] - x_fov = json_content["x_fov"] - - fy = h / 2 / np.tan(y_fov / 2) - fx = w / 2 / np.tan(x_fov / 2) - - cx = w // 2 - cy = h // 2 - - intrinsics = torch.tensor( - [ - [fx, fy], - [cx, cy], - [w, h], - ], - dtype=torch.float32, - ) - return intrinsics - - -def compose_extrinsic_RT(RT: torch.Tensor): - """ - Compose the standard form extrinsic matrix from RT. - Batched I/O. - """ - return torch.cat( - [ - RT, - torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat( - RT.shape[0], 1, 1 - ), - ], - dim=1, - ) - - -def get_normalized_camera_intrinsics(intrinsics: torch.Tensor): - """ - intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]] - Return batched fx, fy, cx, cy - """ - fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1] - cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1] - width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1] - fx, fy = fx / width, fy / height - cx, cy = cx / width, cy / height - return fx, fy, cx, cy - - -def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor): - """ - RT: (N, 3, 4) - intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]] - """ - E = compose_extrinsic_RT(RT) - fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics) - I = torch.stack( - [ - torch.stack([fx, torch.zeros_like(fx), cx], dim=-1), - torch.stack([torch.zeros_like(fy), fy, cy], dim=-1), - torch.tensor([[0, 0, 1]], dtype=torch.float32).repeat(RT.shape[0], 1), - ], - dim=1, - ) - return torch.cat( - [ - E.reshape(-1, 16), - I.reshape(-1, 9), - ], - dim=-1, - ) - - -def calc_elevation(c2w): - ## works for single or batched c2w - ## assume world up is (0, 0, 1) - pos = c2w[..., :3, 3] - - return np.arcsin(pos[..., 2] / np.linalg.norm(pos, axis=-1, keepdims=False)) - - -def read_camera_matrix_single(json_file): - with open(json_file, "r", encoding="utf8") as reader: - json_content = json.load(reader) - - # negative sign for opencv to opengl - # camera_matrix = np.zeros([3, 4]) - # camera_matrix[:3, 0] = np.array(json_content["x"]) - # camera_matrix[:3, 1] = -np.array(json_content["y"]) - # camera_matrix[:3, 2] = -np.array(json_content["z"]) - # camera_matrix[:3, 3] = np.array(json_content["origin"]) - camera_matrix = torch.zeros([3, 4]) - camera_matrix[:3, 0] = torch.tensor(json_content["x"]) - camera_matrix[:3, 1] = -torch.tensor(json_content["y"]) - camera_matrix[:3, 2] = -torch.tensor(json_content["z"]) - camera_matrix[:3, 3] = torch.tensor(json_content["origin"]) - """ - camera_matrix = np.eye(4) - camera_matrix[:3, 0] = np.array(json_content['x']) - camera_matrix[:3, 1] = np.array(json_content['y']) - camera_matrix[:3, 2] = np.array(json_content['z']) - camera_matrix[:3, 3] = np.array(json_content['origin']) - # print(camera_matrix) - """ - - return camera_matrix - - -def blend_white_bg(image): - new_image = Image.new("RGB", image.size, (255, 255, 255)) - new_image.paste(image, mask=image.split()[3]) - - return new_image - - -def flatten_for_video(input): - return input.flatten() - - -FLATTEN_FIELDS = ["fps_id", "motion_bucket_id", "cond_aug", "elevation"] - - -def video_collate_fn(batch: list[dict], *args, **kwargs): - out = {} - for key in batch[0].keys(): - if key in FLATTEN_FIELDS: - out[key] = default_collate([item[key] for item in batch]) - out[key] = flatten_for_video(out[key]) - elif key == "num_video_frames": - out[key] = batch[0][key] - elif key in ["frames", "latents", "rgb"]: - out[key] = default_collate([item[key] for item in batch]) - out[key] = rearrange(out[key], "b t c h w -> (b t) c h w") - else: - out[key] = default_collate([item[key] for item in batch]) - - if "pixelnerf_input" in out: - out["pixelnerf_input"]["rgb"] = rearrange( - out["pixelnerf_input"]["rgb"], "b t c h w -> (b t) c h w" - ) - - return out - - -class GObjaverse(Dataset): - def __init__( - self, - root_dir, - split="train", - transform=None, - random_front=False, - max_item=None, - cond_aug_mean=-3.0, - cond_aug_std=0.5, - condition_on_elevation=False, - fps_id=0.0, - motion_bucket_id=300.0, - use_latents=False, - load_caps=False, - front_view_selection="random", - load_pixelnerf=False, - debug_base_idx=None, - scale_pose: bool = False, - max_n_cond: int = 1, - **unused_kwargs, - ): - self.root_dir = Path(root_dir) - self.split = split - self.random_front = random_front - self.transform = transform - self.use_latents = use_latents - - self.ids = json.load(open(self.root_dir / "valid_uids.json", "r")) - self.n_views = 24 - - self.load_caps = load_caps - if self.load_caps: - self.caps = json.load(open(self.root_dir / "text_captions_cap3d.json", "r")) - - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - self.condition_on_elevation = condition_on_elevation - self.fps_id = fps_id - self.motion_bucket_id = motion_bucket_id - self.load_pixelnerf = load_pixelnerf - self.scale_pose = scale_pose - self.max_n_cond = max_n_cond - - if self.use_latents: - self.latents_dir = self.root_dir / "latents256" - self.clip_dir = self.root_dir / "clip_emb256" - - self.front_view_selection = front_view_selection - if self.front_view_selection == "random": - pass - elif self.front_view_selection == "fixed": - pass - elif self.front_view_selection.startswith("clip_score"): - self.clip_scores = torch.load(self.root_dir / "clip_score_per_view.pt") - self.ids = list(self.clip_scores.keys()) - else: - raise ValueError( - f"Unknown front view selection method {self.front_view_selection}" - ) - - if max_item is not None: - self.ids = self.ids[:max_item] - ## debug - self.ids = self.ids * 10000 - - if debug_base_idx is not None: - print(f"debug mode with base idx: {debug_base_idx}") - self.debug_base_idx = debug_base_idx - - def __getitem__(self, idx: int): - if hasattr(self, "debug_base_idx"): - idx = (idx + self.debug_base_idx) % len(self.ids) - data = {} - idx_list = np.arange(self.n_views) - # if self.random_front: - # roll_idx = np.random.randint(self.n_views) - # idx_list = np.roll(idx_list, roll_idx) - if self.front_view_selection == "random": - roll_idx = np.random.randint(self.n_views) - idx_list = np.roll(idx_list, roll_idx) - elif self.front_view_selection == "fixed": - pass - elif self.front_view_selection == "clip_score_softmax": - this_clip_score = ( - F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy() - ) - roll_idx = np.random.choice(idx_list, p=this_clip_score) - idx_list = np.roll(idx_list, roll_idx) - elif self.front_view_selection == "clip_score_max": - this_clip_score = ( - F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy() - ) - roll_idx = np.argmax(this_clip_score) - idx_list = np.roll(idx_list, roll_idx) - frames = [] - if not self.use_latents: - try: - for view_idx in idx_list: - frame = Image.open( - self.root_dir - / "gobjaverse" - / self.ids[idx] - / f"{view_idx:05d}/{view_idx:05d}.png" - ) - frames.append(self.transform(frame)) - except: - idx = 0 - frames = [] - for view_idx in idx_list: - frame = Image.open( - self.root_dir - / "gobjaverse" - / self.ids[idx] - / f"{view_idx:05d}/{view_idx:05d}.png" - ) - frames.append(self.transform(frame)) - # a workaround for some bugs in gobjaverse - # use idx=0 and the repeat will be resolved when gathering results, valid number of items can be checked by the len of results - frames = torch.stack(frames, dim=0) - cond = frames[0] - - cond_aug = np.exp( - np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean - ) - - data.update( - { - "frames": frames, - "cond_frames_without_noise": cond, - "cond_aug": torch.as_tensor([cond_aug] * self.n_views), - "cond_frames": cond + cond_aug * torch.randn_like(cond), - "fps_id": torch.as_tensor([self.fps_id] * self.n_views), - "motion_bucket_id": torch.as_tensor( - [self.motion_bucket_id] * self.n_views - ), - "num_video_frames": 24, - "image_only_indicator": torch.as_tensor([0.0] * self.n_views), - } - ) - else: - latents = torch.load(self.latents_dir / f"{self.ids[idx]}.pt")[idx_list] - clip_emb = torch.load(self.clip_dir / f"{self.ids[idx]}.pt")[idx_list][0] - - cond = latents[0] - - cond_aug = np.exp( - np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean - ) - - data.update( - { - "latents": latents, - "cond_frames_without_noise": clip_emb, - "cond_aug": torch.as_tensor([cond_aug] * self.n_views), - "cond_frames": cond + cond_aug * torch.randn_like(cond), - "fps_id": torch.as_tensor([self.fps_id] * self.n_views), - "motion_bucket_id": torch.as_tensor( - [self.motion_bucket_id] * self.n_views - ), - "num_video_frames": 24, - "image_only_indicator": torch.as_tensor([0.0] * self.n_views), - } - ) - - if self.condition_on_elevation: - sample_c2w = read_camera_matrix_single( - self.root_dir / self.ids[idx] / f"00000/00000.json" - ) - elevation = calc_elevation(sample_c2w) - data["elevation"] = torch.as_tensor([elevation] * self.n_views) - - if self.load_pixelnerf: - assert "frames" in data, f"pixelnerf cannot work with latents only mode" - data["pixelnerf_input"] = {} - RTs = [] - intrinsics = [] - for view_idx in idx_list: - meta = ( - self.root_dir - / "gobjaverse" - / self.ids[idx] - / f"{view_idx:05d}/{view_idx:05d}.json" - ) - RTs.append(read_camera_matrix_single(meta)[:3]) - intrinsics.append(read_camera_instrinsics_single(meta, 256, 256)) - RTs = torch.stack(RTs, dim=0) - intrinsics = torch.stack(intrinsics, dim=0) - cameras = build_camera_standard(RTs, intrinsics) - data["pixelnerf_input"]["cameras"] = cameras - - downsampled = [] - for view_idx in idx_list: - frame = Image.open( - self.root_dir - / "gobjaverse" - / self.ids[idx] - / f"{view_idx:05d}/{view_idx:05d}.png" - ).resize((32, 32)) - downsampled.append(to_tensor(blend_white_bg(frame))) - data["pixelnerf_input"]["rgb"] = torch.stack(downsampled, dim=0) - data["pixelnerf_input"]["frames"] = data["frames"] - if self.scale_pose: - c2ws = cameras[..., :16].reshape(-1, 4, 4) - center = c2ws[:, :3, 3].mean(0) - radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max() - scale = 1.5 / radius - c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale - cameras[..., :16] = c2ws.reshape(-1, 16) - - if self.load_caps: - data["caption"] = self.caps[self.ids[idx]] - data["ids"] = self.ids[idx] - - return data - - def __len__(self): - return len(self.ids) - - def collate_fn(self, batch): - if self.max_n_cond > 1: - n_cond = np.random.randint(1, self.max_n_cond + 1) - if n_cond > 1: - for b in batch: - source_index = [0] + np.random.choice( - np.arange(1, self.n_views), - self.max_n_cond - 1, - replace=False, - ).tolist() - b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index) - b["pixelnerf_input"]["n_cond"] = n_cond - b["pixelnerf_input"]["source_images"] = b["frames"][source_index] - b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][ - "cameras" - ][source_index] - - return video_collate_fn(batch) - - -class ObjaverseSpiral(Dataset): - def __init__( - self, - root_dir, - split="train", - transform=None, - random_front=False, - max_item=None, - cond_aug_mean=-3.0, - cond_aug_std=0.5, - condition_on_elevation=False, - **unused_kwargs, - ): - self.root_dir = Path(root_dir) - self.split = split - self.random_front = random_front - self.transform = transform - - self.ids = json.load(open(self.root_dir / f"{split}_ids.json", "r")) - self.n_views = 24 - valid_ids = [] - for idx in self.ids: - if (self.root_dir / idx).exists(): - valid_ids.append(idx) - self.ids = valid_ids - - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - self.condition_on_elevation = condition_on_elevation - - if max_item is not None: - self.ids = self.ids[:max_item] - - ## debug - self.ids = self.ids * 10000 - - def __getitem__(self, idx: int): - frames = [] - idx_list = np.arange(self.n_views) - if self.random_front: - roll_idx = np.random.randint(self.n_views) - idx_list = np.roll(idx_list, roll_idx) - for view_idx in idx_list: - frame = Image.open( - self.root_dir / self.ids[idx] / f"{view_idx:05d}/{view_idx:05d}.png" - ) - frames.append(self.transform(frame)) - - # data = {"jpg": torch.stack(frames, dim=0)} # [T, C, H, W] - frames = torch.stack(frames, dim=0) - cond = frames[0] - - cond_aug = np.exp( - np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean - ) - - data = { - "frames": frames, - "cond_frames_without_noise": cond, - "cond_aug": torch.as_tensor([cond_aug] * self.n_views), - "cond_frames": cond + cond_aug * torch.randn_like(cond), - "fps_id": torch.as_tensor([1.0] * self.n_views), - "motion_bucket_id": torch.as_tensor([300.0] * self.n_views), - "num_video_frames": 24, - "image_only_indicator": torch.as_tensor([0.0] * self.n_views), - } - - if self.condition_on_elevation: - sample_c2w = read_camera_matrix_single( - self.root_dir / self.ids[idx] / f"00000/00000.json" - ) - elevation = calc_elevation(sample_c2w) - data["elevation"] = torch.as_tensor([elevation] * self.n_views) - - return data - - def __len__(self): - return len(self.ids) - - -class ObjaverseLVISSpiral(Dataset): - def __init__( - self, - root_dir, - split="train", - transform=None, - random_front=False, - max_item=None, - cond_aug_mean=-3.0, - cond_aug_std=0.5, - condition_on_elevation=False, - use_precomputed_latents=False, - **unused_kwargs, - ): - print("Using LVIS subset") - self.root_dir = Path(root_dir) - self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512") - self.split = split - self.random_front = random_front - self.transform = transform - self.use_precomputed_latents = use_precomputed_latents - - self.ids = json.load(open("./assets/lvis_uids.json", "r")) - self.n_views = 18 - valid_ids = [] - for idx in self.ids: - if (self.root_dir / idx).exists(): - valid_ids.append(idx) - self.ids = valid_ids - print("=" * 30) - print("Number of valid ids: ", len(self.ids)) - print("=" * 30) - - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - self.condition_on_elevation = condition_on_elevation - - if max_item is not None: - self.ids = self.ids[:max_item] - - ## debug - self.ids = self.ids * 10000 - - def __getitem__(self, idx: int): - frames = [] - idx_list = np.arange(self.n_views) - if self.random_front: - roll_idx = np.random.randint(self.n_views) - idx_list = np.roll(idx_list, roll_idx) - for view_idx in idx_list: - frame = Image.open( - self.root_dir - / self.ids[idx] - / "elevations_0" - / f"colors_{view_idx * 2}.png" - ) - frames.append(self.transform(frame)) - - frames = torch.stack(frames, dim=0) - cond = frames[0] - - cond_aug = np.exp( - np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean - ) - - data = { - "frames": frames, - "cond_frames_without_noise": cond, - "cond_aug": torch.as_tensor([cond_aug] * self.n_views), - "cond_frames": cond + cond_aug * torch.randn_like(cond), - "fps_id": torch.as_tensor([0.0] * self.n_views), - "motion_bucket_id": torch.as_tensor([300.0] * self.n_views), - "num_video_frames": self.n_views, - "image_only_indicator": torch.as_tensor([0.0] * self.n_views), - } - - if self.use_precomputed_latents: - data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt") - - if self.condition_on_elevation: - # sample_c2w = read_camera_matrix_single( - # self.root_dir / self.ids[idx] / f"00000/00000.json" - # ) - # elevation = calc_elevation(sample_c2w) - # data["elevation"] = torch.as_tensor([elevation] * self.n_views) - assert False, "currently assumes elevation 0" - - return data - - def __len__(self): - return len(self.ids) - - -class ObjaverseALLSpiral(ObjaverseLVISSpiral): - def __init__( - self, - root_dir, - split="train", - transform=None, - random_front=False, - max_item=None, - cond_aug_mean=-3.0, - cond_aug_std=0.5, - condition_on_elevation=False, - use_precomputed_latents=False, - **unused_kwargs, - ): - print("Using ALL objects in Objaverse") - self.root_dir = Path(root_dir) - self.split = split - self.random_front = random_front - self.transform = transform - self.use_precomputed_latents = use_precomputed_latents - self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512") - - self.ids = json.load(open("./assets/all_ids.json", "r")) - self.n_views = 18 - valid_ids = [] - for idx in self.ids: - if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir(): - valid_ids.append(idx) - self.ids = valid_ids - print("=" * 30) - print("Number of valid ids: ", len(self.ids)) - print("=" * 30) - - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - self.condition_on_elevation = condition_on_elevation - - if max_item is not None: - self.ids = self.ids[:max_item] - - ## debug - self.ids = self.ids * 10000 - - -class ObjaverseWithPose(Dataset): - def __init__( - self, - root_dir, - split="train", - transform=None, - random_front=False, - max_item=None, - cond_aug_mean=-3.0, - cond_aug_std=0.5, - condition_on_elevation=False, - use_precomputed_latents=False, - **unused_kwargs, - ): - print("Using Objaverse with poses") - self.root_dir = Path(root_dir) - self.split = split - self.random_front = random_front - self.transform = transform - self.use_precomputed_latents = use_precomputed_latents - self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512") - - self.ids = json.load(open("./assets/all_ids.json", "r")) - self.n_views = 18 - valid_ids = [] - for idx in self.ids: - if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir(): - valid_ids.append(idx) - self.ids = valid_ids - print("=" * 30) - print("Number of valid ids: ", len(self.ids)) - print("=" * 30) - - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - self.condition_on_elevation = condition_on_elevation - - def __getitem__(self, idx: int): - frames = [] - idx_list = np.arange(self.n_views) - if self.random_front: - roll_idx = np.random.randint(self.n_views) - idx_list = np.roll(idx_list, roll_idx) - for view_idx in idx_list: - frame = Image.open( - self.root_dir - / self.ids[idx] - / "elevations_0" - / f"colors_{view_idx * 2}.png" - ) - frames.append(self.transform(frame)) - - frames = torch.stack(frames, dim=0) - cond = frames[0] - - cond_aug = np.exp( - np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean - ) - - data = { - "frames": frames, - "cond_frames_without_noise": cond, - "cond_aug": torch.as_tensor([cond_aug] * self.n_views), - "cond_frames": cond + cond_aug * torch.randn_like(cond), - "fps_id": torch.as_tensor([0.0] * self.n_views), - "motion_bucket_id": torch.as_tensor([300.0] * self.n_views), - "num_video_frames": self.n_views, - "image_only_indicator": torch.as_tensor([0.0] * self.n_views), - } - - if self.use_precomputed_latents: - data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt") - - if self.condition_on_elevation: - assert False, "currently assumes elevation 0" - - return data - - -class LatentObjaverse(Dataset): - def __init__( - self, - root_dir, - split="train", - random_front=False, - subset="lvis", - fps_id=1.0, - motion_bucket_id=300.0, - cond_aug_mean=-3.0, - cond_aug_std=0.5, - **unused_kwargs, - ): - self.root_dir = Path(root_dir) - self.split = split - self.random_front = random_front - self.ids = json.load(open(Path("./assets") / f"{subset}_ids.json", "r")) - self.clip_emb_dir = self.root_dir / ".." / "clip_emb512" - self.n_views = 18 - self.fps_id = fps_id - self.motion_bucket_id = motion_bucket_id - self.cond_aug_mean = cond_aug_mean - self.cond_aug_std = cond_aug_std - if self.random_front: - print("Using a random view as front view") - - valid_ids = [] - for idx in self.ids: - if (self.root_dir / f"{idx}.pt").exists() and ( - self.clip_emb_dir / f"{idx}.pt" - ).exists(): - valid_ids.append(idx) - self.ids = valid_ids - print("=" * 30) - print("Number of valid ids: ", len(self.ids)) - print("=" * 30) - - def __getitem__(self, idx: int): - uid = self.ids[idx] - idx_list = torch.arange(self.n_views) - latents = torch.load(self.root_dir / f"{uid}.pt") - clip_emb = torch.load(self.clip_emb_dir / f"{uid}.pt") - if self.random_front: - idx_list = torch.roll(idx_list, np.random.randint(self.n_views)) - latents = latents[idx_list] - clip_emb = clip_emb[idx_list][0] - - cond_aug = np.exp( - np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean - ) - cond = latents[0] - - data = { - "latents": latents, - "cond_frames_without_noise": clip_emb, - "cond_frames": cond + cond_aug * torch.randn_like(cond), - "fps_id": torch.as_tensor([self.fps_id] * self.n_views), - "motion_bucket_id": torch.as_tensor([self.motion_bucket_id] * self.n_views), - "cond_aug": torch.as_tensor([cond_aug] * self.n_views), - "num_video_frames": self.n_views, - "image_only_indicator": torch.as_tensor([0.0] * self.n_views), - } - - return data - - def __len__(self): - return len(self.ids) - - -class ObjaverseSpiralDataset(LightningDataModule): - def __init__( - self, - root_dir, - random_front=False, - batch_size=2, - num_workers=10, - prefetch_factor=2, - shuffle=True, - max_item=None, - dataset_cls="richdreamer", - reso: int = 256, - **kwargs, - ) -> None: - super().__init__() - - self.batch_size = batch_size - self.num_workers = num_workers - self.prefetch_factor = prefetch_factor - self.shuffle = shuffle - self.max_item = max_item - - self.transform = Compose( - [ - blend_white_bg, - Resize((reso, reso)), - ToTensor(), - Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), - ] - ) - - data_cls = { - "richdreamer": ObjaverseSpiral, - "lvis": ObjaverseLVISSpiral, - "shengshu_all": ObjaverseALLSpiral, - "latent": LatentObjaverse, - "gobjaverse": GObjaverse, - }[dataset_cls] - - self.train_dataset = data_cls( - root_dir=root_dir, - split="train", - random_front=random_front, - transform=self.transform, - max_item=self.max_item, - **kwargs, - ) - self.test_dataset = data_cls( - root_dir=root_dir, - split="val", - random_front=random_front, - transform=self.transform, - max_item=self.max_item, - **kwargs, - ) - - def train_dataloader(self): - return DataLoader( - self.train_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=video_collate_fn - if not hasattr(self.train_dataset, "collate_fn") - else self.train_dataset.collate_fn, - ) - - def test_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=video_collate_fn - if not hasattr(self.test_dataset, "collate_fn") - else self.train_dataset.collate_fn, - ) - - def val_dataloader(self): - return DataLoader( - self.test_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - prefetch_factor=self.prefetch_factor, - collate_fn=video_collate_fn - if not hasattr(self.test_dataset, "collate_fn") - else self.train_dataset.collate_fn, - ) diff --git a/sgm/inference/api.py b/sgm/inference/api.py deleted file mode 100644 index 7171ff4abb774556b638c98ad809e195082bdccf..0000000000000000000000000000000000000000 --- a/sgm/inference/api.py +++ /dev/null @@ -1,385 +0,0 @@ -import pathlib -from dataclasses import asdict, dataclass -from enum import Enum -from typing import Optional - -from omegaconf import OmegaConf - -from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img, - do_sample) -from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, - DPMPP2SAncestralSampler, - EulerAncestralSampler, - EulerEDMSampler, - HeunEDMSampler, - LinearMultistepSampler) -from sgm.util import load_model_from_config - - -class ModelArchitecture(str, Enum): - SD_2_1 = "stable-diffusion-v2-1" - SD_2_1_768 = "stable-diffusion-v2-1-768" - SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" - SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" - SDXL_V1_BASE = "stable-diffusion-xl-v1-base" - SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" - - -class Sampler(str, Enum): - EULER_EDM = "EulerEDMSampler" - HEUN_EDM = "HeunEDMSampler" - EULER_ANCESTRAL = "EulerAncestralSampler" - DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler" - DPMPP2M = "DPMPP2MSampler" - LINEAR_MULTISTEP = "LinearMultistepSampler" - - -class Discretization(str, Enum): - LEGACY_DDPM = "LegacyDDPMDiscretization" - EDM = "EDMDiscretization" - - -class Guider(str, Enum): - VANILLA = "VanillaCFG" - IDENTITY = "IdentityGuider" - - -class Thresholder(str, Enum): - NONE = "None" - - -@dataclass -class SamplingParams: - width: int = 1024 - height: int = 1024 - steps: int = 50 - sampler: Sampler = Sampler.DPMPP2M - discretization: Discretization = Discretization.LEGACY_DDPM - guider: Guider = Guider.VANILLA - thresholder: Thresholder = Thresholder.NONE - scale: float = 6.0 - aesthetic_score: float = 5.0 - negative_aesthetic_score: float = 5.0 - img2img_strength: float = 1.0 - orig_width: int = 1024 - orig_height: int = 1024 - crop_coords_top: int = 0 - crop_coords_left: int = 0 - sigma_min: float = 0.0292 - sigma_max: float = 14.6146 - rho: float = 3.0 - s_churn: float = 0.0 - s_tmin: float = 0.0 - s_tmax: float = 999.0 - s_noise: float = 1.0 - eta: float = 1.0 - order: int = 4 - - -@dataclass -class SamplingSpec: - width: int - height: int - channels: int - factor: int - is_legacy: bool - config: str - ckpt: str - is_guided: bool - - -model_specs = { - ModelArchitecture.SD_2_1: SamplingSpec( - height=512, - width=512, - channels=4, - factor=8, - is_legacy=True, - config="sd_2_1.yaml", - ckpt="v2-1_512-ema-pruned.safetensors", - is_guided=True, - ), - ModelArchitecture.SD_2_1_768: SamplingSpec( - height=768, - width=768, - channels=4, - factor=8, - is_legacy=True, - config="sd_2_1_768.yaml", - ckpt="v2-1_768-ema-pruned.safetensors", - is_guided=True, - ), - ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec( - height=1024, - width=1024, - channels=4, - factor=8, - is_legacy=False, - config="sd_xl_base.yaml", - ckpt="sd_xl_base_0.9.safetensors", - is_guided=True, - ), - ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec( - height=1024, - width=1024, - channels=4, - factor=8, - is_legacy=True, - config="sd_xl_refiner.yaml", - ckpt="sd_xl_refiner_0.9.safetensors", - is_guided=True, - ), - ModelArchitecture.SDXL_V1_BASE: SamplingSpec( - height=1024, - width=1024, - channels=4, - factor=8, - is_legacy=False, - config="sd_xl_base.yaml", - ckpt="sd_xl_base_1.0.safetensors", - is_guided=True, - ), - ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( - height=1024, - width=1024, - channels=4, - factor=8, - is_legacy=True, - config="sd_xl_refiner.yaml", - ckpt="sd_xl_refiner_1.0.safetensors", - is_guided=True, - ), -} - - -class SamplingPipeline: - def __init__( - self, - model_id: ModelArchitecture, - model_path="checkpoints", - config_path="configs/inference", - device="cuda", - use_fp16=True, - ) -> None: - if model_id not in model_specs: - raise ValueError(f"Model {model_id} not supported") - self.model_id = model_id - self.specs = model_specs[self.model_id] - self.config = str(pathlib.Path(config_path, self.specs.config)) - self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) - self.device = device - self.model = self._load_model(device=device, use_fp16=use_fp16) - - def _load_model(self, device="cuda", use_fp16=True): - config = OmegaConf.load(self.config) - model = load_model_from_config(config, self.ckpt) - if model is None: - raise ValueError(f"Model {self.model_id} could not be loaded") - model.to(device) - if use_fp16: - model.conditioner.half() - model.model.half() - return model - - def text_to_image( - self, - params: SamplingParams, - prompt: str, - negative_prompt: str = "", - samples: int = 1, - return_latents: bool = False, - ): - sampler = get_sampler_config(params) - value_dict = asdict(params) - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt - value_dict["target_width"] = params.width - value_dict["target_height"] = params.height - return do_sample( - self.model, - sampler, - value_dict, - samples, - params.height, - params.width, - self.specs.channels, - self.specs.factor, - force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], - return_latents=return_latents, - filter=None, - ) - - def image_to_image( - self, - params: SamplingParams, - image, - prompt: str, - negative_prompt: str = "", - samples: int = 1, - return_latents: bool = False, - ): - sampler = get_sampler_config(params) - - if params.img2img_strength < 1.0: - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, - strength=params.img2img_strength, - ) - height, width = image.shape[2], image.shape[3] - value_dict = asdict(params) - value_dict["prompt"] = prompt - value_dict["negative_prompt"] = negative_prompt - value_dict["target_width"] = width - value_dict["target_height"] = height - return do_img2img( - image, - self.model, - sampler, - value_dict, - samples, - force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], - return_latents=return_latents, - filter=None, - ) - - def refiner( - self, - params: SamplingParams, - image, - prompt: str, - negative_prompt: Optional[str] = None, - samples: int = 1, - return_latents: bool = False, - ): - sampler = get_sampler_config(params) - value_dict = { - "orig_width": image.shape[3] * 8, - "orig_height": image.shape[2] * 8, - "target_width": image.shape[3] * 8, - "target_height": image.shape[2] * 8, - "prompt": prompt, - "negative_prompt": negative_prompt, - "crop_coords_top": 0, - "crop_coords_left": 0, - "aesthetic_score": 6.0, - "negative_aesthetic_score": 2.5, - } - - return do_img2img( - image, - self.model, - sampler, - value_dict, - samples, - skip_encode=True, - return_latents=return_latents, - filter=None, - ) - - -def get_guider_config(params: SamplingParams): - if params.guider == Guider.IDENTITY: - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" - } - elif params.guider == Guider.VANILLA: - scale = params.scale - - thresholder = params.thresholder - - if thresholder == Thresholder.NONE: - dyn_thresh_config = { - "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" - } - else: - raise NotImplementedError - - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", - "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, - } - else: - raise NotImplementedError - return guider_config - - -def get_discretization_config(params: SamplingParams): - if params.discretization == Discretization.LEGACY_DDPM: - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", - } - elif params.discretization == Discretization.EDM: - discretization_config = { - "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", - "params": { - "sigma_min": params.sigma_min, - "sigma_max": params.sigma_max, - "rho": params.rho, - }, - } - else: - raise ValueError(f"unknown discretization {params.discretization}") - return discretization_config - - -def get_sampler_config(params: SamplingParams): - discretization_config = get_discretization_config(params) - guider_config = get_guider_config(params) - sampler = None - if params.sampler == Sampler.EULER_EDM: - return EulerEDMSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=params.s_churn, - s_tmin=params.s_tmin, - s_tmax=params.s_tmax, - s_noise=params.s_noise, - verbose=True, - ) - if params.sampler == Sampler.HEUN_EDM: - return HeunEDMSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - s_churn=params.s_churn, - s_tmin=params.s_tmin, - s_tmax=params.s_tmax, - s_noise=params.s_noise, - verbose=True, - ) - if params.sampler == Sampler.EULER_ANCESTRAL: - return EulerAncestralSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=params.eta, - s_noise=params.s_noise, - verbose=True, - ) - if params.sampler == Sampler.DPMPP2S_ANCESTRAL: - return DPMPP2SAncestralSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - eta=params.eta, - s_noise=params.s_noise, - verbose=True, - ) - if params.sampler == Sampler.DPMPP2M: - return DPMPP2MSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - verbose=True, - ) - if params.sampler == Sampler.LINEAR_MULTISTEP: - return LinearMultistepSampler( - num_steps=params.steps, - discretization_config=discretization_config, - guider_config=guider_config, - order=params.order, - verbose=True, - ) - - raise ValueError(f"unknown sampler {params.sampler}!") diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py deleted file mode 100644 index 31b0ec3dc414bf522261e35f73805810cd35582d..0000000000000000000000000000000000000000 --- a/sgm/inference/helpers.py +++ /dev/null @@ -1,305 +0,0 @@ -import math -import os -from typing import List, Optional, Union - -import numpy as np -import torch -from einops import rearrange -from imwatermark import WatermarkEncoder -from omegaconf import ListConfig -from PIL import Image -from torch import autocast - -from sgm.util import append_dims - - -class WatermarkEmbedder: - def __init__(self, watermark): - self.watermark = watermark - self.num_bits = len(WATERMARK_BITS) - self.encoder = WatermarkEncoder() - self.encoder.set_watermark("bits", self.watermark) - - def __call__(self, image: torch.Tensor) -> torch.Tensor: - """ - Adds a predefined watermark to the input image - - Args: - image: ([N,] B, RGB, H, W) in range [0, 1] - - Returns: - same as input but watermarked - """ - squeeze = len(image.shape) == 4 - if squeeze: - image = image[None, ...] - n = image.shape[0] - image_np = rearrange( - (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy()[:, :, :, ::-1] - # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] - # watermarking libary expects input as cv2 BGR format - for k in range(image_np.shape[0]): - image_np[k] = self.encoder.encode(image_np[k], "dwtDct") - image = torch.from_numpy( - rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) - ).to(image.device) - image = torch.clamp(image / 255, min=0.0, max=1.0) - if squeeze: - image = image[0] - return image - - -# A fixed 48-bit message that was choosen at random -# WATERMARK_MESSAGE = 0xB3EC907BB19E -WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 -# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 -WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] -embed_watermark = WatermarkEmbedder(WATERMARK_BITS) - - -def get_unique_embedder_keys_from_conditioner(conditioner): - return list({x.input_key for x in conditioner.embedders}) - - -def perform_save_locally(save_path, samples): - os.makedirs(os.path.join(save_path), exist_ok=True) - base_count = len(os.listdir(os.path.join(save_path))) - samples = embed_watermark(samples) - for sample in samples: - sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") - Image.fromarray(sample.astype(np.uint8)).save( - os.path.join(save_path, f"{base_count:09}.png") - ) - base_count += 1 - - -class Img2ImgDiscretizationWrapper: - """ - wraps a discretizer, and prunes the sigmas - params: - strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned) - """ - - def __init__(self, discretization, strength: float = 1.0): - self.discretization = discretization - self.strength = strength - assert 0.0 <= self.strength <= 1.0 - - def __call__(self, *args, **kwargs): - # sigmas start large first, and decrease then - sigmas = self.discretization(*args, **kwargs) - print(f"sigmas after discretization, before pruning img2img: ", sigmas) - sigmas = torch.flip(sigmas, (0,)) - sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] - print("prune index:", max(int(self.strength * len(sigmas)), 1)) - sigmas = torch.flip(sigmas, (0,)) - print(f"sigmas after pruning: ", sigmas) - return sigmas - - -def do_sample( - model, - sampler, - value_dict, - num_samples, - H, - W, - C, - F, - force_uc_zero_embeddings: Optional[List] = None, - batch2model_input: Optional[List] = None, - return_latents=False, - filter=None, - device="cuda", -): - if force_uc_zero_embeddings is None: - force_uc_zero_embeddings = [] - if batch2model_input is None: - batch2model_input = [] - - with torch.no_grad(): - with autocast(device) as precision_scope: - with model.ema_scope(): - num_samples = [num_samples] - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - num_samples, - ) - for key in batch: - if isinstance(batch[key], torch.Tensor): - print(key, batch[key].shape) - elif isinstance(batch[key], list): - print(key, [len(l) for l in batch[key]]) - else: - print(key, batch[key]) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - - for k in c: - if not k == "crossattn": - c[k], uc[k] = map( - lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) - ) - - additional_model_inputs = {} - for k in batch2model_input: - additional_model_inputs[k] = batch[k] - - shape = (math.prod(num_samples), C, H // F, W // F) - randn = torch.randn(shape).to(device) - - def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) - - samples_z = sampler(denoiser, randn, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - if filter is not None: - samples = filter(samples) - - if return_latents: - return samples, samples_z - return samples - - -def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): - # Hardcoded demo setups; might undergo some changes in the future - - batch = {} - batch_uc = {} - - for key in keys: - if key == "txt": - batch["txt"] = ( - np.repeat([value_dict["prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - batch_uc["txt"] = ( - np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() - ) - elif key == "original_size_as_tuple": - batch["original_size_as_tuple"] = ( - torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) - .to(device) - .repeat(*N, 1) - ) - elif key == "crop_coords_top_left": - batch["crop_coords_top_left"] = ( - torch.tensor( - [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] - ) - .to(device) - .repeat(*N, 1) - ) - elif key == "aesthetic_score": - batch["aesthetic_score"] = ( - torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) - ) - batch_uc["aesthetic_score"] = ( - torch.tensor([value_dict["negative_aesthetic_score"]]) - .to(device) - .repeat(*N, 1) - ) - - elif key == "target_size_as_tuple": - batch["target_size_as_tuple"] = ( - torch.tensor([value_dict["target_height"], value_dict["target_width"]]) - .to(device) - .repeat(*N, 1) - ) - else: - batch[key] = value_dict[key] - - for key in batch.keys(): - if key not in batch_uc and isinstance(batch[key], torch.Tensor): - batch_uc[key] = torch.clone(batch[key]) - return batch, batch_uc - - -def get_input_image_tensor(image: Image.Image, device="cuda"): - w, h = image.size - print(f"loaded input image of size ({w}, {h})") - width, height = map( - lambda x: x - x % 64, (w, h) - ) # resize to integer multiple of 64 - image = image.resize((width, height)) - image_array = np.array(image.convert("RGB")) - image_array = image_array[None].transpose(0, 3, 1, 2) - image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0 - return image_tensor.to(device) - - -def do_img2img( - img, - model, - sampler, - value_dict, - num_samples, - force_uc_zero_embeddings=[], - additional_kwargs={}, - offset_noise_level: float = 0.0, - return_latents=False, - skip_encode=False, - filter=None, - device="cuda", -): - with torch.no_grad(): - with autocast(device) as precision_scope: - with model.ema_scope(): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [num_samples], - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=force_uc_zero_embeddings, - ) - - for k in c: - c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) - - for k in additional_kwargs: - c[k] = uc[k] = additional_kwargs[k] - if skip_encode: - z = img - else: - z = model.encode_first_stage(img) - noise = torch.randn_like(z) - sigmas = sampler.discretization(sampler.num_steps) - sigma = sigmas[0].to(z.device) - - if offset_noise_level > 0.0: - noise = noise + offset_noise_level * append_dims( - torch.randn(z.shape[0], device=z.device), z.ndim - ) - noised_z = z + noise * append_dims(sigma, z.ndim) - noised_z = noised_z / torch.sqrt( - 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. - - def denoiser(x, sigma, c): - return model.denoiser(model.model, x, sigma, c) - - samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - if filter is not None: - samples = filter(samples) - - if return_latents: - return samples, samples_z - return samples diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py deleted file mode 100644 index b2f4d384c1fcaff0df13e0564450d3fa972ace42..0000000000000000000000000000000000000000 --- a/sgm/lr_scheduler.py +++ /dev/null @@ -1,135 +0,0 @@ -import numpy as np - - -class LambdaWarmUpCosineScheduler: - """ - note: use with a base_lr of 1.0 - """ - - def __init__( - self, - warm_up_steps, - lr_min, - lr_max, - lr_start, - max_decay_steps, - verbosity_interval=0, - ): - self.lr_warm_up_steps = warm_up_steps - self.lr_start = lr_start - self.lr_min = lr_min - self.lr_max = lr_max - self.lr_max_decay_steps = max_decay_steps - self.last_lr = 0.0 - self.verbosity_interval = verbosity_interval - - def schedule(self, n, **kwargs): - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") - if n < self.lr_warm_up_steps: - lr = ( - self.lr_max - self.lr_start - ) / self.lr_warm_up_steps * n + self.lr_start - self.last_lr = lr - return lr - else: - t = (n - self.lr_warm_up_steps) / ( - self.lr_max_decay_steps - self.lr_warm_up_steps - ) - t = min(t, 1.0) - lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( - 1 + np.cos(t * np.pi) - ) - self.last_lr = lr - return lr - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaWarmUpCosineScheduler2: - """ - supports repeated iterations, configurable via lists - note: use with a base_lr of 1.0. - """ - - def __init__( - self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 - ): - assert ( - len(warm_up_steps) - == len(f_min) - == len(f_max) - == len(f_start) - == len(cycle_lengths) - ) - self.lr_warm_up_steps = warm_up_steps - self.f_start = f_start - self.f_min = f_min - self.f_max = f_max - self.cycle_lengths = cycle_lengths - self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) - self.last_f = 0.0 - self.verbosity_interval = verbosity_interval - - def find_in_interval(self, n): - interval = 0 - for cl in self.cum_cycles[1:]: - if n <= cl: - return interval - interval += 1 - - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}" - ) - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ - cycle - ] * n + self.f_start[cycle] - self.last_f = f - return f - else: - t = (n - self.lr_warm_up_steps[cycle]) / ( - self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] - ) - t = min(t, 1.0) - f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( - 1 + np.cos(t * np.pi) - ) - self.last_f = f - return f - - def __call__(self, n, **kwargs): - return self.schedule(n, **kwargs) - - -class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): - def schedule(self, n, **kwargs): - cycle = self.find_in_interval(n) - n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0: - if n % self.verbosity_interval == 0: - print( - f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}" - ) - - if n < self.lr_warm_up_steps[cycle]: - f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ - cycle - ] * n + self.f_start[cycle] - self.last_f = f - return f - else: - f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( - self.cycle_lengths[cycle] - n - ) / (self.cycle_lengths[cycle]) - self.last_f = f - return f diff --git a/sgm/models/__init__.py b/sgm/models/__init__.py deleted file mode 100644 index c410b3747afc208e4204c8f140170e0a7808eace..0000000000000000000000000000000000000000 --- a/sgm/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .autoencoder import AutoencodingEngine -from .diffusion import DiffusionEngine diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py deleted file mode 100644 index 2949b91011a2be7a6b8ca17ce260812f20ce8b75..0000000000000000000000000000000000000000 --- a/sgm/models/autoencoder.py +++ /dev/null @@ -1,615 +0,0 @@ -import logging -import math -import re -from abc import abstractmethod -from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple, Union - -import pytorch_lightning as pl -import torch -import torch.nn as nn -from einops import rearrange -from packaging import version - -from ..modules.autoencoding.regularizers import AbstractRegularizer -from ..modules.ema import LitEma -from ..util import (default, get_nested_attribute, get_obj_from_str, - instantiate_from_config) - -logpy = logging.getLogger(__name__) - - -class AbstractAutoencoder(pl.LightningModule): - """ - This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, - unCLIP models, etc. Hence, it is fairly general, and specific features - (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. - """ - - def __init__( - self, - ema_decay: Union[None, float] = None, - monitor: Union[None, str] = None, - input_key: str = "jpg", - ): - super().__init__() - - self.input_key = input_key - self.use_ema = ema_decay is not None - if monitor is not None: - self.monitor = monitor - - if self.use_ema: - self.model_ema = LitEma(self, decay=ema_decay) - logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - if version.parse(torch.__version__) >= version.parse("2.0.0"): - self.automatic_optimization = False - - def apply_ckpt(self, ckpt: Union[None, str, dict]): - if ckpt is None: - return - if isinstance(ckpt, str): - ckpt = { - "target": "sgm.modules.checkpoint.CheckpointEngine", - "params": {"ckpt_path": ckpt}, - } - engine = instantiate_from_config(ckpt) - engine(self) - - @abstractmethod - def get_input(self, batch) -> Any: - raise NotImplementedError() - - def on_train_batch_end(self, *args, **kwargs): - # for EMA computation - if self.use_ema: - self.model_ema(self) - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.parameters()) - self.model_ema.copy_to(self) - if context is not None: - logpy.info(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.parameters()) - if context is not None: - logpy.info(f"{context}: Restored training weights") - - @abstractmethod - def encode(self, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError("encode()-method of abstract base class called") - - @abstractmethod - def decode(self, *args, **kwargs) -> torch.Tensor: - raise NotImplementedError("decode()-method of abstract base class called") - - def instantiate_optimizer_from_config(self, params, lr, cfg): - logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") - return get_obj_from_str(cfg["target"])( - params, lr=lr, **cfg.get("params", dict()) - ) - - def configure_optimizers(self) -> Any: - raise NotImplementedError() - - -class AutoencodingEngine(AbstractAutoencoder): - """ - Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL - (we also restore them explicitly as special cases for legacy reasons). - Regularizations such as KL or VQ are moved to the regularizer class. - """ - - def __init__( - self, - *args, - encoder_config: Dict, - decoder_config: Dict, - loss_config: Dict, - regularizer_config: Dict, - optimizer_config: Union[Dict, None] = None, - lr_g_factor: float = 1.0, - trainable_ae_params: Optional[List[List[str]]] = None, - ae_optimizer_args: Optional[List[dict]] = None, - trainable_disc_params: Optional[List[List[str]]] = None, - disc_optimizer_args: Optional[List[dict]] = None, - disc_start_iter: int = 0, - diff_boost_factor: float = 3.0, - ckpt_engine: Union[None, str, dict] = None, - ckpt_path: Optional[str] = None, - additional_decode_keys: Optional[List[str]] = None, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.automatic_optimization = False # pytorch lightning - - self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) - self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) - self.loss: torch.nn.Module = instantiate_from_config(loss_config) - self.regularization: AbstractRegularizer = instantiate_from_config( - regularizer_config - ) - self.optimizer_config = default( - optimizer_config, {"target": "torch.optim.Adam"} - ) - self.diff_boost_factor = diff_boost_factor - self.disc_start_iter = disc_start_iter - self.lr_g_factor = lr_g_factor - self.trainable_ae_params = trainable_ae_params - if self.trainable_ae_params is not None: - self.ae_optimizer_args = default( - ae_optimizer_args, - [{} for _ in range(len(self.trainable_ae_params))], - ) - assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) - else: - self.ae_optimizer_args = [{}] # makes type consitent - - self.trainable_disc_params = trainable_disc_params - if self.trainable_disc_params is not None: - self.disc_optimizer_args = default( - disc_optimizer_args, - [{} for _ in range(len(self.trainable_disc_params))], - ) - assert len(self.disc_optimizer_args) == len(self.trainable_disc_params) - else: - self.disc_optimizer_args = [{}] # makes type consitent - - if ckpt_path is not None: - assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" - logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") - self.apply_ckpt(default(ckpt_path, ckpt_engine)) - self.additional_decode_keys = set(default(additional_decode_keys, [])) - - def get_input(self, batch: Dict) -> torch.Tensor: - # assuming unified data format, dataloader returns a dict. - # image tensors should be scaled to -1 ... 1 and in channels-first - # format (e.g., bchw instead if bhwc) - return batch[self.input_key] - - def get_autoencoder_params(self) -> list: - params = [] - if hasattr(self.loss, "get_trainable_autoencoder_parameters"): - params += list(self.loss.get_trainable_autoencoder_parameters()) - if hasattr(self.regularization, "get_trainable_parameters"): - params += list(self.regularization.get_trainable_parameters()) - params = params + list(self.encoder.parameters()) - params = params + list(self.decoder.parameters()) - return params - - def get_discriminator_params(self) -> list: - if hasattr(self.loss, "get_trainable_parameters"): - params = list(self.loss.get_trainable_parameters()) # e.g., discriminator - else: - params = [] - return params - - def get_last_layer(self): - return self.decoder.get_last_layer() - - def encode( - self, - x: torch.Tensor, - return_reg_log: bool = False, - unregularized: bool = False, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: - z = self.encoder(x) - if unregularized: - return z, dict() - z, reg_log = self.regularization(z) - if return_reg_log: - return z, reg_log - return z - - def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: - x = self.decoder(z, **kwargs) - return x - - def forward( - self, x: torch.Tensor, **additional_decode_kwargs - ) -> Tuple[torch.Tensor, torch.Tensor, dict]: - z, reg_log = self.encode(x, return_reg_log=True) - dec = self.decode(z, **additional_decode_kwargs) - return z, dec, reg_log - - def inner_training_step( - self, batch: dict, batch_idx: int, optimizer_idx: int = 0 - ) -> torch.Tensor: - x = self.get_input(batch) - additional_decode_kwargs = { - key: batch[key] for key in self.additional_decode_keys.intersection(batch) - } - z, xrec, regularization_log = self(x, **additional_decode_kwargs) - if hasattr(self.loss, "forward_keys"): - extra_info = { - "z": z, - "optimizer_idx": optimizer_idx, - "global_step": self.global_step, - "last_layer": self.get_last_layer(), - "split": "train", - "regularization_log": regularization_log, - "autoencoder": self, - } - extra_info = {k: extra_info[k] for k in self.loss.forward_keys} - else: - extra_info = dict() - - if optimizer_idx == 0: - # autoencode - out_loss = self.loss(x, xrec, **extra_info) - if isinstance(out_loss, tuple): - aeloss, log_dict_ae = out_loss - else: - # simple loss function - aeloss = out_loss - log_dict_ae = {"train/loss/rec": aeloss.detach()} - - self.log_dict( - log_dict_ae, - prog_bar=False, - logger=True, - on_step=True, - on_epoch=True, - sync_dist=False, - ) - self.log( - "loss", - aeloss.mean().detach(), - prog_bar=True, - logger=False, - on_epoch=False, - on_step=True, - ) - return aeloss - elif optimizer_idx == 1: - # discriminator - discloss, log_dict_disc = self.loss(x, xrec, **extra_info) - # -> discriminator always needs to return a tuple - self.log_dict( - log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True - ) - return discloss - else: - raise NotImplementedError(f"Unknown optimizer {optimizer_idx}") - - def training_step(self, batch: dict, batch_idx: int): - opts = self.optimizers() - if not isinstance(opts, list): - # Non-adversarial case - opts = [opts] - optimizer_idx = batch_idx % len(opts) - if self.global_step < self.disc_start_iter: - optimizer_idx = 0 - opt = opts[optimizer_idx] - opt.zero_grad() - with opt.toggle_model(): - loss = self.inner_training_step( - batch, batch_idx, optimizer_idx=optimizer_idx - ) - self.manual_backward(loss) - opt.step() - - def validation_step(self, batch: dict, batch_idx: int) -> Dict: - log_dict = self._validation_step(batch, batch_idx) - with self.ema_scope(): - log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") - log_dict.update(log_dict_ema) - return log_dict - - def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict: - x = self.get_input(batch) - - z, xrec, regularization_log = self(x) - if hasattr(self.loss, "forward_keys"): - extra_info = { - "z": z, - "optimizer_idx": 0, - "global_step": self.global_step, - "last_layer": self.get_last_layer(), - "split": "val" + postfix, - "regularization_log": regularization_log, - "autoencoder": self, - } - extra_info = {k: extra_info[k] for k in self.loss.forward_keys} - else: - extra_info = dict() - out_loss = self.loss(x, xrec, **extra_info) - if isinstance(out_loss, tuple): - aeloss, log_dict_ae = out_loss - else: - # simple loss function - aeloss = out_loss - log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()} - full_log_dict = log_dict_ae - - if "optimizer_idx" in extra_info: - extra_info["optimizer_idx"] = 1 - discloss, log_dict_disc = self.loss(x, xrec, **extra_info) - full_log_dict.update(log_dict_disc) - self.log( - f"val{postfix}/loss/rec", - log_dict_ae[f"val{postfix}/loss/rec"], - sync_dist=True, - ) - self.log_dict(full_log_dict, sync_dist=True) - return full_log_dict - - def get_param_groups( - self, parameter_names: List[List[str]], optimizer_args: List[dict] - ) -> Tuple[List[Dict[str, Any]], int]: - groups = [] - num_params = 0 - for names, args in zip(parameter_names, optimizer_args): - params = [] - for pattern_ in names: - pattern_params = [] - pattern = re.compile(pattern_) - for p_name, param in self.named_parameters(): - if re.match(pattern, p_name): - pattern_params.append(param) - num_params += param.numel() - if len(pattern_params) == 0: - logpy.warn(f"Did not find parameters for pattern {pattern_}") - params.extend(pattern_params) - groups.append({"params": params, **args}) - return groups, num_params - - def configure_optimizers(self) -> List[torch.optim.Optimizer]: - if self.trainable_ae_params is None: - ae_params = self.get_autoencoder_params() - else: - ae_params, num_ae_params = self.get_param_groups( - self.trainable_ae_params, self.ae_optimizer_args - ) - logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") - if self.trainable_disc_params is None: - disc_params = self.get_discriminator_params() - else: - disc_params, num_disc_params = self.get_param_groups( - self.trainable_disc_params, self.disc_optimizer_args - ) - logpy.info( - f"Number of trainable discriminator parameters: {num_disc_params:,}" - ) - opt_ae = self.instantiate_optimizer_from_config( - ae_params, - default(self.lr_g_factor, 1.0) * self.learning_rate, - self.optimizer_config, - ) - opts = [opt_ae] - if len(disc_params) > 0: - opt_disc = self.instantiate_optimizer_from_config( - disc_params, self.learning_rate, self.optimizer_config - ) - opts.append(opt_disc) - - return opts - - @torch.no_grad() - def log_images( - self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs - ) -> dict: - log = dict() - additional_decode_kwargs = {} - x = self.get_input(batch) - additional_decode_kwargs.update( - {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} - ) - - _, xrec, _ = self(x, **additional_decode_kwargs) - log["inputs"] = x - log["reconstructions"] = xrec - diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x) - diff.clamp_(0, 1.0) - log["diff"] = 2.0 * diff - 1.0 - # diff_boost shows location of small errors, by boosting their - # brightness. - log["diff_boost"] = ( - 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 - ) - if hasattr(self.loss, "log_images"): - log.update(self.loss.log_images(x, xrec)) - with self.ema_scope(): - _, xrec_ema, _ = self(x, **additional_decode_kwargs) - log["reconstructions_ema"] = xrec_ema - diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) - diff_ema.clamp_(0, 1.0) - log["diff_ema"] = 2.0 * diff_ema - 1.0 - log["diff_boost_ema"] = ( - 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 - ) - if additional_log_kwargs: - additional_decode_kwargs.update(additional_log_kwargs) - _, xrec_add, _ = self(x, **additional_decode_kwargs) - log_str = "reconstructions-" + "-".join( - [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs] - ) - log[log_str] = xrec_add - return log - - -class AutoencodingEngineLegacy(AutoencodingEngine): - def __init__(self, embed_dim: int, **kwargs): - self.max_batch_size = kwargs.pop("max_batch_size", None) - ddconfig = kwargs.pop("ddconfig") - ckpt_path = kwargs.pop("ckpt_path", None) - ckpt_engine = kwargs.pop("ckpt_engine", None) - super().__init__( - encoder_config={ - "target": "sgm.modules.diffusionmodules.model.Encoder", - "params": ddconfig, - }, - decoder_config={ - "target": "sgm.modules.diffusionmodules.model.Decoder", - "params": ddconfig, - }, - **kwargs, - ) - self.quant_conv = torch.nn.Conv2d( - (1 + ddconfig["double_z"]) * ddconfig["z_channels"], - (1 + ddconfig["double_z"]) * embed_dim, - 1, - ) - self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) - self.embed_dim = embed_dim - - self.apply_ckpt(default(ckpt_path, ckpt_engine)) - - def get_autoencoder_params(self) -> list: - params = super().get_autoencoder_params() - return params - - def encode( - self, x: torch.Tensor, return_reg_log: bool = False - ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: - if self.max_batch_size is None: - z = self.encoder(x) - z = self.quant_conv(z) - else: - N = x.shape[0] - bs = self.max_batch_size - n_batches = int(math.ceil(N / bs)) - z = list() - for i_batch in range(n_batches): - z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) - z_batch = self.quant_conv(z_batch) - z.append(z_batch) - z = torch.cat(z, 0) - - z, reg_log = self.regularization(z) - if return_reg_log: - return z, reg_log - return z - - def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: - if self.max_batch_size is None: - dec = self.post_quant_conv(z) - dec = self.decoder(dec, **decoder_kwargs) - else: - N = z.shape[0] - bs = self.max_batch_size - n_batches = int(math.ceil(N / bs)) - dec = list() - for i_batch in range(n_batches): - dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) - dec_batch = self.decoder(dec_batch, **decoder_kwargs) - dec.append(dec_batch) - dec = torch.cat(dec, 0) - - return dec - - -class AutoencoderKL(AutoencodingEngineLegacy): - def __init__(self, **kwargs): - if "lossconfig" in kwargs: - kwargs["loss_config"] = kwargs.pop("lossconfig") - super().__init__( - regularizer_config={ - "target": ( - "sgm.modules.autoencoding.regularizers" - ".DiagonalGaussianRegularizer" - ) - }, - **kwargs, - ) - - -class AutoencoderLegacyVQ(AutoencodingEngineLegacy): - def __init__( - self, - embed_dim: int, - n_embed: int, - sane_index_shape: bool = False, - **kwargs, - ): - if "lossconfig" in kwargs: - logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.") - kwargs["loss_config"] = kwargs.pop("lossconfig") - super().__init__( - regularizer_config={ - "target": ( - "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer" - ), - "params": { - "n_e": n_embed, - "e_dim": embed_dim, - "sane_index_shape": sane_index_shape, - }, - }, - **kwargs, - ) - - -class IdentityFirstStage(AbstractAutoencoder): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def get_input(self, x: Any) -> Any: - return x - - def encode(self, x: Any, *args, **kwargs) -> Any: - return x - - def decode(self, x: Any, *args, **kwargs) -> Any: - return x - - -class AEIntegerWrapper(nn.Module): - def __init__( - self, - model: nn.Module, - shape: Union[None, Tuple[int, int], List[int]] = (16, 16), - regularization_key: str = "regularization", - encoder_kwargs: Optional[Dict[str, Any]] = None, - ): - super().__init__() - self.model = model - assert hasattr(model, "encode") and hasattr( - model, "decode" - ), "Need AE interface" - self.regularization = get_nested_attribute(model, regularization_key) - self.shape = shape - self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True}) - - def encode(self, x) -> torch.Tensor: - assert ( - not self.training - ), f"{self.__class__.__name__} only supports inference currently" - _, log = self.model.encode(x, **self.encoder_kwargs) - assert isinstance(log, dict) - inds = log["min_encoding_indices"] - return rearrange(inds, "b ... -> b (...)") - - def decode( - self, inds: torch.Tensor, shape: Union[None, tuple, list] = None - ) -> torch.Tensor: - # expect inds shape (b, s) with s = h*w - shape = default(shape, self.shape) # Optional[(h, w)] - if shape is not None: - assert len(shape) == 2, f"Unhandeled shape {shape}" - inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1]) - h = self.regularization.get_codebook_entry(inds) # (b, h, w, c) - h = rearrange(h, "b h w c -> b c h w") - return self.model.decode(h) - - -class AutoencoderKLModeOnly(AutoencodingEngineLegacy): - def __init__(self, **kwargs): - if "lossconfig" in kwargs: - kwargs["loss_config"] = kwargs.pop("lossconfig") - super().__init__( - regularizer_config={ - "target": ( - "sgm.modules.autoencoding.regularizers" - ".DiagonalGaussianRegularizer" - ), - "params": {"sample": False}, - }, - **kwargs, - ) diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py deleted file mode 100644 index 41a0f4a7c6a7ed49e2d2538879d47d18ede16cba..0000000000000000000000000000000000000000 --- a/sgm/models/diffusion.py +++ /dev/null @@ -1,358 +0,0 @@ -import math -from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple, Union - -import pytorch_lightning as pl -import torch -from omegaconf import ListConfig, OmegaConf -from safetensors.torch import load_file as load_safetensors -from torch.optim.lr_scheduler import LambdaLR -from einops import rearrange - -from ..modules import UNCONDITIONAL_CONFIG -from ..modules.autoencoding.temporal_ae import VideoDecoder -from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER -from ..modules.ema import LitEma -from ..util import ( - default, - disabled_train, - get_obj_from_str, - instantiate_from_config, - log_txt_as_img, -) - - -class DiffusionEngine(pl.LightningModule): - def __init__( - self, - network_config, - denoiser_config, - first_stage_config, - conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, - sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, - optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, - scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, - loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, - network_wrapper: Union[None, str] = None, - ckpt_path: Union[None, str] = None, - use_ema: bool = False, - ema_decay_rate: float = 0.9999, - scale_factor: float = 1.0, - disable_first_stage_autocast=False, - input_key: str = "jpg", - log_keys: Union[List, None] = None, - no_cond_log: bool = False, - compile_model: bool = False, - en_and_decode_n_samples_a_time: Optional[int] = None, - ): - super().__init__() - self.log_keys = log_keys - self.input_key = input_key - self.optimizer_config = default( - optimizer_config, {"target": "torch.optim.AdamW"} - ) - model = instantiate_from_config(network_config) - self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( - model, compile_model=compile_model - ) - - self.denoiser = instantiate_from_config(denoiser_config) - self.sampler = ( - instantiate_from_config(sampler_config) - if sampler_config is not None - else None - ) - self.conditioner = instantiate_from_config( - default(conditioner_config, UNCONDITIONAL_CONFIG) - ) - self.scheduler_config = scheduler_config - self._init_first_stage(first_stage_config) - - self.loss_fn = ( - instantiate_from_config(loss_fn_config) - if loss_fn_config is not None - else None - ) - - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model, decay=ema_decay_rate) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.scale_factor = scale_factor - self.disable_first_stage_autocast = disable_first_stage_autocast - self.no_cond_log = no_cond_log - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path) - - self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time - - def init_from_ckpt( - self, - path: str, - ) -> None: - if path.endswith("ckpt"): - sd = torch.load(path, map_location="cpu")["state_dict"] - elif path.endswith("safetensors"): - sd = load_safetensors(path) - else: - raise NotImplementedError - - missing, unexpected = self.load_state_dict(sd, strict=False) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - - def _init_first_stage(self, config): - model = instantiate_from_config(config).eval() - model.train = disabled_train - for param in model.parameters(): - param.requires_grad = False - self.first_stage_model = model - - def get_input(self, batch): - # assuming unified data format, dataloader returns a dict. - # image tensors should be scaled to -1 ... 1 and in bchw format - return batch[self.input_key] - - @torch.no_grad() - def decode_first_stage(self, z): - z = 1.0 / self.scale_factor * z - n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) - - n_rounds = math.ceil(z.shape[0] / n_samples) - all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - for n in range(n_rounds): - if isinstance(self.first_stage_model.decoder, VideoDecoder): - kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} - else: - kwargs = {} - out = self.first_stage_model.decode( - z[n * n_samples : (n + 1) * n_samples], **kwargs - ) - all_out.append(out) - out = torch.cat(all_out, dim=0) - return out - - @torch.no_grad() - def encode_first_stage(self, x): - bs = x.shape[0] - is_video_input = False - if x.dim() == 5: - is_video_input = True - # for video diffusion - x = rearrange(x, "b t c h w -> (b t) c h w") - n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) - n_rounds = math.ceil(x.shape[0] / n_samples) - all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - for n in range(n_rounds): - out = self.first_stage_model.encode( - x[n * n_samples : (n + 1) * n_samples] - ) - all_out.append(out) - z = torch.cat(all_out, dim=0) - z = self.scale_factor * z - - if is_video_input: - z = rearrange(z, "(b t) c h w -> b t c h w", b=bs) - - return z - - def forward(self, x, batch): - loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch) - loss_mean = loss.mean() - loss_dict = {"loss": loss_mean} - return loss_mean, loss_dict - - def shared_step(self, batch: Dict) -> Any: - x = self.get_input(batch) - breakpoint() - x = self.encode_first_stage(x) - batch["global_step"] = self.global_step - loss, loss_dict = self(x, batch) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - loss, loss_dict = self.shared_step(batch) - - self.log_dict( - loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False - ) - - self.log( - "global_step", - self.global_step, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - if self.scheduler_config is not None: - lr = self.optimizers().param_groups[0]["lr"] - self.log( - "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False - ) - - return loss - - def on_train_start(self, *args, **kwargs): - if self.sampler is None or self.loss_fn is None: - raise ValueError("Sampler and loss function need to be set for training.") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def instantiate_optimizer_from_config(self, params, lr, cfg): - return get_obj_from_str(cfg["target"])( - params, lr=lr, **cfg.get("params", dict()) - ) - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - for embedder in self.conditioner.embedders: - if embedder.is_trainable: - params = params + list(embedder.parameters()) - opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), - "interval": "step", - "frequency": 1, - } - ] - return [opt], scheduler - return opt - - @torch.no_grad() - def sample( - self, - cond: Dict, - uc: Union[Dict, None] = None, - batch_size: int = 16, - shape: Union[None, Tuple, List] = None, - **kwargs, - ): - randn = torch.randn(batch_size, *shape).to(self.device) - - denoiser = lambda input, sigma, c: self.denoiser( - self.model, input, sigma, c, **kwargs - ) - samples = self.sampler(denoiser, randn, cond, uc=uc) - return samples - - @torch.no_grad() - def log_conditionings(self, batch: Dict, n: int) -> Dict: - """ - Defines heuristics to log different conditionings. - These can be lists of strings (text-to-image), tensors, ints, ... - """ - image_h, image_w = batch[self.input_key].shape[2:] - log = dict() - - for embedder in self.conditioner.embedders: - if ( - (self.log_keys is None) or (embedder.input_key in self.log_keys) - ) and not self.no_cond_log: - x = batch[embedder.input_key][:n] - if isinstance(x, torch.Tensor): - if x.dim() == 1: - # class-conditional, convert integer to string - x = [str(x[i].item()) for i in range(x.shape[0])] - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) - elif x.dim() == 2: - # size and crop cond and the like - x = [ - "x".join([str(xx) for xx in x[i].tolist()]) - for i in range(x.shape[0]) - ] - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) - else: - raise NotImplementedError() - elif isinstance(x, (List, ListConfig)): - if isinstance(x[0], str): - # strings - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) - else: - raise NotImplementedError() - else: - raise NotImplementedError() - log[embedder.input_key] = xc - return log - - @torch.no_grad() - def log_images( - self, - batch: Dict, - N: int = 8, - sample: bool = True, - ucg_keys: List[str] = None, - **kwargs, - ) -> Dict: - conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] - if ucg_keys: - assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( - "Each defined ucg key for sampling must be in the provided conditioner input keys," - f"but we have {ucg_keys} vs. {conditioner_input_keys}" - ) - else: - ucg_keys = conditioner_input_keys - log = dict() - - x = self.get_input(batch) - - c, uc = self.conditioner.get_unconditional_conditioning( - batch, - force_uc_zero_embeddings=ucg_keys - if len(self.conditioner.embedders) > 0 - else [], - ) - - sampling_kwargs = {} - - N = min(x.shape[0], N) - x = x.to(self.device)[:N] - log["inputs"] = x - z = self.encode_first_stage(x) - log["reconstructions"] = self.decode_first_stage(z) - log.update(self.log_conditionings(batch, N)) - - for k in c: - if isinstance(c[k], torch.Tensor): - c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) - - if sample: - with self.ema_scope("Plotting"): - samples = self.sample( - c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs - ) - samples = self.decode_first_stage(samples) - log["samples"] = samples - return log diff --git a/sgm/models/video3d_diffusion.py b/sgm/models/video3d_diffusion.py deleted file mode 100644 index 8c4f97ec0c975937f4686471b1fa5698af013197..0000000000000000000000000000000000000000 --- a/sgm/models/video3d_diffusion.py +++ /dev/null @@ -1,524 +0,0 @@ -import re -import math -from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple, Union - -import pytorch_lightning as pl -from pytorch_lightning.loggers import WandbLogger -import torch -from omegaconf import ListConfig, OmegaConf -from safetensors.torch import load_file as load_safetensors -from torch.optim.lr_scheduler import LambdaLR -from torchvision.utils import make_grid -from einops import rearrange, repeat - -from ..modules import UNCONDITIONAL_CONFIG -from ..modules.autoencoding.temporal_ae import VideoDecoder -from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER -from ..modules.ema import LitEma -from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder -from ..util import ( - default, - disabled_train, - get_obj_from_str, - instantiate_from_config, - log_txt_as_img, - video_frames_as_grid, -) - - -def flatten_for_video(input): - return input.flatten() - - -class Video3DDiffusionEngine(pl.LightningModule): - def __init__( - self, - network_config, - denoiser_config, - first_stage_config, - conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, - sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, - optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, - scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, - loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, - network_wrapper: Union[None, str] = None, - ckpt_path: Union[None, str] = None, - use_ema: bool = False, - ema_decay_rate: float = 0.9999, - scale_factor: float = 1.0, - disable_first_stage_autocast=False, - input_key: str = "frames", # for video inputs - log_keys: Union[List, None] = None, - no_cond_log: bool = False, - compile_model: bool = False, - en_and_decode_n_samples_a_time: Optional[int] = None, - ): - super().__init__() - self.log_keys = log_keys - self.input_key = input_key - self.optimizer_config = default( - optimizer_config, {"target": "torch.optim.AdamW"} - ) - model = instantiate_from_config(network_config) - self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( - model, compile_model=compile_model - ) - - self.denoiser = instantiate_from_config(denoiser_config) - self.sampler = ( - instantiate_from_config(sampler_config) - if sampler_config is not None - else None - ) - self.conditioner = instantiate_from_config( - default(conditioner_config, UNCONDITIONAL_CONFIG) - ) - self.scheduler_config = scheduler_config - self._init_first_stage(first_stage_config) - - self.loss_fn = ( - instantiate_from_config(loss_fn_config) - if loss_fn_config is not None - else None - ) - - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model, decay=ema_decay_rate) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.scale_factor = scale_factor - self.disable_first_stage_autocast = disable_first_stage_autocast - self.no_cond_log = no_cond_log - - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path) - - self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time - - def _load_last_embedder(self, original_state_dict): - original_module_name = "conditioner.embedders.3" - state_dict = dict() - for k, v in original_state_dict.items(): - m = re.match(rf"^{original_module_name}\.(.*)$", k) - if m is None: - continue - state_dict[m.group(1)] = v - - idx = -1 - for i in range(len(self.conditioner.embedders)): - if isinstance( - self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder - ): - idx = i - - print(f"Embedder [{idx}] is the frame encoder, make sure this is expected") - - self.conditioner.embedders[idx].load_state_dict(state_dict) - - def init_from_ckpt( - self, - path: str, - ) -> None: - if path.endswith("ckpt"): - sd = torch.load(path, map_location="cpu")["state_dict"] - elif path.endswith("safetensors"): - sd = load_safetensors(path) - else: - raise NotImplementedError - - self_sd = self.state_dict() - input_keys = [ - "model.diffusion_model.input_blocks.0.0.weight", - "model_ema.diffusion_modelinput_blocks00weight", - ] - for input_key in input_keys: - if input_key not in sd or input_key not in self_sd: - continue - - input_weight = self_sd[input_key] - - if input_weight.shape != sd[input_key].shape: - print("Manual init: {}".format(input_key)) - input_weight.zero_() - input_weight[:, :8, :, :].copy_(sd[input_key]) - - deleted_keys = [] - for k, v in self.state_dict().items(): - # resolve shape dismatch - if k in sd: - if v.shape != sd[k].shape: - del sd[k] - deleted_keys.append(k) - - if len(deleted_keys) > 0: - print(f"Deleted Keys: {deleted_keys}") - - missing, unexpected = self.load_state_dict(sd, strict=False) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - if len(deleted_keys) > 0: - print(f"Deleted Keys: {deleted_keys}") - - if len(missing) > 0 or len(unexpected) > 0: - # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id) - print("Modified embedder to support 3d spiral video inputs") - try: - self._load_last_embedder(sd) - except: - print("Failed to load last embedder, make sure this is expected") - - def _init_first_stage(self, config): - model = instantiate_from_config(config).eval() - model.train = disabled_train - for param in model.parameters(): - param.requires_grad = False - self.first_stage_model = model - - def get_input(self, batch): - # assuming unified data format, dataloader returns a dict. - # image tensors should be scaled to -1 ... 1 and in bchw format - return batch[self.input_key] - - @torch.no_grad() - def decode_first_stage(self, z): - z = 1.0 / self.scale_factor * z - is_video_input = False - bs = z.shape[0] - if z.dim() == 5: - is_video_input = True - # for video diffusion - z = rearrange(z, "b t c h w -> (b t) c h w") - n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) - - n_rounds = math.ceil(z.shape[0] / n_samples) - all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - for n in range(n_rounds): - if isinstance(self.first_stage_model.decoder, VideoDecoder): - kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} - else: - kwargs = {} - out = self.first_stage_model.decode( - z[n * n_samples : (n + 1) * n_samples], **kwargs - ) - all_out.append(out) - out = torch.cat(all_out, dim=0) - - if is_video_input: - out = rearrange(out, "(b t) c h w -> b t c h w", b=bs) - - return out - - @torch.no_grad() - def encode_first_stage(self, x): - if self.input_key == "latents": - return x - - bs = x.shape[0] - is_video_input = False - if x.dim() == 5: - is_video_input = True - # for video diffusion - x = rearrange(x, "b t c h w -> (b t) c h w") - n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) - n_rounds = math.ceil(x.shape[0] / n_samples) - all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - for n in range(n_rounds): - out = self.first_stage_model.encode( - x[n * n_samples : (n + 1) * n_samples] - ) - all_out.append(out) - z = torch.cat(all_out, dim=0) - z = self.scale_factor * z - - # if is_video_input: - # z = rearrange(z, "(b t) c h w -> b t c h w", b=bs) - - return z - - def forward(self, x, batch): - loss, model_output = self.loss_fn( - self.model, - self.denoiser, - self.conditioner, - x, - batch, - return_model_output=True, - ) - loss_mean = loss.mean() - loss_dict = {"loss": loss_mean, "model_output": model_output} - return loss_mean, loss_dict - - def shared_step(self, batch: Dict) -> Any: - # TODO: move this shit to collate_fn in dataloader - # if "fps_id" in batch: - # batch["fps_id"] = flatten_for_video(batch["fps_id"]) - # if "motion_bucket_id" in batch: - # batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"]) - # if "cond_aug" in batch: - # batch["cond_aug"] = flatten_for_video(batch["cond_aug"]) - x = self.get_input(batch) - x = self.encode_first_stage(x) - # ## debug - # x_recon = self.decode_first_stage(x) - # video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg") - # video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg") - # ## debug - batch["global_step"] = self.global_step - loss, loss_dict = self(x, batch) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - loss, loss_dict = self.shared_step(batch) - - with torch.no_grad(): - if "model_output" in loss_dict: - if batch_idx % 100 == 0: - if isinstance(self.logger, WandbLogger): - model_output = loss_dict["model_output"].detach()[ - : batch["num_video_frames"] - ] - recons = ( - (self.decode_first_stage(model_output) + 1.0) / 2.0 - ).clamp(0.0, 1.0) - recon_grid = make_grid(recons, nrow=4) - self.logger.log_image( - key=f"train/model_output_recon", - images=[recon_grid], - step=self.global_step, - ) - del loss_dict["model_output"] - - if torch.isnan(loss).any(): - print("Nan detected") - loss = None - - self.log_dict( - loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False - ) - - self.log( - "global_step", - self.global_step, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - if self.scheduler_config is not None: - lr = self.optimizers().param_groups[0]["lr"] - self.log( - "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False - ) - - return loss - - def on_train_start(self, *args, **kwargs): - if self.sampler is None or self.loss_fn is None: - raise ValueError("Sampler and loss function need to be set for training.") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def instantiate_optimizer_from_config(self, params, lr, cfg): - return get_obj_from_str(cfg["target"])( - params, lr=lr, **cfg.get("params", dict()) - ) - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - for embedder in self.conditioner.embedders: - if embedder.is_trainable: - params = params + list(embedder.parameters()) - opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), - "interval": "step", - "frequency": 1, - } - ] - return [opt], scheduler - return opt - - @torch.no_grad() - def sample( - self, - cond: Dict, - uc: Union[Dict, None] = None, - batch_size: int = 16, - shape: Union[None, Tuple, List] = None, - **kwargs, - ): - randn = torch.randn(batch_size, *shape).to(self.device) - - denoiser = lambda input, sigma, c: self.denoiser( - self.model, input, sigma, c, **kwargs - ) - samples = self.sampler(denoiser, randn, cond, uc=uc) - return samples - - @torch.no_grad() - def log_conditionings(self, batch: Dict, n: int) -> Dict: - """ - Defines heuristics to log different conditionings. - These can be lists of strings (text-to-image), tensors, ints, ... - """ - image_h, image_w = batch[self.input_key].shape[-2:] - log = dict() - - for embedder in self.conditioner.embedders: - if ( - (self.log_keys is None) or (embedder.input_key in self.log_keys) - ) and not self.no_cond_log: - x = batch[embedder.input_key][:n] - if isinstance(x, torch.Tensor): - if x.dim() == 1: - # class-conditional, convert integer to string - x = [str(x[i].item()) for i in range(x.shape[0])] - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) - elif x.dim() == 2: - # size and crop cond and the like - x = [ - "x".join([str(xx) for xx in x[i].tolist()]) - for i in range(x.shape[0]) - ] - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) - elif x.dim() == 4: - # image - xc = x - else: - raise NotImplementedError() - elif isinstance(x, (List, ListConfig)): - if isinstance(x[0], str): - # strings - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) - else: - raise NotImplementedError() - else: - raise NotImplementedError() - log[embedder.input_key] = xc - - return log - - # for video diffusions will be logging frames of a video - @torch.no_grad() - def log_images( - self, - batch: Dict, - N: int = 1, - sample: bool = True, - ucg_keys: List[str] = None, - **kwargs, - ) -> Dict: - # # debug - # return {} - # # debug - assert "num_video_frames" in batch, "num_video_frames must be in batch" - num_video_frames = batch["num_video_frames"] - conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] - conditioner_input_keys = [] - for e in self.conditioner.embedders: - if e.input_key is not None: - conditioner_input_keys.append(e.input_key) - else: - conditioner_input_keys.extend(e.input_keys) - if ucg_keys: - assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( - "Each defined ucg key for sampling must be in the provided conditioner input keys," - f"but we have {ucg_keys} vs. {conditioner_input_keys}" - ) - else: - ucg_keys = conditioner_input_keys - log = dict() - - x = self.get_input(batch) - - c, uc = self.conditioner.get_unconditional_conditioning( - batch, - force_uc_zero_embeddings=ucg_keys - if len(self.conditioner.embedders) > 0 - else [], - ) - - sampling_kwargs = {"num_video_frames": num_video_frames} - n = min(x.shape[0] // num_video_frames, N) - sampling_kwargs["image_only_indicator"] = torch.cat( - [batch["image_only_indicator"][:n]] * 2 - ) - - N = min(x.shape[0] // num_video_frames, N) * num_video_frames - x = x.to(self.device)[:N] - # log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames) - log["inputs"] = x - z = self.encode_first_stage(x) - recon = self.decode_first_stage(z) - # log["reconstructions"] = rearrange( - # recon, "(b t) c h w -> b c h (t w)", t=num_video_frames - # ) - log["reconstructions"] = recon - log.update(self.log_conditionings(batch, N)) - log["pixelnerf_rgb"] = c["rgb"] - - for k in ["crossattn", "concat", "vector"]: - if k in c: - c[k] = c[k][:N] - uc[k] = uc[k][:N] - - # for k in c: - # if isinstance(c[k], torch.Tensor): - # if k == "vector": - # end = N - # else: - # end = n - # c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc)) - - # # for k in c: - # # print(c[k].shape) - - # breakpoint() - # for k in ["crossattn", "concat"]: - # c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames) - # c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames) - # uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames) - # uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_video_frames) - - # for k in c: - # print(c[k].shape) - if sample: - with self.ema_scope("Plotting"): - samples = self.sample( - c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs - ) - samples = self.decode_first_stage(samples) - log["samples"] = samples - return log diff --git a/sgm/models/video_diffusion.py b/sgm/models/video_diffusion.py deleted file mode 100644 index 5dbaa4a6d99e44fb2662f13e7cb5ca3ff9b0939e..0000000000000000000000000000000000000000 --- a/sgm/models/video_diffusion.py +++ /dev/null @@ -1,503 +0,0 @@ -import re -import math -from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Tuple, Union - -import pytorch_lightning as pl -from pytorch_lightning.loggers import WandbLogger -import torch -from omegaconf import ListConfig, OmegaConf -from safetensors.torch import load_file as load_safetensors -from torch.optim.lr_scheduler import LambdaLR -from torchvision.utils import make_grid -from einops import rearrange, repeat - -from ..modules import UNCONDITIONAL_CONFIG -from ..modules.autoencoding.temporal_ae import VideoDecoder -from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER -from ..modules.ema import LitEma -from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder -from ..util import ( - default, - disabled_train, - get_obj_from_str, - instantiate_from_config, - log_txt_as_img, - video_frames_as_grid, -) - - -def flatten_for_video(input): - return input.flatten() - - -class DiffusionEngine(pl.LightningModule): - def __init__( - self, - network_config, - denoiser_config, - first_stage_config, - conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, - sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, - optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, - scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, - loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, - network_wrapper: Union[None, str] = None, - ckpt_path: Union[None, str] = None, - use_ema: bool = False, - ema_decay_rate: float = 0.9999, - scale_factor: float = 1.0, - disable_first_stage_autocast=False, - input_key: str = "frames", # for video inputs - log_keys: Union[List, None] = None, - no_cond_log: bool = False, - compile_model: bool = False, - en_and_decode_n_samples_a_time: Optional[int] = None, - load_last_embedder: bool = False, - from_scratch: bool = False, - ): - super().__init__() - self.log_keys = log_keys - self.input_key = input_key - self.optimizer_config = default( - optimizer_config, {"target": "torch.optim.AdamW"} - ) - model = instantiate_from_config(network_config) - self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( - model, compile_model=compile_model - ) - - self.denoiser = instantiate_from_config(denoiser_config) - self.sampler = ( - instantiate_from_config(sampler_config) - if sampler_config is not None - else None - ) - self.conditioner = instantiate_from_config( - default(conditioner_config, UNCONDITIONAL_CONFIG) - ) - self.scheduler_config = scheduler_config - self._init_first_stage(first_stage_config) - - self.loss_fn = ( - instantiate_from_config(loss_fn_config) - if loss_fn_config is not None - else None - ) - - self.use_ema = use_ema - if self.use_ema: - self.model_ema = LitEma(self.model, decay=ema_decay_rate) - print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") - - self.scale_factor = scale_factor - self.disable_first_stage_autocast = disable_first_stage_autocast - self.no_cond_log = no_cond_log - - self.load_last_embedder = load_last_embedder - if ckpt_path is not None: - self.init_from_ckpt(ckpt_path, from_scratch) - - self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time - - def _load_last_embedder(self, original_state_dict): - original_module_name = "conditioner.embedders.3" - state_dict = dict() - for k, v in original_state_dict.items(): - m = re.match(rf"^{original_module_name}\.(.*)$", k) - if m is None: - continue - state_dict[m.group(1)] = v - - idx = -1 - for i in range(len(self.conditioner.embedders)): - if isinstance( - self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder - ): - idx = i - - print(f"Embedder [{idx}] is the frame encoder, make sure this is expected") - - self.conditioner.embedders[idx].load_state_dict(state_dict) - - def init_from_ckpt( - self, - path: str, - from_scratch: bool = False, - ) -> None: - if path.endswith("ckpt"): - sd = torch.load(path, map_location="cpu")["state_dict"] - elif path.endswith("safetensors"): - sd = load_safetensors(path) - else: - raise NotImplementedError - - deleted_keys = [] - for k, v in self.state_dict().items(): - # resolve shape dismatch - if k in sd: - if v.shape != sd[k].shape: - del sd[k] - deleted_keys.append(k) - - if from_scratch: - new_sd = {} - for k in sd: - if "first_stage_model" in k: - new_sd[k] = sd[k] - sd = new_sd - print(sd.keys()) - - if len(deleted_keys) > 0: - print(f"Deleted Keys: {deleted_keys}") - - missing, unexpected = self.load_state_dict(sd, strict=False) - print( - f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" - ) - if len(missing) > 0: - print(f"Missing Keys: {missing}") - if len(unexpected) > 0: - print(f"Unexpected Keys: {unexpected}") - if len(deleted_keys) > 0: - print(f"Deleted Keys: {deleted_keys}") - - if (len(missing) > 0 or len(unexpected) > 0) and self.load_last_embedder: - # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id) - print("Modified embedder to support 3d spiral video inputs") - self._load_last_embedder(sd) - - def _init_first_stage(self, config): - model = instantiate_from_config(config).eval() - model.train = disabled_train - for param in model.parameters(): - param.requires_grad = False - self.first_stage_model = model - - def get_input(self, batch): - # assuming unified data format, dataloader returns a dict. - # image tensors should be scaled to -1 ... 1 and in bchw format - return batch[self.input_key] - - @torch.no_grad() - def decode_first_stage(self, z): - z = 1.0 / self.scale_factor * z - is_video_input = False - bs = z.shape[0] - if z.dim() == 5: - is_video_input = True - # for video diffusion - z = rearrange(z, "b t c h w -> (b t) c h w") - n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) - - n_rounds = math.ceil(z.shape[0] / n_samples) - all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - for n in range(n_rounds): - if isinstance(self.first_stage_model.decoder, VideoDecoder): - kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} - else: - kwargs = {} - out = self.first_stage_model.decode( - z[n * n_samples : (n + 1) * n_samples], **kwargs - ) - all_out.append(out) - out = torch.cat(all_out, dim=0) - - if is_video_input: - out = rearrange(out, "(b t) c h w -> b t c h w", b=bs) - - return out - - @torch.no_grad() - def encode_first_stage(self, x): - if self.input_key == "latents": - return x * self.scale_factor - - bs = x.shape[0] - is_video_input = False - if x.dim() == 5: - is_video_input = True - # for video diffusion - x = rearrange(x, "b t c h w -> (b t) c h w") - n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) - n_rounds = math.ceil(x.shape[0] / n_samples) - all_out = [] - with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): - for n in range(n_rounds): - out = self.first_stage_model.encode( - x[n * n_samples : (n + 1) * n_samples] - ) - all_out.append(out) - z = torch.cat(all_out, dim=0) - z = self.scale_factor * z - - # if is_video_input: - # z = rearrange(z, "(b t) c h w -> b t c h w", b=bs) - - return z - - def forward(self, x, batch): - loss, model_output = self.loss_fn( - self.model, - self.denoiser, - self.conditioner, - x, - batch, - return_model_output=True, - ) - loss_mean = loss.mean() - loss_dict = {"loss": loss_mean, "model_output": model_output} - return loss_mean, loss_dict - - def shared_step(self, batch: Dict) -> Any: - # TODO: move this shit to collate_fn in dataloader - # if "fps_id" in batch: - # batch["fps_id"] = flatten_for_video(batch["fps_id"]) - # if "motion_bucket_id" in batch: - # batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"]) - # if "cond_aug" in batch: - # batch["cond_aug"] = flatten_for_video(batch["cond_aug"]) - x = self.get_input(batch) - x = self.encode_first_stage(x) - # ## debug - # x_recon = self.decode_first_stage(x) - # video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg") - # video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg") - # ## debug - batch["global_step"] = self.global_step - # breakpoint() - loss, loss_dict = self(x, batch) - return loss, loss_dict - - def training_step(self, batch, batch_idx): - loss, loss_dict = self.shared_step(batch) - - with torch.no_grad(): - if "model_output" in loss_dict: - if batch_idx % 100 == 0: - if isinstance(self.logger, WandbLogger): - model_output = loss_dict["model_output"].detach()[ - : batch["num_video_frames"] - ] - recons = ( - (self.decode_first_stage(model_output) + 1.0) / 2.0 - ).clamp(0.0, 1.0) - recon_grid = make_grid(recons, nrow=4) - self.logger.log_image( - key=f"train/model_output_recon", - images=[recon_grid], - step=self.global_step, - ) - del loss_dict["model_output"] - - self.log_dict( - loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False - ) - - self.log( - "global_step", - self.global_step, - prog_bar=True, - logger=True, - on_step=True, - on_epoch=False, - ) - - if self.scheduler_config is not None: - lr = self.optimizers().param_groups[0]["lr"] - self.log( - "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False - ) - - return loss - - def on_train_start(self, *args, **kwargs): - if self.sampler is None or self.loss_fn is None: - raise ValueError("Sampler and loss function need to be set for training.") - - def on_train_batch_end(self, *args, **kwargs): - if self.use_ema: - self.model_ema(self.model) - - @contextmanager - def ema_scope(self, context=None): - if self.use_ema: - self.model_ema.store(self.model.parameters()) - self.model_ema.copy_to(self.model) - if context is not None: - print(f"{context}: Switched to EMA weights") - try: - yield None - finally: - if self.use_ema: - self.model_ema.restore(self.model.parameters()) - if context is not None: - print(f"{context}: Restored training weights") - - def instantiate_optimizer_from_config(self, params, lr, cfg): - return get_obj_from_str(cfg["target"])( - params, lr=lr, **cfg.get("params", dict()) - ) - - def configure_optimizers(self): - lr = self.learning_rate - params = list(self.model.parameters()) - for embedder in self.conditioner.embedders: - if embedder.is_trainable: - params = params + list(embedder.parameters()) - opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) - if self.scheduler_config is not None: - scheduler = instantiate_from_config(self.scheduler_config) - print("Setting up LambdaLR scheduler...") - scheduler = [ - { - "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), - "interval": "step", - "frequency": 1, - } - ] - return [opt], scheduler - return opt - - @torch.no_grad() - def sample( - self, - cond: Dict, - uc: Union[Dict, None] = None, - batch_size: int = 16, - shape: Union[None, Tuple, List] = None, - **kwargs, - ): - randn = torch.randn(batch_size, *shape).to(self.device) - - denoiser = lambda input, sigma, c: self.denoiser( - self.model, input, sigma, c, **kwargs - ) - samples = self.sampler(denoiser, randn, cond, uc=uc) - return samples - - @torch.no_grad() - def log_conditionings(self, batch: Dict, n: int) -> Dict: - """ - Defines heuristics to log different conditionings. - These can be lists of strings (text-to-image), tensors, ints, ... - """ - image_h, image_w = batch[self.input_key].shape[-2:] - log = dict() - - for embedder in self.conditioner.embedders: - if ( - (self.log_keys is None) or (embedder.input_key in self.log_keys) - ) and not self.no_cond_log: - x = batch[embedder.input_key][:n] - if isinstance(x, torch.Tensor): - if x.dim() == 1: - # class-conditional, convert integer to string - x = [str(x[i].item()) for i in range(x.shape[0])] - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) - elif x.dim() == 2: - # size and crop cond and the like - x = [ - "x".join([str(xx) for xx in x[i].tolist()]) - for i in range(x.shape[0]) - ] - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) - elif x.dim() == 4: - # image - xc = x - else: - pass - # breakpoint() - # raise NotImplementedError() - elif isinstance(x, (List, ListConfig)): - if isinstance(x[0], str): - # strings - xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) - else: - raise NotImplementedError() - else: - raise NotImplementedError() - log[embedder.input_key] = xc - return log - - # for video diffusions will be logging frames of a video - @torch.no_grad() - def log_images( - self, - batch: Dict, - N: int = 1, - sample: bool = True, - ucg_keys: List[str] = None, - **kwargs, - ) -> Dict: - # # debug - # return {} - # # debug - assert "num_video_frames" in batch, "num_video_frames must be in batch" - num_video_frames = batch["num_video_frames"] - conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] - if ucg_keys: - assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( - "Each defined ucg key for sampling must be in the provided conditioner input keys," - f"but we have {ucg_keys} vs. {conditioner_input_keys}" - ) - else: - ucg_keys = conditioner_input_keys - log = dict() - - x = self.get_input(batch) - - c, uc = self.conditioner.get_unconditional_conditioning( - batch, - force_uc_zero_embeddings=ucg_keys - if len(self.conditioner.embedders) > 0 - else [], - ) - - sampling_kwargs = {"num_video_frames": num_video_frames} - n = min(x.shape[0] // num_video_frames, N) - sampling_kwargs["image_only_indicator"] = torch.cat( - [batch["image_only_indicator"][:n]] * 2 - ) - - N = min(x.shape[0] // num_video_frames, N) * num_video_frames - x = x.to(self.device)[:N] - # log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames) - if self.input_key != "latents": - log["inputs"] = x - z = self.encode_first_stage(x) - recon = self.decode_first_stage(z) - # log["reconstructions"] = rearrange( - # recon, "(b t) c h w -> b c h (t w)", t=num_video_frames - # ) - log["reconstructions"] = recon - log.update(self.log_conditionings(batch, N)) - - for k in c: - if isinstance(c[k], torch.Tensor): - if k == "vector": - end = N - else: - end = n - c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc)) - - # for k in c: - # print(c[k].shape) - - for k in ["crossattn", "concat"]: - c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames) - c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames) - uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames) - uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_video_frames) - - # for k in c: - # print(c[k].shape) - if sample: - with self.ema_scope("Plotting"): - samples = self.sample( - c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs - ) - samples = self.decode_first_stage(samples) - log["samples"] = samples - return log diff --git a/sgm/modules/__init__.py b/sgm/modules/__init__.py deleted file mode 100644 index 2aa9ad360acf32dab22989d81630b3eb7978abb1..0000000000000000000000000000000000000000 --- a/sgm/modules/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .encoders.modules import GeneralConditioner, ExtraConditioner - -UNCONDITIONAL_CONFIG = { - "target": "sgm.modules.GeneralConditioner", - "params": {"emb_models": []}, -} diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py deleted file mode 100644 index f3b60cabce854b52527f6dee85ea4f0cb0951eb6..0000000000000000000000000000000000000000 --- a/sgm/modules/attention.py +++ /dev/null @@ -1,764 +0,0 @@ -import logging -import math -from inspect import isfunction -from typing import Any, Optional -from functools import partial - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat -from packaging import version -from torch import nn - -# from torch.utils.checkpoint import checkpoint - -checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - - -logpy = logging.getLogger(__name__) - -if version.parse(torch.__version__) >= version.parse("2.0.0"): - SDP_IS_AVAILABLE = True - from torch.backends.cuda import SDPBackend, sdp_kernel - - BACKEND_MAP = { - SDPBackend.MATH: { - "enable_math": True, - "enable_flash": False, - "enable_mem_efficient": False, - }, - SDPBackend.FLASH_ATTENTION: { - "enable_math": False, - "enable_flash": True, - "enable_mem_efficient": False, - }, - SDPBackend.EFFICIENT_ATTENTION: { - "enable_math": False, - "enable_flash": False, - "enable_mem_efficient": True, - }, - None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, - } -else: - from contextlib import nullcontext - - SDP_IS_AVAILABLE = False - sdp_kernel = nullcontext - BACKEND_MAP = {} - logpy.warn( - f"No SDP backend available, likely because you are running in pytorch " - f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. " - f"You might want to consider upgrading." - ) - -try: - import xformers - import xformers.ops - - XFORMERS_IS_AVAILABLE = True -except: - XFORMERS_IS_AVAILABLE = False - logpy.warn("no module 'xformers'. Processing without...") - -# from .diffusionmodules.util import mixed_checkpoint as checkpoint - - -def exists(val): - return val is not None - - -def uniq(arr): - return {el: True for el in arr}.keys() - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def max_neg_value(t): - return -torch.finfo(t.dtype).max - - -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor - - -# feedforward -class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): - super().__init__() - self.proj = nn.Linear(dim_in, dim_out * 2) - - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = ( - nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) - if not glu - else GEGLU(dim, inner_dim) - ) - - self.net = nn.Sequential( - project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm( - num_groups=32, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange( - qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 - ) - k = k.softmax(dim=-1) - context = torch.einsum("bhdn,bhen->bhde", k, v) - out = torch.einsum("bhde,bhdn->bhen", context, q) - out = rearrange( - out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w - ) - return self.to_out(out) - - -class SelfAttention(nn.Module): - ATTENTION_MODES = ("xformers", "torch", "math") - - def __init__( - self, - dim: int, - num_heads: int = 8, - qkv_bias: bool = False, - qk_scale: Optional[float] = None, - attn_drop: float = 0.0, - proj_drop: float = 0.0, - attn_mode: str = "xformers", - ): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - assert attn_mode in self.ATTENTION_MODES - self.attn_mode = attn_mode - - def forward(self, x: torch.Tensor) -> torch.Tensor: - B, L, C = x.shape - - qkv = self.qkv(x) - if self.attn_mode == "torch": - qkv = rearrange( - qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads - ).float() - q, k, v = qkv[0], qkv[1], qkv[2] # B H L D - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) - x = rearrange(x, "B H L D -> B L (H D)") - elif self.attn_mode == "xformers": - qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) - q, k, v = qkv[0], qkv[1], qkv[2] # B L H D - x = xformers.ops.memory_efficient_attention(q, k, v) - x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads) - elif self.attn_mode == "math": - qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) - q, k, v = qkv[0], qkv[1], qkv[2] # B H L D - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = (attn @ v).transpose(1, 2).reshape(B, L, C) - else: - raise NotImplemented - - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b, c, h, w = q.shape - q = rearrange(q, "b c h w -> b (h w) c") - k = rearrange(k, "b c h w -> b c (h w)") - w_ = torch.einsum("bij,bjk->bik", q, k) - - w_ = w_ * (int(c) ** (-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, "b c h w -> b c (h w)") - w_ = rearrange(w_, "b i j -> b j i") - h_ = torch.einsum("bij,bjk->bik", v, w_) - h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) - h_ = self.proj_out(h_) - - return x + h_ - - -class CrossAttention(nn.Module): - def __init__( - self, - query_dim, - context_dim=None, - heads=8, - dim_head=64, - dropout=0.0, - backend=None, - ): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.scale = dim_head**-0.5 - self.heads = heads - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - self.backend = backend - - def forward( - self, - x, - context=None, - mask=None, - additional_tokens=None, - n_times_crossframe_attn_in_self=0, - ): - h = self.heads - - if additional_tokens is not None: - # get the number of masked tokens at the beginning of the output sequence - n_tokens_to_mask = additional_tokens.shape[1] - # add additional token - x = torch.cat([additional_tokens, x], dim=1) - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - if n_times_crossframe_attn_in_self: - # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 - assert x.shape[0] % n_times_crossframe_attn_in_self == 0 - n_cp = x.shape[0] // n_times_crossframe_attn_in_self - k = repeat( - k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp - ) - v = repeat( - v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp - ) - - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) - - ## old - """ - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - del q, k - - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - sim = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', sim, v) - """ - ## new - with sdp_kernel(**BACKEND_MAP[self.backend]): - # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) - out = F.scaled_dot_product_attention( - q, k, v, attn_mask=mask - ) # scale is dim_head ** -0.5 per default - - del q, k, v - out = rearrange(out, "b h n d -> b n (h d)", h=h) - - if additional_tokens is not None: - # remove additional token - out = out[:, n_tokens_to_mask:] - return self.to_out(out) - - -class MemoryEfficientCrossAttention(nn.Module): - # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__( - self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs - ): - super().__init__() - logpy.debug( - f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, " - f"context_dim is {context_dim} and using {heads} heads with a " - f"dimension of {dim_head}." - ) - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.heads = heads - self.dim_head = dim_head - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) - ) - self.attention_op: Optional[Any] = None - - def forward( - self, - x, - context=None, - mask=None, - additional_tokens=None, - n_times_crossframe_attn_in_self=0, - ): - if additional_tokens is not None: - # get the number of masked tokens at the beginning of the output sequence - n_tokens_to_mask = additional_tokens.shape[1] - # add additional token - x = torch.cat([additional_tokens, x], dim=1) - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) - - if n_times_crossframe_attn_in_self: - # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 - assert x.shape[0] % n_times_crossframe_attn_in_self == 0 - # n_cp = x.shape[0]//n_times_crossframe_attn_in_self - k = repeat( - k[::n_times_crossframe_attn_in_self], - "b ... -> (b n) ...", - n=n_times_crossframe_attn_in_self, - ) - v = repeat( - v[::n_times_crossframe_attn_in_self], - "b ... -> (b n) ...", - n=n_times_crossframe_attn_in_self, - ) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - if version.parse(xformers.__version__) >= version.parse("0.0.21"): - # NOTE: workaround for - # https://github.com/facebookresearch/xformers/issues/845 - max_bs = 32768 - N = q.shape[0] - n_batches = math.ceil(N / max_bs) - out = list() - for i_batch in range(n_batches): - batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs) - out.append( - xformers.ops.memory_efficient_attention( - q[batch], - k[batch], - v[batch], - attn_bias=None, - op=self.attention_op, - ) - ) - out = torch.cat(out, 0) - else: - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None, op=self.attention_op - ) - - # TODO: Use this directly in the attention operation, as a bias - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) - if additional_tokens is not None: - # remove additional token - out = out[:, n_tokens_to_mask:] - return self.to_out(out) - - -class BasicTransformerBlock(nn.Module): - ATTENTION_MODES = { - "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention, # ampere - } - - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - disable_self_attn=False, - attn_mode="softmax", - sdp_backend=None, - ): - super().__init__() - assert attn_mode in self.ATTENTION_MODES - if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: - logpy.warn( - f"Attention mode '{attn_mode}' is not available. Falling " - f"back to native attention. This is not a problem in " - f"Pytorch >= 2.0. FYI, you are running with PyTorch " - f"version {torch.__version__}." - ) - attn_mode = "softmax" - elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: - logpy.warn( - "We do not support vanilla attention anymore, as it is too " - "expensive. Sorry." - ) - if not XFORMERS_IS_AVAILABLE: - assert ( - False - ), "Please install xformers via e.g. 'pip install xformers==0.0.16'" - else: - logpy.info("Falling back to xformers efficient attention.") - attn_mode = "softmax-xformers" - attn_cls = self.ATTENTION_MODES[attn_mode] - if version.parse(torch.__version__) >= version.parse("2.0.0"): - assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) - else: - assert sdp_backend is None - self.disable_self_attn = disable_self_attn - self.attn1 = attn_cls( - query_dim=dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, - backend=sdp_backend, - ) # is a self-attention if not self.disable_self_attn - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = attn_cls( - query_dim=dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - backend=sdp_backend, - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - if self.checkpoint: - logpy.debug(f"{self.__class__.__name__} is using checkpointing") - - def forward( - self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 - ): - kwargs = {"x": x} - - if context is not None: - kwargs.update({"context": context}) - - if additional_tokens is not None: - kwargs.update({"additional_tokens": additional_tokens}) - - if n_times_crossframe_attn_in_self: - kwargs.update( - {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self} - ) - - # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) - if self.checkpoint: - # inputs = {"x": x, "context": context} - return checkpoint(self._forward, x, context) - # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) - else: - return self._forward(**kwargs) - - def _forward( - self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 - ): - x = ( - self.attn1( - self.norm1(x), - context=context if self.disable_self_attn else None, - additional_tokens=additional_tokens, - n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self - if not self.disable_self_attn - else 0, - ) - + x - ) - x = ( - self.attn2( - self.norm2(x), context=context, additional_tokens=additional_tokens - ) - + x - ) - x = self.ff(self.norm3(x)) + x - return x - - -class BasicTransformerSingleLayerBlock(nn.Module): - ATTENTION_MODES = { - "softmax": CrossAttention, # vanilla attention - "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version - # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) - } - - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - attn_mode="softmax", - ): - super().__init__() - assert attn_mode in self.ATTENTION_MODES - attn_cls = self.ATTENTION_MODES[attn_mode] - self.attn1 = attn_cls( - query_dim=dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - context_dim=context_dim, - ) - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def forward(self, x, context=None): - # inputs = {"x": x, "context": context} - # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) - return checkpoint(self._forward, x, context) - - def _forward(self, x, context=None): - x = self.attn1(self.norm1(x), context=context) + x - x = self.ff(self.norm2(x)) + x - return x - - -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image - NEW: use_linear for more efficiency instead of the 1x1 convs - """ - - def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - context_dim=None, - disable_self_attn=False, - use_linear=False, - attn_type="softmax", - use_checkpoint=True, - # sdp_backend=SDPBackend.FLASH_ATTENTION - sdp_backend=None, - ): - super().__init__() - logpy.debug( - f"constructing {self.__class__.__name__} of depth {depth} w/ " - f"{in_channels} channels and {n_heads} heads." - ) - - if exists(context_dim) and not isinstance(context_dim, list): - context_dim = [context_dim] - if exists(context_dim) and isinstance(context_dim, list): - if depth != len(context_dim): - logpy.warn( - f"{self.__class__.__name__}: Found context dims " - f"{context_dim} of depth {len(context_dim)}, which does not " - f"match the specified 'depth' of {depth}. Setting context_dim " - f"to {depth * [context_dim[0]]} now." - ) - # depth does not match context dims. - assert all( - map(lambda x: x == context_dim[0], context_dim) - ), "need homogenous context_dim to match depth automatically" - context_dim = depth * [context_dim[0]] - elif context_dim is None: - context_dim = [None] * depth - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = Normalize(in_channels) - if not use_linear: - self.proj_in = nn.Conv2d( - in_channels, inner_dim, kernel_size=1, stride=1, padding=0 - ) - else: - self.proj_in = nn.Linear(in_channels, inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - n_heads, - d_head, - dropout=dropout, - context_dim=context_dim[d], - disable_self_attn=disable_self_attn, - attn_mode=attn_type, - checkpoint=use_checkpoint, - sdp_backend=sdp_backend, - ) - for d in range(depth) - ] - ) - if not use_linear: - self.proj_out = zero_module( - nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - ) - else: - # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) - self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) - self.use_linear = use_linear - - def forward(self, x, context=None): - # note: if no context is given, cross-attention defaults to self-attention - if not isinstance(context, list): - context = [context] - b, c, h, w = x.shape - x_in = x - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c").contiguous() - if self.use_linear: - x = self.proj_in(x) - for i, block in enumerate(self.transformer_blocks): - if i > 0 and len(context) == 1: - i = 0 # use same context for each block - x = block(x, context=context[i]) - if self.use_linear: - x = self.proj_out(x) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() - if not self.use_linear: - x = self.proj_out(x) - return x + x_in - - -class SimpleTransformer(nn.Module): - def __init__( - self, - dim: int, - depth: int, - heads: int, - dim_head: int, - context_dim: Optional[int] = None, - dropout: float = 0.0, - checkpoint: bool = True, - ): - super().__init__() - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - BasicTransformerBlock( - dim, - heads, - dim_head, - dropout=dropout, - context_dim=context_dim, - attn_mode="softmax-xformers", - checkpoint=checkpoint, - ) - ) - - def forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - for layer in self.layers: - x = layer(x, context) - return x diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py deleted file mode 100644 index 6b316c7aa6ea1c5e31a58987aa3b37b2933eb7e2..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/losses/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -__all__ = [ - "GeneralLPIPSWithDiscriminator", - "LatentLPIPS", -] - -from .discriminator_loss import GeneralLPIPSWithDiscriminator -from .lpips import LatentLPIPS diff --git a/sgm/modules/autoencoding/losses/discriminator_loss.py b/sgm/modules/autoencoding/losses/discriminator_loss.py deleted file mode 100644 index 09b6829267bf8e4d98c3f29abdc19e58dcbcbe64..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/losses/discriminator_loss.py +++ /dev/null @@ -1,306 +0,0 @@ -from typing import Dict, Iterator, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torchvision -from einops import rearrange -from matplotlib import colormaps -from matplotlib import pyplot as plt - -from ....util import default, instantiate_from_config -from ..lpips.loss.lpips import LPIPS -from ..lpips.model.model import weights_init -from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss - - -class GeneralLPIPSWithDiscriminator(nn.Module): - def __init__( - self, - disc_start: int, - logvar_init: float = 0.0, - disc_num_layers: int = 3, - disc_in_channels: int = 3, - disc_factor: float = 1.0, - disc_weight: float = 1.0, - perceptual_weight: float = 1.0, - disc_loss: str = "hinge", - scale_input_to_tgt_size: bool = False, - dims: int = 2, - learn_logvar: bool = False, - regularization_weights: Union[None, Dict[str, float]] = None, - additional_log_keys: Optional[List[str]] = None, - discriminator_config: Optional[Dict] = None, - ): - super().__init__() - self.dims = dims - if self.dims > 2: - print( - f"running with dims={dims}. This means that for perceptual loss " - f"calculation, the LPIPS loss will be applied to each frame " - f"independently." - ) - self.scale_input_to_tgt_size = scale_input_to_tgt_size - assert disc_loss in ["hinge", "vanilla"] - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - # output log variance - self.logvar = nn.Parameter( - torch.full((), logvar_init), requires_grad=learn_logvar - ) - self.learn_logvar = learn_logvar - - discriminator_config = default( - discriminator_config, - { - "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator", - "params": { - "input_nc": disc_in_channels, - "n_layers": disc_num_layers, - "use_actnorm": False, - }, - }, - ) - - self.discriminator = instantiate_from_config(discriminator_config).apply( - weights_init - ) - self.discriminator_iter_start = disc_start - self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss - self.disc_factor = disc_factor - self.discriminator_weight = disc_weight - self.regularization_weights = default(regularization_weights, {}) - - self.forward_keys = [ - "optimizer_idx", - "global_step", - "last_layer", - "split", - "regularization_log", - ] - - self.additional_log_keys = set(default(additional_log_keys, [])) - self.additional_log_keys.update(set(self.regularization_weights.keys())) - - def get_trainable_parameters(self) -> Iterator[nn.Parameter]: - return self.discriminator.parameters() - - def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]: - if self.learn_logvar: - yield self.logvar - yield from () - - @torch.no_grad() - def log_images( - self, inputs: torch.Tensor, reconstructions: torch.Tensor - ) -> Dict[str, torch.Tensor]: - # calc logits of real/fake - logits_real = self.discriminator(inputs.contiguous().detach()) - if len(logits_real.shape) < 4: - # Non patch-discriminator - return dict() - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - # -> (b, 1, h, w) - - # parameters for colormapping - high = max(logits_fake.abs().max(), logits_real.abs().max()).item() - cmap = colormaps["PiYG"] # diverging colormap - - def to_colormap(logits: torch.Tensor) -> torch.Tensor: - """(b, 1, ...) -> (b, 3, ...)""" - logits = (logits + high) / (2 * high) - logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel - # -> (b, 1, ..., 3) - logits = torch.from_numpy(logits_np).to(logits.device) - return rearrange(logits, "b 1 ... c -> b c ...") - - logits_real = torch.nn.functional.interpolate( - logits_real, - size=inputs.shape[-2:], - mode="nearest", - antialias=False, - ) - logits_fake = torch.nn.functional.interpolate( - logits_fake, - size=reconstructions.shape[-2:], - mode="nearest", - antialias=False, - ) - - # alpha value of logits for overlay - alpha_real = torch.abs(logits_real) / high - alpha_fake = torch.abs(logits_fake) / high - # -> (b, 1, h, w) in range [0, 0.5] - # alpha value of lines don't really matter, since the values are the same - # for both images and logits anyway - grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4) - grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4) - grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1) - # -> (1, h, w) - # blend logits and images together - - # prepare logits for plotting - logits_real = to_colormap(logits_real) - logits_fake = to_colormap(logits_fake) - # resize logits - # -> (b, 3, h, w) - - # make some grids - # add all logits to one plot - logits_real = torchvision.utils.make_grid(logits_real, nrow=4) - logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4) - # I just love how torchvision calls the number of columns `nrow` - grid_logits = torch.cat((logits_real, logits_fake), dim=1) - # -> (3, h, w) - - grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4) - grid_images_fake = torchvision.utils.make_grid( - 0.5 * reconstructions + 0.5, nrow=4 - ) - grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1) - # -> (3, h, w) in range [0, 1] - - grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images - - # Create labeled colorbar - dpi = 100 - height = 128 / dpi - width = grid_logits.shape[2] / dpi - fig, ax = plt.subplots(figsize=(width, height), dpi=dpi) - img = ax.imshow(np.array([[-high, high]]), cmap=cmap) - plt.colorbar( - img, - cax=ax, - orientation="horizontal", - fraction=0.9, - aspect=width / height, - pad=0.0, - ) - img.set_visible(False) - fig.tight_layout() - fig.canvas.draw() - # manually convert figure to numpy - cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) - cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0 - cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device) - - # Add colorbar to plot - annotated_grid = torch.cat((grid_logits, cbar), dim=1) - blended_grid = torch.cat((grid_blend, cbar), dim=1) - return { - "vis_logits": 2 * annotated_grid[None, ...] - 1, - "vis_logits_blended": 2 * blended_grid[None, ...] - 1, - } - - def calculate_adaptive_weight( - self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor - ) -> torch.Tensor: - nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] - g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] - - d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) - d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() - d_weight = d_weight * self.discriminator_weight - return d_weight - - def forward( - self, - inputs: torch.Tensor, - reconstructions: torch.Tensor, - *, # added because I changed the order here - regularization_log: Dict[str, torch.Tensor], - optimizer_idx: int, - global_step: int, - last_layer: torch.Tensor, - split: str = "train", - weights: Union[None, float, torch.Tensor] = None, - ) -> Tuple[torch.Tensor, dict]: - if self.scale_input_to_tgt_size: - inputs = torch.nn.functional.interpolate( - inputs, reconstructions.shape[2:], mode="bicubic", antialias=True - ) - - if self.dims > 2: - inputs, reconstructions = map( - lambda x: rearrange(x, "b c t h w -> (b t) c h w"), - (inputs, reconstructions), - ) - - rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) - if self.perceptual_weight > 0: - p_loss = self.perceptual_loss( - inputs.contiguous(), reconstructions.contiguous() - ) - rec_loss = rec_loss + self.perceptual_weight * p_loss - - nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) - - # now the GAN part - if optimizer_idx == 0: - # generator update - if global_step >= self.discriminator_iter_start or not self.training: - logits_fake = self.discriminator(reconstructions.contiguous()) - g_loss = -torch.mean(logits_fake) - if self.training: - d_weight = self.calculate_adaptive_weight( - nll_loss, g_loss, last_layer=last_layer - ) - else: - d_weight = torch.tensor(1.0) - else: - d_weight = torch.tensor(0.0) - g_loss = torch.tensor(0.0, requires_grad=True) - - loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss - log = dict() - for k in regularization_log: - if k in self.regularization_weights: - loss = loss + self.regularization_weights[k] * regularization_log[k] - if k in self.additional_log_keys: - log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() - - log.update( - { - f"{split}/loss/total": loss.clone().detach().mean(), - f"{split}/loss/nll": nll_loss.detach().mean(), - f"{split}/loss/rec": rec_loss.detach().mean(), - f"{split}/loss/g": g_loss.detach().mean(), - f"{split}/scalars/logvar": self.logvar.detach(), - f"{split}/scalars/d_weight": d_weight.detach(), - } - ) - - return loss, log - elif optimizer_idx == 1: - # second pass for discriminator update - logits_real = self.discriminator(inputs.contiguous().detach()) - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - - if global_step >= self.discriminator_iter_start or not self.training: - d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) - else: - d_loss = torch.tensor(0.0, requires_grad=True) - - log = { - f"{split}/loss/disc": d_loss.clone().detach().mean(), - f"{split}/logits/real": logits_real.detach().mean(), - f"{split}/logits/fake": logits_fake.detach().mean(), - } - return d_loss, log - else: - raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}") - - def get_nll_loss( - self, - rec_loss: torch.Tensor, - weights: Optional[Union[float, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar - weighted_nll_loss = nll_loss - if weights is not None: - weighted_nll_loss = weights * nll_loss - weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] - nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] - - return nll_loss, weighted_nll_loss diff --git a/sgm/modules/autoencoding/losses/lpips.py b/sgm/modules/autoencoding/losses/lpips.py deleted file mode 100644 index b329fcc2ee9477f0122aa7d066866cdfe71ce521..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/losses/lpips.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import torch.nn as nn - -from ....util import default, instantiate_from_config -from ..lpips.loss.lpips import LPIPS - - -class LatentLPIPS(nn.Module): - def __init__( - self, - decoder_config, - perceptual_weight=1.0, - latent_weight=1.0, - scale_input_to_tgt_size=False, - scale_tgt_to_input_size=False, - perceptual_weight_on_inputs=0.0, - ): - super().__init__() - self.scale_input_to_tgt_size = scale_input_to_tgt_size - self.scale_tgt_to_input_size = scale_tgt_to_input_size - self.init_decoder(decoder_config) - self.perceptual_loss = LPIPS().eval() - self.perceptual_weight = perceptual_weight - self.latent_weight = latent_weight - self.perceptual_weight_on_inputs = perceptual_weight_on_inputs - - def init_decoder(self, config): - self.decoder = instantiate_from_config(config) - if hasattr(self.decoder, "encoder"): - del self.decoder.encoder - - def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): - log = dict() - loss = (latent_inputs - latent_predictions) ** 2 - log[f"{split}/latent_l2_loss"] = loss.mean().detach() - image_reconstructions = None - if self.perceptual_weight > 0.0: - image_reconstructions = self.decoder.decode(latent_predictions) - image_targets = self.decoder.decode(latent_inputs) - perceptual_loss = self.perceptual_loss( - image_targets.contiguous(), image_reconstructions.contiguous() - ) - loss = ( - self.latent_weight * loss.mean() - + self.perceptual_weight * perceptual_loss.mean() - ) - log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() - - if self.perceptual_weight_on_inputs > 0.0: - image_reconstructions = default( - image_reconstructions, self.decoder.decode(latent_predictions) - ) - if self.scale_input_to_tgt_size: - image_inputs = torch.nn.functional.interpolate( - image_inputs, - image_reconstructions.shape[2:], - mode="bicubic", - antialias=True, - ) - elif self.scale_tgt_to_input_size: - image_reconstructions = torch.nn.functional.interpolate( - image_reconstructions, - image_inputs.shape[2:], - mode="bicubic", - antialias=True, - ) - - perceptual_loss2 = self.perceptual_loss( - image_inputs.contiguous(), image_reconstructions.contiguous() - ) - loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() - log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() - return loss, log diff --git a/sgm/modules/autoencoding/lpips/loss/.gitignore b/sgm/modules/autoencoding/lpips/loss/.gitignore deleted file mode 100644 index a92958a1cd4ffe005e1f5448ab3e6fd9c795a43a..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/loss/.gitignore +++ /dev/null @@ -1 +0,0 @@ -vgg.pth \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/loss/LICENSE b/sgm/modules/autoencoding/lpips/loss/LICENSE deleted file mode 100644 index 924cfc85b8d63ef538f5676f830a2a8497932108..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/loss/LICENSE +++ /dev/null @@ -1,23 +0,0 @@ -Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/loss/lpips.py b/sgm/modules/autoencoding/lpips/loss/lpips.py deleted file mode 100644 index 3e34f3d083674f675a5ca024e9bd27fb77e2b6b5..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/loss/lpips.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" - -from collections import namedtuple - -import torch -import torch.nn as nn -from torchvision import models - -from ..util import get_ckpt_path - - -class LPIPS(nn.Module): - # Learned perceptual metric - def __init__(self, use_dropout=True): - super().__init__() - self.scaling_layer = ScalingLayer() - self.chns = [64, 128, 256, 512, 512] # vg16 features - self.net = vgg16(pretrained=True, requires_grad=False) - self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) - self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) - self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) - self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) - self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) - self.load_from_pretrained() - for param in self.parameters(): - param.requires_grad = False - - def load_from_pretrained(self, name="vgg_lpips"): - ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") - self.load_state_dict( - torch.load(ckpt, map_location=torch.device("cpu")), strict=False - ) - print("loaded pretrained LPIPS loss from {}".format(ckpt)) - - @classmethod - def from_pretrained(cls, name="vgg_lpips"): - if name != "vgg_lpips": - raise NotImplementedError - model = cls() - ckpt = get_ckpt_path(name) - model.load_state_dict( - torch.load(ckpt, map_location=torch.device("cpu")), strict=False - ) - return model - - def forward(self, input, target): - in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) - outs0, outs1 = self.net(in0_input), self.net(in1_input) - feats0, feats1, diffs = {}, {}, {} - lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] - for kk in range(len(self.chns)): - feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( - outs1[kk] - ) - diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 - - res = [ - spatial_average(lins[kk].model(diffs[kk]), keepdim=True) - for kk in range(len(self.chns)) - ] - val = res[0] - for l in range(1, len(self.chns)): - val += res[l] - return val - - -class ScalingLayer(nn.Module): - def __init__(self): - super(ScalingLayer, self).__init__() - self.register_buffer( - "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] - ) - self.register_buffer( - "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] - ) - - def forward(self, inp): - return (inp - self.shift) / self.scale - - -class NetLinLayer(nn.Module): - """A single linear layer which does a 1x1 conv""" - - def __init__(self, chn_in, chn_out=1, use_dropout=False): - super(NetLinLayer, self).__init__() - layers = ( - [ - nn.Dropout(), - ] - if (use_dropout) - else [] - ) - layers += [ - nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), - ] - self.model = nn.Sequential(*layers) - - -class vgg16(torch.nn.Module): - def __init__(self, requires_grad=False, pretrained=True): - super(vgg16, self).__init__() - vgg_pretrained_features = models.vgg16(pretrained=pretrained).features - self.slice1 = torch.nn.Sequential() - self.slice2 = torch.nn.Sequential() - self.slice3 = torch.nn.Sequential() - self.slice4 = torch.nn.Sequential() - self.slice5 = torch.nn.Sequential() - self.N_slices = 5 - for x in range(4): - self.slice1.add_module(str(x), vgg_pretrained_features[x]) - for x in range(4, 9): - self.slice2.add_module(str(x), vgg_pretrained_features[x]) - for x in range(9, 16): - self.slice3.add_module(str(x), vgg_pretrained_features[x]) - for x in range(16, 23): - self.slice4.add_module(str(x), vgg_pretrained_features[x]) - for x in range(23, 30): - self.slice5.add_module(str(x), vgg_pretrained_features[x]) - if not requires_grad: - for param in self.parameters(): - param.requires_grad = False - - def forward(self, X): - h = self.slice1(X) - h_relu1_2 = h - h = self.slice2(h) - h_relu2_2 = h - h = self.slice3(h) - h_relu3_3 = h - h = self.slice4(h) - h_relu4_3 = h - h = self.slice5(h) - h_relu5_3 = h - vgg_outputs = namedtuple( - "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] - ) - out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) - return out - - -def normalize_tensor(x, eps=1e-10): - norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) - return x / (norm_factor + eps) - - -def spatial_average(x, keepdim=True): - return x.mean([2, 3], keepdim=keepdim) diff --git a/sgm/modules/autoencoding/lpips/model/LICENSE b/sgm/modules/autoencoding/lpips/model/LICENSE deleted file mode 100644 index 4b356e66b5aa689b339f1a80a9f1b5ba378003bb..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/model/LICENSE +++ /dev/null @@ -1,58 +0,0 @@ -Copyright (c) 2017, Jun-Yan Zhu and Taesung Park -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ---------------------------- LICENSE FOR pix2pix -------------------------------- -BSD License - -For pix2pix software -Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - ------------------------------ LICENSE FOR DCGAN -------------------------------- -BSD License - -For dcgan.torch software - -Copyright (c) 2015, Facebook, Inc. All rights reserved. - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - -Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - -Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - -Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/model/__init__.py b/sgm/modules/autoencoding/lpips/model/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/autoencoding/lpips/model/model.py b/sgm/modules/autoencoding/lpips/model/model.py deleted file mode 100644 index 66357d4e627f9a69a5abbbad15546c96fcd758fe..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/model/model.py +++ /dev/null @@ -1,88 +0,0 @@ -import functools - -import torch.nn as nn - -from ..util import ActNorm - - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - nn.init.normal_(m.weight.data, 0.0, 0.02) - elif classname.find("BatchNorm") != -1: - nn.init.normal_(m.weight.data, 1.0, 0.02) - nn.init.constant_(m.bias.data, 0) - - -class NLayerDiscriminator(nn.Module): - """Defines a PatchGAN discriminator as in Pix2Pix - --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py - """ - - def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): - """Construct a PatchGAN discriminator - Parameters: - input_nc (int) -- the number of channels in input images - ndf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer - """ - super(NLayerDiscriminator, self).__init__() - if not use_actnorm: - norm_layer = nn.BatchNorm2d - else: - norm_layer = ActNorm - if ( - type(norm_layer) == functools.partial - ): # no need to use bias as BatchNorm2d has affine parameters - use_bias = norm_layer.func != nn.BatchNorm2d - else: - use_bias = norm_layer != nn.BatchNorm2d - - kw = 4 - padw = 1 - sequence = [ - nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), - nn.LeakyReLU(0.2, True), - ] - nf_mult = 1 - nf_mult_prev = 1 - for n in range(1, n_layers): # gradually increase the number of filters - nf_mult_prev = nf_mult - nf_mult = min(2**n, 8) - sequence += [ - nn.Conv2d( - ndf * nf_mult_prev, - ndf * nf_mult, - kernel_size=kw, - stride=2, - padding=padw, - bias=use_bias, - ), - norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True), - ] - - nf_mult_prev = nf_mult - nf_mult = min(2**n_layers, 8) - sequence += [ - nn.Conv2d( - ndf * nf_mult_prev, - ndf * nf_mult, - kernel_size=kw, - stride=1, - padding=padw, - bias=use_bias, - ), - norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.2, True), - ] - - sequence += [ - nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) - ] # output 1 channel prediction map - self.main = nn.Sequential(*sequence) - - def forward(self, input): - """Standard forward.""" - return self.main(input) diff --git a/sgm/modules/autoencoding/lpips/util.py b/sgm/modules/autoencoding/lpips/util.py deleted file mode 100644 index 49c76e370bf16888ab61f42844b3c9f14ad9014c..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/util.py +++ /dev/null @@ -1,128 +0,0 @@ -import hashlib -import os - -import requests -import torch -import torch.nn as nn -from tqdm import tqdm - -URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} - -CKPT_MAP = {"vgg_lpips": "vgg.pth"} - -MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} - - -def download(url, local_path, chunk_size=1024): - os.makedirs(os.path.split(local_path)[0], exist_ok=True) - with requests.get(url, stream=True) as r: - total_size = int(r.headers.get("content-length", 0)) - with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: - with open(local_path, "wb") as f: - for data in r.iter_content(chunk_size=chunk_size): - if data: - f.write(data) - pbar.update(chunk_size) - - -def md5_hash(path): - with open(path, "rb") as f: - content = f.read() - return hashlib.md5(content).hexdigest() - - -def get_ckpt_path(name, root, check=False): - assert name in URL_MAP - path = os.path.join(root, CKPT_MAP[name]) - if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): - print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) - download(URL_MAP[name], path) - md5 = md5_hash(path) - assert md5 == MD5_MAP[name], md5 - return path - - -class ActNorm(nn.Module): - def __init__( - self, num_features, logdet=False, affine=True, allow_reverse_init=False - ): - assert affine - super().__init__() - self.logdet = logdet - self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) - self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) - self.allow_reverse_init = allow_reverse_init - - self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) - - def initialize(self, input): - with torch.no_grad(): - flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) - mean = ( - flatten.mean(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - std = ( - flatten.std(1) - .unsqueeze(1) - .unsqueeze(2) - .unsqueeze(3) - .permute(1, 0, 2, 3) - ) - - self.loc.data.copy_(-mean) - self.scale.data.copy_(1 / (std + 1e-6)) - - def forward(self, input, reverse=False): - if reverse: - return self.reverse(input) - if len(input.shape) == 2: - input = input[:, :, None, None] - squeeze = True - else: - squeeze = False - - _, _, height, width = input.shape - - if self.training and self.initialized.item() == 0: - self.initialize(input) - self.initialized.fill_(1) - - h = self.scale * (input + self.loc) - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - - if self.logdet: - log_abs = torch.log(torch.abs(self.scale)) - logdet = height * width * torch.sum(log_abs) - logdet = logdet * torch.ones(input.shape[0]).to(input) - return h, logdet - - return h - - def reverse(self, output): - if self.training and self.initialized.item() == 0: - if not self.allow_reverse_init: - raise RuntimeError( - "Initializing ActNorm in reverse direction is " - "disabled by default. Use allow_reverse_init=True to enable." - ) - else: - self.initialize(output) - self.initialized.fill_(1) - - if len(output.shape) == 2: - output = output[:, :, None, None] - squeeze = True - else: - squeeze = False - - h = output / self.scale - self.loc - - if squeeze: - h = h.squeeze(-1).squeeze(-1) - return h diff --git a/sgm/modules/autoencoding/lpips/vqperceptual.py b/sgm/modules/autoencoding/lpips/vqperceptual.py deleted file mode 100644 index 6195f0a6ed7ee6fd32c1bccea071e6075e95ee43..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/lpips/vqperceptual.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch -import torch.nn.functional as F - - -def hinge_d_loss(logits_real, logits_fake): - loss_real = torch.mean(F.relu(1.0 - logits_real)) - loss_fake = torch.mean(F.relu(1.0 + logits_fake)) - d_loss = 0.5 * (loss_real + loss_fake) - return d_loss - - -def vanilla_d_loss(logits_real, logits_fake): - d_loss = 0.5 * ( - torch.mean(torch.nn.functional.softplus(-logits_real)) - + torch.mean(torch.nn.functional.softplus(logits_fake)) - ) - return d_loss diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py deleted file mode 100644 index ff2b1815a5ba88892375e8ec9bedacea49024113..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/regularizers/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -from abc import abstractmethod -from typing import Any, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ....modules.distributions.distributions import \ - DiagonalGaussianDistribution -from .base import AbstractRegularizer - - -class DiagonalGaussianRegularizer(AbstractRegularizer): - def __init__(self, sample: bool = True): - super().__init__() - self.sample = sample - - def get_trainable_parameters(self) -> Any: - yield from () - - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: - log = dict() - posterior = DiagonalGaussianDistribution(z) - if self.sample: - z = posterior.sample() - else: - z = posterior.mode() - kl_loss = posterior.kl() - kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] - log["kl_loss"] = kl_loss - return z, log diff --git a/sgm/modules/autoencoding/regularizers/base.py b/sgm/modules/autoencoding/regularizers/base.py deleted file mode 100644 index fca681bb3c1f4818b57e956e31b98f76077ccb67..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/regularizers/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from abc import abstractmethod -from typing import Any, Tuple - -import torch -import torch.nn.functional as F -from torch import nn - - -class AbstractRegularizer(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: - raise NotImplementedError() - - @abstractmethod - def get_trainable_parameters(self) -> Any: - raise NotImplementedError() - - -class IdentityRegularizer(AbstractRegularizer): - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: - return z, dict() - - def get_trainable_parameters(self) -> Any: - yield from () - - -def measure_perplexity( - predicted_indices: torch.Tensor, num_centroids: int -) -> Tuple[torch.Tensor, torch.Tensor]: - # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py - # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally - encodings = ( - F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) - ) - avg_probs = encodings.mean(0) - perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() - cluster_use = torch.sum(avg_probs > 0) - return perplexity, cluster_use diff --git a/sgm/modules/autoencoding/regularizers/quantize.py b/sgm/modules/autoencoding/regularizers/quantize.py deleted file mode 100644 index 86a4dbdd10101b24f03bba134c4f8d2ab007f0db..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/regularizers/quantize.py +++ /dev/null @@ -1,487 +0,0 @@ -import logging -from abc import abstractmethod -from typing import Dict, Iterator, Literal, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch import einsum - -from .base import AbstractRegularizer, measure_perplexity - -logpy = logging.getLogger(__name__) - - -class AbstractQuantizer(AbstractRegularizer): - def __init__(self): - super().__init__() - # Define these in your init - # shape (N,) - self.used: Optional[torch.Tensor] - self.re_embed: int - self.unknown_index: Union[Literal["random"], int] - - def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor: - assert self.used is not None, "You need to define used indices for remap" - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - match = (inds[:, :, None] == used[None, None, ...]).long() - new = match.argmax(-1) - unknown = match.sum(2) < 1 - if self.unknown_index == "random": - new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( - device=new.device - ) - else: - new[unknown] = self.unknown_index - return new.reshape(ishape) - - def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor: - assert self.used is not None, "You need to define used indices for remap" - ishape = inds.shape - assert len(ishape) > 1 - inds = inds.reshape(ishape[0], -1) - used = self.used.to(inds) - if self.re_embed > self.used.shape[0]: # extra token - inds[inds >= self.used.shape[0]] = 0 # simply set to zero - back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) - return back.reshape(ishape) - - @abstractmethod - def get_codebook_entry( - self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None - ) -> torch.Tensor: - raise NotImplementedError() - - def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: - yield from self.parameters() - - -class GumbelQuantizer(AbstractQuantizer): - """ - credit to @karpathy: - https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) - Gumbel Softmax trick quantizer - Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 - https://arxiv.org/abs/1611.01144 - """ - - def __init__( - self, - num_hiddens: int, - embedding_dim: int, - n_embed: int, - straight_through: bool = True, - kl_weight: float = 5e-4, - temp_init: float = 1.0, - remap: Optional[str] = None, - unknown_index: str = "random", - loss_key: str = "loss/vq", - ) -> None: - super().__init__() - - self.loss_key = loss_key - self.embedding_dim = embedding_dim - self.n_embed = n_embed - - self.straight_through = straight_through - self.temperature = temp_init - self.kl_weight = kl_weight - - self.proj = nn.Conv2d(num_hiddens, n_embed, 1) - self.embed = nn.Embedding(n_embed, embedding_dim) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - else: - self.used = None - self.re_embed = n_embed - if unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - else: - assert unknown_index == "random" or isinstance( - unknown_index, int - ), "unknown index needs to be 'random', 'extra' or any integer" - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.remap is not None: - logpy.info( - f"Remapping {self.n_embed} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - - def forward( - self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False - ) -> Tuple[torch.Tensor, Dict]: - # force hard = True when we are in eval mode, as we must quantize. - # actually, always true seems to work - hard = self.straight_through if self.training else True - temp = self.temperature if temp is None else temp - out_dict = {} - logits = self.proj(z) - if self.remap is not None: - # continue only with used logits - full_zeros = torch.zeros_like(logits) - logits = logits[:, self.used, ...] - - soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) - if self.remap is not None: - # go back to all entries but unused set to zero - full_zeros[:, self.used, ...] = soft_one_hot - soft_one_hot = full_zeros - z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) - - # + kl divergence to the prior loss - qy = F.softmax(logits, dim=1) - diff = ( - self.kl_weight - * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() - ) - out_dict[self.loss_key] = diff - - ind = soft_one_hot.argmax(dim=1) - out_dict["indices"] = ind - if self.remap is not None: - ind = self.remap_to_used(ind) - - if return_logits: - out_dict["logits"] = logits - - return z_q, out_dict - - def get_codebook_entry(self, indices, shape): - # TODO: shape not yet optional - b, h, w, c = shape - assert b * h * w == indices.shape[0] - indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) - if self.remap is not None: - indices = self.unmap_to_all(indices) - one_hot = ( - F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() - ) - z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) - return z_q - - -class VectorQuantizer(AbstractQuantizer): - """ - ____________________________________________ - Discretization bottleneck part of the VQ-VAE. - Inputs: - - n_e : number of embeddings - - e_dim : dimension of embedding - - beta : commitment cost used in loss term, - beta * ||z_e(x)-sg[e]||^2 - _____________________________________________ - """ - - def __init__( - self, - n_e: int, - e_dim: int, - beta: float = 0.25, - remap: Optional[str] = None, - unknown_index: str = "random", - sane_index_shape: bool = False, - log_perplexity: bool = False, - embedding_weight_norm: bool = False, - loss_key: str = "loss/vq", - ): - super().__init__() - self.n_e = n_e - self.e_dim = e_dim - self.beta = beta - self.loss_key = loss_key - - if not embedding_weight_norm: - self.embedding = nn.Embedding(self.n_e, self.e_dim) - self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) - else: - self.embedding = torch.nn.utils.weight_norm( - nn.Embedding(self.n_e, self.e_dim), dim=1 - ) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - else: - self.used = None - self.re_embed = n_e - if unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - else: - assert unknown_index == "random" or isinstance( - unknown_index, int - ), "unknown index needs to be 'random', 'extra' or any integer" - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.remap is not None: - logpy.info( - f"Remapping {self.n_e} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - - self.sane_index_shape = sane_index_shape - self.log_perplexity = log_perplexity - - def forward( - self, - z: torch.Tensor, - ) -> Tuple[torch.Tensor, Dict]: - do_reshape = z.ndim == 4 - if do_reshape: - # # reshape z -> (batch, height, width, channel) and flatten - z = rearrange(z, "b c h w -> b h w c").contiguous() - - else: - assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined" - z = z.contiguous() - - z_flattened = z.view(-1, self.e_dim) - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - - d = ( - torch.sum(z_flattened**2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight**2, dim=1) - - 2 - * torch.einsum( - "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") - ) - ) - - min_encoding_indices = torch.argmin(d, dim=1) - z_q = self.embedding(min_encoding_indices).view(z.shape) - loss_dict = {} - if self.log_perplexity: - perplexity, cluster_usage = measure_perplexity( - min_encoding_indices.detach(), self.n_e - ) - loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage}) - - # compute loss for embedding - loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( - (z_q - z.detach()) ** 2 - ) - loss_dict[self.loss_key] = loss - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - if do_reshape: - z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() - - if self.remap is not None: - min_encoding_indices = min_encoding_indices.reshape( - z.shape[0], -1 - ) # add batch axis - min_encoding_indices = self.remap_to_used(min_encoding_indices) - min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten - - if self.sane_index_shape: - if do_reshape: - min_encoding_indices = min_encoding_indices.reshape( - z_q.shape[0], z_q.shape[2], z_q.shape[3] - ) - else: - min_encoding_indices = rearrange( - min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0] - ) - - loss_dict["min_encoding_indices"] = min_encoding_indices - - return z_q, loss_dict - - def get_codebook_entry( - self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None - ) -> torch.Tensor: - # shape specifying (batch, height, width, channel) - if self.remap is not None: - assert shape is not None, "Need to give shape for remap" - indices = indices.reshape(shape[0], -1) # add batch axis - indices = self.unmap_to_all(indices) - indices = indices.reshape(-1) # flatten again - - # get quantized latent vectors - z_q = self.embedding(indices) - - if shape is not None: - z_q = z_q.view(shape) - # reshape back to match original input shape - z_q = z_q.permute(0, 3, 1, 2).contiguous() - - return z_q - - -class EmbeddingEMA(nn.Module): - def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): - super().__init__() - self.decay = decay - self.eps = eps - weight = torch.randn(num_tokens, codebook_dim) - self.weight = nn.Parameter(weight, requires_grad=False) - self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) - self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) - self.update = True - - def forward(self, embed_id): - return F.embedding(embed_id, self.weight) - - def cluster_size_ema_update(self, new_cluster_size): - self.cluster_size.data.mul_(self.decay).add_( - new_cluster_size, alpha=1 - self.decay - ) - - def embed_avg_ema_update(self, new_embed_avg): - self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) - - def weight_update(self, num_tokens): - n = self.cluster_size.sum() - smoothed_cluster_size = ( - (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n - ) - # normalize embedding average with smoothed cluster size - embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) - self.weight.data.copy_(embed_normalized) - - -class EMAVectorQuantizer(AbstractQuantizer): - def __init__( - self, - n_embed: int, - embedding_dim: int, - beta: float, - decay: float = 0.99, - eps: float = 1e-5, - remap: Optional[str] = None, - unknown_index: str = "random", - loss_key: str = "loss/vq", - ): - super().__init__() - self.codebook_dim = embedding_dim - self.num_tokens = n_embed - self.beta = beta - self.loss_key = loss_key - - self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) - - self.remap = remap - if self.remap is not None: - self.register_buffer("used", torch.tensor(np.load(self.remap))) - self.re_embed = self.used.shape[0] - else: - self.used = None - self.re_embed = n_embed - if unknown_index == "extra": - self.unknown_index = self.re_embed - self.re_embed = self.re_embed + 1 - else: - assert unknown_index == "random" or isinstance( - unknown_index, int - ), "unknown index needs to be 'random', 'extra' or any integer" - self.unknown_index = unknown_index # "random" or "extra" or integer - if self.remap is not None: - logpy.info( - f"Remapping {self.n_embed} indices to {self.re_embed} indices. " - f"Using {self.unknown_index} for unknown indices." - ) - - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: - # reshape z -> (batch, height, width, channel) and flatten - # z, 'b c h w -> b h w c' - z = rearrange(z, "b c h w -> b h w c") - z_flattened = z.reshape(-1, self.codebook_dim) - - # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z - d = ( - z_flattened.pow(2).sum(dim=1, keepdim=True) - + self.embedding.weight.pow(2).sum(dim=1) - - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) - ) # 'n d -> d n' - - encoding_indices = torch.argmin(d, dim=1) - - z_q = self.embedding(encoding_indices).view(z.shape) - encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) - avg_probs = torch.mean(encodings, dim=0) - perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) - - if self.training and self.embedding.update: - # EMA cluster size - encodings_sum = encodings.sum(0) - self.embedding.cluster_size_ema_update(encodings_sum) - # EMA embedding average - embed_sum = encodings.transpose(0, 1) @ z_flattened - self.embedding.embed_avg_ema_update(embed_sum) - # normalize embed_avg and update weight - self.embedding.weight_update(self.num_tokens) - - # compute loss for embedding - loss = self.beta * F.mse_loss(z_q.detach(), z) - - # preserve gradients - z_q = z + (z_q - z).detach() - - # reshape back to match original input shape - # z_q, 'b h w c -> b c h w' - z_q = rearrange(z_q, "b h w c -> b c h w") - - out_dict = { - self.loss_key: loss, - "encodings": encodings, - "encoding_indices": encoding_indices, - "perplexity": perplexity, - } - - return z_q, out_dict - - -class VectorQuantizerWithInputProjection(VectorQuantizer): - def __init__( - self, - input_dim: int, - n_codes: int, - codebook_dim: int, - beta: float = 1.0, - output_dim: Optional[int] = None, - **kwargs, - ): - super().__init__(n_codes, codebook_dim, beta, **kwargs) - self.proj_in = nn.Linear(input_dim, codebook_dim) - self.output_dim = output_dim - if output_dim is not None: - self.proj_out = nn.Linear(codebook_dim, output_dim) - else: - self.proj_out = nn.Identity() - - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: - rearr = False - in_shape = z.shape - - if z.ndim > 3: - rearr = self.output_dim is not None - z = rearrange(z, "b c ... -> b (...) c") - z = self.proj_in(z) - z_q, loss_dict = super().forward(z) - - z_q = self.proj_out(z_q) - if rearr: - if len(in_shape) == 4: - z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1]) - elif len(in_shape) == 5: - z_q = rearrange( - z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2] - ) - else: - raise NotImplementedError( - f"rearranging not available for {len(in_shape)}-dimensional input." - ) - - return z_q, loss_dict diff --git a/sgm/modules/autoencoding/temporal_ae.py b/sgm/modules/autoencoding/temporal_ae.py deleted file mode 100644 index 374373e2e4330846ffef28d9061dcc64f70d2722..0000000000000000000000000000000000000000 --- a/sgm/modules/autoencoding/temporal_ae.py +++ /dev/null @@ -1,349 +0,0 @@ -from typing import Callable, Iterable, Union - -import torch -from einops import rearrange, repeat - -from sgm.modules.diffusionmodules.model import ( - XFORMERS_IS_AVAILABLE, - AttnBlock, - Decoder, - MemoryEfficientAttnBlock, - ResnetBlock, -) -from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding -from sgm.modules.video_attention import VideoTransformerBlock -from sgm.util import partialclass - - -class VideoResBlock(ResnetBlock): - def __init__( - self, - out_channels, - *args, - dropout=0.0, - video_kernel_size=3, - alpha=0.0, - merge_strategy="learned", - **kwargs, - ): - super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) - if video_kernel_size is None: - video_kernel_size = [3, 1, 1] - self.time_stack = ResBlock( - channels=out_channels, - emb_channels=0, - dropout=dropout, - dims=3, - use_scale_shift_norm=False, - use_conv=False, - up=False, - down=False, - kernel_size=video_kernel_size, - use_checkpoint=False, - skip_t_emb=True, - ) - - self.merge_strategy = merge_strategy - if self.merge_strategy == "fixed": - self.register_buffer("mix_factor", torch.Tensor([alpha])) - elif self.merge_strategy == "learned": - self.register_parameter( - "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) - ) - else: - raise ValueError(f"unknown merge strategy {self.merge_strategy}") - - def get_alpha(self, bs): - if self.merge_strategy == "fixed": - return self.mix_factor - elif self.merge_strategy == "learned": - return torch.sigmoid(self.mix_factor) - else: - raise NotImplementedError() - - def forward(self, x, temb, skip_video=False, timesteps=None): - if timesteps is None: - timesteps = self.timesteps - - b, c, h, w = x.shape - - x = super().forward(x, temb) - - if not skip_video: - x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) - - x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) - - x = self.time_stack(x, temb) - - alpha = self.get_alpha(bs=b // timesteps) - x = alpha * x + (1.0 - alpha) * x_mix - - x = rearrange(x, "b c t h w -> (b t) c h w") - return x - - -class AE3DConv(torch.nn.Conv2d): - def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): - super().__init__(in_channels, out_channels, *args, **kwargs) - if isinstance(video_kernel_size, Iterable): - padding = [int(k // 2) for k in video_kernel_size] - else: - padding = int(video_kernel_size // 2) - - self.time_mix_conv = torch.nn.Conv3d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=video_kernel_size, - padding=padding, - ) - - def forward(self, input, timesteps, skip_video=False): - x = super().forward(input) - if skip_video: - return x - x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) - x = self.time_mix_conv(x) - return rearrange(x, "b c t h w -> (b t) c h w") - - -class VideoBlock(AttnBlock): - def __init__( - self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" - ): - super().__init__(in_channels) - # no context, single headed, as in base class - self.time_mix_block = VideoTransformerBlock( - dim=in_channels, - n_heads=1, - d_head=in_channels, - checkpoint=False, - ff_in=True, - attn_mode="softmax", - ) - - time_embed_dim = self.in_channels * 4 - self.video_time_embed = torch.nn.Sequential( - torch.nn.Linear(self.in_channels, time_embed_dim), - torch.nn.SiLU(), - torch.nn.Linear(time_embed_dim, self.in_channels), - ) - - self.merge_strategy = merge_strategy - if self.merge_strategy == "fixed": - self.register_buffer("mix_factor", torch.Tensor([alpha])) - elif self.merge_strategy == "learned": - self.register_parameter( - "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) - ) - else: - raise ValueError(f"unknown merge strategy {self.merge_strategy}") - - def forward(self, x, timesteps, skip_video=False): - if skip_video: - return super().forward(x) - - x_in = x - x = self.attention(x) - h, w = x.shape[2:] - x = rearrange(x, "b c h w -> b (h w) c") - - x_mix = x - num_frames = torch.arange(timesteps, device=x.device) - num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) - num_frames = rearrange(num_frames, "b t -> (b t)") - t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) - emb = self.video_time_embed(t_emb) # b, n_channels - emb = emb[:, None, :] - x_mix = x_mix + emb - - alpha = self.get_alpha() - x_mix = self.time_mix_block(x_mix, timesteps=timesteps) - x = alpha * x + (1.0 - alpha) * x_mix # alpha merge - - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.proj_out(x) - - return x_in + x - - def get_alpha( - self, - ): - if self.merge_strategy == "fixed": - return self.mix_factor - elif self.merge_strategy == "learned": - return torch.sigmoid(self.mix_factor) - else: - raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") - - -class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): - def __init__( - self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" - ): - super().__init__(in_channels) - # no context, single headed, as in base class - self.time_mix_block = VideoTransformerBlock( - dim=in_channels, - n_heads=1, - d_head=in_channels, - checkpoint=False, - ff_in=True, - attn_mode="softmax-xformers", - ) - - time_embed_dim = self.in_channels * 4 - self.video_time_embed = torch.nn.Sequential( - torch.nn.Linear(self.in_channels, time_embed_dim), - torch.nn.SiLU(), - torch.nn.Linear(time_embed_dim, self.in_channels), - ) - - self.merge_strategy = merge_strategy - if self.merge_strategy == "fixed": - self.register_buffer("mix_factor", torch.Tensor([alpha])) - elif self.merge_strategy == "learned": - self.register_parameter( - "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) - ) - else: - raise ValueError(f"unknown merge strategy {self.merge_strategy}") - - def forward(self, x, timesteps, skip_time_block=False): - if skip_time_block: - return super().forward(x) - - x_in = x - x = self.attention(x) - h, w = x.shape[2:] - x = rearrange(x, "b c h w -> b (h w) c") - - x_mix = x - num_frames = torch.arange(timesteps, device=x.device) - num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) - num_frames = rearrange(num_frames, "b t -> (b t)") - t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) - emb = self.video_time_embed(t_emb) # b, n_channels - emb = emb[:, None, :] - x_mix = x_mix + emb - - alpha = self.get_alpha() - x_mix = self.time_mix_block(x_mix, timesteps=timesteps) - x = alpha * x + (1.0 - alpha) * x_mix # alpha merge - - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - x = self.proj_out(x) - - return x_in + x - - def get_alpha( - self, - ): - if self.merge_strategy == "fixed": - return self.mix_factor - elif self.merge_strategy == "learned": - return torch.sigmoid(self.mix_factor) - else: - raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") - - -def make_time_attn( - in_channels, - attn_type="vanilla", - attn_kwargs=None, - alpha: float = 0, - merge_strategy: str = "learned", -): - assert attn_type in [ - "vanilla", - "vanilla-xformers", - ], f"attn_type {attn_type} not supported for spatio-temporal attention" - print( - f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" - ) - if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": - print( - f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " - f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" - ) - attn_type = "vanilla" - - if attn_type == "vanilla": - assert attn_kwargs is None - return partialclass( - VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy - ) - elif attn_type == "vanilla-xformers": - print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") - return partialclass( - MemoryEfficientVideoBlock, - in_channels, - alpha=alpha, - merge_strategy=merge_strategy, - ) - else: - return NotImplementedError() - - -class Conv2DWrapper(torch.nn.Conv2d): - def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: - return super().forward(input) - - -class VideoDecoder(Decoder): - available_time_modes = ["all", "conv-only", "attn-only"] - - def __init__( - self, - *args, - video_kernel_size: Union[int, list] = 3, - alpha: float = 0.0, - merge_strategy: str = "learned", - time_mode: str = "conv-only", - **kwargs, - ): - self.video_kernel_size = video_kernel_size - self.alpha = alpha - self.merge_strategy = merge_strategy - self.time_mode = time_mode - assert ( - self.time_mode in self.available_time_modes - ), f"time_mode parameter has to be in {self.available_time_modes}" - super().__init__(*args, **kwargs) - - def get_last_layer(self, skip_time_mix=False, **kwargs): - if self.time_mode == "attn-only": - raise NotImplementedError("TODO") - else: - return ( - self.conv_out.time_mix_conv.weight - if not skip_time_mix - else self.conv_out.weight - ) - - def _make_attn(self) -> Callable: - if self.time_mode not in ["conv-only", "only-last-conv"]: - return partialclass( - make_time_attn, - alpha=self.alpha, - merge_strategy=self.merge_strategy, - ) - else: - return super()._make_attn() - - def _make_conv(self) -> Callable: - if self.time_mode != "attn-only": - return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) - else: - return Conv2DWrapper - - def _make_resblock(self) -> Callable: - if self.time_mode not in ["attn-only", "only-last-conv"]: - return partialclass( - VideoResBlock, - video_kernel_size=self.video_kernel_size, - alpha=self.alpha, - merge_strategy=self.merge_strategy, - ) - else: - return super()._make_resblock() diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/diffusionmodules/denoiser.py b/sgm/modules/diffusionmodules/denoiser.py deleted file mode 100644 index d86e7a262d1f036139e41f500d8579a2b95071ef..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/denoiser.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Dict, Union - -import torch -import torch.nn as nn - -from ...util import append_dims, instantiate_from_config -from .denoiser_scaling import DenoiserScaling -from .discretizer import Discretization - - -class Denoiser(nn.Module): - def __init__(self, scaling_config: Dict): - super().__init__() - - self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) - - def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: - return sigma - - def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: - return c_noise - - def forward( - self, - network: nn.Module, - input: torch.Tensor, - sigma: torch.Tensor, - cond: Dict, - **additional_model_inputs, - ) -> torch.Tensor: - sigma = self.possibly_quantize_sigma(sigma) - sigma_shape = sigma.shape - sigma = append_dims(sigma, input.ndim) - c_skip, c_out, c_in, c_noise = self.scaling(sigma) - c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) - return ( - network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out - + input * c_skip - ) - - -class DiscreteDenoiser(Denoiser): - def __init__( - self, - scaling_config: Dict, - num_idx: int, - discretization_config: Dict, - do_append_zero: bool = False, - quantize_c_noise: bool = True, - flip: bool = True, - ): - super().__init__(scaling_config) - self.discretization: Discretization = instantiate_from_config( - discretization_config - ) - sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) - self.register_buffer("sigmas", sigmas) - self.quantize_c_noise = quantize_c_noise - self.num_idx = num_idx - - def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: - dists = sigma - self.sigmas[:, None] - return dists.abs().argmin(dim=0).view(sigma.shape) - - def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: - return self.sigmas[idx] - - def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: - return self.idx_to_sigma(self.sigma_to_idx(sigma)) - - def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: - if self.quantize_c_noise: - return self.sigma_to_idx(c_noise) - else: - return c_noise diff --git a/sgm/modules/diffusionmodules/denoiser_scaling.py b/sgm/modules/diffusionmodules/denoiser_scaling.py deleted file mode 100644 index f4e287bfe8a82839a9a12fbd25c3446f43ab493b..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/denoiser_scaling.py +++ /dev/null @@ -1,59 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Tuple - -import torch - - -class DenoiserScaling(ABC): - @abstractmethod - def __call__( - self, sigma: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - pass - - -class EDMScaling: - def __init__(self, sigma_data: float = 0.5): - self.sigma_data = sigma_data - - def __call__( - self, sigma: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) - c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 - c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 - c_noise = 0.25 * sigma.log() - return c_skip, c_out, c_in, c_noise - - -class EpsScaling: - def __call__( - self, sigma: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - c_skip = torch.ones_like(sigma, device=sigma.device) - c_out = -sigma - c_in = 1 / (sigma**2 + 1.0) ** 0.5 - c_noise = sigma.clone() - return c_skip, c_out, c_in, c_noise - - -class VScaling: - def __call__( - self, sigma: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - c_skip = 1.0 / (sigma**2 + 1.0) - c_out = -sigma / (sigma**2 + 1.0) ** 0.5 - c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 - c_noise = sigma.clone() - return c_skip, c_out, c_in, c_noise - - -class VScalingWithEDMcNoise(DenoiserScaling): - def __call__( - self, sigma: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - c_skip = 1.0 / (sigma**2 + 1.0) - c_out = -sigma / (sigma**2 + 1.0) ** 0.5 - c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 - c_noise = 0.25 * sigma.log() - return c_skip, c_out, c_in, c_noise diff --git a/sgm/modules/diffusionmodules/denoiser_weighting.py b/sgm/modules/diffusionmodules/denoiser_weighting.py deleted file mode 100644 index b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/denoiser_weighting.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch - - -class UnitWeighting: - def __call__(self, sigma): - return torch.ones_like(sigma, device=sigma.device) - - -class EDMWeighting: - def __init__(self, sigma_data=0.5): - self.sigma_data = sigma_data - - def __call__(self, sigma): - return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - - -class VWeighting(EDMWeighting): - def __init__(self): - super().__init__(sigma_data=1.0) - - -class EpsWeighting: - def __call__(self, sigma): - return sigma**-2.0 diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py deleted file mode 100644 index 02add6081c5e3164d4402619b44d5be235d3ec58..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/discretizer.py +++ /dev/null @@ -1,69 +0,0 @@ -from abc import abstractmethod -from functools import partial - -import numpy as np -import torch - -from ...modules.diffusionmodules.util import make_beta_schedule -from ...util import append_zero - - -def generate_roughly_equally_spaced_steps( - num_substeps: int, max_step: int -) -> np.ndarray: - return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] - - -class Discretization: - def __call__(self, n, do_append_zero=True, device="cpu", flip=False): - sigmas = self.get_sigmas(n, device=device) - sigmas = append_zero(sigmas) if do_append_zero else sigmas - return sigmas if not flip else torch.flip(sigmas, (0,)) - - @abstractmethod - def get_sigmas(self, n, device): - pass - - -class EDMDiscretization(Discretization): - def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): - self.sigma_min = sigma_min - self.sigma_max = sigma_max - self.rho = rho - - def get_sigmas(self, n, device="cpu"): - ramp = torch.linspace(0, 1, n, device=device) - min_inv_rho = self.sigma_min ** (1 / self.rho) - max_inv_rho = self.sigma_max ** (1 / self.rho) - sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho - return sigmas - - -class LegacyDDPMDiscretization(Discretization): - def __init__( - self, - linear_start=0.00085, - linear_end=0.0120, - num_timesteps=1000, - ): - super().__init__() - self.num_timesteps = num_timesteps - betas = make_beta_schedule( - "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end - ) - alphas = 1.0 - betas - self.alphas_cumprod = np.cumprod(alphas, axis=0) - self.to_torch = partial(torch.tensor, dtype=torch.float32) - - def get_sigmas(self, n, device="cpu"): - if n < self.num_timesteps: - timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) - alphas_cumprod = self.alphas_cumprod[timesteps] - elif n == self.num_timesteps: - alphas_cumprod = self.alphas_cumprod - else: - raise ValueError - - to_torch = partial(torch.tensor, dtype=torch.float32, device=device) - sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 - return torch.flip(sigmas, (0,)) diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py deleted file mode 100644 index 63b5775b6ca857b4706f65f8cf3187cc8e4506d8..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/guiders.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging -from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, Union - -import torch -from einops import rearrange, repeat - -from ...util import append_dims, default - -logpy = logging.getLogger(__name__) - - -class Guider(ABC): - @abstractmethod - def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: - pass - - def prepare_inputs( - self, x: torch.Tensor, s: float, c: Dict, uc: Dict - ) -> Tuple[torch.Tensor, float, Dict]: - pass - - -class VanillaCFG(Guider): - def __init__(self, scale: float): - self.scale = scale - - def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - x_u, x_c = x.chunk(2) - x_pred = x_u + self.scale * (x_c - x_u) - return x_pred - - def prepare_inputs(self, x, s, c, uc): - c_out = dict() - - for k in c: - if k in ["vector", "crossattn", "concat"]: - c_out[k] = torch.cat((uc[k], c[k]), 0) - else: - assert c[k] == uc[k] - c_out[k] = c[k] - return torch.cat([x] * 2), torch.cat([s] * 2), c_out - - -class IdentityGuider(Guider): - def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: - return x - - def prepare_inputs( - self, x: torch.Tensor, s: float, c: Dict, uc: Dict - ) -> Tuple[torch.Tensor, float, Dict]: - c_out = dict() - - for k in c: - c_out[k] = c[k] - - return x, s, c_out - - -class LinearPredictionGuider(Guider): - def __init__( - self, - max_scale: float, - num_frames: int, - min_scale: float = 1.0, - additional_cond_keys: Optional[Union[List[str], str]] = None, - ): - self.min_scale = min_scale - self.max_scale = max_scale - self.num_frames = num_frames - self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) - - additional_cond_keys = default(additional_cond_keys, []) - if isinstance(additional_cond_keys, str): - additional_cond_keys = [additional_cond_keys] - self.additional_cond_keys = additional_cond_keys - - def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - x_u, x_c = x.chunk(2) - - x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) - x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) - scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) - scale = append_dims(scale, x_u.ndim).to(x_u.device) - - return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") - - def prepare_inputs( - self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict - ) -> Tuple[torch.Tensor, torch.Tensor, dict]: - c_out = dict() - - for k in c: - if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: - c_out[k] = torch.cat((uc[k], c[k]), 0) - else: - if k == "rgb": - continue - assert c[k] == uc[k] - c_out[k] = c[k] - return torch.cat([x] * 2), torch.cat([s] * 2), c_out - - -class CentralPredictionGuider(Guider): - def __init__( - self, - max_scale: float, - num_frames: int, - min_scale: float = 1.0, - additional_cond_keys: Optional[Union[List[str], str]] = None, - ): - self.min_scale = min_scale - self.max_scale = max_scale - self.num_frames = num_frames - # self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) - self.scale = torch.linspace(min_scale, 2 * max_scale, num_frames) - self.scale[num_frames // 2 :] = 2 * max_scale - self.scale[num_frames // 2 :] - self.scale = self.scale.unsqueeze(0) - - additional_cond_keys = default(additional_cond_keys, []) - if isinstance(additional_cond_keys, str): - additional_cond_keys = [additional_cond_keys] - self.additional_cond_keys = additional_cond_keys - - def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: - x_u, x_c = x.chunk(2) - - x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) - x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) - scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) - scale = append_dims(scale, x_u.ndim).to(x_u.device) - - return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") - - def prepare_inputs( - self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict - ) -> Tuple[torch.Tensor, torch.Tensor, dict]: - c_out = dict() - - for k in c: - if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: - c_out[k] = torch.cat((uc[k], c[k]), 0) - else: - assert c[k] == uc[k] - c_out[k] = c[k] - return torch.cat([x] * 2), torch.cat([s] * 2), c_out diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py deleted file mode 100644 index 9b2c437fab37bed10ea79c197560ade7bf511cad..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/loss.py +++ /dev/null @@ -1,187 +0,0 @@ -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from einops import rearrange, repeat - -from ...modules.autoencoding.lpips.loss.lpips import LPIPS -from ...modules.encoders.modules import GeneralConditioner -from ...util import append_dims, instantiate_from_config -from .denoiser import Denoiser - - -class StandardDiffusionLoss(nn.Module): - def __init__( - self, - sigma_sampler_config: dict, - loss_weighting_config: dict, - loss_type: str = "l2", - offset_noise_level: float = 0.0, - batch2model_keys: Optional[Union[str, List[str]]] = None, - ): - super().__init__() - - assert loss_type in ["l2", "l1", "lpips"] - - self.sigma_sampler = instantiate_from_config(sigma_sampler_config) - self.loss_weighting = instantiate_from_config(loss_weighting_config) - - self.loss_type = loss_type - self.offset_noise_level = offset_noise_level - - if loss_type == "lpips": - self.lpips = LPIPS().eval() - - if not batch2model_keys: - batch2model_keys = [] - - if isinstance(batch2model_keys, str): - batch2model_keys = [batch2model_keys] - - self.batch2model_keys = set(batch2model_keys) - - def get_noised_input( - self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor - ) -> torch.Tensor: - noised_input = input + noise * sigmas_bc - return noised_input - - def forward( - self, - network: nn.Module, - denoiser: Denoiser, - conditioner: GeneralConditioner, - input: torch.Tensor, - batch: Dict, - return_model_output: bool = False, - ) -> torch.Tensor: - cond = conditioner(batch) - # for video diffusion - if "num_video_frames" in batch: - num_frames = batch["num_video_frames"] - for k in ["crossattn", "concat"]: - cond[k] = repeat(cond[k], "b ... -> b t ...", t=num_frames) - cond[k] = rearrange(cond[k], "b t ... -> (b t) ...", t=num_frames) - return self._forward(network, denoiser, cond, input, batch, return_model_output) - - def _forward( - self, - network: nn.Module, - denoiser: Denoiser, - cond: Dict, - input: torch.Tensor, - batch: Dict, - return_model_output: bool = False, - ) -> Tuple[torch.Tensor, Dict]: - additional_model_inputs = { - key: batch[key] for key in self.batch2model_keys.intersection(batch) - } - sigmas = self.sigma_sampler(input.shape[0]).to(input) - - noise = torch.randn_like(input) - if self.offset_noise_level > 0.0: - offset_shape = ( - (input.shape[0], 1, input.shape[2]) - if self.n_frames is not None - else (input.shape[0], input.shape[1]) - ) - noise = noise + self.offset_noise_level * append_dims( - torch.randn(offset_shape, device=input.device), - input.ndim, - ) - sigmas_bc = append_dims(sigmas, input.ndim) - noised_input = self.get_noised_input(sigmas_bc, noise, input) - - model_output = denoiser( - network, noised_input, sigmas, cond, **additional_model_inputs - ) - w = append_dims(self.loss_weighting(sigmas), input.ndim) - if not return_model_output: - return self.get_loss(model_output, input, w) - else: - return self.get_loss(model_output, input, w), model_output - - def get_loss(self, model_output, target, w): - if self.loss_type == "l2": - return torch.mean( - (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 - ) - elif self.loss_type == "l1": - return torch.mean( - (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 - ) - elif self.loss_type == "lpips": - loss = self.lpips(model_output, target).reshape(-1) - return loss - else: - raise NotImplementedError(f"Unknown loss type {self.loss_type}") - - -class StandardDiffusionLossWithPixelNeRFLoss(StandardDiffusionLoss): - def __init__( - self, - sigma_sampler_config: Dict, - loss_weighting_config: Dict, - loss_type: str = "l2", - offset_noise_level: float = 0, - batch2model_keys: str | List[str] | None = None, - pixelnerf_loss_weight: float = 1.0, - pixelnerf_loss_type: str = "l2", - ): - super().__init__( - sigma_sampler_config, - loss_weighting_config, - loss_type, - offset_noise_level, - batch2model_keys, - ) - self.pixelnerf_loss_weight = pixelnerf_loss_weight - self.pixelnerf_loss_type = pixelnerf_loss_type - - def get_pixelnerf_loss(self, model_output, target): - if self.pixelnerf_loss_type == "l2": - return torch.mean( - ((model_output - target) ** 2).reshape(target.shape[0], -1), 1 - ) - elif self.pixelnerf_loss_type == "l1": - return torch.mean( - ((model_output - target).abs()).reshape(target.shape[0], -1), 1 - ) - elif self.pixelnerf_loss_type == "lpips": - loss = self.lpips(model_output, target).reshape(-1) - return loss - else: - raise NotImplementedError(f"Unknown loss type {self.loss_type}") - - def forward( - self, - network: nn.Module, - denoiser: Denoiser, - conditioner: GeneralConditioner, - input: torch.Tensor, - batch: Dict, - return_model_output: bool = False, - ) -> torch.Tensor: - cond = conditioner(batch) - return self._forward(network, denoiser, cond, input, batch, return_model_output) - - def _forward( - self, - network: nn.Module, - denoiser: Denoiser, - cond: Dict, - input: torch.Tensor, - batch: Dict, - return_model_output: bool = False, - ) -> Tuple[torch.Tensor | Dict]: - loss = super()._forward( - network, denoiser, cond, input, batch, return_model_output - ) - pixelnerf_loss = self.get_pixelnerf_loss( - cond["rgb"], batch["pixelnerf_input"]["rgb"] - ) - - if not return_model_output: - return loss + self.pixelnerf_loss_weight * pixelnerf_loss - else: - return loss[0] + self.pixelnerf_loss_weight * pixelnerf_loss, loss[1] diff --git a/sgm/modules/diffusionmodules/loss_weighting.py b/sgm/modules/diffusionmodules/loss_weighting.py deleted file mode 100644 index e12c0a76635435babd1af33969e82fa284525af8..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/loss_weighting.py +++ /dev/null @@ -1,32 +0,0 @@ -from abc import ABC, abstractmethod - -import torch - - -class DiffusionLossWeighting(ABC): - @abstractmethod - def __call__(self, sigma: torch.Tensor) -> torch.Tensor: - pass - - -class UnitWeighting(DiffusionLossWeighting): - def __call__(self, sigma: torch.Tensor) -> torch.Tensor: - return torch.ones_like(sigma, device=sigma.device) - - -class EDMWeighting(DiffusionLossWeighting): - def __init__(self, sigma_data: float = 0.5): - self.sigma_data = sigma_data - - def __call__(self, sigma: torch.Tensor) -> torch.Tensor: - return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 - - -class VWeighting(EDMWeighting): - def __init__(self): - super().__init__(sigma_data=1.0) - - -class EpsWeighting(DiffusionLossWeighting): - def __call__(self, sigma: torch.Tensor) -> torch.Tensor: - return sigma**-2.0 diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py deleted file mode 100644 index 4cf9d92140dee8443a0ea6b5cf218f2879ad88f4..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/model.py +++ /dev/null @@ -1,748 +0,0 @@ -# pytorch_diffusion + derived encoder decoder -import logging -import math -from typing import Any, Callable, Optional - -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange -from packaging import version - -logpy = logging.getLogger(__name__) - -try: - import xformers - import xformers.ops - - XFORMERS_IS_AVAILABLE = True -except: - XFORMERS_IS_AVAILABLE = False - logpy.warning("no module 'xformers'. Processing without...") - -from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention - - -def get_timestep_embedding(timesteps, embedding_dim): - """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". - """ - assert len(timesteps.shape) == 1 - - half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) - emb = emb.to(device=timesteps.device) - emb = timesteps.float()[:, None] * emb[None, :] - emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if embedding_dim % 2 == 1: # zero pad - emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) - return emb - - -def nonlinearity(x): - # swish - return x * torch.sigmoid(x) - - -def Normalize(in_channels, num_groups=32): - return torch.nn.GroupNorm( - num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True - ) - - -class Upsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x): - x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") - if self.with_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - def __init__(self, in_channels, with_conv): - super().__init__() - self.with_conv = with_conv - if self.with_conv: - # no asymmetric padding in torch conv, must do it ourselves - self.conv = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=3, stride=2, padding=0 - ) - - def forward(self, x): - if self.with_conv: - pad = (0, 1, 0, 1) - x = torch.nn.functional.pad(x, pad, mode="constant", value=0) - x = self.conv(x) - else: - x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) - return x - - -class ResnetBlock(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout, - temb_channels=512, - ): - super().__init__() - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - - self.norm1 = Normalize(in_channels) - self.conv1 = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if temb_channels > 0: - self.temb_proj = torch.nn.Linear(temb_channels, out_channels) - self.norm2 = Normalize(out_channels) - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = torch.nn.Conv2d( - out_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - self.conv_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=3, stride=1, padding=1 - ) - else: - self.nin_shortcut = torch.nn.Conv2d( - in_channels, out_channels, kernel_size=1, stride=1, padding=0 - ) - - def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) - - if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] - - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) - - if self.in_channels != self.out_channels: - if self.use_conv_shortcut: - x = self.conv_shortcut(x) - else: - x = self.nin_shortcut(x) - - return x + h - - -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) - - -class AttnBlock(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - - def attention(self, h_: torch.Tensor) -> torch.Tensor: - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - b, c, h, w = q.shape - q, k, v = map( - lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) - ) - h_ = torch.nn.functional.scaled_dot_product_attention( - q, k, v - ) # scale is dim ** -0.5 per default - # compute attention - - return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) - - def forward(self, x, **kwargs): - h_ = x - h_ = self.attention(h_) - h_ = self.proj_out(h_) - return x + h_ - - -class MemoryEfficientAttnBlock(nn.Module): - """ - Uses xformers efficient implementation, - see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - Note: this is a single-head self-attention operation - """ - - # - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.k = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.v = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.proj_out = torch.nn.Conv2d( - in_channels, in_channels, kernel_size=1, stride=1, padding=0 - ) - self.attention_op: Optional[Any] = None - - def attention(self, h_: torch.Tensor) -> torch.Tensor: - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - B, C, H, W = q.shape - q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) - - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(B, t.shape[1], 1, C) - .permute(0, 2, 1, 3) - .reshape(B * 1, t.shape[1], C) - .contiguous(), - (q, k, v), - ) - out = xformers.ops.memory_efficient_attention( - q, k, v, attn_bias=None, op=self.attention_op - ) - - out = ( - out.unsqueeze(0) - .reshape(B, 1, out.shape[1], C) - .permute(0, 2, 1, 3) - .reshape(B, out.shape[1], C) - ) - return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) - - def forward(self, x, **kwargs): - h_ = x - h_ = self.attention(h_) - h_ = self.proj_out(h_) - return x + h_ - - -class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None, **unused_kwargs): - b, c, h, w = x.shape - x = rearrange(x, "b c h w -> b (h w) c") - out = super().forward(x, context=context, mask=mask) - out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) - return x + out - - -def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): - assert attn_type in [ - "vanilla", - "vanilla-xformers", - "memory-efficient-cross-attn", - "linear", - "none", - ], f"attn_type {attn_type} unknown" - if ( - version.parse(torch.__version__) < version.parse("2.0.0") - and attn_type != "none" - ): - assert XFORMERS_IS_AVAILABLE, ( - f"We do not support vanilla attention in {torch.__version__} anymore, " - f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'" - ) - attn_type = "vanilla-xformers" - logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels") - if attn_type == "vanilla": - assert attn_kwargs is None - return AttnBlock(in_channels) - elif attn_type == "vanilla-xformers": - logpy.info( - f"building MemoryEfficientAttnBlock with {in_channels} in_channels..." - ) - return MemoryEfficientAttnBlock(in_channels) - elif type == "memory-efficient-cross-attn": - attn_kwargs["query_dim"] = in_channels - return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) - elif attn_type == "none": - return nn.Identity(in_channels) - else: - return LinAttnBlock(in_channels) - - -class Model(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - use_timestep=True, - use_linear_attn=False, - attn_type="vanilla", - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = self.ch * 4 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - self.use_timestep = use_timestep - if self.use_timestep: - # timestep embedding - self.temb = nn.Module() - self.temb.dense = nn.ModuleList( - [ - torch.nn.Linear(self.ch, self.temb_ch), - torch.nn.Linear(self.temb_ch, self.temb_ch), - ] - ) - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - skip_in = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - if i_block == self.num_res_blocks: - skip_in = ch * in_ch_mult[i_level] - block.append( - ResnetBlock( - in_channels=block_in + skip_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def forward(self, x, t=None, context=None): - # assert x.shape[2] == x.shape[3] == self.resolution - if context is not None: - # assume aligned context, cat along channel axis - x = torch.cat((x, context), dim=1) - if self.use_timestep: - # timestep embedding - assert t is not None - temb = get_timestep_embedding(t, self.ch) - temb = self.temb.dense[0](temb) - temb = nonlinearity(temb) - temb = self.temb.dense[1](temb) - else: - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block]( - torch.cat([h, hs.pop()], dim=1), temb - ) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - def get_last_layer(self): - return self.conv_out.weight - - -class Encoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - double_z=True, - use_linear_attn=False, - attn_type="vanilla", - **ignore_kwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - - # downsampling - self.conv_in = torch.nn.Conv2d( - in_channels, self.ch, kernel_size=3, stride=1, padding=1 - ) - - curr_res = resolution - in_ch_mult = (1,) + tuple(ch_mult) - self.in_ch_mult = in_ch_mult - self.down = nn.ModuleList() - for i_level in range(self.num_resolutions): - block = nn.ModuleList() - attn = nn.ModuleList() - block_in = ch * in_ch_mult[i_level] - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks): - block.append( - ResnetBlock( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn(block_in, attn_type=attn_type)) - down = nn.Module() - down.block = block - down.attn = attn - if i_level != self.num_resolutions - 1: - down.downsample = Downsample(block_in, resamp_with_conv) - curr_res = curr_res // 2 - self.down.append(down) - - # middle - self.mid = nn.Module() - self.mid.block_1 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) - self.mid.block_2 = ResnetBlock( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # end - self.norm_out = Normalize(block_in) - self.conv_out = torch.nn.Conv2d( - block_in, - 2 * z_channels if double_z else z_channels, - kernel_size=3, - stride=1, - padding=1, - ) - - def forward(self, x): - # timestep embedding - temb = None - - # downsampling - hs = [self.conv_in(x)] - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](hs[-1], temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - hs.append(h) - if i_level != self.num_resolutions - 1: - hs.append(self.down[i_level].downsample(hs[-1])) - - # middle - h = hs[-1] - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) - - # end - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) - return h - - -class Decoder(nn.Module): - def __init__( - self, - *, - ch, - out_ch, - ch_mult=(1, 2, 4, 8), - num_res_blocks, - attn_resolutions, - dropout=0.0, - resamp_with_conv=True, - in_channels, - resolution, - z_channels, - give_pre_end=False, - tanh_out=False, - use_linear_attn=False, - attn_type="vanilla", - **ignorekwargs, - ): - super().__init__() - if use_linear_attn: - attn_type = "linear" - self.ch = ch - self.temb_ch = 0 - self.num_resolutions = len(ch_mult) - self.num_res_blocks = num_res_blocks - self.resolution = resolution - self.in_channels = in_channels - self.give_pre_end = give_pre_end - self.tanh_out = tanh_out - - # compute in_ch_mult, block_in and curr_res at lowest res - in_ch_mult = (1,) + tuple(ch_mult) - block_in = ch * ch_mult[self.num_resolutions - 1] - curr_res = resolution // 2 ** (self.num_resolutions - 1) - self.z_shape = (1, z_channels, curr_res, curr_res) - logpy.info( - "Working with z of shape {} = {} dimensions.".format( - self.z_shape, np.prod(self.z_shape) - ) - ) - - make_attn_cls = self._make_attn() - make_resblock_cls = self._make_resblock() - make_conv_cls = self._make_conv() - # z to block_in - self.conv_in = torch.nn.Conv2d( - z_channels, block_in, kernel_size=3, stride=1, padding=1 - ) - - # middle - self.mid = nn.Module() - self.mid.block_1 = make_resblock_cls( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) - self.mid.block_2 = make_resblock_cls( - in_channels=block_in, - out_channels=block_in, - temb_channels=self.temb_ch, - dropout=dropout, - ) - - # upsampling - self.up = nn.ModuleList() - for i_level in reversed(range(self.num_resolutions)): - block = nn.ModuleList() - attn = nn.ModuleList() - block_out = ch * ch_mult[i_level] - for i_block in range(self.num_res_blocks + 1): - block.append( - make_resblock_cls( - in_channels=block_in, - out_channels=block_out, - temb_channels=self.temb_ch, - dropout=dropout, - ) - ) - block_in = block_out - if curr_res in attn_resolutions: - attn.append(make_attn_cls(block_in, attn_type=attn_type)) - up = nn.Module() - up.block = block - up.attn = attn - if i_level != 0: - up.upsample = Upsample(block_in, resamp_with_conv) - curr_res = curr_res * 2 - self.up.insert(0, up) # prepend to get consistent order - - # end - self.norm_out = Normalize(block_in) - self.conv_out = make_conv_cls( - block_in, out_ch, kernel_size=3, stride=1, padding=1 - ) - - def _make_attn(self) -> Callable: - return make_attn - - def _make_resblock(self) -> Callable: - return ResnetBlock - - def _make_conv(self) -> Callable: - return torch.nn.Conv2d - - def get_last_layer(self, **kwargs): - return self.conv_out.weight - - def forward(self, z, **kwargs): - # assert z.shape[1:] == self.z_shape[1:] - self.last_z_shape = z.shape - - # timestep embedding - temb = None - - # z to block_in - h = self.conv_in(z) - - # middle - h = self.mid.block_1(h, temb, **kwargs) - h = self.mid.attn_1(h, **kwargs) - h = self.mid.block_2(h, temb, **kwargs) - - # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) - - # end - if self.give_pre_end: - return h - - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py deleted file mode 100644 index e762e6823540def71743e27131e284ea28cdb56e..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/openaimodel.py +++ /dev/null @@ -1,863 +0,0 @@ -import logging -import math -from abc import abstractmethod -from typing import Iterable, List, Optional, Tuple, Union - -import torch -import torch as th -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from functools import partial - -# from torch.utils.checkpoint import checkpoint - -checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - -from ...modules.attention import SpatialTransformer -from ...modules.diffusionmodules.util import ( - avg_pool_nd, - conv_nd, - linear, - normalization, - timestep_embedding, - zero_module, -) -from ...modules.video_attention import SpatialVideoTransformer -from ...util import exists - -logpy = logging.getLogger(__name__) - - -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: Optional[int] = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter( - th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 - ) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x: th.Tensor) -> th.Tensor: - b, c, _ = x.shape - x = x.reshape(b, c, -1) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x: th.Tensor, emb: th.Tensor): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward( - self, - x: th.Tensor, - emb: th.Tensor, - context: Optional[th.Tensor] = None, - image_only_indicator: Optional[th.Tensor] = None, - time_context: Optional[int] = None, - num_video_frames: Optional[int] = None, - ): - from ...modules.diffusionmodules.video_model import VideoResBlock - - for layer in self: - module = layer - - if isinstance(module, TimestepBlock) and not isinstance( - module, VideoResBlock - ): - x = layer(x, emb) - elif isinstance(module, VideoResBlock): - x = layer(x, emb, num_video_frames, image_only_indicator) - elif isinstance(module, SpatialVideoTransformer): - x = layer( - x, - context, - time_context, - num_video_frames, - image_only_indicator, - ) - elif isinstance(module, SpatialTransformer): - x = layer(x, context) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__( - self, - channels: int, - use_conv: bool, - dims: int = 2, - out_channels: Optional[int] = None, - padding: int = 1, - third_up: bool = False, - kernel_size: int = 3, - scale_factor: int = 2, - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - self.third_up = third_up - self.scale_factor = scale_factor - if use_conv: - self.conv = conv_nd( - dims, self.channels, self.out_channels, kernel_size, padding=padding - ) - - def forward(self, x: th.Tensor) -> th.Tensor: - assert x.shape[1] == self.channels - - if self.dims == 3: - t_factor = 1 if not self.third_up else self.scale_factor - x = F.interpolate( - x, - ( - t_factor * x.shape[2], - x.shape[3] * self.scale_factor, - x.shape[4] * self.scale_factor, - ), - mode="nearest", - ) - else: - x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__( - self, - channels: int, - use_conv: bool, - dims: int = 2, - out_channels: Optional[int] = None, - padding: int = 1, - third_down: bool = False, - ): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) - if use_conv: - logpy.info(f"Building a Downsample layer with {dims} dims.") - logpy.info( - f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " - f"kernel-size: 3, stride: {stride}, padding: {padding}" - ) - if dims == 3: - logpy.info(f" --> Downsampling third axis (time): {third_down}") - self.op = conv_nd( - dims, - self.channels, - self.out_channels, - 3, - stride=stride, - padding=padding, - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x: th.Tensor) -> th.Tensor: - assert x.shape[1] == self.channels - - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. - """ - - def __init__( - self, - channels: int, - emb_channels: int, - dropout: float, - out_channels: Optional[int] = None, - use_conv: bool = False, - use_scale_shift_norm: bool = False, - dims: int = 2, - use_checkpoint: bool = False, - up: bool = False, - down: bool = False, - kernel_size: int = 3, - exchange_temb_dims: bool = False, - skip_t_emb: bool = False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - self.exchange_temb_dims = exchange_temb_dims - - if isinstance(kernel_size, Iterable): - padding = [k // 2 for k in kernel_size] - else: - padding = kernel_size // 2 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.skip_t_emb = skip_t_emb - self.emb_out_channels = ( - 2 * self.out_channels if use_scale_shift_norm else self.out_channels - ) - if self.skip_t_emb: - logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}") - assert not self.use_scale_shift_norm - self.emb_layers = None - self.exchange_temb_dims = False - else: - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - self.emb_out_channels, - ), - ) - - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd( - dims, - self.out_channels, - self.out_channels, - kernel_size, - padding=padding, - ) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, kernel_size, padding=padding - ) - else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) - - def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - if self.use_checkpoint: - return checkpoint(self._forward, x, emb) - else: - return self._forward(x, emb) - - def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - - if self.skip_t_emb: - emb_out = th.zeros_like(h) - else: - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - if self.exchange_temb_dims: - emb_out = rearrange(emb_out, "b t c ... -> b c t ...") - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - """ - - def __init__( - self, - channels: int, - num_heads: int = 1, - num_head_channels: int = -1, - use_checkpoint: bool = False, - use_new_attention_order: bool = False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x: th.Tensor, **kwargs) -> th.Tensor: - return checkpoint(self._forward, x) - - def _forward(self, x: th.Tensor) -> th.Tensor: - b, c, *spatial = x.shape - x = x.reshape(b, c, -1) - qkv = self.qkv(self.norm(x)) - h = self.attention(qkv) - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads: int): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv: th.Tensor) -> th.Tensor: - """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads: int): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv: th.Tensor) -> th.Tensor: - """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) - - -class Timestep(nn.Module): - def __init__(self, dim: int): - super().__init__() - self.dim = dim - - def forward(self, t: th.Tensor) -> th.Tensor: - return timestep_embedding(t, self.dim) - - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - def __init__( - self, - in_channels: int, - model_channels: int, - out_channels: int, - num_res_blocks: int, - attention_resolutions: int, - dropout: float = 0.0, - channel_mult: Union[List, Tuple] = (1, 2, 4, 8), - conv_resample: bool = True, - dims: int = 2, - num_classes: Optional[Union[int, str]] = None, - use_checkpoint: bool = False, - num_heads: int = -1, - num_head_channels: int = -1, - num_heads_upsample: int = -1, - use_scale_shift_norm: bool = False, - resblock_updown: bool = False, - transformer_depth: int = 1, - context_dim: Optional[int] = None, - disable_self_attentions: Optional[List[bool]] = None, - num_attention_blocks: Optional[List[int]] = None, - disable_middle_self_attn: bool = False, - disable_middle_transformer: bool = False, - use_linear_in_transformer: bool = False, - spatial_transformer_attn_type: str = "softmax", - adm_in_channels: Optional[int] = None, - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert ( - num_head_channels != -1 - ), "Either num_heads or num_head_channels has to be set" - - if num_head_channels == -1: - assert ( - num_heads != -1 - ), "Either num_heads or num_head_channels has to be set" - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - transformer_depth_middle = transformer_depth[-1] - - if isinstance(num_res_blocks, int): - self.num_res_blocks = len(channel_mult) * [num_res_blocks] - else: - if len(num_res_blocks) != len(channel_mult): - raise ValueError( - "provide num_res_blocks either as an int (globally constant) or " - "as a list/tuple (per-level) with the same length as channel_mult" - ) - self.num_res_blocks = num_res_blocks - - if disable_self_attentions is not None: - assert len(disable_self_attentions) == len(channel_mult) - if num_attention_blocks is not None: - assert len(num_attention_blocks) == len(self.num_res_blocks) - assert all( - map( - lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], - range(len(num_attention_blocks)), - ) - ) - logpy.info( - f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " - f"This option has LESS priority than attention_resolutions {attention_resolutions}, " - f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " - f"attention will still not be set." - ) - - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.use_checkpoint = use_checkpoint - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - elif self.num_classes == "continuous": - logpy.info("setting up linear c_adm embedding layer") - self.label_emb = nn.Linear(1, time_embed_dim) - elif self.num_classes == "timestep": - self.label_emb = nn.Sequential( - Timestep(model_channels), - nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ), - ) - elif self.num_classes == "sequential": - assert adm_in_channels is not None - self.label_emb = nn.Sequential( - nn.Sequential( - linear(adm_in_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - ) - else: - raise ValueError - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - for level, mult in enumerate(channel_mult): - for nr in range(self.num_res_blocks[level]): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - if context_dim is not None and exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if ( - not exists(num_attention_blocks) - or nr < num_attention_blocks[level] - ): - layers.append( - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth[level], - context_dim=context_dim, - disable_self_attn=disabled_sa, - use_linear=use_linear_in_transformer, - attn_type=spatial_transformer_attn_type, - use_checkpoint=use_checkpoint, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth_middle, - context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, - use_linear=use_linear_in_transformer, - attn_type=spatial_transformer_attn_type, - use_checkpoint=use_checkpoint, - ) - if not disable_middle_transformer - else th.nn.Identity(), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(self.num_res_blocks[level] + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - if exists(disable_self_attentions): - disabled_sa = disable_self_attentions[level] - else: - disabled_sa = False - - if ( - not exists(num_attention_blocks) - or i < num_attention_blocks[level] - ): - layers.append( - SpatialTransformer( - ch, - num_heads, - dim_head, - depth=transformer_depth[level], - context_dim=context_dim, - disable_self_attn=disabled_sa, - use_linear=use_linear_in_transformer, - attn_type=spatial_transformer_attn_type, - use_checkpoint=use_checkpoint, - ) - ) - if level and i == self.num_res_blocks[level]: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - - def forward( - self, - x: th.Tensor, - timesteps: Optional[th.Tensor] = None, - context: Optional[th.Tensor] = None, - y: Optional[th.Tensor] = None, - **kwargs, - ) -> th.Tensor: - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. - """ - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x - for module in self.input_blocks: - h = module(h, emb, context) - hs.append(h) - h = self.middle_block(h, emb, context) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb, context) - h = h.type(x.dtype) - - return self.out(h) diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py deleted file mode 100644 index 6346829c86a76ab549ed69431f1704e01379535a..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/sampling.py +++ /dev/null @@ -1,365 +0,0 @@ -""" - Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py -""" - - -from typing import Dict, Union - -import torch -from omegaconf import ListConfig, OmegaConf -from tqdm import tqdm - -from ...modules.diffusionmodules.sampling_utils import ( - get_ancestral_step, - linear_multistep_coeff, - to_d, - to_neg_log_sigma, - to_sigma, -) -from ...util import append_dims, default, instantiate_from_config - -DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} - - -class BaseDiffusionSampler: - def __init__( - self, - discretization_config: Union[Dict, ListConfig, OmegaConf], - num_steps: Union[int, None] = None, - guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, - verbose: bool = False, - device: str = "cuda", - ): - self.num_steps = num_steps - self.discretization = instantiate_from_config(discretization_config) - self.guider = instantiate_from_config( - default( - guider_config, - DEFAULT_GUIDER, - ) - ) - self.verbose = verbose - self.device = device - - def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): - sigmas = self.discretization( - self.num_steps if num_steps is None else num_steps, device=self.device - ) - uc = default(uc, cond) - - x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) - num_sigmas = len(sigmas) - - s_in = x.new_ones([x.shape[0]]) - - return x, s_in, sigmas, num_sigmas, cond, uc - - def denoise(self, x, denoiser, sigma, cond, uc): - denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) - denoised = self.guider(denoised, sigma) - return denoised - - def get_sigma_gen(self, num_sigmas): - sigma_generator = range(num_sigmas - 1) - if self.verbose: - print("#" * 30, " Sampling setting ", "#" * 30) - print(f"Sampler: {self.__class__.__name__}") - print(f"Discretization: {self.discretization.__class__.__name__}") - print(f"Guider: {self.guider.__class__.__name__}") - sigma_generator = tqdm( - sigma_generator, - total=num_sigmas, - desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", - ) - return sigma_generator - - -class SingleStepDiffusionSampler(BaseDiffusionSampler): - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): - raise NotImplementedError - - def euler_step(self, x, d, dt): - return x + dt * d - - -class EDMSampler(SingleStepDiffusionSampler): - def __init__( - self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs - ): - super().__init__(*args, **kwargs) - - self.s_churn = s_churn - self.s_tmin = s_tmin - self.s_tmax = s_tmax - self.s_noise = s_noise - - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): - sigma_hat = sigma * (gamma + 1.0) - if gamma > 0: - eps = torch.randn_like(x) * self.s_noise - x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 - - denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) - d = to_d(x, sigma_hat, denoised) - dt = append_dims(next_sigma - sigma_hat, x.ndim) - - euler_step = self.euler_step(x, d, dt) - x = self.possible_correction_step( - euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ) - return x - - def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) - - for i in self.get_sigma_gen(num_sigmas): - gamma = ( - min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) - if self.s_tmin <= sigmas[i] <= self.s_tmax - else 0.0 - ) - x = self.sampler_step( - s_in * sigmas[i], - s_in * sigmas[i + 1], - denoiser, - x, - cond, - uc, - gamma, - ) - - return x - - -class AncestralSampler(SingleStepDiffusionSampler): - def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): - super().__init__(*args, **kwargs) - - self.eta = eta - self.s_noise = s_noise - self.noise_sampler = lambda x: torch.randn_like(x) - - def ancestral_euler_step(self, x, denoised, sigma, sigma_down): - d = to_d(x, sigma, denoised) - dt = append_dims(sigma_down - sigma, x.ndim) - - return self.euler_step(x, d, dt) - - def ancestral_step(self, x, sigma, next_sigma, sigma_up): - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, - x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), - x, - ) - return x - - def __call__(self, denoiser, x, cond, uc=None, num_steps=None): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) - - for i in self.get_sigma_gen(num_sigmas): - x = self.sampler_step( - s_in * sigmas[i], - s_in * sigmas[i + 1], - denoiser, - x, - cond, - uc, - ) - - return x - - -class LinearMultistepSampler(BaseDiffusionSampler): - def __init__( - self, - order=4, - *args, - **kwargs, - ): - super().__init__(*args, **kwargs) - - self.order = order - - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) - - ds = [] - sigmas_cpu = sigmas.detach().cpu().numpy() - for i in self.get_sigma_gen(num_sigmas): - sigma = s_in * sigmas[i] - denoised = denoiser( - *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs - ) - denoised = self.guider(denoised, sigma) - d = to_d(x, sigma, denoised) - ds.append(d) - if len(ds) > self.order: - ds.pop(0) - cur_order = min(i + 1, self.order) - coeffs = [ - linear_multistep_coeff(cur_order, sigmas_cpu, i, j) - for j in range(cur_order) - ] - x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) - - return x - - -class EulerEDMSampler(EDMSampler): - def possible_correction_step( - self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ): - return euler_step - - -class HeunEDMSampler(EDMSampler): - def possible_correction_step( - self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc - ): - if torch.sum(next_sigma) < 1e-14: - # Save a network evaluation if all noise levels are 0 - return euler_step - else: - denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) - d_new = to_d(euler_step, next_sigma, denoised) - d_prime = (d + d_new) / 2.0 - - # apply correction if noise level is not 0 - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step - ) - return x - - -class EulerAncestralSampler(AncestralSampler): - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): - sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) - denoised = self.denoise(x, denoiser, sigma, cond, uc) - x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) - x = self.ancestral_step(x, sigma, next_sigma, sigma_up) - - return x - - -class DPMPP2SAncestralSampler(AncestralSampler): - def get_variables(self, sigma, sigma_down): - t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] - h = t_next - t - s = t + 0.5 * h - return h, s, t, t_next - - def get_mult(self, h, s, t, t_next): - mult1 = to_sigma(s) / to_sigma(t) - mult2 = (-0.5 * h).expm1() - mult3 = to_sigma(t_next) / to_sigma(t) - mult4 = (-h).expm1() - - return mult1, mult2, mult3, mult4 - - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): - sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) - denoised = self.denoise(x, denoiser, sigma, cond, uc) - x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) - - if torch.sum(sigma_down) < 1e-14: - # Save a network evaluation if all noise levels are 0 - x = x_euler - else: - h, s, t, t_next = self.get_variables(sigma, sigma_down) - mult = [ - append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) - ] - - x2 = mult[0] * x - mult[1] * denoised - denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) - x_dpmpp2s = mult[2] * x - mult[3] * denoised2 - - # apply correction if noise level is not 0 - x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) - - x = self.ancestral_step(x, sigma, next_sigma, sigma_up) - return x - - -class DPMPP2MSampler(BaseDiffusionSampler): - def get_variables(self, sigma, next_sigma, previous_sigma=None): - t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] - h = t_next - t - - if previous_sigma is not None: - h_last = t - to_neg_log_sigma(previous_sigma) - r = h_last / h - return h, r, t, t_next - else: - return h, None, t, t_next - - def get_mult(self, h, r, t, t_next, previous_sigma): - mult1 = to_sigma(t_next) / to_sigma(t) - mult2 = (-h).expm1() - - if previous_sigma is not None: - mult3 = 1 + 1 / (2 * r) - mult4 = 1 / (2 * r) - return mult1, mult2, mult3, mult4 - else: - return mult1, mult2 - - def sampler_step( - self, - old_denoised, - previous_sigma, - sigma, - next_sigma, - denoiser, - x, - cond, - uc=None, - ): - denoised = self.denoise(x, denoiser, sigma, cond, uc) - - h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) - mult = [ - append_dims(mult, x.ndim) - for mult in self.get_mult(h, r, t, t_next, previous_sigma) - ] - - x_standard = mult[0] * x - mult[1] * denoised - if old_denoised is None or torch.sum(next_sigma) < 1e-14: - # Save a network evaluation if all noise levels are 0 or on the first step - return x_standard, denoised - else: - denoised_d = mult[2] * denoised - mult[3] * old_denoised - x_advanced = mult[0] * x - mult[1] * denoised_d - - # apply correction if noise level is not 0 and not first step - x = torch.where( - append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard - ) - - return x, denoised - - def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): - x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( - x, cond, uc, num_steps - ) - - old_denoised = None - for i in self.get_sigma_gen(num_sigmas): - x, old_denoised = self.sampler_step( - old_denoised, - None if i == 0 else s_in * sigmas[i - 1], - s_in * sigmas[i], - s_in * sigmas[i + 1], - denoiser, - x, - cond, - uc=uc, - ) - - return x diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py deleted file mode 100644 index ce78527ea9052a8bfd0856ed2278901516fb9130..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/sampling_utils.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -from scipy import integrate - -from ...util import append_dims - - -def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): - if order - 1 > i: - raise ValueError(f"Order {order} too high for step {i}") - - def fn(tau): - prod = 1.0 - for k in range(order): - if j == k: - continue - prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) - return prod - - return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] - - -def get_ancestral_step(sigma_from, sigma_to, eta=1.0): - if not eta: - return sigma_to, 0.0 - sigma_up = torch.minimum( - sigma_to, - eta - * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, - ) - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - return sigma_down, sigma_up - - -def to_d(x, sigma, denoised): - return (x - denoised) / append_dims(sigma, x.ndim) - - -def to_neg_log_sigma(sigma): - return sigma.log().neg() - - -def to_sigma(neg_log_sigma): - return neg_log_sigma.neg().exp() diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py deleted file mode 100644 index d54724c6ef6a7b8067784a4192b0fe2f41123063..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/sigma_sampling.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from ...util import default, instantiate_from_config - - -class EDMSampling: - def __init__(self, p_mean=-1.2, p_std=1.2): - self.p_mean = p_mean - self.p_std = p_std - - def __call__(self, n_samples, rand=None): - log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) - return log_sigma.exp() - - -class DiscreteSampling: - def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): - self.num_idx = num_idx - self.sigmas = instantiate_from_config(discretization_config)( - num_idx, do_append_zero=do_append_zero, flip=flip - ) - - def idx_to_sigma(self, idx): - return self.sigmas[idx] - - def __call__(self, n_samples, rand=None): - idx = default( - rand, - torch.randint(0, self.num_idx, (n_samples,)), - ) - return self.idx_to_sigma(idx) diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py deleted file mode 100644 index 389f0e449367b1b628d61dca105343d066dbefff..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/util.py +++ /dev/null @@ -1,369 +0,0 @@ -""" -partially adopted from -https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py -and -https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -and -https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py - -thanks! -""" - -import math -from typing import Optional - -import torch -import torch.nn as nn -from einops import rearrange, repeat - - -def make_beta_schedule( - schedule, - n_timestep, - linear_start=1e-4, - linear_end=2e-2, -): - if schedule == "linear": - betas = ( - torch.linspace( - linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 - ) - ** 2 - ) - return betas.numpy() - - -def extract_into_tensor(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def mixed_checkpoint(func, inputs: dict, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function - borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that - it also works with non-tensor inputs - :param func: the function to evaluate. - :param inputs: the argument dictionary to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] - tensor_inputs = [ - inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) - ] - non_tensor_keys = [ - key for key in inputs if not isinstance(inputs[key], torch.Tensor) - ] - non_tensor_inputs = [ - inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) - ] - args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) - return MixedCheckpointFunction.apply( - func, - len(tensor_inputs), - len(non_tensor_inputs), - tensor_keys, - non_tensor_keys, - *args, - ) - else: - return func(**inputs) - - -class MixedCheckpointFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - run_function, - length_tensors, - length_non_tensors, - tensor_keys, - non_tensor_keys, - *args, - ): - ctx.end_tensors = length_tensors - ctx.end_non_tensors = length_tensors + length_non_tensors - ctx.gpu_autocast_kwargs = { - "enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled(), - } - assert ( - len(tensor_keys) == length_tensors - and len(non_tensor_keys) == length_non_tensors - ) - - ctx.input_tensors = { - key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) - } - ctx.input_non_tensors = { - key: val - for (key, val) in zip( - non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) - ) - } - ctx.run_function = run_function - ctx.input_params = list(args[ctx.end_non_tensors :]) - - with torch.no_grad(): - output_tensors = ctx.run_function( - **ctx.input_tensors, **ctx.input_non_tensors - ) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} - ctx.input_tensors = { - key: ctx.input_tensors[key].detach().requires_grad_(True) - for key in ctx.input_tensors - } - - with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = { - key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) - for key in ctx.input_tensors - } - # shallow_copies.update(additional_args) - output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) - input_grads = torch.autograd.grad( - output_tensors, - list(ctx.input_tensors.values()) + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return ( - (None, None, None, None, None) - + input_grads[: ctx.end_tensors] - + (None,) * (ctx.end_non_tensors - ctx.end_tensors) - + input_grads[ctx.end_tensors :] - ) - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - ctx.gpu_autocast_kwargs = { - "enabled": torch.is_autocast_enabled(), - "dtype": torch.get_autocast_gpu_dtype(), - "cache_enabled": torch.is_autocast_cache_enabled(), - } - with torch.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. - shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = torch.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads - - -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ - Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - if not repeat_only: - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - else: - embedding = repeat(timesteps, "b -> b d", d=dim) - return embedding - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels) - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * torch.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. - """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -class AlphaBlender(nn.Module): - strategies = ["learned", "fixed", "learned_with_images"] - - def __init__( - self, - alpha: float, - merge_strategy: str = "learned_with_images", - rearrange_pattern: str = "b t -> (b t) 1 1", - ): - super().__init__() - self.merge_strategy = merge_strategy - self.rearrange_pattern = rearrange_pattern - - assert ( - merge_strategy in self.strategies - ), f"merge_strategy needs to be in {self.strategies}" - - if self.merge_strategy == "fixed": - self.register_buffer("mix_factor", torch.Tensor([alpha])) - elif ( - self.merge_strategy == "learned" - or self.merge_strategy == "learned_with_images" - ): - self.register_parameter( - "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) - ) - else: - raise ValueError(f"unknown merge strategy {self.merge_strategy}") - - def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: - if self.merge_strategy == "fixed": - alpha = self.mix_factor - elif self.merge_strategy == "learned": - alpha = torch.sigmoid(self.mix_factor) - elif self.merge_strategy == "learned_with_images": - assert image_only_indicator is not None, "need image_only_indicator ..." - alpha = torch.where( - image_only_indicator.bool(), - torch.ones(1, 1, device=image_only_indicator.device), - rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"), - ) - alpha = rearrange(alpha, self.rearrange_pattern) - else: - raise NotImplementedError - return alpha - - def forward( - self, - x_spatial: torch.Tensor, - x_temporal: torch.Tensor, - image_only_indicator: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - alpha = self.get_alpha(image_only_indicator) - x = ( - alpha.to(x_spatial.dtype) * x_spatial - + (1.0 - alpha).to(x_spatial.dtype) * x_temporal - ) - return x diff --git a/sgm/modules/diffusionmodules/video_model.py b/sgm/modules/diffusionmodules/video_model.py deleted file mode 100644 index ff2d077c7d0c7ed1c4a2c21f14105c266abc4926..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/video_model.py +++ /dev/null @@ -1,493 +0,0 @@ -from functools import partial -from typing import List, Optional, Union - -from einops import rearrange - -from ...modules.diffusionmodules.openaimodel import * -from ...modules.video_attention import SpatialVideoTransformer -from ...util import default -from .util import AlphaBlender - - -class VideoResBlock(ResBlock): - def __init__( - self, - channels: int, - emb_channels: int, - dropout: float, - video_kernel_size: Union[int, List[int]] = 3, - merge_strategy: str = "fixed", - merge_factor: float = 0.5, - out_channels: Optional[int] = None, - use_conv: bool = False, - use_scale_shift_norm: bool = False, - dims: int = 2, - use_checkpoint: bool = False, - up: bool = False, - down: bool = False, - ): - super().__init__( - channels, - emb_channels, - dropout, - out_channels=out_channels, - use_conv=use_conv, - use_scale_shift_norm=use_scale_shift_norm, - dims=dims, - use_checkpoint=use_checkpoint, - up=up, - down=down, - ) - - self.time_stack = ResBlock( - default(out_channels, channels), - emb_channels, - dropout=dropout, - dims=3, - out_channels=default(out_channels, channels), - use_scale_shift_norm=False, - use_conv=False, - up=False, - down=False, - kernel_size=video_kernel_size, - use_checkpoint=use_checkpoint, - exchange_temb_dims=True, - ) - self.time_mixer = AlphaBlender( - alpha=merge_factor, - merge_strategy=merge_strategy, - rearrange_pattern="b t -> b 1 t 1 1", - ) - - def forward( - self, - x: th.Tensor, - emb: th.Tensor, - num_video_frames: int, - image_only_indicator: Optional[th.Tensor] = None, - ) -> th.Tensor: - x = super().forward(x, emb) - - x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) - x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) - - x = self.time_stack( - x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames) - ) - x = self.time_mixer( - x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator - ) - x = rearrange(x, "b c t h w -> (b t) c h w") - return x - - -class VideoUNet(nn.Module): - def __init__( - self, - in_channels: int, - model_channels: int, - out_channels: int, - num_res_blocks: int, - attention_resolutions: int, - dropout: float = 0.0, - channel_mult: List[int] = (1, 2, 4, 8), - conv_resample: bool = True, - dims: int = 2, - num_classes: Optional[int] = None, - use_checkpoint: bool = False, - num_heads: int = -1, - num_head_channels: int = -1, - num_heads_upsample: int = -1, - use_scale_shift_norm: bool = False, - resblock_updown: bool = False, - transformer_depth: Union[List[int], int] = 1, - transformer_depth_middle: Optional[int] = None, - context_dim: Optional[int] = None, - time_downup: bool = False, - time_context_dim: Optional[int] = None, - extra_ff_mix_layer: bool = False, - use_spatial_context: bool = False, - merge_strategy: str = "fixed", - merge_factor: float = 0.5, - spatial_transformer_attn_type: str = "softmax", - video_kernel_size: Union[int, List[int]] = 3, - use_linear_in_transformer: bool = False, - adm_in_channels: Optional[int] = None, - disable_temporal_crossattention: bool = False, - max_ddpm_temb_period: int = 10000, - ): - super().__init__() - assert context_dim is not None - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - if num_heads == -1: - assert num_head_channels != -1 - - if num_head_channels == -1: - assert num_heads != -1 - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - if isinstance(transformer_depth, int): - transformer_depth = len(channel_mult) * [transformer_depth] - transformer_depth_middle = default( - transformer_depth_middle, transformer_depth[-1] - ) - - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.num_classes = num_classes - self.use_checkpoint = use_checkpoint - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - if isinstance(self.num_classes, int): - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - elif self.num_classes == "continuous": - print("setting up linear c_adm embedding layer") - self.label_emb = nn.Linear(1, time_embed_dim) - elif self.num_classes == "timestep": - self.label_emb = nn.Sequential( - Timestep(model_channels), - nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ), - ) - - elif self.num_classes == "sequential": - assert adm_in_channels is not None - self.label_emb = nn.Sequential( - nn.Sequential( - linear(adm_in_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - ) - else: - raise ValueError() - - self.input_blocks = nn.ModuleList( - [ - TimestepEmbedSequential( - conv_nd(dims, in_channels, model_channels, 3, padding=1) - ) - ] - ) - self._feature_size = model_channels - input_block_chans = [model_channels] - ch = model_channels - ds = 1 - - def get_attention_layer( - ch, - num_heads, - dim_head, - depth=1, - context_dim=None, - use_checkpoint=False, - disabled_sa=False, - ): - return SpatialVideoTransformer( - ch, - num_heads, - dim_head, - depth=depth, - context_dim=context_dim, - time_context_dim=time_context_dim, - dropout=dropout, - ff_in=extra_ff_mix_layer, - use_spatial_context=use_spatial_context, - merge_strategy=merge_strategy, - merge_factor=merge_factor, - checkpoint=use_checkpoint, - use_linear=use_linear_in_transformer, - attn_mode=spatial_transformer_attn_type, - disable_self_attn=disabled_sa, - disable_temporal_crossattention=disable_temporal_crossattention, - max_time_embed_period=max_ddpm_temb_period, - ) - - def get_resblock( - merge_factor, - merge_strategy, - video_kernel_size, - ch, - time_embed_dim, - dropout, - out_ch, - dims, - use_checkpoint, - use_scale_shift_norm, - down=False, - up=False, - ): - return VideoResBlock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - channels=ch, - emb_channels=time_embed_dim, - dropout=dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=down, - up=up, - ) - - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - time_embed_dim=time_embed_dim, - dropout=dropout, - out_ch=mult * model_channels, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = mult * model_channels - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - layers.append( - get_attention_layer( - ch, - num_heads, - dim_head, - depth=transformer_depth[level], - context_dim=context_dim, - use_checkpoint=use_checkpoint, - disabled_sa=False, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - ds *= 2 - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - time_embed_dim=time_embed_dim, - dropout=dropout, - out_ch=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, - conv_resample, - dims=dims, - out_channels=out_ch, - third_down=time_downup, - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - - self._feature_size += ch - - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - self.middle_block = TimestepEmbedSequential( - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - time_embed_dim=time_embed_dim, - out_ch=None, - dropout=dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - get_attention_layer( - ch, - num_heads, - dim_head, - depth=transformer_depth_middle, - context_dim=context_dim, - use_checkpoint=use_checkpoint, - ), - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - out_ch=None, - time_embed_dim=time_embed_dim, - dropout=dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch + ich, - time_embed_dim=time_embed_dim, - dropout=dropout, - out_ch=model_channels * mult, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = model_channels * mult - if ds in attention_resolutions: - if num_head_channels == -1: - dim_head = ch // num_heads - else: - num_heads = ch // num_head_channels - dim_head = num_head_channels - - layers.append( - get_attention_layer( - ch, - num_heads, - dim_head, - depth=transformer_depth[level], - context_dim=context_dim, - use_checkpoint=use_checkpoint, - disabled_sa=False, - ) - ) - if level and i == num_res_blocks: - out_ch = ch - ds //= 2 - layers.append( - get_resblock( - merge_factor=merge_factor, - merge_strategy=merge_strategy, - video_kernel_size=video_kernel_size, - ch=ch, - time_embed_dim=time_embed_dim, - dropout=dropout, - out_ch=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample( - ch, - conv_resample, - dims=dims, - out_channels=out_ch, - third_up=time_downup, - ) - ) - - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), - ) - - def forward( - self, - x: th.Tensor, - timesteps: th.Tensor, - context: Optional[th.Tensor] = None, - y: Optional[th.Tensor] = None, - time_context: Optional[th.Tensor] = None, - num_video_frames: Optional[int] = None, - image_only_indicator: Optional[th.Tensor] = None, - ): - assert (y is not None) == ( - self.num_classes is not None - ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" - hs = [] - t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) - emb = self.time_embed(t_emb) - - if self.num_classes is not None: - assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) - - h = x - for module in self.input_blocks: - h = module( - h, - emb, - context=context, - image_only_indicator=image_only_indicator, - time_context=time_context, - num_video_frames=num_video_frames, - ) - hs.append(h) - h = self.middle_block( - h, - emb, - context=context, - image_only_indicator=image_only_indicator, - time_context=time_context, - num_video_frames=num_video_frames, - ) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module( - h, - emb, - context=context, - image_only_indicator=image_only_indicator, - time_context=time_context, - num_video_frames=num_video_frames, - ) - h = h.type(x.dtype) - return self.out(h) diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py deleted file mode 100644 index 37449ea63e992b9f89856f1f47c18ba68be8e334..0000000000000000000000000000000000000000 --- a/sgm/modules/diffusionmodules/wrappers.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch -import torch.nn as nn -from packaging import version - -OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" - - -class IdentityWrapper(nn.Module): - def __init__(self, diffusion_model, compile_model: bool = False): - super().__init__() - compile = ( - torch.compile - if (version.parse(torch.__version__) >= version.parse("2.0.0")) - and compile_model - else lambda x: x - ) - self.diffusion_model = compile(diffusion_model) - - def forward(self, *args, **kwargs): - return self.diffusion_model(*args, **kwargs) - - -class OpenAIWrapper(IdentityWrapper): - def forward( - self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs - ) -> torch.Tensor: - x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) - return self.diffusion_model( - x, - timesteps=t, - context=c.get("crossattn", None), - y=c.get("vector", None), - **kwargs, - ) diff --git a/sgm/modules/distributions/__init__.py b/sgm/modules/distributions/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py deleted file mode 100644 index 016be35523187ea366db9ade391fe8ee276db60b..0000000000000000000000000000000000000000 --- a/sgm/modules/distributions/distributions.py +++ /dev/null @@ -1,102 +0,0 @@ -import numpy as np -import torch - - -class AbstractDistribution: - def sample(self): - raise NotImplementedError() - - def mode(self): - raise NotImplementedError() - - -class DiracDistribution(AbstractDistribution): - def __init__(self, value): - self.value = value - - def sample(self): - return self.value - - def mode(self): - return self.value - - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters, deterministic=False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like(self.mean).to( - device=self.parameters.device - ) - - def sample(self): - x = self.mean + self.std * torch.randn(self.mean.shape).to( - device=self.parameters.device - ) - return x - - def kl(self, other=None): - if self.deterministic: - return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - - def nll(self, sample, dims=[1, 2, 3]): - if self.deterministic: - return torch.Tensor([0.0]) - logtwopi = np.log(2.0 * np.pi) - return 0.5 * torch.sum( - logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, - dim=dims, - ) - - def mode(self): - return self.mean - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 - Compute the KL divergence between two gaussians. - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, torch.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for torch.exp(). - logvar1, logvar2 = [ - x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) - for x in (logvar1, logvar2) - ] - - return 0.5 * ( - -1.0 - + logvar2 - - logvar1 - + torch.exp(logvar1 - logvar2) - + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) - ) diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py deleted file mode 100644 index 97b5ae2b230f89b4dba57e44c4f851478ad86f68..0000000000000000000000000000000000000000 --- a/sgm/modules/ema.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -from torch import nn - - -class LitEma(nn.Module): - def __init__(self, model, decay=0.9999, use_num_upates=True): - super().__init__() - if decay < 0.0 or decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.m_name2s_name = {} - self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) - self.register_buffer( - "num_updates", - torch.tensor(0, dtype=torch.int) - if use_num_upates - else torch.tensor(-1, dtype=torch.int), - ) - - for name, p in model.named_parameters(): - if p.requires_grad: - # remove as '.'-character is not allowed in buffers - s_name = name.replace(".", "") - self.m_name2s_name.update({name: s_name}) - self.register_buffer(s_name, p.clone().detach().data) - - self.collected_params = [] - - def reset_num_updates(self): - del self.num_updates - self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) - - def forward(self, model): - decay = self.decay - - if self.num_updates >= 0: - self.num_updates += 1 - decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) - - one_minus_decay = 1.0 - decay - - with torch.no_grad(): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - - for key in m_param: - if m_param[key].requires_grad: - sname = self.m_name2s_name[key] - shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) - shadow_params[sname].sub_( - one_minus_decay * (shadow_params[sname] - m_param[key]) - ) - else: - assert not key in self.m_name2s_name - - def copy_to(self, model): - m_param = dict(model.named_parameters()) - shadow_params = dict(self.named_buffers()) - for key in m_param: - if m_param[key].requires_grad: - m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) - else: - assert not key in self.m_name2s_name - - def store(self, parameters): - """ - Save the current parameters for restoring later. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. - """ - self.collected_params = [param.clone() for param in parameters] - - def restore(self, parameters): - """ - Restore the parameters stored with the `store` method. - Useful to validate the model with EMA parameters without affecting the - original optimization process. Store the parameters before the - `copy_to` method. After validation (or model saving), use this to - restore the former parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. - """ - for c_param, param in zip(self.collected_params, parameters): - param.data.copy_(c_param.data) diff --git a/sgm/modules/encoders/__init__.py b/sgm/modules/encoders/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/sgm/modules/encoders/image_encoder.py b/sgm/modules/encoders/image_encoder.py deleted file mode 100644 index 60d693245bc562987376b7d0fff80086fb936279..0000000000000000000000000000000000000000 --- a/sgm/modules/encoders/image_encoder.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright 2020 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F -import importlib - - -def class_for_name(module_name, class_name): - # load the module, will raise ImportError if module cannot be loaded - m = importlib.import_module(module_name) - return getattr(m, class_name) - - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d( - in_planes, - out_planes, - kernel_size=3, - stride=stride, - padding=dilation, - groups=groups, - bias=False, - dilation=dilation, - padding_mode="reflect", - ) - - -def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" - return nn.Conv2d( - in_planes, - out_planes, - kernel_size=1, - stride=stride, - bias=False, - padding_mode="reflect", - ) - - -class BasicBlock(nn.Module): - expansion = 1 - - def __init__( - self, - inplanes, - planes, - stride=1, - downsample=None, - groups=1, - base_width=64, - dilation=1, - norm_layer=None, - ): - super(BasicBlock, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - # norm_layer = nn.InstanceNorm2d - if groups != 1 or base_width != 64: - raise ValueError("BasicBlock only supports groups=1 and base_width=64") - if dilation > 1: - raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv3x3(inplanes, planes, stride) - self.bn1 = norm_layer(planes, track_running_stats=False, affine=True) - self.relu = nn.ReLU(inplace=True) - self.conv2 = conv3x3(planes, planes) - self.bn2 = norm_layer(planes, track_running_stats=False, affine=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class Bottleneck(nn.Module): - # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) - # while original implementation places the stride at the first 1x1 convolution(self.conv1) - # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. - # This variant is also known as ResNet V1.5 and improves accuracy according to - # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. - - expansion = 4 - - def __init__( - self, - inplanes, - planes, - stride=1, - downsample=None, - groups=1, - base_width=64, - dilation=1, - norm_layer=None, - ): - super(Bottleneck, self).__init__() - if norm_layer is None: - norm_layer = nn.BatchNorm2d - # norm_layer = nn.InstanceNorm2d - width = int(planes * (base_width / 64.0)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 - self.conv1 = conv1x1(inplanes, width) - self.bn1 = norm_layer(width, track_running_stats=False, affine=True) - self.conv2 = conv3x3(width, width, stride, groups, dilation) - self.bn2 = norm_layer(width, track_running_stats=False, affine=True) - self.conv3 = conv1x1(width, planes * self.expansion) - self.bn3 = norm_layer( - planes * self.expansion, track_running_stats=False, affine=True - ) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - identity = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - - return out - - -class conv(nn.Module): - def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): - super(conv, self).__init__() - self.kernel_size = kernel_size - self.conv = nn.Conv2d( - num_in_layers, - num_out_layers, - kernel_size=kernel_size, - stride=stride, - padding=(self.kernel_size - 1) // 2, - padding_mode="reflect", - ) - # self.bn = nn.InstanceNorm2d( - # num_out_layers, track_running_stats=False, affine=True - # ) - self.bn = nn.BatchNorm2d(num_out_layers, track_running_stats=False, affine=True) - # self.bn = nn.LayerNorm(num_out_layers) - - def forward(self, x): - return F.elu(self.bn(self.conv(x)), inplace=True) - - -class upconv(nn.Module): - def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): - super(upconv, self).__init__() - self.scale = scale - self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1) - - def forward(self, x): - x = nn.functional.interpolate( - x, scale_factor=self.scale, align_corners=True, mode="bilinear" - ) - return self.conv(x) - - -class ResUNet(nn.Module): - def __init__( - self, - encoder="resnet34", - coarse_out_ch=32, - fine_out_ch=32, - norm_layer=None, - coarse_only=False, - ): - super(ResUNet, self).__init__() - assert encoder in [ - "resnet18", - "resnet34", - "resnet50", - "resnet101", - "resnet152", - ], "Incorrect encoder type" - if encoder in ["resnet18", "resnet34"]: - filters = [64, 128, 256, 512] - else: - filters = [256, 512, 1024, 2048] - self.coarse_only = coarse_only - if self.coarse_only: - fine_out_ch = 0 - self.coarse_out_ch = coarse_out_ch - self.fine_out_ch = fine_out_ch - out_ch = coarse_out_ch + fine_out_ch - - # original - layers = [3, 4, 6, 3] - if norm_layer is None: - norm_layer = nn.BatchNorm2d - # norm_layer = nn.InstanceNorm2d - self._norm_layer = norm_layer - self.dilation = 1 - block = BasicBlock - replace_stride_with_dilation = [False, False, False] - self.inplanes = 64 - self.groups = 1 - self.base_width = 64 - self.conv1 = nn.Conv2d( - 3, - self.inplanes, - kernel_size=7, - stride=2, - padding=3, - bias=False, - padding_mode="reflect", - ) - self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True) - self.relu = nn.ReLU(inplace=True) - self.layer1 = self._make_layer(block, 64, layers[0], stride=2) - self.layer2 = self._make_layer( - block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] - ) - self.layer3 = self._make_layer( - block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] - ) - - # decoder - self.upconv3 = upconv(filters[2], 128, 3, 2) - self.iconv3 = conv(filters[1] + 128, 128, 3, 1) - self.upconv2 = upconv(128, 64, 3, 2) - self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1) - - # fine-level conv - self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1) - - def _make_layer(self, block, planes, blocks, stride=1, dilate=False): - norm_layer = self._norm_layer - downsample = None - previous_dilation = self.dilation - if dilate: - self.dilation *= stride - stride = 1 - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - conv1x1(self.inplanes, planes * block.expansion, stride), - norm_layer( - planes * block.expansion, track_running_stats=False, affine=True - ), - ) - - layers = [] - layers.append( - block( - self.inplanes, - planes, - stride, - downsample, - self.groups, - self.base_width, - previous_dilation, - norm_layer, - ) - ) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append( - block( - self.inplanes, - planes, - groups=self.groups, - base_width=self.base_width, - dilation=self.dilation, - norm_layer=norm_layer, - ) - ) - - return nn.Sequential(*layers) - - def skipconnect(self, x1, x2): - diffY = x2.size()[2] - x1.size()[2] - diffX = x2.size()[3] - x1.size()[3] - - x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) - - # for padding issues, see - # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a - # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd - - x = torch.cat([x2, x1], dim=1) - return x - - def forward(self, x): - x = self.relu(self.bn1(self.conv1(x))) - - x1 = self.layer1(x) - x2 = self.layer2(x1) - x3 = self.layer3(x2) - - x = self.upconv3(x3) - x = self.skipconnect(x2, x) - x = self.iconv3(x) - - x = self.upconv2(x) - x = self.skipconnect(x1, x) - x = self.iconv2(x) - - x_out = self.out_conv(x) - - return x_out - - # if self.coarse_only: - # x_coarse = x_out - # x_fine = None - # else: - # x_coarse = x_out[:, : self.coarse_out_ch, :] - # x_fine = x_out[:, -self.fine_out_ch :, :] - # return x_coarse, x_fine diff --git a/sgm/modules/encoders/image_encoder_v2.py b/sgm/modules/encoders/image_encoder_v2.py deleted file mode 100644 index 72c782b3edee155fa4367e697a94d6b8b6b86b85..0000000000000000000000000000000000000000 --- a/sgm/modules/encoders/image_encoder_v2.py +++ /dev/null @@ -1,160 +0,0 @@ -""" -UNet Network in PyTorch, modified from https://github.com/milesial/Pytorch-UNet -with architecture referenced from https://keras.io/examples/vision/depth_estimation -for monocular depth estimation from RGB images, i.e. one output channel. -""" - -import torch -from torch import nn - - -class UNet(nn.Module): - """ - The overall UNet architecture. - """ - - def __init__(self): - super().__init__() - - self.downscale_blocks = nn.ModuleList( - [ - DownBlock(16, 32), - DownBlock(32, 64), - DownBlock(64, 128), - DownBlock(128, 256), - ] - ) - self.upscale_blocks = nn.ModuleList( - [ - UpBlock(256, 128), - UpBlock(128, 64), - UpBlock(64, 32), - UpBlock(32, 16), - ] - ) - - self.input_conv = nn.Conv2d(3, 16, kernel_size=3, padding="same") - self.output_conv = nn.Conv2d(16, 1, kernel_size=1) - self.bridge = BottleNeckBlock(256) - self.activation = nn.Sigmoid() - - def forward(self, x): - x = self.input_conv(x) - - skip_features = [] - for block in self.downscale_blocks: - c, x = block(x) - skip_features.append(c) - - x = self.bridge(x) - - skip_features.reverse() - for block, skip in zip(self.upscale_blocks, skip_features): - x = block(x, skip) - - x = self.output_conv(x) - x = self.activation(x) - return x - - -class DownBlock(nn.Module): - """ - Module that performs downscaling with residual connections. - """ - - def __init__(self, in_channels, out_channels, padding="same", stride=1): - super().__init__() - self.conv1 = nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=padding, - bias=False, - ) - self.conv2 = nn.Conv2d( - out_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=padding, - bias=False, - ) - self.bn1 = nn.BatchNorm2d(out_channels) - self.bn2 = nn.BatchNorm2d(out_channels) - self.relu = nn.LeakyReLU(0.2) - self.maxpool = nn.MaxPool2d(2) - - def forward(self, x): - d = self.conv1(x) - x = self.bn1(d) - x = self.relu(x) - - x = self.conv2(x) - x = self.bn2(x) - x = self.relu(x) - - x = x + d - p = self.maxpool(x) - return x, p - - -class UpBlock(nn.Module): - """ - Module that performs upscaling after concatenation with skip connections. - """ - - def __init__(self, in_channels, out_channels, padding="same", stride=1): - super().__init__() - self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) - self.conv1 = nn.Conv2d( - in_channels * 2, - in_channels, - kernel_size=3, - stride=stride, - padding=padding, - bias=False, - ) - self.conv2 = nn.Conv2d( - in_channels, - out_channels, - kernel_size=3, - stride=stride, - padding=padding, - bias=False, - ) - self.bn1 = nn.BatchNorm2d(in_channels) - self.bn2 = nn.BatchNorm2d(out_channels) - self.relu = nn.LeakyReLU(0.2) - - def forward(self, x, skip): - x = self.up(x) - x = torch.cat([x, skip], dim=1) - - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - - x = self.conv2(x) - x = self.bn2(x) - x = self.relu(x) - return x - - -class BottleNeckBlock(nn.Module): - """ - BottleNeckBlock that serves as the UNet bridge. - """ - - def __init__(self, channels, padding="same", strides=1): - super().__init__() - self.conv1 = nn.Conv2d(channels, channels, 3, 1, "same") - self.conv2 = nn.Conv2d(channels, channels, 3, 1, "same") - self.relu = nn.LeakyReLU(0.2) - - def forward(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.conv2(x) - x = self.relu(x) - return x \ No newline at end of file diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py deleted file mode 100644 index 9860779362c766f4e9171d98c7411a2b178a842d..0000000000000000000000000000000000000000 --- a/sgm/modules/encoders/modules.py +++ /dev/null @@ -1,1189 +0,0 @@ -import math -from contextlib import nullcontext -from functools import partial -from typing import Dict, List, Optional, Tuple, Union - -import kornia -import numpy as np -import open_clip -import torch -import torch.nn as nn -from einops import rearrange, repeat -from omegaconf import ListConfig - -# from torch.utils.checkpoint import checkpoint - -checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) - -from transformers import ( - ByT5Tokenizer, - CLIPTextModel, - CLIPTokenizer, - T5EncoderModel, - T5Tokenizer, -) - -from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer -from ...modules.diffusionmodules.model import Encoder -from ...modules.diffusionmodules.openaimodel import Timestep -from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule -from ...modules.distributions.distributions import DiagonalGaussianDistribution -from ...util import ( - append_dims, - autocast, - count_params, - default, - disabled_train, - expand_dims_like, - instantiate_from_config, -) - - -class AbstractEmbModel(nn.Module): - def __init__(self): - super().__init__() - self._is_trainable = None - self._ucg_rate = None - self._input_key = None - - @property - def is_trainable(self) -> bool: - return self._is_trainable - - @property - def ucg_rate(self) -> Union[float, torch.Tensor]: - return self._ucg_rate - - @property - def input_key(self) -> str: - return self._input_key - - @is_trainable.setter - def is_trainable(self, value: bool): - self._is_trainable = value - - @ucg_rate.setter - def ucg_rate(self, value: Union[float, torch.Tensor]): - self._ucg_rate = value - - @input_key.setter - def input_key(self, value: str): - self._input_key = value - - @is_trainable.deleter - def is_trainable(self): - del self._is_trainable - - @ucg_rate.deleter - def ucg_rate(self): - del self._ucg_rate - - @input_key.deleter - def input_key(self): - del self._input_key - - -class GeneralConditioner(nn.Module): - OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} - KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} - - def __init__(self, emb_models: Union[List, ListConfig]): - super().__init__() - embedders = [] - for n, embconfig in enumerate(emb_models): - embedder = instantiate_from_config(embconfig) - assert isinstance( - embedder, AbstractEmbModel - ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" - embedder.is_trainable = embconfig.get("is_trainable", False) - embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) - if not embedder.is_trainable: - embedder.train = disabled_train - for param in embedder.parameters(): - param.requires_grad = False - embedder.eval() - print( - f"Initialized embedder #{n}: {embedder.__class__.__name__} " - f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" - ) - - if "input_key" in embconfig: - embedder.input_key = embconfig["input_key"] - elif "input_keys" in embconfig: - embedder.input_keys = embconfig["input_keys"] - else: - raise KeyError( - f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}" - ) - - embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) - if embedder.legacy_ucg_val is not None: - embedder.ucg_prng = np.random.RandomState() - - embedders.append(embedder) - self.embedders = nn.ModuleList(embedders) - - def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: - assert embedder.legacy_ucg_val is not None - p = embedder.ucg_rate - val = embedder.legacy_ucg_val - for i in range(len(batch[embedder.input_key])): - if embedder.ucg_prng.choice(2, p=[1 - p, p]): - batch[embedder.input_key][i] = val - return batch - - def forward( - self, batch: Dict, force_zero_embeddings: Optional[List] = None - ) -> Dict: - output = dict() - if force_zero_embeddings is None: - force_zero_embeddings = [] - for embedder in self.embedders: - embedding_context = nullcontext if embedder.is_trainable else torch.no_grad - with embedding_context(): - if hasattr(embedder, "input_key") and (embedder.input_key is not None): - if embedder.legacy_ucg_val is not None: - batch = self.possibly_get_ucg_val(embedder, batch) - emb_out = embedder(batch[embedder.input_key]) - elif hasattr(embedder, "input_keys"): - emb_out = embedder(*[batch[k] for k in embedder.input_keys]) - assert isinstance( - emb_out, (torch.Tensor, list, tuple) - ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" - if not isinstance(emb_out, (list, tuple)): - emb_out = [emb_out] - for emb in emb_out: - out_key = self.OUTPUT_DIM2KEYS[emb.dim()] - if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: - emb = ( - expand_dims_like( - torch.bernoulli( - (1.0 - embedder.ucg_rate) - * torch.ones(emb.shape[0], device=emb.device) - ), - emb, - ) - * emb - ) - if ( - hasattr(embedder, "input_key") - and embedder.input_key in force_zero_embeddings - ): - emb = torch.zeros_like(emb) - if out_key in output: - output[out_key] = torch.cat( - (output[out_key], emb), self.KEY2CATDIM[out_key] - ) - else: - output[out_key] = emb - - # if "num_video_frames" in batch: - # num_frames = batch["num_video_frames"] - # for k in ["crossattn", "concat"]: - # output[k] = repeat(output[k], "b ... -> b t ...", t=num_frames) - # output[k] = rearrange(output[k], "b t ... -> (b t) ...", t=num_frames) - - return output - - def get_unconditional_conditioning( - self, - batch_c: Dict, - batch_uc: Optional[Dict] = None, - force_uc_zero_embeddings: Optional[List[str]] = None, - force_cond_zero_embeddings: Optional[List[str]] = None, - ): - if force_uc_zero_embeddings is None: - force_uc_zero_embeddings = [] - ucg_rates = list() - for embedder in self.embedders: - ucg_rates.append(embedder.ucg_rate) - embedder.ucg_rate = 0.0 - c = self(batch_c, force_cond_zero_embeddings) - uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) - - for embedder, rate in zip(self.embedders, ucg_rates): - embedder.ucg_rate = rate - return c, uc - - -class InceptionV3(nn.Module): - """Wrapper around the https://github.com/mseitzer/pytorch-fid inception - port with an additional squeeze at the end""" - - def __init__(self, normalize_input=False, **kwargs): - super().__init__() - from pytorch_fid import inception - - kwargs["resize_input"] = True - self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs) - - def forward(self, inp): - outp = self.model(inp) - - if len(outp) == 1: - return outp[0].squeeze() - - return outp - - -class IdentityEncoder(AbstractEmbModel): - def encode(self, x): - return x - - def forward(self, x): - return x - - -class ClassEmbedder(AbstractEmbModel): - def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False): - super().__init__() - self.embedding = nn.Embedding(n_classes, embed_dim) - self.n_classes = n_classes - self.add_sequence_dim = add_sequence_dim - - def forward(self, c): - c = self.embedding(c) - if self.add_sequence_dim: - c = c[:, None, :] - return c - - def get_unconditional_conditioning(self, bs, device="cuda"): - uc_class = ( - self.n_classes - 1 - ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) - uc = torch.ones((bs,), device=device) * uc_class - uc = {self.key: uc.long()} - return uc - - -class ClassEmbedderForMultiCond(ClassEmbedder): - def forward(self, batch, key=None, disable_dropout=False): - out = batch - key = default(key, self.key) - islist = isinstance(batch[key], list) - if islist: - batch[key] = batch[key][0] - c_out = super().forward(batch, key, disable_dropout) - out[key] = [c_out] if islist else c_out - return out - - -class FrozenT5Embedder(AbstractEmbModel): - """Uses the T5 transformer encoder for text""" - - def __init__( - self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True - ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl - super().__init__() - self.tokenizer = T5Tokenizer.from_pretrained(version) - self.transformer = T5EncoderModel.from_pretrained(version) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - - def freeze(self): - self.transformer = self.transformer.eval() - - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - with torch.autocast("cuda", enabled=False): - outputs = self.transformer(input_ids=tokens) - z = outputs.last_hidden_state - return z - - def encode(self, text): - return self(text) - - -class FrozenByT5Embedder(AbstractEmbModel): - """ - Uses the ByT5 transformer encoder for text. Is character-aware. - """ - - def __init__( - self, version="google/byt5-base", device="cuda", max_length=77, freeze=True - ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl - super().__init__() - self.tokenizer = ByT5Tokenizer.from_pretrained(version) - self.transformer = T5EncoderModel.from_pretrained(version) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - - def freeze(self): - self.transformer = self.transformer.eval() - - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - with torch.autocast("cuda", enabled=False): - outputs = self.transformer(input_ids=tokens) - z = outputs.last_hidden_state - return z - - def encode(self, text): - return self(text) - - -class FrozenCLIPEmbedder(AbstractEmbModel): - """Uses the CLIP transformer encoder for text (from huggingface)""" - - LAYERS = ["last", "pooled", "hidden"] - - def __init__( - self, - version="openai/clip-vit-large-patch14", - device="cuda", - max_length=77, - freeze=True, - layer="last", - layer_idx=None, - always_return_pooled=False, - ): # clip-vit-base-patch32 - super().__init__() - assert layer in self.LAYERS - self.tokenizer = CLIPTokenizer.from_pretrained(version) - self.transformer = CLIPTextModel.from_pretrained(version) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - self.layer = layer - self.layer_idx = layer_idx - self.return_pooled = always_return_pooled - if layer == "hidden": - assert layer_idx is not None - assert 0 <= abs(layer_idx) <= 12 - - def freeze(self): - self.transformer = self.transformer.eval() - - for param in self.parameters(): - param.requires_grad = False - - @autocast - def forward(self, text): - batch_encoding = self.tokenizer( - text, - truncation=True, - max_length=self.max_length, - return_length=True, - return_overflowing_tokens=False, - padding="max_length", - return_tensors="pt", - ) - tokens = batch_encoding["input_ids"].to(self.device) - outputs = self.transformer( - input_ids=tokens, output_hidden_states=self.layer == "hidden" - ) - if self.layer == "last": - z = outputs.last_hidden_state - elif self.layer == "pooled": - z = outputs.pooler_output[:, None, :] - else: - z = outputs.hidden_states[self.layer_idx] - if self.return_pooled: - return z, outputs.pooler_output - return z - - def encode(self, text): - return self(text) - - -class FrozenOpenCLIPEmbedder2(AbstractEmbModel): - """ - Uses the OpenCLIP transformer encoder for text - """ - - LAYERS = ["pooled", "last", "penultimate"] - - def __init__( - self, - arch="ViT-H-14", - version="laion2b_s32b_b79k", - device="cuda", - max_length=77, - freeze=True, - layer="last", - always_return_pooled=False, - legacy=True, - ): - super().__init__() - assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms( - arch, - device=torch.device("cpu"), - pretrained=version, - ) - del model.visual - self.model = model - - self.device = device - self.max_length = max_length - self.return_pooled = always_return_pooled - if freeze: - self.freeze() - self.layer = layer - if self.layer == "last": - self.layer_idx = 0 - elif self.layer == "penultimate": - self.layer_idx = 1 - else: - raise NotImplementedError() - self.legacy = legacy - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - @autocast - def forward(self, text): - tokens = open_clip.tokenize(text) - z = self.encode_with_transformer(tokens.to(self.device)) - if not self.return_pooled and self.legacy: - return z - if self.return_pooled: - assert not self.legacy - return z[self.layer], z["pooled"] - return z[self.layer] - - def encode_with_transformer(self, text): - x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] - x = x + self.model.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - if self.legacy: - x = x[self.layer] - x = self.model.ln_final(x) - return x - else: - # x is a dict and will stay a dict - o = x["last"] - o = self.model.ln_final(o) - pooled = self.pool(o, text) - x["pooled"] = pooled - return x - - def pool(self, x, text): - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = ( - x[torch.arange(x.shape[0]), text.argmax(dim=-1)] - @ self.model.text_projection - ) - return x - - def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): - outputs = {} - for i, r in enumerate(self.model.transformer.resblocks): - if i == len(self.model.transformer.resblocks) - 1: - outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD - if ( - self.model.transformer.grad_checkpointing - and not torch.jit.is_scripting() - ): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - outputs["last"] = x.permute(1, 0, 2) # LND -> NLD - return outputs - - def encode(self, text): - return self(text) - - -class FrozenOpenCLIPEmbedder(AbstractEmbModel): - LAYERS = [ - # "pooled", - "last", - "penultimate", - ] - - def __init__( - self, - arch="ViT-H-14", - version="laion2b_s32b_b79k", - device="cuda", - max_length=77, - freeze=True, - layer="last", - ): - super().__init__() - assert layer in self.LAYERS - model, _, _ = open_clip.create_model_and_transforms( - arch, - device=torch.device("cpu"), - pretrained=version, - ) - del model.visual - self.model = model - - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - self.layer = layer - if self.layer == "last": - self.layer_idx = 0 - elif self.layer == "penultimate": - self.layer_idx = 1 - else: - raise NotImplementedError() - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - def forward(self, text): - tokens = open_clip.tokenize(text) - z = self.encode_with_transformer(tokens.to(self.device)) - return z - - def encode_with_transformer(self, text): - x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] - x = x + self.model.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.model.ln_final(x) - return x - - def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): - for i, r in enumerate(self.model.transformer.resblocks): - if i == len(self.model.transformer.resblocks) - self.layer_idx: - break - if ( - self.model.transformer.grad_checkpointing - and not torch.jit.is_scripting() - ): - x = checkpoint(r, x, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - return x - - def encode(self, text): - return self(text) - - -class FrozenOpenCLIPImageEmbedder(AbstractEmbModel): - """ - Uses the OpenCLIP vision transformer encoder for images - """ - - def __init__( - self, - arch="ViT-H-14", - version="laion2b_s32b_b79k", - device="cuda", - max_length=77, - freeze=True, - antialias=True, - ucg_rate=0.0, - unsqueeze_dim=False, - repeat_to_max_len=False, - num_image_crops=0, - output_tokens=False, - init_device=None, - ): - super().__init__() - model, _, _ = open_clip.create_model_and_transforms( - arch, - device=torch.device(default(init_device, "cpu")), - pretrained=version, - ) - del model.transformer - self.model = model - self.max_crops = num_image_crops - self.pad_to_max_len = self.max_crops > 0 - self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) - self.device = device - self.max_length = max_length - if freeze: - self.freeze() - - self.antialias = antialias - - self.register_buffer( - "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False - ) - self.register_buffer( - "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False - ) - self.ucg_rate = ucg_rate - self.unsqueeze_dim = unsqueeze_dim - self.stored_batch = None - self.model.visual.output_tokens = output_tokens - self.output_tokens = output_tokens - - def preprocess(self, x): - # normalize to [0,1] - x = kornia.geometry.resize( - x, - (224, 224), - interpolation="bicubic", - align_corners=True, - antialias=self.antialias, - ) - x = (x + 1.0) / 2.0 - # renormalize according to clip - x = kornia.enhance.normalize(x, self.mean, self.std) - return x - - def freeze(self): - self.model = self.model.eval() - for param in self.parameters(): - param.requires_grad = False - - @autocast - def forward(self, image, no_dropout=False): - z = self.encode_with_vision_transformer(image) - tokens = None - if self.output_tokens: - z, tokens = z[0], z[1] - z = z.to(image.dtype) - if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): - z = ( - torch.bernoulli( - (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device) - )[:, None] - * z - ) - if tokens is not None: - tokens = ( - expand_dims_like( - torch.bernoulli( - (1.0 - self.ucg_rate) - * torch.ones(tokens.shape[0], device=tokens.device) - ), - tokens, - ) - * tokens - ) - if self.unsqueeze_dim: - z = z[:, None, :] - if self.output_tokens: - assert not self.repeat_to_max_len - assert not self.pad_to_max_len - return tokens, z - if self.repeat_to_max_len: - if z.dim() == 2: - z_ = z[:, None, :] - else: - z_ = z - return repeat(z_, "b 1 d -> b n d", n=self.max_length), z - elif self.pad_to_max_len: - assert z.dim() == 3 - z_pad = torch.cat( - ( - z, - torch.zeros( - z.shape[0], - self.max_length - z.shape[1], - z.shape[2], - device=z.device, - ), - ), - 1, - ) - return z_pad, z_pad[:, 0, ...] - return z - - def encode_with_vision_transformer(self, img): - # if self.max_crops > 0: - # img = self.preprocess_by_cropping(img) - if img.dim() == 5: - assert self.max_crops == img.shape[1] - img = rearrange(img, "b n c h w -> (b n) c h w") - img = self.preprocess(img) - if not self.output_tokens: - assert not self.model.visual.output_tokens - x = self.model.visual(img) - tokens = None - else: - assert self.model.visual.output_tokens - x, tokens = self.model.visual(img) - if self.max_crops > 0: - x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) - # drop out between 0 and all along the sequence axis - x = ( - torch.bernoulli( - (1.0 - self.ucg_rate) - * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) - ) - * x - ) - if tokens is not None: - tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) - print( - f"You are running very experimental token-concat in {self.__class__.__name__}. " - f"Check what you are doing, and then remove this message." - ) - if self.output_tokens: - return x, tokens - return x - - def encode(self, text): - return self(text) - - -class FrozenCLIPT5Encoder(AbstractEmbModel): - def __init__( - self, - clip_version="openai/clip-vit-large-patch14", - t5_version="google/t5-v1_1-xl", - device="cuda", - clip_max_length=77, - t5_max_length=77, - ): - super().__init__() - self.clip_encoder = FrozenCLIPEmbedder( - clip_version, device, max_length=clip_max_length - ) - self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - print( - f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " - f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." - ) - - def encode(self, text): - return self(text) - - def forward(self, text): - clip_z = self.clip_encoder.encode(text) - t5_z = self.t5_encoder.encode(text) - return [clip_z, t5_z] - - -class SpatialRescaler(nn.Module): - def __init__( - self, - n_stages=1, - method="bilinear", - multiplier=0.5, - in_channels=3, - out_channels=None, - bias=False, - wrap_video=False, - kernel_size=1, - remap_output=False, - ): - super().__init__() - self.n_stages = n_stages - assert self.n_stages >= 0 - assert method in [ - "nearest", - "linear", - "bilinear", - "trilinear", - "bicubic", - "area", - ] - self.multiplier = multiplier - self.interpolator = partial(torch.nn.functional.interpolate, mode=method) - self.remap_output = out_channels is not None or remap_output - if self.remap_output: - print( - f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." - ) - self.channel_mapper = nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - bias=bias, - padding=kernel_size // 2, - ) - self.wrap_video = wrap_video - - def forward(self, x): - if self.wrap_video and x.ndim == 5: - B, C, T, H, W = x.shape - x = rearrange(x, "b c t h w -> b t c h w") - x = rearrange(x, "b t c h w -> (b t) c h w") - - for stage in range(self.n_stages): - x = self.interpolator(x, scale_factor=self.multiplier) - - if self.wrap_video: - x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C) - x = rearrange(x, "b t c h w -> b c t h w") - if self.remap_output: - x = self.channel_mapper(x) - return x - - def encode(self, x): - return self(x) - - -class LowScaleEncoder(nn.Module): - def __init__( - self, - model_config, - linear_start, - linear_end, - timesteps=1000, - max_noise_level=250, - output_size=64, - scale_factor=1.0, - ): - super().__init__() - self.max_noise_level = max_noise_level - self.model = instantiate_from_config(model_config) - self.augmentation_schedule = self.register_schedule( - timesteps=timesteps, linear_start=linear_start, linear_end=linear_end - ) - self.out_size = output_size - self.scale_factor = scale_factor - - def register_schedule( - self, - beta_schedule="linear", - timesteps=1000, - linear_start=1e-4, - linear_end=2e-2, - cosine_s=8e-3, - ): - betas = make_beta_schedule( - beta_schedule, - timesteps, - linear_start=linear_start, - linear_end=linear_end, - cosine_s=cosine_s, - ) - alphas = 1.0 - betas - alphas_cumprod = np.cumprod(alphas, axis=0) - alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.linear_start = linear_start - self.linear_end = linear_end - assert ( - alphas_cumprod.shape[0] == self.num_timesteps - ), "alphas have to be defined for each timestep" - - to_torch = partial(torch.tensor, dtype=torch.float32) - - self.register_buffer("betas", to_torch(betas)) - self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) - self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) - self.register_buffer( - "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) - ) - self.register_buffer( - "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) - ) - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - return ( - extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise - ) - - def forward(self, x): - z = self.model.encode(x) - if isinstance(z, DiagonalGaussianDistribution): - z = z.sample() - z = z * self.scale_factor - noise_level = torch.randint( - 0, self.max_noise_level, (x.shape[0],), device=x.device - ).long() - z = self.q_sample(z, noise_level) - if self.out_size is not None: - z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") - return z, noise_level - - def decode(self, z): - z = z / self.scale_factor - return self.model.decode(z) - - -class ConcatTimestepEmbedderND(AbstractEmbModel): - """embeds each dimension independently and concatenates them""" - - def __init__(self, outdim): - super().__init__() - self.timestep = Timestep(outdim) - self.outdim = outdim - - def forward(self, x): - if x.ndim == 1: - x = x[:, None] - assert len(x.shape) == 2 - b, dims = x.shape[0], x.shape[1] - x = rearrange(x, "b d -> (b d)") - emb = self.timestep(x) - emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) - return emb - - -class GaussianEncoder(Encoder, AbstractEmbModel): - def __init__( - self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs - ): - super().__init__(*args, **kwargs) - self.posterior = DiagonalGaussianRegularizer() - self.weight = weight - self.flatten_output = flatten_output - - def forward(self, x) -> Tuple[Dict, torch.Tensor]: - z = super().forward(x) - z, log = self.posterior(z) - log["loss"] = log["kl_loss"] - log["weight"] = self.weight - if self.flatten_output: - z = rearrange(z, "b c h w -> b (h w ) c") - return log, z - - -class VideoPredictionEmbedderWithEncoder(AbstractEmbModel): - def __init__( - self, - n_cond_frames: int, - n_copies: int, - encoder_config: dict, - sigma_sampler_config: Optional[dict] = None, - sigma_cond_config: Optional[dict] = None, - is_ae: bool = False, - scale_factor: float = 1.0, - disable_encoder_autocast: bool = False, - en_and_decode_n_samples_a_time: Optional[int] = None, - ): - super().__init__() - - self.n_cond_frames = n_cond_frames - self.n_copies = n_copies - self.encoder = instantiate_from_config(encoder_config) - self.sigma_sampler = ( - instantiate_from_config(sigma_sampler_config) - if sigma_sampler_config is not None - else None - ) - self.sigma_cond = ( - instantiate_from_config(sigma_cond_config) - if sigma_cond_config is not None - else None - ) - self.is_ae = is_ae - self.scale_factor = scale_factor - self.disable_encoder_autocast = disable_encoder_autocast - self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time - - def forward( - self, vid: torch.Tensor - ) -> Union[ - torch.Tensor, - Tuple[torch.Tensor, torch.Tensor], - Tuple[torch.Tensor, dict], - Tuple[Tuple[torch.Tensor, torch.Tensor], dict], - ]: - if self.sigma_sampler is not None: - b = vid.shape[0] // self.n_cond_frames - sigmas = self.sigma_sampler(b).to(vid.device) - if self.sigma_cond is not None: - sigma_cond = self.sigma_cond(sigmas) - sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies) - sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames) - noise = torch.randn_like(vid) - vid = vid + noise * append_dims(sigmas, vid.ndim) - - with torch.autocast("cuda", enabled=not self.disable_encoder_autocast): - n_samples = ( - self.en_and_decode_n_samples_a_time - if self.en_and_decode_n_samples_a_time is not None - else vid.shape[0] - ) - n_rounds = math.ceil(vid.shape[0] / n_samples) - all_out = [] - for n in range(n_rounds): - if self.is_ae: - out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples]) - else: - out = self.encoder(vid[n * n_samples : (n + 1) * n_samples]) - all_out.append(out) - - vid = torch.cat(all_out, dim=0) - vid *= self.scale_factor - - vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames) - vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies) - # modified for svd - # vid = repeat(vid, "b 1 c h w -> b t c h w", t=self.n_copies) - - return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid - - return return_val - - -class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel): - def __init__( - self, - open_clip_embedding_config: Dict, - n_cond_frames: int, - n_copies: int, - ): - super().__init__() - - self.n_cond_frames = n_cond_frames - self.n_copies = n_copies - self.open_clip = instantiate_from_config(open_clip_embedding_config) - - def forward(self, vid): - vid = self.open_clip(vid) - vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames) - vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies) - - return vid - - -class PixelNeRFEmbedder(AbstractEmbModel): - def __init__( - self, - image_encoder_config: dict, - pixelnerf_encoder_config: dict, - render_size: int, - num_video_frames: int, - ): - super().__init__() - self.render_size = render_size - self.num_video_frames = num_video_frames - self.image_encoder = instantiate_from_config(image_encoder_config) - self.pixelnerf_encoder = instantiate_from_config(pixelnerf_encoder_config) - - def forward(self, pixelnerf_input): - if "source_index" not in pixelnerf_input: - source_images = pixelnerf_input["frames"][:, 0] - image_feats = self.image_encoder(source_images) - image_feats = image_feats[:, None] - source_cameras = pixelnerf_input["cameras"][:, :1] - else: - # source_images = pixelnerf_input["frames"][ - # :, pixelnerf_input["source_index"] - # ] - source_images = pixelnerf_input["source_images"] - n_source_images = source_images.shape[1] - source_images = rearrange(source_images, "b t c h w -> (b t) c h w") - image_feats = self.image_encoder(source_images) - image_feats = rearrange( - image_feats, "(b t) c h w -> b t c h w", t=n_source_images - ) - source_cameras = pixelnerf_input["source_cameras"] - cameras = pixelnerf_input["cameras"] - target_cameras = cameras[:, :] - # source_images = source_images[:, None, ...] - source_c2ws = source_cameras[..., :16].reshape(*source_cameras.shape[:-1], 4, 4) - source_intrinsics = source_cameras[..., 16:].reshape( - *source_cameras.shape[:-1], 3, 3 - ) - target_c2ws = target_cameras[..., :16].reshape(*target_cameras.shape[:-1], 4, 4) - target_intrinsics = target_cameras[..., 16:].reshape( - *target_cameras.shape[:-1], 3, 3 - ) - - rgb, feats = self.pixelnerf_encoder( - image_feats, - source_c2ws, - source_intrinsics, - target_c2ws, - target_intrinsics, - self.render_size, - ) - - rgb = rearrange(rgb, "b t c h w -> (b t) c h w") - feats = rearrange(feats, "b t c h w -> (b t) c h w") - - return rgb, feats - - -class ExtraConditioner(GeneralConditioner): - def forward(self, batch: Dict, force_zero_embeddings: List | None = None) -> Dict: - bs = batch["frames"].shape[0] - num_frames = batch["num_video_frames"] - output = dict() - if force_zero_embeddings is None: - force_zero_embeddings = [] - for embedder in self.embedders: - embedding_context = nullcontext if embedder.is_trainable else torch.no_grad - with embedding_context(): - if hasattr(embedder, "input_key") and (embedder.input_key is not None): - if embedder.legacy_ucg_val is not None: - batch = self.possibly_get_ucg_val(embedder, batch) - emb_out = embedder(batch[embedder.input_key]) - elif hasattr(embedder, "input_keys"): - emb_out = embedder(*[batch[k] for k in embedder.input_keys]) - assert isinstance( - emb_out, (torch.Tensor, list, tuple) - ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" - if not isinstance(emb_out, (list, tuple)): - emb_out = [emb_out] - if isinstance(embedder, PixelNeRFEmbedder): - # a hack for pixelnerf input - output["rgb"] = emb_out[0] - emb_out = emb_out[1:] - for emb in emb_out: - out_key = self.OUTPUT_DIM2KEYS[emb.dim()] - if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: - emb = ( - expand_dims_like( - torch.bernoulli( - (1.0 - embedder.ucg_rate) - * torch.ones(emb.shape[0], device=emb.device) - ), - emb, - ) - * emb - ) - if ( - hasattr(embedder, "input_key") - and embedder.input_key in force_zero_embeddings - ): - emb = torch.zeros_like(emb) - if out_key in output: - output[out_key] = torch.cat( - (output[out_key], emb), self.KEY2CATDIM[out_key] - ) - else: - output[out_key] = emb - - if out_key in ["crossattn", "concat"]: - if output[out_key].shape[0] != bs: - output[out_key] = repeat( - output[out_key], "b ... -> (b t) ...", t=num_frames - ) - return output diff --git a/sgm/modules/encoders/pixelnerf.py b/sgm/modules/encoders/pixelnerf.py deleted file mode 100644 index 515699c3aa52097e27ddde98c3491547c2e3a0b7..0000000000000000000000000000000000000000 --- a/sgm/modules/encoders/pixelnerf.py +++ /dev/null @@ -1,368 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.autograd.profiler as profiler -import numpy as np -from einops import rearrange, repeat, einsum - -from .math_utils import get_ray_limits_box, linspace - -from ...modules.diffusionmodules.openaimodel import Timestep - - -class ImageEncoder(nn.Module): - def __init__(self, output_dim: int = 64) -> None: - super().__init__() - self.output_dim = output_dim - - def forward(self, image): - return image - - -class PositionalEncoding(torch.nn.Module): - """ - Implement NeRF's positional encoding - """ - - def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True): - super().__init__() - self.num_freqs = num_freqs - self.d_in = d_in - self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs) - self.d_out = self.num_freqs * 2 * d_in - self.include_input = include_input - if include_input: - self.d_out += d_in - # f1 f1 f2 f2 ... to multiply x by - self.register_buffer( - "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1) - ) - # 0 pi/2 0 pi/2 ... so that - # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...) - _phases = torch.zeros(2 * self.num_freqs) - _phases[1::2] = np.pi * 0.5 - self.register_buffer("_phases", _phases.view(1, -1, 1)) - - def forward(self, x): - """ - Apply positional encoding (new implementation) - :param x (batch, self.d_in) - :return (batch, self.d_out) - """ - with profiler.record_function("positional_enc"): - # embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1) - embed = repeat(x, "... C -> ... N C", N=self.num_freqs * 2) - embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs)) - embed = rearrange(embed, "... N C -> ... (N C)") - if self.include_input: - embed = torch.cat((x, embed), dim=-1) - return embed - - -class RayGenerator(torch.nn.Module): - """ - from camera pose and intrinsics to ray origins and directions - """ - - def __init__(self): - super().__init__() - ( - self.ray_origins_h, - self.ray_directions, - self.depths, - self.image_coords, - self.rendering_options, - ) = (None, None, None, None, None) - - def forward(self, cam2world_matrix, intrinsics, render_size): - """ - Create batches of rays and return origins and directions. - - cam2world_matrix: (N, 4, 4) - intrinsics: (N, 3, 3) - render_size: int - - ray_origins: (N, M, 3) - ray_dirs: (N, M, 2) - """ - - N, M = cam2world_matrix.shape[0], render_size**2 - cam_locs_world = cam2world_matrix[:, :3, 3] - fx = intrinsics[:, 0, 0] - fy = intrinsics[:, 1, 1] - cx = intrinsics[:, 0, 2] - cy = intrinsics[:, 1, 2] - sk = intrinsics[:, 0, 1] - - uv = torch.stack( - torch.meshgrid( - torch.arange( - render_size, dtype=torch.float32, device=cam2world_matrix.device - ), - torch.arange( - render_size, dtype=torch.float32, device=cam2world_matrix.device - ), - indexing="ij", - ) - ) - uv = uv.flip(0).reshape(2, -1).transpose(1, 0) - uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) - - x_cam = uv[:, :, 0].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) - y_cam = uv[:, :, 1].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) - z_cam = torch.ones((N, M), device=cam2world_matrix.device) - - x_lift = ( - ( - x_cam - - cx.unsqueeze(-1) - + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1) - - sk.unsqueeze(-1) * y_cam / fy.unsqueeze(-1) - ) - / fx.unsqueeze(-1) - * z_cam - ) - y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam - - cam_rel_points = torch.stack( - (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1 - ) - - # NOTE: this should be named _blender2opencv - _opencv2blender = ( - torch.tensor( - [ - [1, 0, 0, 0], - [0, -1, 0, 0], - [0, 0, -1, 0], - [0, 0, 0, 1], - ], - dtype=torch.float32, - device=cam2world_matrix.device, - ) - .unsqueeze(0) - .repeat(N, 1, 1) - ) - - cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) - - world_rel_points = torch.bmm( - cam2world_matrix, cam_rel_points.permute(0, 2, 1) - ).permute(0, 2, 1)[:, :, :3] - - ray_dirs = world_rel_points - cam_locs_world[:, None, :] - ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) - - ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) - - return ray_origins, ray_dirs - - -class RaySampler(torch.nn.Module): - def __init__( - self, - num_samples_per_ray, - bbox_length=1.0, - near=0.5, - far=10000.0, - disparity=False, - ): - super().__init__() - self.num_samples_per_ray = num_samples_per_ray - self.bbox_length = bbox_length - self.near = near - self.far = far - self.disparity = disparity - - def forward(self, ray_origins, ray_directions): - if not self.disparity: - t_start, t_end = get_ray_limits_box( - ray_origins, ray_directions, 2 * self.bbox_length - ) - else: - t_start = torch.full_like(ray_origins, self.near) - t_end = torch.full_like(ray_origins, self.far) - is_ray_valid = t_end > t_start - if torch.any(is_ray_valid).item(): - t_start[~is_ray_valid] = t_start[is_ray_valid].min() - t_end[~is_ray_valid] = t_start[is_ray_valid].max() - - if not self.disparity: - depths = linspace(t_start, t_end, self.num_samples_per_ray) - depths += ( - torch.rand_like(depths) - * (t_end - t_start) - / (self.num_samples_per_ray - 1) - ) - else: - step = 1.0 / self.num_samples_per_ray - z_steps = torch.linspace( - 0, 1 - step, self.num_samples_per_ray, device=ray_origins.device - ) - z_steps += torch.rand_like(z_steps) * step - depths = 1 / (1 / self.near * (1 - z_steps) + 1 / self.far * z_steps) - depths = depths[..., None, None, None] - - return ray_origins[None] + ray_directions[None] * depths - - -class PixelNeRF(torch.nn.Module): - def __init__( - self, - num_samples_per_ray: int = 128, - feature_dim: int = 64, - interp: str = "bilinear", - padding: str = "border", - disparity: bool = False, - near: float = 0.5, - far: float = 10000.0, - use_feats_std: bool = False, - use_pos_emb: bool = False, - ) -> None: - super().__init__() - # self.positional_encoder = Timestep(3) # TODO - self.num_samples_per_ray = num_samples_per_ray - self.ray_generator = RayGenerator() - self.ray_sampler = RaySampler( - num_samples_per_ray, near=near, far=far, disparity=disparity - ) # TODO - self.interp = interp - self.padding = padding - - self.positional_encoder = PositionalEncoding() - - # self.feature_aggregator = nn.Linear(128, 129) # TODO - self.use_feats_std = use_feats_std - self.use_pos_emb = use_pos_emb - d_in = feature_dim - if use_feats_std: - d_in += feature_dim - if use_pos_emb: - d_in += self.positional_encoder.d_out - self.feature_aggregator = nn.Sequential( - nn.Linear(d_in, 128), - nn.ReLU(), - nn.Linear(128, 128), - nn.ReLU(), - nn.Linear(128, 129), - ) - - # self.decoder = nn.Linear(128, 131) # TODO - self.decoder = nn.Sequential( - nn.Linear(128, 128), - nn.ReLU(), - nn.Linear(128, 128), - nn.ReLU(), - nn.Linear(128, 131), - ) - - def project(self, ray_samples, source_c2ws, source_instrincs): - # TODO: implement - # S for number of source cameras - # ray_samples: [B, N, H * W, N_sample, 3] - # source_c2ws: [B, S, 4, 4] - # source_intrinsics: [B, S, 3, 3] - # return [B, S, N, H * W, N_sample, 2] - S = source_c2ws.shape[1] - B = ray_samples.shape[0] - N = ray_samples.shape[1] - HW = ray_samples.shape[2] - ray_samples = repeat( - ray_samples, - "B N HW N_sample C -> B S N HW N_sample C", - S=source_c2ws.shape[1], - ) - padding = torch.ones((B, S, N, HW, self.num_samples_per_ray, 1)).to(ray_samples) - ray_samples_homo = torch.cat([ray_samples, padding], dim=-1) - source_c2ws = repeat(source_c2ws, "B S C1 C2 -> B S N 1 1 C1 C2", N=N) - source_instrincs = repeat(source_instrincs, "B S C1 C2 -> B S N 1 1 C1 C2", N=N) - source_w2c = source_c2ws.inverse() - projected_samples = einsum( - source_w2c, ray_samples_homo, "... i j, ... j -> ... i" - )[..., :3] - # NOTE: assumes opengl convention - projected_samples = -1 * projected_samples[..., :2] / projected_samples[..., 2:] - # NOTE: intrinsics here are normalized by resolution - fx = source_instrincs[..., 0, 0] - fy = source_instrincs[..., 1, 1] - cx = source_instrincs[..., 0, 2] - cy = source_instrincs[..., 1, 2] - x = projected_samples[..., 0] * fx + cx - # negative sign here is caused by opengl, F.grid_sample is consistent with openCV convention - y = -projected_samples[..., 1] * fy + cy - - return torch.stack([x, y], dim=-1) - - def forward( - self, image_feats, source_c2ws, source_intrinsics, c2ws, intrinsics, render_size - ): - # image_feats: [B S C H W] - B = c2ws.shape[0] - T = c2ws.shape[1] - ray_origins, ray_directions = self.ray_generator( - c2ws.reshape(-1, 4, 4), intrinsics.reshape(-1, 3, 3), render_size - ) # [B * N, H * W, 3] - # breakpoint() - - ray_samples = self.ray_sampler( - ray_origins, ray_directions - ) # [N_sample, B * N, H * W, 3] - ray_samples = rearrange(ray_samples, "Ns (B N) HW C -> B N HW Ns C", B=B) - - projected_samples = self.project(ray_samples, source_c2ws, source_intrinsics) - # # debug - # p = projected_samples[:, :, 0, :, 0, :] - # p = p.reshape(p.shape[0] * p.shape[1], *p.shape[2:]) - - # breakpoint() - - # image_feats = repeat(image_feats, "B S C H W -> (B S N) C H W", N=T) - image_feats = rearrange(image_feats, "B S C H W -> (B S) C H W") - projected_samples = rearrange( - projected_samples, "B S N HW Ns xy -> (B S) (N Ns) HW xy" - ) - # make sure the projected samples are in the range of [-1, 1], as required by F.grid_sample - joint = F.grid_sample( - image_feats, - projected_samples * 2.0 - 1.0, - padding_mode=self.padding, - mode=self.interp, - align_corners=True, - ) - # print("image_feats", image_feats.max(), image_feats.min()) - # print("samples", projected_samples.max(), projected_samples.min()) - joint = rearrange( - joint, - "(B S) C (N Ns) HW -> B S N HW Ns C", - B=B, - Ns=self.num_samples_per_ray, - ) - - reduced = torch.mean(joint, dim=1) # reduce on source dimension - if self.use_feats_std: - if not joint.shape[1] == 1: - reduced = torch.cat((reduced, joint.std(dim=1)), dim=-1) - else: - reduced = torch.cat((reduced, torch.zeros_like(reduced)), dim=-1) - - if self.use_pos_emb: - reduced = torch.cat((reduced, self.positional_encoder(ray_samples)), dim=-1) - reduced = self.feature_aggregator(reduced) - - feats, weights = reduced.split([reduced.shape[-1] - 1, 1], dim=-1) - # feats: [B, N, H * W, N_samples, N_c] - # weights: [B, N, H * W, N_samples, 1] - weights = F.softmax(weights, dim=-2) - - feats = torch.sum(feats * weights, dim=-2) - - rgb, feats = self.decoder(feats).split([3, 128], dim=-1) - - rgb = F.sigmoid(rgb) - rgb = rearrange(rgb, "B N (H W) C -> B N C H W", H=render_size) - feats = rearrange(feats, "B N (H W) C -> B N C H W", H=render_size) - - # print(rgb.max(), rgb.min()) - # print(feats.max(), feats.min()) - - return rgb, feats diff --git a/sgm/modules/video_attention.py b/sgm/modules/video_attention.py deleted file mode 100644 index 783395aa554144936766b57380f35dab29c093c3..0000000000000000000000000000000000000000 --- a/sgm/modules/video_attention.py +++ /dev/null @@ -1,301 +0,0 @@ -import torch - -from ..modules.attention import * -from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding - - -class TimeMixSequential(nn.Sequential): - def forward(self, x, context=None, timesteps=None): - for layer in self: - x = layer(x, context, timesteps) - - return x - - -class VideoTransformerBlock(nn.Module): - ATTENTION_MODES = { - "softmax": CrossAttention, - "softmax-xformers": MemoryEfficientCrossAttention, - } - - def __init__( - self, - dim, - n_heads, - d_head, - dropout=0.0, - context_dim=None, - gated_ff=True, - checkpoint=True, - timesteps=None, - ff_in=False, - inner_dim=None, - attn_mode="softmax", - disable_self_attn=False, - disable_temporal_crossattention=False, - switch_temporal_ca_to_sa=False, - ): - super().__init__() - - attn_cls = self.ATTENTION_MODES[attn_mode] - - self.ff_in = ff_in or inner_dim is not None - if inner_dim is None: - inner_dim = dim - - assert int(n_heads * d_head) == inner_dim - - self.is_res = inner_dim == dim - - if self.ff_in: - self.norm_in = nn.LayerNorm(dim) - self.ff_in = FeedForward( - dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff - ) - - self.timesteps = timesteps - self.disable_self_attn = disable_self_attn - if self.disable_self_attn: - self.attn1 = attn_cls( - query_dim=inner_dim, - heads=n_heads, - dim_head=d_head, - context_dim=context_dim, - dropout=dropout, - ) # is a cross-attention - else: - self.attn1 = attn_cls( - query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - - self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) - - if disable_temporal_crossattention: - if switch_temporal_ca_to_sa: - raise ValueError - else: - self.attn2 = None - else: - self.norm2 = nn.LayerNorm(inner_dim) - if switch_temporal_ca_to_sa: - self.attn2 = attn_cls( - query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - else: - self.attn2 = attn_cls( - query_dim=inner_dim, - context_dim=context_dim, - heads=n_heads, - dim_head=d_head, - dropout=dropout, - ) # is self-attn if context is none - - self.norm1 = nn.LayerNorm(inner_dim) - self.norm3 = nn.LayerNorm(inner_dim) - self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa - - self.checkpoint = checkpoint - if self.checkpoint: - print(f"{self.__class__.__name__} is using checkpointing") - - def forward( - self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None - ) -> torch.Tensor: - if self.checkpoint: - return checkpoint(self._forward, x, context, timesteps) - else: - return self._forward(x, context, timesteps=timesteps) - - def _forward(self, x, context=None, timesteps=None): - assert self.timesteps or timesteps - assert not (self.timesteps and timesteps) or self.timesteps == timesteps - timesteps = self.timesteps or timesteps - B, S, C = x.shape - x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) - - if self.ff_in: - x_skip = x - x = self.ff_in(self.norm_in(x)) - if self.is_res: - x += x_skip - - if self.disable_self_attn: - x = self.attn1(self.norm1(x), context=context) + x - else: - x = self.attn1(self.norm1(x)) + x - - if self.attn2 is not None: - if self.switch_temporal_ca_to_sa: - x = self.attn2(self.norm2(x)) + x - else: - x = self.attn2(self.norm2(x), context=context) + x - x_skip = x - x = self.ff(self.norm3(x)) - if self.is_res: - x += x_skip - - x = rearrange( - x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps - ) - return x - - def get_last_layer(self): - return self.ff.net[-1].weight - - -class SpatialVideoTransformer(SpatialTransformer): - def __init__( - self, - in_channels, - n_heads, - d_head, - depth=1, - dropout=0.0, - use_linear=False, - context_dim=None, - use_spatial_context=False, - timesteps=None, - merge_strategy: str = "fixed", - merge_factor: float = 0.5, - time_context_dim=None, - ff_in=False, - checkpoint=False, - time_depth=1, - attn_mode="softmax", - disable_self_attn=False, - disable_temporal_crossattention=False, - max_time_embed_period: int = 10000, - ): - super().__init__( - in_channels, - n_heads, - d_head, - depth=depth, - dropout=dropout, - attn_type=attn_mode, - use_checkpoint=checkpoint, - context_dim=context_dim, - use_linear=use_linear, - disable_self_attn=disable_self_attn, - ) - self.time_depth = time_depth - self.depth = depth - self.max_time_embed_period = max_time_embed_period - - time_mix_d_head = d_head - n_time_mix_heads = n_heads - - time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) - - inner_dim = n_heads * d_head - if use_spatial_context: - time_context_dim = context_dim - - self.time_stack = nn.ModuleList( - [ - VideoTransformerBlock( - inner_dim, - n_time_mix_heads, - time_mix_d_head, - dropout=dropout, - context_dim=time_context_dim, - timesteps=timesteps, - checkpoint=checkpoint, - ff_in=ff_in, - inner_dim=time_mix_inner_dim, - attn_mode=attn_mode, - disable_self_attn=disable_self_attn, - disable_temporal_crossattention=disable_temporal_crossattention, - ) - for _ in range(self.depth) - ] - ) - - assert len(self.time_stack) == len(self.transformer_blocks) - - self.use_spatial_context = use_spatial_context - self.in_channels = in_channels - - time_embed_dim = self.in_channels * 4 - self.time_pos_embed = nn.Sequential( - linear(self.in_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, self.in_channels), - ) - - self.time_mixer = AlphaBlender( - alpha=merge_factor, merge_strategy=merge_strategy - ) - - def forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - time_context: Optional[torch.Tensor] = None, - timesteps: Optional[int] = None, - image_only_indicator: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - _, _, h, w = x.shape - x_in = x - spatial_context = None - if exists(context): - spatial_context = context - - if self.use_spatial_context: - assert ( - context.ndim == 3 - ), f"n dims of spatial context should be 3 but are {context.ndim}" - - time_context = context - time_context_first_timestep = time_context[::timesteps] - time_context = repeat( - time_context_first_timestep, "b ... -> (b n) ...", n=h * w - ) - elif time_context is not None and not self.use_spatial_context: - time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w) - if time_context.ndim == 2: - time_context = rearrange(time_context, "b c -> b 1 c") - - x = self.norm(x) - if not self.use_linear: - x = self.proj_in(x) - x = rearrange(x, "b c h w -> b (h w) c") - if self.use_linear: - x = self.proj_in(x) - - num_frames = torch.arange(timesteps, device=x.device) - num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) - num_frames = rearrange(num_frames, "b t -> (b t)") - t_emb = timestep_embedding( - num_frames, - self.in_channels, - repeat_only=False, - max_period=self.max_time_embed_period, - ) - emb = self.time_pos_embed(t_emb) - emb = emb[:, None, :] - - for it_, (block, mix_block) in enumerate( - zip(self.transformer_blocks, self.time_stack) - ): - x = block( - x, - context=spatial_context, - ) - - x_mix = x - x_mix = x_mix + emb - - x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) - x = self.time_mixer( - x_spatial=x, - x_temporal=x_mix, - image_only_indicator=image_only_indicator, - ) - if self.use_linear: - x = self.proj_out(x) - x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) - if not self.use_linear: - x = self.proj_out(x) - out = x + x_in - return out diff --git a/sgm/sampling/hier.py b/sgm/sampling/hier.py deleted file mode 100644 index 375261c89b9f2fb38b2b853af8872ef4f0f500af..0000000000000000000000000000000000000000 --- a/sgm/sampling/hier.py +++ /dev/null @@ -1 +0,0 @@ -# hierachical sampling, (autogressive sampling like GeNVS) diff --git a/sgm/util.py b/sgm/util.py deleted file mode 100644 index 49cc0df0e14326087e1adaf515b76137c2977fbe..0000000000000000000000000000000000000000 --- a/sgm/util.py +++ /dev/null @@ -1,310 +0,0 @@ -import functools -import importlib -import os -from functools import partial -from inspect import isfunction - -import fsspec -import numpy as np -import torch -from PIL import Image, ImageDraw, ImageFont -from safetensors.torch import load_file as load_safetensors -from einops import rearrange -from mediapy import write_image - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def get_string_from_tuple(s): - try: - # Check if the string starts and ends with parentheses - if s[0] == "(" and s[-1] == ")": - # Convert the string to a tuple - t = eval(s) - # Check if the type of t is tuple - if type(t) == tuple: - return t[0] - else: - pass - except: - pass - return s - - -def is_power_of_two(n): - """ - chat.openai.com/chat - Return True if n is a power of 2, otherwise return False. - - The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. - The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. - If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. - Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. - - """ - if n <= 0: - return False - return (n & (n - 1)) == 0 - - -def autocast(f, enabled=True): - def do_autocast(*args, **kwargs): - with torch.cuda.amp.autocast( - enabled=enabled, - dtype=torch.get_autocast_gpu_dtype(), - cache_enabled=torch.is_autocast_cache_enabled(), - ): - return f(*args, **kwargs) - - return do_autocast - - -def load_partial_from_config(config): - return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) - - -def log_txt_as_img(wh, xc, size=10): - # wh a tuple of (width, height) - # xc a list of captions to plot - b = len(xc) - txts = list() - for bi in range(b): - txt = Image.new("RGB", wh, color="white") - draw = ImageDraw.Draw(txt) - font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) - nc = int(40 * (wh[0] / 256)) - if isinstance(xc[bi], list): - text_seq = xc[bi][0] - else: - text_seq = xc[bi] - lines = "\n".join( - text_seq[start : start + nc] for start in range(0, len(text_seq), nc) - ) - - try: - draw.text((0, 0), lines, fill="black", font=font) - except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") - - txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 - txts.append(txt) - txts = np.stack(txts) - txts = torch.tensor(txts) - return txts - - -def partialclass(cls, *args, **kwargs): - class NewCls(cls): - __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) - - return NewCls - - -def make_path_absolute(path): - fs, p = fsspec.core.url_to_fs(path) - if fs.protocol == "file": - return os.path.abspath(p) - return path - - -def ismap(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] > 3) - - -def isimage(x): - if not isinstance(x, torch.Tensor): - return False - return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) - - -def isheatmap(x): - if not isinstance(x, torch.Tensor): - return False - - return x.ndim == 2 - - -def isneighbors(x): - if not isinstance(x, torch.Tensor): - return False - return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) - - -def exists(x): - return x is not None - - -def expand_dims_like(x, y): - while x.dim() != y.dim(): - x = x.unsqueeze(-1) - return x - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def mean_flat(tensor): - """ - https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") - return total_params - - -def instantiate_from_config(config): - if not "target" in config: - if config == "__is_first_stage__": - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False, invalidate_cache=True): - module, cls = string.rsplit(".", 1) - if invalidate_cache: - importlib.invalidate_caches() - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - - -def append_zero(x): - return torch.cat([x, x.new_zeros([1])]) - - -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError( - f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" - ) - return x[(...,) + (None,) * dims_to_append] - - -def load_model_from_config(config, ckpt, verbose=True, freeze=True): - print(f"Loading model from {ckpt}") - if ckpt.endswith("ckpt"): - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - elif ckpt.endswith("safetensors"): - sd = load_safetensors(ckpt) - else: - raise NotImplementedError - - model = instantiate_from_config(config.model) - - m, u = model.load_state_dict(sd, strict=False) - - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - if freeze: - for param in model.parameters(): - param.requires_grad = False - - model.eval() - return model - - -def get_configs_path() -> str: - """ - Get the `configs` directory. - For a working copy, this is the one in the root of the repository, - but for an installed copy, it's in the `sgm` package (see pyproject.toml). - """ - this_dir = os.path.dirname(__file__) - candidates = ( - os.path.join(this_dir, "configs"), - os.path.join(this_dir, "..", "configs"), - ) - for candidate in candidates: - candidate = os.path.abspath(candidate) - if os.path.isdir(candidate): - return candidate - raise FileNotFoundError(f"Could not find SGM configs in {candidates}") - - -def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): - """ - Will return the result of a recursive get attribute call. - E.g.: - a.b.c - = getattr(getattr(a, "b"), "c") - = get_nested_attribute(a, "b.c") - If any part of the attribute call is an integer x with current obj a, will - try to call a[x] instead of a.x first. - """ - attributes = attribute_path.split(".") - if depth is not None and depth > 0: - attributes = attributes[:depth] - assert len(attributes) > 0, "At least one attribute should be selected" - current_attribute = obj - current_key = None - for level, attribute in enumerate(attributes): - current_key = ".".join(attributes[: level + 1]) - try: - id_ = int(attribute) - current_attribute = current_attribute[id_] - except ValueError: - current_attribute = getattr(current_attribute, attribute) - - return (current_attribute, current_key) if return_key else current_attribute - - -def video_frames_as_grid(frames, save_path): - # frames: [T, C, H, W] - frames = frames.detach().cpu() - frames = rearrange(frames, "t c h w -> h (t w) c") - write_image(save_path, frames) - - -def server_safe_call(keep_trying: bool = False): - """Decorator for server calls. If the call fails, it will keep trying until it succeeds. - - Args: - keep_trying (bool, optional): whether to call again if the first try fails. Defaults to False. - """ - - def decorator(func): - def wrapper(*args, **kwargs): - success = False - while not success: - try: - ret = func(*args, **kwargs) - success = True - except KeyboardInterrupt: - raise - except: - if not keep_trying: - break - return ret - - return wrapper - - return decorator diff --git a/scripts/__init__.py b/src/__init__.py similarity index 100% rename from scripts/__init__.py rename to src/__init__.py diff --git a/scripts/util/__init__.py b/src/data/__init__.py similarity index 100% rename from scripts/util/__init__.py rename to src/data/__init__.py diff --git a/src/data/objaverse.py b/src/data/objaverse.py new file mode 100644 index 0000000000000000000000000000000000000000..dd27f86c2469e74da28e27929929d84cd1718965 --- /dev/null +++ b/src/data/objaverse.py @@ -0,0 +1,329 @@ +import os, sys +import math +import json +import importlib +from pathlib import Path + +import cv2 +import random +import numpy as np +from PIL import Image +import webdataset as wds +import pytorch_lightning as pl + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision import transforms + +from src.utils.train_util import instantiate_from_config +from src.utils.camera_util import ( + FOV_to_intrinsics, + center_looking_at_camera_pose, + get_surrounding_views, +) + + +class DataModuleFromConfig(pl.LightningDataModule): + def __init__( + self, + batch_size=8, + num_workers=4, + train=None, + validation=None, + test=None, + **kwargs, + ): + super().__init__() + + self.batch_size = batch_size + self.num_workers = num_workers + + self.dataset_configs = dict() + if train is not None: + self.dataset_configs['train'] = train + if validation is not None: + self.dataset_configs['validation'] = validation + if test is not None: + self.dataset_configs['test'] = test + + def setup(self, stage): + + if stage in ['fit']: + self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) + else: + raise NotImplementedError + + def train_dataloader(self): + + sampler = DistributedSampler(self.datasets['train']) + return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) + + def val_dataloader(self): + + sampler = DistributedSampler(self.datasets['validation']) + return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler) + + def test_dataloader(self): + + return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + + +class ObjaverseData(Dataset): + def __init__(self, + root_dir='objaverse/', + meta_fname='valid_paths.json', + input_image_dir='rendering_random_32views', + target_image_dir='rendering_random_32views', + input_view_num=6, + target_view_num=2, + total_view_n=32, + fov=50, + camera_rotation=True, + validation=False, + ): + self.root_dir = Path(root_dir) + self.input_image_dir = input_image_dir + self.target_image_dir = target_image_dir + + self.input_view_num = input_view_num + self.target_view_num = target_view_num + self.total_view_n = total_view_n + self.fov = fov + self.camera_rotation = camera_rotation + + with open(os.path.join(root_dir, meta_fname)) as f: + filtered_dict = json.load(f) + paths = filtered_dict['good_objs'] + self.paths = paths + + self.depth_scale = 4.0 + + total_objects = len(self.paths) + print('============= length of dataset %d =============' % len(self.paths)) + + def __len__(self): + return len(self.paths) + + def load_im(self, path, color): + ''' + replace background pixel with random color in rendering + ''' + pil_img = Image.open(path) + + image = np.asarray(pil_img, dtype=np.float32) / 255. + alpha = image[:, :, 3:] + image = image[:, :, :3] * alpha + color * (1 - alpha) + + image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float() + alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float() + return image, alpha + + def __getitem__(self, index): + # load data + while True: + input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index]) + target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index]) + + indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False) + input_indices = indices[:self.input_view_num] + target_indices = indices[self.input_view_num:] + + '''background color, default: white''' + bg_white = [1., 1., 1.] + bg_black = [0., 0., 0.] + + image_list = [] + alpha_list = [] + depth_list = [] + normal_list = [] + pose_list = [] + + try: + input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses'] + for idx in input_indices: + image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white) + normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black) + depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale + depth = torch.from_numpy(depth).unsqueeze(0) + pose = input_cameras[idx] + pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0) + + image_list.append(image) + alpha_list.append(alpha) + depth_list.append(depth) + normal_list.append(normal) + pose_list.append(pose) + + target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses'] + for idx in target_indices: + image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white) + normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black) + depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale + depth = torch.from_numpy(depth).unsqueeze(0) + pose = target_cameras[idx] + pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0) + + image_list.append(image) + alpha_list.append(alpha) + depth_list.append(depth) + normal_list.append(normal) + pose_list.append(pose) + + except Exception as e: + print(e) + index = np.random.randint(0, len(self.paths)) + continue + + break + + images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W) + alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W) + depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W) + normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W) + w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4) + c2ws = torch.linalg.inv(w2cs).float() + + normals = normals * 2.0 - 1.0 + normals = F.normalize(normals, dim=1) + normals = (normals + 1.0) / 2.0 + normals = torch.lerp(torch.zeros_like(normals), normals, alphas) + + # random rotation along z axis + if self.camera_rotation: + degree = np.random.uniform(0, math.pi * 2) + rot = torch.tensor([ + [np.cos(degree), -np.sin(degree), 0, 0], + [np.sin(degree), np.cos(degree), 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ]).unsqueeze(0).float() + c2ws = torch.matmul(rot, c2ws) + + # rotate normals + N, _, H, W = normals.shape + normals = normals * 2.0 - 1.0 + normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W) + normals = F.normalize(normals, dim=1) + normals = (normals + 1.0) / 2.0 + normals = torch.lerp(torch.zeros_like(normals), normals, alphas) + + # random scaling + if np.random.rand() < 0.5: + scale = np.random.uniform(0.8, 1.0) + c2ws[:, :3, 3] *= scale + depths *= scale + + # instrinsics of perspective cameras + K = FOV_to_intrinsics(self.fov) + Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float() + + data = { + 'input_images': images[:self.input_view_num], # (6, 3, H, W) + 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W) + 'input_depths': depths[:self.input_view_num], # (6, 1, H, W) + 'input_normals': normals[:self.input_view_num], # (6, 3, H, W) + 'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4) + 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3) + + # lrm generator input and supervision + 'target_images': images[self.input_view_num:], # (V, 3, H, W) + 'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W) + 'target_depths': depths[self.input_view_num:], # (V, 1, H, W) + 'target_normals': normals[self.input_view_num:], # (V, 3, H, W) + 'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4) + 'target_Ks': Ks[self.input_view_num:], # (V, 3, 3) + + 'depth_available': 1, + } + return data + + +class ValidationData(Dataset): + def __init__(self, + root_dir='objaverse/', + input_view_num=6, + input_image_size=256, + fov=50, + ): + self.root_dir = Path(root_dir) + self.input_view_num = input_view_num + self.input_image_size = input_image_size + self.fov = fov + + self.paths = sorted(os.listdir(self.root_dir)) + print('============= length of dataset %d =============' % len(self.paths)) + + cam_distance = 2.5 + azimuths = np.array([30, 90, 150, 210, 270, 330]) + elevations = np.array([30, -20, 30, -20, 30, -20]) + azimuths = np.deg2rad(azimuths) + elevations = np.deg2rad(elevations) + + x = cam_distance * np.cos(elevations) * np.cos(azimuths) + y = cam_distance * np.cos(elevations) * np.sin(azimuths) + z = cam_distance * np.sin(elevations) + + cam_locations = np.stack([x, y, z], axis=-1) + cam_locations = torch.from_numpy(cam_locations).float() + c2ws = center_looking_at_camera_pose(cam_locations) + self.c2ws = c2ws.float() + self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float() + + render_c2ws = get_surrounding_views(M=8, radius=cam_distance) + render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1) + self.render_c2ws = render_c2ws.float() + self.render_Ks = render_Ks.float() + + def __len__(self): + return len(self.paths) + + def load_im(self, path, color): + ''' + replace background pixel with random color in rendering + ''' + pil_img = Image.open(path) + pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC) + + image = np.asarray(pil_img, dtype=np.float32) / 255. + if image.shape[-1] == 4: + alpha = image[:, :, 3:] + image = image[:, :, :3] * alpha + color * (1 - alpha) + else: + alpha = np.ones_like(image[:, :, :1]) + + image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float() + alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float() + return image, alpha + + def __getitem__(self, index): + # load data + input_image_path = os.path.join(self.root_dir, self.paths[index]) + + '''background color, default: white''' + # color = np.random.uniform(0.48, 0.52) + bkg_color = [1.0, 1.0, 1.0] + + image_list = [] + alpha_list = [] + + for idx in range(self.input_view_num): + image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color) + image_list.append(image) + alpha_list.append(alpha) + + images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W) + alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W) + + data = { + 'input_images': images, # (6, 3, H, W) + 'input_alphas': alphas, # (6, 1, H, W) + 'input_c2ws': self.c2ws, # (6, 4, 4) + 'input_Ks': self.Ks, # (6, 3, 3) + + 'render_c2ws': self.render_c2ws, + 'render_Ks': self.render_Ks, + } + return data diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000000000000000000000000000000000000..584a6dcc59a641104f8942e7f4b4fc225e551f6a --- /dev/null +++ b/src/model.py @@ -0,0 +1,310 @@ +import os +import numpy as np +import torch +import torch.nn.functional as F +from torchvision.transforms import v2 +from torchvision.utils import make_grid, save_image +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +import pytorch_lightning as pl +from einops import rearrange, repeat + +from src.utils.train_util import instantiate_from_config + + +class MVRecon(pl.LightningModule): + def __init__( + self, + lrm_generator_config, + lrm_path=None, + input_size=256, + render_size=192, + ): + super(MVRecon, self).__init__() + + self.input_size = input_size + self.render_size = render_size + + # init modules + self.lrm_generator = instantiate_from_config(lrm_generator_config) + if lrm_path is not None: + lrm_ckpt = torch.load(lrm_path) + self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False) + + self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg') + + self.validation_step_outputs = [] + + def on_fit_start(self): + if self.global_rank == 0: + os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) + os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) + + def prepare_batch_data(self, batch): + lrm_generator_input = {} + render_gt = {} # for supervision + + # input images + images = batch['input_images'] + images = v2.functional.resize( + images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) + + lrm_generator_input['images'] = images.to(self.device) + + # input cameras and render cameras + input_c2ws = batch['input_c2ws'].flatten(-2) + input_Ks = batch['input_Ks'].flatten(-2) + target_c2ws = batch['target_c2ws'].flatten(-2) + target_Ks = batch['target_Ks'].flatten(-2) + render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1) + render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1) + render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1) + + input_extrinsics = input_c2ws[:, :, :12] + input_intrinsics = torch.stack([ + input_Ks[:, :, 0], input_Ks[:, :, 4], + input_Ks[:, :, 2], input_Ks[:, :, 5], + ], dim=-1) + cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1) + + # add noise to input cameras + cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02 + + lrm_generator_input['cameras'] = cameras.to(self.device) + lrm_generator_input['render_cameras'] = render_cameras.to(self.device) + + # target images + target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1) + target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1) + target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1) + + # random crop + render_size = np.random.randint(self.render_size, 513) + target_images = v2.functional.resize( + target_images, render_size, interpolation=3, antialias=True).clamp(0, 1) + target_depths = v2.functional.resize( + target_depths, render_size, interpolation=0, antialias=True) + target_alphas = v2.functional.resize( + target_alphas, render_size, interpolation=0, antialias=True) + + crop_params = v2.RandomCrop.get_params( + target_images, output_size=(self.render_size, self.render_size)) + target_images = v2.functional.crop(target_images, *crop_params) + target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1] + target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1] + + lrm_generator_input['render_size'] = render_size + lrm_generator_input['crop_params'] = crop_params + + render_gt['target_images'] = target_images.to(self.device) + render_gt['target_depths'] = target_depths.to(self.device) + render_gt['target_alphas'] = target_alphas.to(self.device) + + return lrm_generator_input, render_gt + + def prepare_validation_batch_data(self, batch): + lrm_generator_input = {} + + # input images + images = batch['input_images'] + images = v2.functional.resize( + images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) + + lrm_generator_input['images'] = images.to(self.device) + + input_c2ws = batch['input_c2ws'].flatten(-2) + input_Ks = batch['input_Ks'].flatten(-2) + + input_extrinsics = input_c2ws[:, :, :12] + input_intrinsics = torch.stack([ + input_Ks[:, :, 0], input_Ks[:, :, 4], + input_Ks[:, :, 2], input_Ks[:, :, 5], + ], dim=-1) + cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1) + + lrm_generator_input['cameras'] = cameras.to(self.device) + + render_c2ws = batch['render_c2ws'].flatten(-2) + render_Ks = batch['render_Ks'].flatten(-2) + render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1) + + lrm_generator_input['render_cameras'] = render_cameras.to(self.device) + lrm_generator_input['render_size'] = 384 + lrm_generator_input['crop_params'] = None + + return lrm_generator_input + + def forward_lrm_generator( + self, + images, + cameras, + render_cameras, + render_size=192, + crop_params=None, + chunk_size=1, + ): + planes = torch.utils.checkpoint.checkpoint( + self.lrm_generator.forward_planes, + images, + cameras, + use_reentrant=False, + ) + frames = [] + for i in range(0, render_cameras.shape[1], chunk_size): + frames.append( + torch.utils.checkpoint.checkpoint( + self.lrm_generator.synthesizer, + planes, + cameras=render_cameras[:, i:i+chunk_size], + render_size=render_size, + crop_params=crop_params, + use_reentrant=False + ) + ) + frames = { + k: torch.cat([r[k] for r in frames], dim=1) + for k in frames[0].keys() + } + return frames + + def forward(self, lrm_generator_input): + images = lrm_generator_input['images'] + cameras = lrm_generator_input['cameras'] + render_cameras = lrm_generator_input['render_cameras'] + render_size = lrm_generator_input['render_size'] + crop_params = lrm_generator_input['crop_params'] + + out = self.forward_lrm_generator( + images, + cameras, + render_cameras, + render_size=render_size, + crop_params=crop_params, + chunk_size=1, + ) + render_images = torch.clamp(out['images_rgb'], 0.0, 1.0) + render_depths = out['images_depth'] + render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0) + + out = { + 'render_images': render_images, + 'render_depths': render_depths, + 'render_alphas': render_alphas, + } + return out + + def training_step(self, batch, batch_idx): + lrm_generator_input, render_gt = self.prepare_batch_data(batch) + + render_out = self.forward(lrm_generator_input) + + loss, loss_dict = self.compute_loss(render_out, render_gt) + + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) + + if self.global_step % 1000 == 0 and self.global_rank == 0: + B, N, C, H, W = render_gt['target_images'].shape + N_in = lrm_generator_input['images'].shape[1] + + input_images = v2.functional.resize( + lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1) + input_images = torch.cat( + [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1) + + input_images = rearrange( + input_images, 'b n c h w -> b c h (n w)') + target_images = rearrange( + render_gt['target_images'], 'b n c h w -> b c h (n w)') + render_images = rearrange( + render_out['render_images'], 'b n c h w -> b c h (n w)') + target_alphas = rearrange( + repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') + render_alphas = rearrange( + repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') + target_depths = rearrange( + repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') + render_depths = rearrange( + repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') + MAX_DEPTH = torch.max(target_depths) + target_depths = target_depths / MAX_DEPTH * target_alphas + render_depths = render_depths / MAX_DEPTH + + grid = torch.cat([ + input_images, + target_images, render_images, + target_alphas, render_alphas, + target_depths, render_depths, + ], dim=-2) + grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) + + save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')) + + return loss + + def compute_loss(self, render_out, render_gt): + # NOTE: the rgb value range of OpenLRM is [0, 1] + render_images = render_out['render_images'] + target_images = render_gt['target_images'].to(render_images) + render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 + target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 + + loss_mse = F.mse_loss(render_images, target_images) + loss_lpips = 2.0 * self.lpips(render_images, target_images) + + render_alphas = render_out['render_alphas'] + target_alphas = render_gt['target_alphas'] + loss_mask = F.mse_loss(render_alphas, target_alphas) + + loss = loss_mse + loss_lpips + loss_mask + + prefix = 'train' + loss_dict = {} + loss_dict.update({f'{prefix}/loss_mse': loss_mse}) + loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) + loss_dict.update({f'{prefix}/loss_mask': loss_mask}) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + lrm_generator_input = self.prepare_validation_batch_data(batch) + + render_out = self.forward(lrm_generator_input) + render_images = render_out['render_images'] + render_images = rearrange(render_images, 'b n c h w -> b c h (n w)') + + self.validation_step_outputs.append(render_images) + + def on_validation_epoch_end(self): + images = torch.cat(self.validation_step_outputs, dim=-1) + + all_images = self.all_gather(images) + all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') + + if self.global_rank == 0: + image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png') + + grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1)) + save_image(grid, image_path) + print(f"Saved image to {image_path}") + + self.validation_step_outputs.clear() + + def configure_optimizers(self): + lr = self.learning_rate + + params = [] + + lrm_params_fast, lrm_params_slow = [], [] + for n, p in self.lrm_generator.named_parameters(): + if 'adaLN_modulation' in n or 'camera_embedder' in n: + lrm_params_fast.append(p) + else: + lrm_params_slow.append(p) + params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 }) + params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 }) + + optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95)) + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4) + + return {'optimizer': optimizer, 'lr_scheduler': scheduler} diff --git a/src/model_mesh.py b/src/model_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..99945a0b410242a71678ad0034bf38315a34571b --- /dev/null +++ b/src/model_mesh.py @@ -0,0 +1,325 @@ +import os +import numpy as np +import torch +import torch.nn.functional as F +from torchvision.transforms import v2 +from torchvision.utils import make_grid, save_image +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +import pytorch_lightning as pl +from einops import rearrange, repeat + +from src.utils.train_util import instantiate_from_config + + +# Regulrarization loss for FlexiCubes +def sdf_reg_loss_batch(sdf, all_edges): + sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ + F.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +class MVRecon(pl.LightningModule): + def __init__( + self, + lrm_generator_config, + input_size=256, + render_size=512, + init_ckpt=None, + ): + super(MVRecon, self).__init__() + + self.input_size = input_size + self.render_size = render_size + + # init modules + self.lrm_generator = instantiate_from_config(lrm_generator_config) + + self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg') + + # Load weights from pretrained MVRecon model, and use the mlp + # weights to initialize the weights of sdf and rgb mlps. + if init_ckpt is not None: + sd = torch.load(init_ckpt, map_location='cpu')['state_dict'] + sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')} + sd_fc = {} + for k, v in sd.items(): + if k.startswith('lrm_generator.synthesizer.decoder.net.'): + if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer + # Here we assume the density filed's isosurface threshold is t, + # we reverse the sign of density filed to initialize SDF field. + # -(w*x + b - t) = (-w)*x + (t - b) + if 'weight' in k: + sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1] + else: + sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1] + sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4] + else: + sd_fc[k.replace('net.', 'net_sdf.')] = v + sd_fc[k.replace('net.', 'net_rgb.')] = v + else: + sd_fc[k] = v + sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()} + # missing `net_deformation` and `net_weight` parameters + self.lrm_generator.load_state_dict(sd_fc, strict=False) + print(f'Loaded weights from {init_ckpt}') + + self.validation_step_outputs = [] + + def on_fit_start(self): + device = torch.device(f'cuda:{self.global_rank}') + self.lrm_generator.init_flexicubes_geometry(device) + if self.global_rank == 0: + os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) + os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) + + def prepare_batch_data(self, batch): + lrm_generator_input = {} + render_gt = {} + + # input images + images = batch['input_images'] + images = v2.functional.resize( + images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) + + lrm_generator_input['images'] = images.to(self.device) + + # input cameras and render cameras + input_c2ws = batch['input_c2ws'] + input_Ks = batch['input_Ks'] + target_c2ws = batch['target_c2ws'] + + render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1) + render_w2cs = torch.linalg.inv(render_c2ws) + + input_extrinsics = input_c2ws.flatten(-2) + input_extrinsics = input_extrinsics[:, :, :12] + input_intrinsics = input_Ks.flatten(-2) + input_intrinsics = torch.stack([ + input_intrinsics[:, :, 0], input_intrinsics[:, :, 4], + input_intrinsics[:, :, 2], input_intrinsics[:, :, 5], + ], dim=-1) + cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1) + + # add noise to input_cameras + cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02 + + lrm_generator_input['cameras'] = cameras.to(self.device) + lrm_generator_input['render_cameras'] = render_w2cs.to(self.device) + + # target images + target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1) + target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1) + target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1) + target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1) + + render_size = self.render_size + target_images = v2.functional.resize( + target_images, render_size, interpolation=3, antialias=True).clamp(0, 1) + target_depths = v2.functional.resize( + target_depths, render_size, interpolation=0, antialias=True) + target_alphas = v2.functional.resize( + target_alphas, render_size, interpolation=0, antialias=True) + target_normals = v2.functional.resize( + target_normals, render_size, interpolation=3, antialias=True) + + lrm_generator_input['render_size'] = render_size + + render_gt['target_images'] = target_images.to(self.device) + render_gt['target_depths'] = target_depths.to(self.device) + render_gt['target_alphas'] = target_alphas.to(self.device) + render_gt['target_normals'] = target_normals.to(self.device) + + return lrm_generator_input, render_gt + + def prepare_validation_batch_data(self, batch): + lrm_generator_input = {} + + # input images + images = batch['input_images'] + images = v2.functional.resize( + images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) + + lrm_generator_input['images'] = images.to(self.device) + + # input cameras + input_c2ws = batch['input_c2ws'].flatten(-2) + input_Ks = batch['input_Ks'].flatten(-2) + + input_extrinsics = input_c2ws[:, :, :12] + input_intrinsics = torch.stack([ + input_Ks[:, :, 0], input_Ks[:, :, 4], + input_Ks[:, :, 2], input_Ks[:, :, 5], + ], dim=-1) + cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1) + + lrm_generator_input['cameras'] = cameras.to(self.device) + + # render cameras + render_c2ws = batch['render_c2ws'] + render_w2cs = torch.linalg.inv(render_c2ws) + + lrm_generator_input['render_cameras'] = render_w2cs.to(self.device) + lrm_generator_input['render_size'] = 384 + + return lrm_generator_input + + def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512): + planes = torch.utils.checkpoint.checkpoint( + self.lrm_generator.forward_planes, + images, + cameras, + use_reentrant=False, + ) + out = self.lrm_generator.forward_geometry( + planes, + render_cameras, + render_size, + ) + return out + + def forward(self, lrm_generator_input): + images = lrm_generator_input['images'] + cameras = lrm_generator_input['cameras'] + render_cameras = lrm_generator_input['render_cameras'] + render_size = lrm_generator_input['render_size'] + + out = self.forward_lrm_generator( + images, cameras, render_cameras, render_size=render_size) + + return out + + def training_step(self, batch, batch_idx): + lrm_generator_input, render_gt = self.prepare_batch_data(batch) + + render_out = self.forward(lrm_generator_input) + + loss, loss_dict = self.compute_loss(render_out, render_gt) + + self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) + + if self.global_step % 1000 == 0 and self.global_rank == 0: + B, N, C, H, W = render_gt['target_images'].shape + N_in = lrm_generator_input['images'].shape[1] + + target_images = rearrange( + render_gt['target_images'], 'b n c h w -> b c h (n w)') + render_images = rearrange( + render_out['img'], 'b n c h w -> b c h (n w)') + target_alphas = rearrange( + repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') + render_alphas = rearrange( + repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') + target_depths = rearrange( + repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') + render_depths = rearrange( + repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') + target_normals = rearrange( + render_gt['target_normals'], 'b n c h w -> b c h (n w)') + render_normals = rearrange( + render_out['normal'], 'b n c h w -> b c h (n w)') + MAX_DEPTH = torch.max(target_depths) + target_depths = target_depths / MAX_DEPTH * target_alphas + render_depths = render_depths / MAX_DEPTH + + grid = torch.cat([ + target_images, render_images, + target_alphas, render_alphas, + target_depths, render_depths, + target_normals, render_normals, + ], dim=-2) + grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) + + image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png') + save_image(grid, image_path) + print(f"Saved image to {image_path}") + + return loss + + def compute_loss(self, render_out, render_gt): + # NOTE: the rgb value range of OpenLRM is [0, 1] + render_images = render_out['img'] + target_images = render_gt['target_images'].to(render_images) + render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 + target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 + loss_mse = F.mse_loss(render_images, target_images) + loss_lpips = 2.0 * self.lpips(render_images, target_images) + + render_alphas = render_out['mask'] + target_alphas = render_gt['target_alphas'] + loss_mask = F.mse_loss(render_alphas, target_alphas) + + render_depths = render_out['depth'] + target_depths = render_gt['target_depths'] + loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0]) + + render_normals = render_out['normal'] * 2.0 - 1.0 + target_normals = render_gt['target_normals'] * 2.0 - 1.0 + similarity = (render_normals * target_normals).sum(dim=-3).abs() + normal_mask = target_alphas.squeeze(-3) + loss_normal = 1 - similarity[normal_mask>0].mean() + loss_normal = 0.2 * loss_normal + + # flexicubes regularization loss + sdf = render_out['sdf'] + sdf_reg_loss = render_out['sdf_reg_loss'] + sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01 + _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss + flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5 + flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1 + + loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg + + loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg + + prefix = 'train' + loss_dict = {} + loss_dict.update({f'{prefix}/loss_mse': loss_mse}) + loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) + loss_dict.update({f'{prefix}/loss_mask': loss_mask}) + loss_dict.update({f'{prefix}/loss_normal': loss_normal}) + loss_dict.update({f'{prefix}/loss_depth': loss_depth}) + loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy}) + loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg}) + loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg}) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + lrm_generator_input = self.prepare_validation_batch_data(batch) + + render_out = self.forward(lrm_generator_input) + render_images = render_out['img'] + render_images = rearrange(render_images, 'b n c h w -> b c h (n w)') + + self.validation_step_outputs.append(render_images) + + def on_validation_epoch_end(self): + images = torch.cat(self.validation_step_outputs, dim=-1) + + all_images = self.all_gather(images) + all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') + + if self.global_rank == 0: + image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png') + + grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1)) + save_image(grid, image_path) + print(f"Saved image to {image_path}") + + self.validation_step_outputs.clear() + + def configure_optimizers(self): + lr = self.learning_rate + + optimizer = torch.optim.AdamW( + self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01) + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0) + + return {'optimizer': optimizer, 'lr_scheduler': scheduler} \ No newline at end of file diff --git a/scripts/util/detection/__init__.py b/src/models/__init__.py similarity index 100% rename from scripts/util/detection/__init__.py rename to src/models/__init__.py diff --git a/sgm/modules/autoencoding/__init__.py b/src/models/decoder/__init__.py similarity index 100% rename from sgm/modules/autoencoding/__init__.py rename to src/models/decoder/__init__.py diff --git a/src/models/decoder/transformer.py b/src/models/decoder/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d8e628c0bf589ee827908c894b93cc107f1c58b9 --- /dev/null +++ b/src/models/decoder/transformer.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class BasicTransformerBlock(nn.Module): + """ + Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. + """ + # use attention from torch.nn.MultiHeadAttention + # Block contains a cross-attention layer, a self-attention layer, and a MLP + def __init__( + self, + inner_dim: int, + cond_dim: int, + num_heads: int, + eps: float, + attn_drop: float = 0., + attn_bias: bool = False, + mlp_ratio: float = 4., + mlp_drop: float = 0., + ): + super().__init__() + + self.norm1 = nn.LayerNorm(inner_dim) + self.cross_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm2 = nn.LayerNorm(inner_dim) + self.self_attn = nn.MultiheadAttention( + embed_dim=inner_dim, num_heads=num_heads, + dropout=attn_drop, bias=attn_bias, batch_first=True) + self.norm3 = nn.LayerNorm(inner_dim) + self.mlp = nn.Sequential( + nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), + nn.GELU(), + nn.Dropout(mlp_drop), + nn.Linear(int(inner_dim * mlp_ratio), inner_dim), + nn.Dropout(mlp_drop), + ) + + def forward(self, x, cond): + # x: [N, L, D] + # cond: [N, L_cond, D_cond] + x = x + self.cross_attn(self.norm1(x), cond, cond)[0] + before_sa = self.norm2(x) + x = x + self.self_attn(before_sa, before_sa, before_sa)[0] + x = x + self.mlp(self.norm3(x)) + return x + + +class TriplaneTransformer(nn.Module): + """ + Transformer with condition that generates a triplane representation. + + Reference: + Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486 + """ + def __init__( + self, + inner_dim: int, + image_feat_dim: int, + triplane_low_res: int, + triplane_high_res: int, + triplane_dim: int, + num_layers: int, + num_heads: int, + eps: float = 1e-6, + ): + super().__init__() + + # attributes + self.triplane_low_res = triplane_low_res + self.triplane_high_res = triplane_high_res + self.triplane_dim = triplane_dim + + # modules + # initialize pos_embed with 1/sqrt(dim) * N(0, 1) + self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5) + self.layers = nn.ModuleList([ + BasicTransformerBlock( + inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps) + for _ in range(num_layers) + ]) + self.norm = nn.LayerNorm(inner_dim, eps=eps) + self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0) + + def forward(self, image_feats): + # image_feats: [N, L_cond, D_cond] + + N = image_feats.shape[0] + H = W = self.triplane_low_res + L = 3 * H * W + + x = self.pos_embed.repeat(N, 1, 1) # [N, L, D] + for layer in self.layers: + x = layer(x, image_feats) + x = self.norm(x) + + # separate each plane and apply deconv + x = x.view(N, 3, H, W, -1) + x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W] + x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] + x = self.deconv(x) # [3*N, D', H', W'] + x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] + x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] + x = x.contiguous() + + return x diff --git a/sgm/modules/autoencoding/lpips/__init__.py b/src/models/encoder/__init__.py similarity index 100% rename from sgm/modules/autoencoding/lpips/__init__.py rename to src/models/encoder/__init__.py diff --git a/src/models/encoder/dino.py b/src/models/encoder/dino.py new file mode 100644 index 0000000000000000000000000000000000000000..684444cab2a13979bcd5688069e9f7729d4ca784 --- /dev/null +++ b/src/models/encoder/dino.py @@ -0,0 +1,550 @@ +# coding=utf-8 +# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch ViT model.""" + + +import collections.abc +import math +from typing import Dict, List, Optional, Set, Tuple, Union + +import torch +from torch import nn + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, +) +from transformers import PreTrainedModel, ViTConfig +from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer + + +class ViTEmbeddings(nn.Module): + """ + Construct the CLS token, position and patch embeddings. Optionally, also the mask token. + """ + + def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: + super().__init__() + + self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None + self.patch_embeddings = ViTPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.config = config + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: + """ + This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher + resolution images. + + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + + num_patches = embeddings.shape[1] - 1 + num_positions = self.position_embeddings.shape[1] - 1 + if num_patches == num_positions and height == width: + return self.position_embeddings + class_pos_embed = self.position_embeddings[:, 0] + patch_pos_embed = self.position_embeddings[:, 1:] + dim = embeddings.shape[-1] + h0 = height // self.config.patch_size + w0 = width // self.config.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + h0, w0 = h0 + 0.1, w0 + 0.1 + patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), + mode="bicubic", + align_corners=False, + ) + assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: Optional[torch.BoolTensor] = None, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) + + if bool_masked_pos is not None: + seq_length = embeddings.shape[1] + mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) + # replace the masked visual tokens by mask_tokens + mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) + embeddings = embeddings * (1.0 - mask) + mask_tokens * mask + + # add the [CLS] token to the embedded patch tokens + cls_tokens = self.cls_token.expand(batch_size, -1, -1) + embeddings = torch.cat((cls_tokens, embeddings), dim=1) + + # add positional encoding to each token + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) + else: + embeddings = embeddings + self.position_embeddings + + embeddings = self.dropout(embeddings) + + return embeddings + + +class ViTPatchEmbeddings(nn.Module): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: + batch_size, num_channels, height, width = pixel_values.shape + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." + ) + if not interpolate_pos_encoding: + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) + return embeddings + + +class ViTSelfAttention(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " + f"heads {config.num_attention_heads}." + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + mixed_query_layer = self.query(hidden_states) + + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + query_layer = self.transpose_for_scores(mixed_query_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class ViTSelfOutput(nn.Module): + """ + The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the + layernorm applied before each block. + """ + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + return hidden_states + + +class ViTAttention(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.attention = ViTSelfAttention(config) + self.output = ViTSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads: Set[int]) -> None: + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.attention.query = prune_linear_layer(self.attention.query, index) + self.attention.key = prune_linear_layer(self.attention.key, index) + self.attention.value = prune_linear_layer(self.attention.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) + self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + self_outputs = self.attention(hidden_states, head_mask, output_attentions) + + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class ViTIntermediate(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +class ViTOutput(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + + hidden_states = hidden_states + input_tensor + + return hidden_states + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class ViTLayer(nn.Module): + """This corresponds to the Block class in the timm implementation.""" + + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = ViTAttention(config) + self.intermediate = ViTIntermediate(config) + self.output = ViTOutput(config) + self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True) + ) + nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.adaLN_modulation[-1].bias, 0) + + def forward( + self, + hidden_states: torch.Tensor, + adaln_input: torch.Tensor = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) + + self_attention_outputs = self.attention( + modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = attention_output + hidden_states + + # in ViT, layernorm is also applied after self-attention + layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp) + layer_output = self.intermediate(layer_output) + + # second residual connection is done here + layer_output = self.output(layer_output, hidden_states) + + outputs = (layer_output,) + outputs + + return outputs + + +class ViTEncoder(nn.Module): + def __init__(self, config: ViTConfig) -> None: + super().__init__() + self.config = config + self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + adaln_input: torch.Tensor = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> Union[tuple, BaseModelOutput]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + adaln_input, + layer_head_mask, + output_attentions, + ) + else: + layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ViTPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = ViTConfig + base_model_prefix = "vit" + main_input_name = "pixel_values" + supports_gradient_checkpointing = True + _no_split_modules = ["ViTEmbeddings", "ViTLayer"] + + def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_( + module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range + ).to(module.weight.dtype) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, ViTEmbeddings): + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.position_embeddings.dtype) + + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), + mean=0.0, + std=self.config.initializer_range, + ).to(module.cls_token.dtype) + + +class ViTModel(ViTPreTrainedModel): + def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): + super().__init__(config) + self.config = config + + self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) + self.encoder = ViTEncoder(config) + + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.pooler = ViTPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> ViTPatchEmbeddings: + return self.embeddings.patch_embeddings + + def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + adaln_input: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): + Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) + expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype + if pixel_values.dtype != expected_dtype: + pixel_values = pixel_values.to(expected_dtype) + + embedding_output = self.embeddings( + pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding + ) + + encoder_outputs = self.encoder( + embedding_output, + adaln_input=adaln_input, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + sequence_output = self.layernorm(sequence_output) + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) + return head_outputs + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class ViTPooler(nn.Module): + def __init__(self, config: ViTConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output \ No newline at end of file diff --git a/src/models/encoder/dino_wrapper.py b/src/models/encoder/dino_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e84fd51e7dfcfd1a969b763f5a49aeb7f608e6f9 --- /dev/null +++ b/src/models/encoder/dino_wrapper.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch.nn as nn +from transformers import ViTImageProcessor +from einops import rearrange, repeat +from .dino import ViTModel + + +class DinoWrapper(nn.Module): + """ + Dino v1 wrapper using huggingface transformer implementation. + """ + def __init__(self, model_name: str, freeze: bool = True): + super().__init__() + self.model, self.processor = self._build_dino(model_name) + self.camera_embedder = nn.Sequential( + nn.Linear(16, self.model.config.hidden_size, bias=True), + nn.SiLU(), + nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True) + ) + if freeze: + self._freeze() + + def forward(self, image, camera): + # image: [B, N, C, H, W] + # camera: [B, N, D] + # RGB image with [0,1] scale and properly sized + if image.ndim == 5: + image = rearrange(image, 'b n c h w -> (b n) c h w') + dtype = image.dtype + inputs = self.processor( + images=image.float(), + return_tensors="pt", + do_rescale=False, + do_resize=False, + ).to(self.model.device).to(dtype) + # embed camera + N = camera.shape[1] + camera_embeddings = self.camera_embedder(camera) + camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d') + embeddings = camera_embeddings + # This resampling of positional embedding uses bicubic interpolation + outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True) + last_hidden_states = outputs.last_hidden_state + return last_hidden_states + + def _freeze(self): + print(f"======== Freezing DinoWrapper ========") + self.model.eval() + for name, param in self.model.named_parameters(): + param.requires_grad = False + + @staticmethod + def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): + import requests + try: + model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) + processor = ViTImageProcessor.from_pretrained(model_name) + return model, processor + except requests.exceptions.ProxyError as err: + if proxy_error_retries > 0: + print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...") + import time + time.sleep(proxy_error_cooldown) + return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) + else: + raise err diff --git a/src/models/geometry/__init__.py b/src/models/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..89e9a6c2fffe82a55693885dae78c1a630924389 --- /dev/null +++ b/src/models/geometry/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. diff --git a/src/models/geometry/camera/__init__.py b/src/models/geometry/camera/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c5c7082e47c65a08e25489b3c3fd010d07ad9758 --- /dev/null +++ b/src/models/geometry/camera/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from torch import nn + + +class Camera(nn.Module): + def __init__(self): + super(Camera, self).__init__() + pass diff --git a/src/models/geometry/camera/perspective_camera.py b/src/models/geometry/camera/perspective_camera.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcab0d2a321a77a5d3c2d4c3f40ba2cc32f6dfa --- /dev/null +++ b/src/models/geometry/camera/perspective_camera.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +from . import Camera +import numpy as np + + +def projection(x=0.1, n=1.0, f=50.0, near_plane=None): + if near_plane is None: + near_plane = n + return np.array( + [[n / x, 0, 0, 0], + [0, n / -x, 0, 0], + [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], + [0, 0, -1, 0]]).astype(np.float32) + + +class PerspectiveCamera(Camera): + def __init__(self, fovy=49.0, device='cuda'): + super(PerspectiveCamera, self).__init__() + self.device = device + focal = np.tan(fovy / 180.0 * np.pi * 0.5) + self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) + + def project(self, points_bxnx4): + out = torch.matmul( + points_bxnx4, + torch.transpose(self.proj_mtx, 1, 2)) + return out diff --git a/src/models/geometry/render/__init__.py b/src/models/geometry/render/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..483cfabbf395853f1ca3e67b856d5f17b9889d1b --- /dev/null +++ b/src/models/geometry/render/__init__.py @@ -0,0 +1,8 @@ +import torch + +class Renderer(): + def __init__(self): + pass + + def forward(self): + pass \ No newline at end of file diff --git a/src/models/geometry/render/neural_render.py b/src/models/geometry/render/neural_render.py new file mode 100644 index 0000000000000000000000000000000000000000..473464480125c050ee6dba973450678a197145fb --- /dev/null +++ b/src/models/geometry/render/neural_render.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import torch.nn.functional as F +import nvdiffrast.torch as dr +from . import Renderer + +_FG_LUT = None + + +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate( + attr.contiguous(), rast, attr_idx, rast_db=rast_db, + diff_attrs=None if rast_db is None else 'all') + + +def xfm_points(points, matrix, use_python=True): + '''Transform points. + Args: + points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] + matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] + use_python: Use PyTorch's torch.matmul (for validation) + Returns: + Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. + ''' + out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) + if torch.is_anomaly_enabled(): + assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" + return out + + +def dot(x, y): + return torch.sum(x * y, -1, keepdim=True) + + +def compute_vertex_normal(v_pos, t_pos_idx): + i0 = t_pos_idx[:, 0] + i1 = t_pos_idx[:, 1] + i2 = t_pos_idx[:, 2] + + v0 = v_pos[i0, :] + v1 = v_pos[i1, :] + v2 = v_pos[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + v_nrm = torch.zeros_like(v_pos) + v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + v_nrm = torch.where( + dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) + ) + v_nrm = F.normalize(v_nrm, dim=1) + assert torch.all(torch.isfinite(v_nrm)) + + return v_nrm + + +class NeuralRender(Renderer): + def __init__(self, device='cuda', camera_model=None): + super(NeuralRender, self).__init__() + self.device = device + self.ctx = dr.RasterizeCudaContext(device=device) + self.projection_mtx = None + self.camera = camera_model + + def render_mesh( + self, + mesh_v_pos_bxnx3, + mesh_t_pos_idx_fx3, + camera_mv_bx4x4, + mesh_v_feat_bxnxd, + resolution=256, + spp=1, + device='cuda', + hierarchical_mask=False + ): + assert not hierarchical_mask + + mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 + v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates + v_pos_clip = self.camera.project(v_pos) # Projection in the camera + + v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates + + # Render the image, + # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render + num_layers = 1 + mask_pyramid = None + assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes + mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos + + with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: + for _ in range(num_layers): + rast, db = peeler.rasterize_next_layer() + gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) + + hard_mask = torch.clamp(rast[..., -1:], 0, 1) + antialias_mask = dr.antialias( + hard_mask.clone().contiguous(), rast, v_pos_clip, + mesh_t_pos_idx_fx3) + + depth = gb_feat[..., -2:-1] + ori_mesh_feature = gb_feat[..., :-4] + + normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) + normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) + normal = F.normalize(normal, dim=-1) + normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background + + return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal diff --git a/src/models/geometry/rep_3d/__init__.py b/src/models/geometry/rep_3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d5628a8433298477d1963f92578d47106b4a0f --- /dev/null +++ b/src/models/geometry/rep_3d/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np + + +class Geometry(): + def __init__(self): + pass + + def forward(self): + pass diff --git a/src/models/geometry/rep_3d/dmtet.py b/src/models/geometry/rep_3d/dmtet.py new file mode 100644 index 0000000000000000000000000000000000000000..b6a709380abac0bbf66fd1c8582485f3982223e4 --- /dev/null +++ b/src/models/geometry/rep_3d/dmtet.py @@ -0,0 +1,504 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +from . import Geometry +from .dmtet_utils import get_center_boundary_index +import torch.nn.functional as F + + +############################################################################### +# DMTet utility functions +############################################################################### +def create_mt_variable(device): + triangle_table = torch.tensor( + [ + [-1, -1, -1, -1, -1, -1], + [1, 0, 2, -1, -1, -1], + [4, 0, 3, -1, -1, -1], + [1, 4, 2, 1, 3, 4], + [3, 1, 5, -1, -1, -1], + [2, 3, 0, 2, 5, 3], + [1, 4, 0, 1, 5, 4], + [4, 2, 5, -1, -1, -1], + [4, 5, 2, -1, -1, -1], + [4, 1, 0, 4, 5, 1], + [3, 2, 0, 3, 5, 2], + [1, 3, 5, -1, -1, -1], + [4, 1, 2, 4, 3, 1], + [3, 0, 4, -1, -1, -1], + [2, 0, 1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1] + ], dtype=torch.long, device=device) + + num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device) + base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device) + v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device)) + return triangle_table, num_triangles_table, base_tet_edges, v_id + + +def sort_edges(edges_ex2): + with torch.no_grad(): + order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() + order = order.unsqueeze(dim=1) + a = torch.gather(input=edges_ex2, index=order, dim=1) + b = torch.gather(input=edges_ex2, index=1 - order, dim=1) + return torch.stack([a, b], -1) + + +############################################################################### +# marching tetrahedrons (differentiable) +############################################################################### + +def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + return verts, faces + + +def create_tetmesh_variables(device='cuda'): + tet_table = torch.tensor( + [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1], + [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8], + [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9], + [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9], + [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9], + [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1], + [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9], + [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9], + [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9], + [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8], + [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8], + [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6], + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device) + num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device) + return tet_table, num_tets_table + + +def marching_tets_tetmesh( + pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id, + return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None): + with torch.no_grad(): + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) + occ_sum = occ_sum[valid_tets] + + # find all vertices + all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) + idx_map = mapping[idx_map] # map edges to verts + + interp_v = unique_edges[mask_edges] # .long() + edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) + edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) + edges_to_interp_sdf[:, -1] *= -1 + + denominator = edges_to_interp_sdf.sum(1, keepdim=True) + + edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator + verts = (edges_to_interp * edges_to_interp_sdf).sum(1) + + idx_map = idx_map.reshape(-1, 6) + + tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) + num_triangles = num_triangles_table[tetindex] + + # Generate triangle indices + faces = torch.cat( + ( + torch.gather( + input=idx_map[num_triangles == 1], dim=1, + index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), + torch.gather( + input=idx_map[num_triangles == 2], dim=1, + index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), + ), dim=0) + if not return_tet_mesh: + return verts, faces + occupied_verts = ori_v[occ_n] + mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1 + mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda") + tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4)) + + idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10 + tet_verts = torch.cat([verts, occupied_verts], 0) + num_tets = num_tets_table[tetindex] + + tets = torch.cat( + ( + torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape( + -1, + 4), + torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape( + -1, + 4), + ), dim=0) + # add fully occupied tets + fully_occupied = occ_fx4.sum(-1) == 4 + tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0] + tets = torch.cat([tets, tet_fully_occupied]) + + return verts, faces, tet_verts, tets + + +############################################################################### +# Compact tet grid +############################################################################### + +def compact_tets(pos_nx3, sdf_n, tet_fx4): + with torch.no_grad(): + # Find surface tets + occ_n = sdf_n > 0 + occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) + occ_sum = torch.sum(occ_fx4, -1) + valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets + + valid_vtx = tet_fx4[valid_tets].reshape(-1) + unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True) + new_pos = pos_nx3[unique_vtx] + new_sdf = sdf_n[unique_vtx] + new_tets = idx_map.reshape(-1, 4) + return new_pos, new_sdf, new_tets + + +############################################################################### +# Subdivide volume +############################################################################### + +def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + device = tet_pos_bxnx3.device + # get new verts + tet_fx4 = tet_bxfx4[0] + edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3] + all_edges = tet_fx4[:, edges].reshape(-1, 2) + all_edges = sort_edges(all_edges) + unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) + idx_map = idx_map + tet_pos_bxnx3.shape[1] + all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1) + mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape( + all_values.shape[0], -1, 2, + all_values.shape[-1]).mean(2) + new_v = torch.cat([all_values, mid_points_pos], 1) + new_v, new_sdf = new_v[..., :3], new_v[..., 3] + + # get new tets + + idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3] + idx_ab = idx_map[0::6] + idx_ac = idx_map[1::6] + idx_ad = idx_map[2::6] + idx_bc = idx_map[3::6] + idx_bd = idx_map[4::6] + idx_cd = idx_map[5::6] + + tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1) + tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1) + tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1) + tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1) + tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1) + tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1) + tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1) + tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1) + + tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0) + tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1) + tet = tet_np.long().to(device) + + return new_v, tet, new_sdf + + +############################################################################### +# Adjacency +############################################################################### +def tet_to_tet_adj_sparse(tet_tx4): + # include self connection!!!!!!!!!!!!!!!!!!! + with torch.no_grad(): + t = tet_tx4.shape[0] + device = tet_tx4.device + idx_array = torch.LongTensor( + [0, 1, 2, + 1, 0, 3, + 2, 3, 0, + 3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3) + + # get all faces + all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape( + -1, + 3) # (tx4, 3) + all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1) + # sort and group + all_faces_sorted, _ = torch.sort(all_faces, dim=1) + + all_faces_unique, inverse_indices, counts = torch.unique( + all_faces_sorted, dim=0, return_counts=True, + return_inverse=True) + tet_face_fx3 = all_faces_unique[counts == 2] + counts = counts[inverse_indices] # tx4 + valid = (counts == 2) + + group = inverse_indices[valid] + # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape) + _, indices = torch.sort(group) + all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices] + tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1) + + tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])]) + adj_self = torch.arange(t, device=tet_tx4.device) + adj_self = torch.stack([adj_self, adj_self], -1) + tet_adj_idx = torch.cat([tet_adj_idx, adj_self]) + + tet_adj_idx = torch.unique(tet_adj_idx, dim=0) + values = torch.ones( + tet_adj_idx.shape[0], device=tet_tx4.device).float() + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + + # normalization + neighbor_num = 1.0 / torch.sparse.sum( + adj_sparse, dim=1).to_dense() + values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0]) + adj_sparse = torch.sparse.FloatTensor( + tet_adj_idx.t(), values, torch.Size([t, t])) + return adj_sparse + + +############################################################################### +# Compact grid +############################################################################### + +def get_tet_bxfx4x3(bxnxz, bxfx4): + n_batch, z = bxnxz.shape[0], bxnxz.shape[2] + gather_input = bxnxz.unsqueeze(2).expand( + n_batch, bxnxz.shape[1], 4, z) + gather_index = bxfx4.unsqueeze(-1).expand( + n_batch, bxfx4.shape[1], 4, z).long() + tet_bxfx4xz = torch.gather( + input=gather_input, dim=1, index=gather_index) + + return tet_bxfx4xz + + +def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf): + with torch.no_grad(): + assert tet_pos_bxnx3.shape[0] == 1 + + occ = grid_sdf[0] > 0 + occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1) + mask = (occ_sum > 0) & (occ_sum < 4) + + # build connectivity graph + adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0]) + mask = mask.float().unsqueeze(-1) + + # Include a one ring of neighbors + for i in range(1): + mask = torch.sparse.mm(adj_matrix, mask) + mask = mask.squeeze(-1) > 0 + + mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long) + new_tet_bxfx4 = tet_bxfx4[:, mask].long() + selected_verts_idx = torch.unique(new_tet_bxfx4) + new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx] + mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device) + new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape) + new_grid_sdf = grid_sdf[:, selected_verts_idx] + return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf + + +############################################################################### +# Regularizer +############################################################################### + +def sdf_reg_loss(sdf, all_edges): + sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 0], + (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits( + sdf_f1x6x2[..., 1], + (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +def sdf_reg_loss_batch(sdf, all_edges): + sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) + mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) + sdf_f1x6x2 = sdf_f1x6x2[mask] + sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ + torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) + return sdf_diff + + +############################################################################### +# Geometry interface +############################################################################### +class DMTetGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(DMTetGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + tets = np.load('data/tets/%d_compress.npz' % (grid_res)) + self.verts = torch.from_numpy(tets['vertices']).float().to(self.device) + # Make sure the tet is zero-centered and length is equal to 1 + length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0] + length = length.max() + mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0 + self.verts = (self.verts - mid.unsqueeze(dim=0)) / length + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + self.indices = torch.from_numpy(tets['tets']).long().to(self.device) + self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device) + self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device) + # Parameters for regularization computation + edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) + all_edges = self.indices[:, edges].reshape(-1, 2) + all_edges_sorted = torch.sort(all_edges, dim=1)[0] + self.all_edges = torch.unique(all_edges_sorted, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces = marching_tets( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces + + def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): + if indices is None: + indices = self.indices + verts, faces, tet_verts, tets = marching_tets_tetmesh( + v_deformed_nx3, sdf_n, indices, self.triangle_table, + self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True, + num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3) + faces = torch.cat( + [faces[:, 0:1], + faces[:, 2:3], + faces[:, 1:2], ], dim=-1) + return verts, faces, tet_verts, tets + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/src/models/geometry/rep_3d/dmtet_utils.py b/src/models/geometry/rep_3d/dmtet_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8d466a9e78c49d947c115707693aa18d759885ad --- /dev/null +++ b/src/models/geometry/rep_3d/dmtet_utils.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch + + +def get_center_boundary_index(verts): + length_ = torch.sum(verts ** 2, dim=-1) + center_idx = torch.argmin(length_) + boundary_neg = verts == verts.max() + boundary_pos = verts == verts.min() + boundary = torch.bitwise_or(boundary_pos, boundary_neg) + boundary = torch.sum(boundary.float(), dim=-1) + boundary_idx = torch.nonzero(boundary) + return center_idx, boundary_idx.squeeze(dim=-1) diff --git a/src/models/geometry/rep_3d/extract_texture_map.py b/src/models/geometry/rep_3d/extract_texture_map.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d62bb5a6c5cdf632fb504db3d2dfa99a3abbd3 --- /dev/null +++ b/src/models/geometry/rep_3d/extract_texture_map.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import numpy as np +import nvdiffrast.torch as dr + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/src/models/geometry/rep_3d/flexicubes.py b/src/models/geometry/rep_3d/flexicubes.py new file mode 100644 index 0000000000000000000000000000000000000000..26d7b91b6266d802baaf55b64238629cd0f740d0 --- /dev/null +++ b/src/models/geometry/rep_3d/flexicubes.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +import torch +from .tables import * + +__all__ = [ + 'FlexiCubes' +] + + +class FlexiCubes: + """ + This class implements the FlexiCubes method for extracting meshes from scalar fields. + It maintains a series of lookup tables and indices to support the mesh extraction process. + FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances + the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting + the surface representation through gradient-based optimization. + + During instantiation, the class loads DMC tables from a file and transforms them into + PyTorch tensors on the specified device. + + Attributes: + device (str): Specifies the computational device (default is "cuda"). + dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges + associated with each dual vertex in 256 Marching Cubes (MC) configurations. + num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of + the 256 MC configurations. + check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 + of the DMC configurations. + tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. + quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles + along one diagonal. + quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into + two triangles along the other diagonal. + quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles + during training by connecting all edges to their midpoints. + cube_corners (torch.Tensor): Defines the positions of a standard unit cube's + eight corners in 3D space, ordered starting from the origin (0,0,0), + moving along the x-axis, then y-axis, and finally z-axis. + Used as a blueprint for generating a voxel grid. + cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used + to retrieve the case id. + cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. + Used to retrieve edge vertices in DMC. + edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with + their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the + first edge is oriented along the x-axis. + dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges + across four adjacent cubes to the shared faces of these cubes. For instance, + dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along + the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. + This tensor is only utilized during isosurface tetrahedralization. + adj_pairs (torch.Tensor): + A tensor containing index pairs that correspond to neighboring cubes that share the same edge. + qef_reg_scale (float): + The scaling factor applied to the regularization loss to prevent issues with singularity + when solving the QEF. This parameter is only used when a 'grad_func' is specified. + weight_scale (float): + The scale of weights in FlexiCubes. Should be between 0 and 1. + """ + + def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): + + self.device = device + self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) + self.num_vd_table = torch.tensor(num_vd_table, + dtype=torch.long, device=device, requires_grad=False) + self.check_table = torch.tensor( + check_table, + dtype=torch.long, device=device, requires_grad=False) + + self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) + self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + self.quad_split_train = torch.tensor( + [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) + + self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ + 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) + self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) + self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, + 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) + + self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], + dtype=torch.long, device=device) + self.dir_faces_table = torch.tensor([ + [[5, 4], [3, 2], [4, 5], [2, 3]], + [[5, 4], [1, 0], [4, 5], [0, 1]], + [[3, 2], [1, 0], [2, 3], [0, 1]] + ], dtype=torch.long, device=device) + self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) + self.qef_reg_scale = qef_reg_scale + self.weight_scale = weight_scale + + def construct_voxel_grid(self, res): + """ + Generates a voxel grid based on the specified resolution. + + Args: + res (int or list[int]): The resolution of the voxel grid. If an integer + is provided, it is used for all three dimensions. If a list or tuple + of 3 integers is provided, they define the resolution for the x, + y, and z dimensions respectively. + + Returns: + (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the + cube corners (index into vertices) of the constructed voxel grid. + The vertices are centered at the origin, with the length of each + dimension in the grid being one. + """ + base_cube_f = torch.arange(8).to(self.device) + if isinstance(res, int): + res = (res, res, res) + voxel_grid_template = torch.ones(res, device=self.device) + + res = torch.tensor([res], dtype=torch.float, device=self.device) + coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 + verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) + cubes = (base_cube_f.unsqueeze(0) + + torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) + + verts_rounded = torch.round(verts * 10**5) / (10**5) + verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) + cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) + + return verts_unique - 0.5, cubes + + def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, + gamma_f=None, training=False, output_tetmesh=False, grad_func=None): + r""" + Main function for mesh extraction from scalar field using FlexiCubes. This function converts + discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, + to triangle or tetrahedral meshes using a differentiable operation as described in + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances + mesh quality and geometric fidelity by adjusting the surface representation based on gradient + optimization. The output surface is differentiable with respect to the input vertex positions, + scalar field values, and weight parameters. + + If you intend to extract a surface mesh from a fixed Signed Distance Field without the + optimization of parameters, it is suggested to provide the "grad_func" which should + return the surface gradient at any given 3D position. When grad_func is provided, the process + to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as + described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. + Please note, this approach is non-differentiable. + + For more details and example usage in optimization, refer to the + `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. + + Args: + x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. + s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values + denote that the corresponding vertex resides inside the isosurface. This affects + the directions of the extracted triangle faces and volume to be tetrahedralized. + cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. + res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it + is used for all three dimensions. If a list or tuple of 3 integers is provided, they + specify the resolution for the x, y, and z dimensions respectively. + beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual + vertices positioning. Defaults to uniform value for all edges. + alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual + vertices positioning. Defaults to uniform value for all vertices. + gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of + quadrilaterals into triangles. Defaults to uniform value for all cubes. + training (bool, optional): If set to True, applies differentiable quad splitting for + training. Defaults to False. + output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, + outputs a triangular mesh. Defaults to False. + grad_func (callable, optional): A function to compute the surface gradient at specified + 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 + tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. + + Returns: + (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: + - Vertices for the extracted triangular/tetrahedral mesh. + - Faces for the extracted triangular/tetrahedral mesh. + - Regularizer L_dev, computed per dual vertex. + + .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: + https://research.nvidia.com/labs/toronto-ai/flexicubes/ + .. _Manifold Dual Contouring: + https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf + """ + + surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) + if surf_cubes.sum() == 0: + return torch.zeros( + (0, 3), + device=self.device), torch.zeros( + (0, 4), + dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros( + (0, 3), + dtype=torch.long, device=self.device), torch.zeros( + (0), + device=self.device) + beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) + + case_ids = self._get_case_id(occ_fx8, surf_cubes, res) + + surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) + + vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( + x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) + vertices, faces, s_edges, edge_indices = self._triangulate( + s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) + if not output_tetmesh: + return vertices, faces, L_dev + else: + vertices, tets = self._tetrahedralize( + x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training) + return vertices, tets, L_dev + + def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): + """ + Regularizer L_dev as in Equation 8 + """ + dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) + mean_l2 = torch.zeros_like(vd[:, 0]) + mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() + mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() + return mad + + def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): + """ + Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. + """ + n_cubes = surf_cubes.shape[0] + + if beta_fx12 is not None: + beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) + else: + beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) + + if alpha_fx8 is not None: + alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) + else: + alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) + + if gamma_f is not None: + gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 + else: + gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) + + return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] + + @torch.no_grad() + def _get_case_id(self, occ_fx8, surf_cubes, res): + """ + Obtains the ID of topology cases based on cell corner occupancy. This function resolves the + ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the + supplementary material. It should be noted that this function assumes a regular grid. + """ + case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) + + problem_config = self.check_table.to(self.device)[case_ids] + to_check = problem_config[..., 0] == 1 + problem_config = problem_config[to_check] + if not isinstance(res, (list, tuple)): + res = [res, res, res] + + # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, + # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). + # This allows efficient checking on adjacent cubes. + problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) + vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 + vol_idx_problem = vol_idx[surf_cubes][to_check] + problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config + vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] + + within_range = ( + vol_idx_problem_adj[..., 0] >= 0) & ( + vol_idx_problem_adj[..., 0] < res[0]) & ( + vol_idx_problem_adj[..., 1] >= 0) & ( + vol_idx_problem_adj[..., 1] < res[1]) & ( + vol_idx_problem_adj[..., 2] >= 0) & ( + vol_idx_problem_adj[..., 2] < res[2]) + + vol_idx_problem = vol_idx_problem[within_range] + vol_idx_problem_adj = vol_idx_problem_adj[within_range] + problem_config = problem_config[within_range] + problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], + vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] + # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. + to_invert = (problem_config_adj[..., 0] == 1) + idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] + case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) + return case_ids + + @torch.no_grad() + def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): + """ + Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge + can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge + and marks the cube edges with this index. + """ + occ_n = s_n < 0 + all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 + + surf_edges_mask = mask_edges[_idx_map] + counts = counts[_idx_map] + + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) + # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index + # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. + idx_map = mapping[_idx_map] + surf_edges = unique_edges[mask_edges] + return surf_edges, idx_map, counts, surf_edges_mask + + @torch.no_grad() + def _identify_surf_cubes(self, s_n, cube_fx8): + """ + Identifies grid cubes that intersect with the underlying surface by checking if the signs at + all corners are not identical. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + _occ_sum = torch.sum(occ_fx8, -1) + surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) + return surf_cubes, occ_fx8 + + def _linear_interp(self, edges_weight, edges_x): + """ + Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. + """ + edge_dim = edges_weight.dim() - 2 + assert edges_weight.shape[edge_dim] == 2 + edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - + torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim) + denominator = edges_weight.sum(edge_dim) + ue = (edges_x * edges_weight).sum(edge_dim) / denominator + return ue + + def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): + p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) + norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) + c_bx3 = c_bx3.reshape(-1, 3) + A = norm_bxnx3 + B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) + + A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) + B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) + A = torch.cat([A, A_reg], 1) + B = torch.cat([B, B_reg], 1) + dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) + return dual_verts + + def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): + """ + Computes the location of dual vertices as described in Section 4.2 + """ + alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) + surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) + surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) + zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) + + idx_map = idx_map.reshape(-1, 12) + num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) + edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] + + total_num_vd = 0 + vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) + if grad_func is not None: + normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) + vd = [] + for num in torch.unique(num_vd): + cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) + curr_num_vd = cur_cubes.sum() * num + curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) + curr_edge_group_to_vd = torch.arange( + curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd + total_num_vd += curr_num_vd + curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ + cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) + + curr_mask = (curr_edge_group != -1) + edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) + edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) + edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) + vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) + vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) + + if grad_func is not None: + with torch.no_grad(): + cube_e_verts_idx = idx_map[cur_cubes] + curr_edge_group[~curr_mask] = 0 + + verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) + verts_group_idx[verts_group_idx == -1] = 0 + verts_group_pos = torch.index_select( + input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) + v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1) + curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) + verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) + + normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( + -1, num.item(), 7, + 3) + curr_mask = curr_mask.squeeze(2) + vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, + verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) + edge_group = torch.cat(edge_group) + edge_group_to_vd = torch.cat(edge_group_to_vd) + edge_group_to_cube = torch.cat(edge_group_to_cube) + vd_num_edges = torch.cat(vd_num_edges) + vd_gamma = torch.cat(vd_gamma) + + if grad_func is not None: + vd = torch.cat(vd) + L_dev = torch.zeros([1], device=self.device) + else: + vd = torch.zeros((total_num_vd, 3), device=self.device) + beta_sum = torch.zeros((total_num_vd, 1), device=self.device) + + idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) + + x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) + s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) + + zero_crossing_group = torch.index_select( + input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) + + alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) + ue_group = self._linear_interp(s_group * alpha_group, x_group) + + beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, + index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) + beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) + vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum + L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) + + v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd + + vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * + 12 + edge_group, src=v_idx[edge_group_to_vd]) + + return vd, L_dev, vd_gamma, vd_idx_map + + def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): + """ + Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into + triangles based on the gamma parameter, as described in Section 4.3. + """ + with torch.no_grad(): + group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. + group = idx_map.reshape(-1)[group_mask] + vd_idx = vd_idx_map[group_mask] + edge_indices, indices = torch.sort(group, stable=True) + quad_vd_idx = vd_idx[indices].reshape(-1, 4) + + # Ensure all face directions point towards the positive SDF to maintain consistent winding. + s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) + flip_mask = s_edges[:, 0] > 0 + quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], + quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) + if grad_func is not None: + # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. + with torch.no_grad(): + vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) + gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) + else: + quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) + gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( + 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) + gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( + 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) + if not training: + mask = (gamma_02 > gamma_13).squeeze(1) + faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) + faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] + faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] + faces = faces.reshape(-1, 3) + else: + vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) + vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 + vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + + torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 + weight_sum = (gamma_02 + gamma_13) + 1e-8 + vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / + weight_sum.unsqueeze(-1)).squeeze(1) + vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] + vd = torch.cat([vd, vd_center]) + faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) + faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) + return vd, faces, s_edges, edge_indices + + def _tetrahedralize( + self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, + surf_cubes, training): + """ + Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. + """ + occ_n = s_n < 0 + occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) + occ_sum = torch.sum(occ_fx8, -1) + + inside_verts = x_nx3[occ_n] + mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] + """ + For each grid edge connecting two grid vertices with different + signs, we first form a four-sided pyramid by connecting one + of the grid vertices with four mesh vertices that correspond + to the grid edge and then subdivide the pyramid into two tetrahedra + """ + inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ + s_edges < 0]] + if not training: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) + else: + inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) + + tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) + """ + For each grid edge connecting two grid vertices with the + same sign, the tetrahedron is formed by the two grid vertices + and two vertices in consecutive adjacent cells + """ + inside_cubes = (occ_sum == 8) + inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) + inside_cubes_center_idx = torch.arange( + inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] + + surface_n_inside_cubes = surf_cubes | inside_cubes + edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), + dtype=torch.long, device=x_nx3.device) * -1 + surf_cubes = surf_cubes[surface_n_inside_cubes] + inside_cubes = inside_cubes[surface_n_inside_cubes] + edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) + edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx + + all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) + unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) + unique_edges = unique_edges.long() + mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 + mask = mask_edges[_idx_map] + counts = counts[_idx_map] + mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 + mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) + idx_map = mapping[_idx_map] + + group_mask = (counts == 4) & mask + group = idx_map.reshape(-1)[group_mask] + edge_indices, indices = torch.sort(group) + cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, + device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] + edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( + 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] + # Identify the face shared by the adjacent cells. + cube_idx_4 = cube_idx[indices].reshape(-1, 4) + edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] + shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) + cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) + # Identify an edge of the face with different signs and + # select the mesh vertex corresponding to the identified edge. + case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 + case_ids_expand[surf_cubes] = case_ids + cases = case_ids_expand[cube_idx_4x2] + quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) + mask = (quad_edge == -1).sum(-1) == 0 + inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) + tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] + + tets = torch.cat([tets_surface, tets_inside]) + vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) + return vertices, tets diff --git a/src/models/geometry/rep_3d/flexicubes_geometry.py b/src/models/geometry/rep_3d/flexicubes_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..bf050ee20361f78957839942f83fe77efde231b7 --- /dev/null +++ b/src/models/geometry/rep_3d/flexicubes_geometry.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import numpy as np +import os +from . import Geometry +from .flexicubes import FlexiCubes # replace later +from .dmtet import sdf_reg_loss_batch +import torch.nn.functional as F + +def get_center_boundary_index(grid_res, device): + v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True + center_indices = torch.nonzero(v.reshape(-1)) + + v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False + v[:2, ...] = True + v[-2:, ...] = True + v[:, :2, ...] = True + v[:, -2:, ...] = True + v[:, :, :2] = True + v[:, :, -2:] = True + boundary_indices = torch.nonzero(v.reshape(-1)) + return center_indices, boundary_indices + +############################################################################### +# Geometry interface +############################################################################### +class FlexiCubesGeometry(Geometry): + def __init__( + self, grid_res=64, scale=2.0, device='cuda', renderer=None, + render_type='neural_render', args=None): + super(FlexiCubesGeometry, self).__init__() + self.grid_res = grid_res + self.device = device + self.args = args + self.fc = FlexiCubes(device, weight_scale=0.5) + self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) + if isinstance(scale, list): + self.verts[:, 0] = self.verts[:, 0] * scale[0] + self.verts[:, 1] = self.verts[:, 1] * scale[1] + self.verts[:, 2] = self.verts[:, 2] * scale[1] + else: + self.verts = self.verts * scale + + all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) + self.all_edges = torch.unique(all_edges, dim=0) + + # Parameters used for fix boundary sdf + self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) + self.renderer = renderer + self.render_type = render_type + + def getAABB(self): + return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values + + def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): + if indices is None: + indices = self.indices + + verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, + beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], + gamma_f=weight_n[:, 20], training=is_training + ) + return verts, faces, v_reg_loss + + + def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): + return_value = dict() + if self.render_type == 'neural_render': + tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh( + mesh_v_nx3.unsqueeze(dim=0), + mesh_f_fx3.int(), + camera_mv_bx4x4, + mesh_v_nx3.unsqueeze(dim=0), + resolution=resolution, + device=self.device, + hierarchical_mask=hierarchical_mask + ) + + return_value['tex_pos'] = tex_pos + return_value['mask'] = mask + return_value['hard_mask'] = hard_mask + return_value['rast'] = rast + return_value['v_pos_clip'] = v_pos_clip + return_value['mask_pyramid'] = mask_pyramid + return_value['depth'] = depth + return_value['normal'] = normal + else: + raise NotImplementedError + + return return_value + + def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): + # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 + v_list = [] + f_list = [] + n_batch = v_deformed_bxnx3.shape[0] + all_render_output = [] + for i_batch in range(n_batch): + verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) + v_list.append(verts_nx3) + f_list.append(faces_fx3) + render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) + all_render_output.append(render_output) + + # Concatenate all render output + return_keys = all_render_output[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in all_render_output] + return_value[k] = value + # We can do concatenation outside of the render + return return_value diff --git a/src/models/geometry/rep_3d/tables.py b/src/models/geometry/rep_3d/tables.py new file mode 100644 index 0000000000000000000000000000000000000000..5873e7727b5595a1e4fbc3bd10ae5be8f3d06cca --- /dev/null +++ b/src/models/geometry/rep_3d/tables.py @@ -0,0 +1,791 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. +dmc_table = [ +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], +[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] +] +num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, +2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, +1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, +1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, +2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, +3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, +2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, +1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, +1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, +1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] +check_table = [ +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 194], +[1, -1, 0, 0, 193], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 164], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 161], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 152], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 145], +[1, 0, 0, 1, 144], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 137], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 133], +[1, 0, 1, 0, 132], +[1, 1, 0, 0, 131], +[1, 1, 0, 0, 130], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 100], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 98], +[0, 0, 0, 0, 0], +[1, 0, 0, 1, 96], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 88], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 82], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 74], +[0, 0, 0, 0, 0], +[1, 0, 1, 0, 72], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 70], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 67], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 65], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 56], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 52], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 44], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 1, 0, 0, 40], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 38], +[1, 0, -1, 0, 37], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 33], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 28], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 26], +[1, 0, 0, -1, 25], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, -1, 0, 0, 20], +[0, 0, 0, 0, 0], +[1, 0, -1, 0, 18], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 9], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[1, 0, 0, -1, 6], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0], +[0, 0, 0, 0, 0] +] +tet_table = [ +[-1, -1, -1, -1, -1, -1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, -1], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, -1], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, -1, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, -1, 2, 4, 4, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, 5, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, -1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[-1, 1, 1, 4, 4, 1], +[0, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[8, 8, 8, 8, 8, 8], +[1, 1, 1, 4, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 4, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 5, 5, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[6, 6, 6, 6, 6, 6], +[6, -1, 0, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 4, -1, 6, 4, 6], +[6, 4, 0, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 1, 1, 6, 1, 6], +[5, 5, 5, 5, 5, 5], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 4, 2, 2, 4, 2], +[0, 4, 0, 4, 4, 0], +[2, 0, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 1, 1, 6, -1, 6], +[6, 1, 1, 6, 0, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[4, 1, 1, 4, 4, 1], +[0, 1, 1, 0, 0, 1], +[4, 0, 0, 4, 4, 4], +[2, 2, 2, 2, 2, 2], +[6, 1, 1, 6, 4, 6], +[6, 1, 1, 6, 4, 6], +[6, 0, 0, 6, 0, 6], +[6, 2, 2, 6, 2, 6], +[5, 1, 1, 5, 5, 1], +[0, 1, 1, 0, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 4, 1], +[0, 4, 0, 4, 4, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 5, 0, 5, 0, 5], +[5, 5, 5, 5, 5, 5], +[5, 5, 5, 5, 5, 5], +[0, 5, 0, 5, 0, 5], +[-1, 5, 0, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[4, 5, -1, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[4, 5, 0, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[6, 6, 6, 6, 6, 6], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 5, 2, 5, -1, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 0, 5], +[1, 5, 1, 5, 1, 5], +[2, 5, 2, 5, 4, 5], +[0, 5, 0, 5, 0, 5], +[2, 5, 2, 5, 4, 5], +[1, 5, 1, 5, 1, 5], +[2, 4, 2, 4, 4, 2], +[0, 4, 0, 4, 4, 4], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 6, 2, 6, 6, 2], +[0, 0, 0, 0, 0, 0], +[2, 0, 2, 0, 0, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[4, 1, 1, 1, 4, 1], +[0, 1, 1, 1, 0, 1], +[4, 0, 0, 4, 4, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 4, 1], +[0, 0, 0, 0, 0, 0], +[4, 0, 0, 4, 4, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[6, 0, 0, 6, 0, 6], +[0, 0, 0, 0, 0, 0], +[6, 6, 6, 6, 6, 6], +[5, 5, 5, 5, 5, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 0, 5, 0, 5], +[5, 5, 1, 5, 1, 5], +[4, 4, 4, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[4, 4, 0, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[4, 4, 4, 4, 4, 4], +[4, 4, 0, 4, 4, 4], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[8, 8, 8, 8, 8, 8], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 0, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 1, 1, 4, 4, 1], +[2, 2, 2, 2, 2, 2], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 0, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 4, 2, 4, 4, 2], +[1, 1, 1, 1, 1, 1], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[2, 2, 2, 2, 2, 2], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[5, 5, 5, 5, 5, 5], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[4, 4, 4, 4, 4, 4], +[1, 1, 1, 1, 1, 1], +[0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0], +[12, 12, 12, 12, 12, 12] +] diff --git a/src/models/lrm.py b/src/models/lrm.py new file mode 100644 index 0000000000000000000000000000000000000000..eea9ee3353d74fb60451fec87f6c2c30816f64ae --- /dev/null +++ b/src/models/lrm.py @@ -0,0 +1,196 @@ +# Copyright (c) 2023, Zexin He +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +import mcubes +import nvdiffrast.torch as dr +from einops import rearrange, repeat + +from .encoder.dino_wrapper import DinoWrapper +from .decoder.transformer import TriplaneTransformer +from .renderer.synthesizer import TriplaneSynthesizer +from ..utils.mesh_util import xatlas_uvmap + + +class InstantNeRF(nn.Module): + """ + Full model of the large reconstruction model. + """ + def __init__( + self, + encoder_freeze: bool = False, + encoder_model_name: str = 'facebook/dino-vitb16', + encoder_feat_dim: int = 768, + transformer_dim: int = 1024, + transformer_layers: int = 16, + transformer_heads: int = 16, + triplane_low_res: int = 32, + triplane_high_res: int = 64, + triplane_dim: int = 80, + rendering_samples_per_ray: int = 128, + ): + super().__init__() + + # modules + self.encoder = DinoWrapper( + model_name=encoder_model_name, + freeze=encoder_freeze, + ) + + self.transformer = TriplaneTransformer( + inner_dim=transformer_dim, + num_layers=transformer_layers, + num_heads=transformer_heads, + image_feat_dim=encoder_feat_dim, + triplane_low_res=triplane_low_res, + triplane_high_res=triplane_high_res, + triplane_dim=triplane_dim, + ) + + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, + samples_per_ray=rendering_samples_per_ray, + ) + + def forward_planes(self, images, cameras): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + B = images.shape[0] + + # encode images + image_feats = self.encoder(images, cameras) + image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) + + # transformer generating planes + planes = self.transformer(image_feats) + + return planes + + def forward(self, images, cameras, render_cameras, render_size: int): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + # render_cameras: [B, M, D_cam_render] + # render_size: int + B, M = render_cameras.shape[:2] + + planes = self.forward_planes(images, cameras) + + # render target views + render_results = self.synthesizer(planes, render_cameras, render_size) + + return { + 'planes': planes, + **render_results, + } + + def get_texture_prediction(self, planes, tex_pos, hard_mask=None): + ''' + Predict Texture given triplanes + :param planes: the triplane feature map + :param tex_pos: Position we want to query the texture field + :param hard_mask: 2D silhoueete of the rendered image + ''' + tex_pos = torch.cat(tex_pos, dim=0) + if not hard_mask is None: + tex_pos = tex_pos * hard_mask.float() + batch_size = tex_pos.shape[0] + tex_pos = tex_pos.reshape(batch_size, -1, 3) + ################### + # We use mask to get the texture location (to save the memory) + if hard_mask is not None: + n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) + sample_tex_pose_list = [] + max_point = n_point_list.max() + expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 + for i in range(tex_pos.shape[0]): + tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) + if tex_pos_one_shape.shape[1] < max_point: + tex_pos_one_shape = torch.cat( + [tex_pos_one_shape, torch.zeros( + 1, max_point - tex_pos_one_shape.shape[1], 3, + device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) + sample_tex_pose_list.append(tex_pos_one_shape) + tex_pos = torch.cat(sample_tex_pose_list, dim=0) + + tex_feat = self.synthesizer.forward_points(planes, tex_pos)['rgb'] + + if hard_mask is not None: + final_tex_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) + expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 + for i in range(planes.shape[0]): + final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) + tex_feat = final_tex_feat + + return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) + + def extract_mesh( + self, + planes: torch.Tensor, + mesh_resolution: int = 256, + mesh_threshold: int = 10.0, + use_texture_map: bool = False, + texture_resolution: int = 1024, + **kwargs, + ): + ''' + Extract a 3D mesh from triplane nerf. Only support batch_size 1. + :param planes: triplane features + :param mesh_resolution: marching cubes resolution + :param mesh_threshold: iso-surface threshold + :param use_texture_map: use texture map or vertex color + :param texture_resolution: the resolution of texture map + ''' + assert planes.shape[0] == 1 + device = planes.device + + grid_out = self.synthesizer.forward_grid( + planes=planes, + grid_size=mesh_resolution, + ) + + vertices, faces = mcubes.marching_cubes( + grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), + mesh_threshold, + ) + vertices = vertices / (mesh_resolution - 1) * 2 - 1 + + if not use_texture_map: + # query vertex colors + vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) + vertices_colors = self.synthesizer.forward_points( + planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy() + vertices_colors = (vertices_colors * 255).astype(np.uint8) + + return vertices, faces, vertices_colors + + # use x-atlas to get uv mapping for the mesh + vertices = torch.tensor(vertices, dtype=torch.float32, device=device) + faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device) + + ctx = dr.RasterizeCudaContext(device=device) + uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( + ctx, vertices, faces, resolution=texture_resolution) + tex_hard_mask = tex_hard_mask.float() + + # query the texture field to get the RGB color for texture map + tex_feat = self.get_texture_prediction( + planes, [gb_pos], tex_hard_mask) + background_feature = torch.zeros_like(tex_feat) + img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) + texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) + + return vertices, faces, uvs, mesh_tex_idx, texture_map \ No newline at end of file diff --git a/src/models/lrm_mesh.py b/src/models/lrm_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..b0f278e6bf73d3320c05c24809de862220f53a00 --- /dev/null +++ b/src/models/lrm_mesh.py @@ -0,0 +1,385 @@ +# Copyright (c) 2023, Tencent Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn as nn +import nvdiffrast.torch as dr +from einops import rearrange, repeat + +from .encoder.dino_wrapper import DinoWrapper +from .decoder.transformer import TriplaneTransformer +from .renderer.synthesizer_mesh import TriplaneSynthesizer +from .geometry.camera.perspective_camera import PerspectiveCamera +from .geometry.render.neural_render import NeuralRender +from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry +from ..utils.mesh_util import xatlas_uvmap + + +class InstantMesh(nn.Module): + """ + Full model of the large reconstruction model. + """ + def __init__( + self, + encoder_freeze: bool = False, + encoder_model_name: str = 'facebook/dino-vitb16', + encoder_feat_dim: int = 768, + transformer_dim: int = 1024, + transformer_layers: int = 16, + transformer_heads: int = 16, + triplane_low_res: int = 32, + triplane_high_res: int = 64, + triplane_dim: int = 80, + rendering_samples_per_ray: int = 128, + grid_res: int = 128, + grid_scale: float = 2.0, + ): + super().__init__() + + # attributes + self.grid_res = grid_res + self.grid_scale = grid_scale + self.deformation_multiplier = 4.0 + + # modules + self.encoder = DinoWrapper( + model_name=encoder_model_name, + freeze=encoder_freeze, + ) + + self.transformer = TriplaneTransformer( + inner_dim=transformer_dim, + num_layers=transformer_layers, + num_heads=transformer_heads, + image_feat_dim=encoder_feat_dim, + triplane_low_res=triplane_low_res, + triplane_high_res=triplane_high_res, + triplane_dim=triplane_dim, + ) + + self.synthesizer = TriplaneSynthesizer( + triplane_dim=triplane_dim, + samples_per_ray=rendering_samples_per_ray, + ) + + def init_flexicubes_geometry(self, device, fovy=50.0, use_renderer=True): + camera = PerspectiveCamera(fovy=fovy, device=device) + if use_renderer: + renderer = NeuralRender(device, camera_model=camera) + else: + renderer = None + self.geometry = FlexiCubesGeometry( + grid_res=self.grid_res, + scale=self.grid_scale, + renderer=renderer, + render_type='neural_render', + device=device, + ) + + def forward_planes(self, images, cameras): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + B = images.shape[0] + + # encode images + image_feats = self.encoder(images, cameras) + image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) + + # decode triplanes + planes = self.transformer(image_feats) + + return planes + + def get_sdf_deformation_prediction(self, planes): + ''' + Predict SDF and deformation for tetrahedron vertices + :param planes: triplane feature map for the geometry + ''' + init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1) + + # Step 1: predict the SDF and deformation + sdf, deformation, weight = torch.utils.checkpoint.checkpoint( + self.synthesizer.get_geometry_prediction, + planes, + init_position, + self.geometry.indices, + use_reentrant=False, + ) + + # Step 2: Normalize the deformation to avoid the flipped triangles. + deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation) + sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32) + + #### + # Step 3: Fix some sdf if we observe empty shape (full positive or full negative) + sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1)) + sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1) + pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1) + neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1) + zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0) + if torch.sum(zero_surface).item() > 0: + update_sdf = torch.zeros_like(sdf[0:1]) + max_sdf = sdf.max() + min_sdf = sdf.min() + update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero + update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero + new_sdf = torch.zeros_like(sdf) + for i_batch in range(zero_surface.shape[0]): + if zero_surface[i_batch]: + new_sdf[i_batch:i_batch + 1] += update_sdf + update_mask = (new_sdf == 0).float() + # Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative) + sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1) + sdf_reg_loss = sdf_reg_loss * zero_surface.float() + sdf = sdf * update_mask + new_sdf * (1 - update_mask) + + # Step 4: Here we remove the gradient for the bad sdf (full positive or full negative) + final_sdf = [] + final_def = [] + for i_batch in range(zero_surface.shape[0]): + if zero_surface[i_batch]: + final_sdf.append(sdf[i_batch: i_batch + 1].detach()) + final_def.append(deformation[i_batch: i_batch + 1].detach()) + else: + final_sdf.append(sdf[i_batch: i_batch + 1]) + final_def.append(deformation[i_batch: i_batch + 1]) + sdf = torch.cat(final_sdf, dim=0) + deformation = torch.cat(final_def, dim=0) + return sdf, deformation, sdf_reg_loss, weight + + def get_geometry_prediction(self, planes=None): + ''' + Function to generate mesh with give triplanes + :param planes: triplane features + ''' + # Step 1: first get the sdf and deformation value for each vertices in the tetrahedon grid. + sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes) + v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation + tets = self.geometry.indices + n_batch = planes.shape[0] + v_list = [] + f_list = [] + flexicubes_surface_reg_list = [] + + # Step 2: Using marching tet to obtain the mesh + for i_batch in range(n_batch): + verts, faces, flexicubes_surface_reg = self.geometry.get_mesh( + v_deformed[i_batch], + sdf[i_batch].squeeze(dim=-1), + with_uv=False, + indices=tets, + weight_n=weight[i_batch].squeeze(dim=-1), + is_training=self.training, + ) + flexicubes_surface_reg_list.append(flexicubes_surface_reg) + v_list.append(verts) + f_list.append(faces) + + flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean() + flexicubes_weight_reg = (weight ** 2).mean() + + return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg) + + def get_texture_prediction(self, planes, tex_pos, hard_mask=None): + ''' + Predict Texture given triplanes + :param planes: the triplane feature map + :param tex_pos: Position we want to query the texture field + :param hard_mask: 2D silhoueete of the rendered image + ''' + tex_pos = torch.cat(tex_pos, dim=0) + if not hard_mask is None: + tex_pos = tex_pos * hard_mask.float() + batch_size = tex_pos.shape[0] + tex_pos = tex_pos.reshape(batch_size, -1, 3) + ################### + # We use mask to get the texture location (to save the memory) + if hard_mask is not None: + n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) + sample_tex_pose_list = [] + max_point = n_point_list.max() + expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 + for i in range(tex_pos.shape[0]): + tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) + if tex_pos_one_shape.shape[1] < max_point: + tex_pos_one_shape = torch.cat( + [tex_pos_one_shape, torch.zeros( + 1, max_point - tex_pos_one_shape.shape[1], 3, + device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) + sample_tex_pose_list.append(tex_pos_one_shape) + tex_pos = torch.cat(sample_tex_pose_list, dim=0) + + tex_feat = torch.utils.checkpoint.checkpoint( + self.synthesizer.get_texture_prediction, + planes, + tex_pos, + use_reentrant=False, + ) + + if hard_mask is not None: + final_tex_feat = torch.zeros( + planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) + expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 + for i in range(planes.shape[0]): + final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) + tex_feat = final_tex_feat + + return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) + + def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256): + ''' + Function to render a generated mesh with nvdiffrast + :param mesh_v: List of vertices for the mesh + :param mesh_f: List of faces for the mesh + :param cam_mv: 4x4 rotation matrix + :return: + ''' + return_value_list = [] + for i_mesh in range(len(mesh_v)): + return_value = self.geometry.render_mesh( + mesh_v[i_mesh], + mesh_f[i_mesh].int(), + cam_mv[i_mesh], + resolution=render_size, + hierarchical_mask=False + ) + return_value_list.append(return_value) + + return_keys = return_value_list[0].keys() + return_value = dict() + for k in return_keys: + value = [v[k] for v in return_value_list] + return_value[k] = value + + mask = torch.cat(return_value['mask'], dim=0) + hard_mask = torch.cat(return_value['hard_mask'], dim=0) + tex_pos = return_value['tex_pos'] + depth = torch.cat(return_value['depth'], dim=0) + normal = torch.cat(return_value['normal'], dim=0) + return mask, hard_mask, tex_pos, depth, normal + + def forward_geometry(self, planes, render_cameras, render_size=256): + ''' + Main function of our Generator. It first generate 3D mesh, then render it into 2D image + with given `render_cameras`. + :param planes: triplane features + :param render_cameras: cameras to render generated 3D shape + ''' + B, NV = render_cameras.shape[:2] + + # Generate 3D mesh first + mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) + + # Render the mesh into 2D image (get 3d position of each image plane) + cam_mv = render_cameras + run_n_view = cam_mv.shape[1] + antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size) + + tex_hard_mask = hard_mask + tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos] + tex_hard_mask = torch.cat( + [torch.cat( + [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1] + for i_view in range(run_n_view)], dim=2) + for i in range(planes.shape[0])], dim=0) + + # Querying the texture field to predict the texture feature for each pixel on the image + tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask) + background_feature = torch.ones_like(tex_feat) # white background + + # Merge them together + img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) + + # We should split it back to the original image shape + img_feat = torch.cat( + [torch.cat( + [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)] + for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0) + + img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) + antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive + normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) + + out = { + 'img': img, + 'mask': antilias_mask, + 'depth': depth, + 'normal': normal, + 'sdf': sdf, + 'mesh_v': mesh_v, + 'mesh_f': mesh_f, + 'sdf_reg_loss': sdf_reg_loss, + } + return out + + def forward(self, images, cameras, render_cameras, render_size: int): + # images: [B, V, C_img, H_img, W_img] + # cameras: [B, V, 16] + # render_cameras: [B, M, D_cam_render] + # render_size: int + B, M = render_cameras.shape[:2] + + planes = self.forward_planes(images, cameras) + out = self.forward_geometry(planes, render_cameras, render_size=render_size) + + return { + 'planes': planes, + **out + } + + def extract_mesh( + self, + planes: torch.Tensor, + use_texture_map: bool = False, + texture_resolution: int = 1024, + **kwargs, + ): + ''' + Extract a 3D mesh from FlexiCubes. Only support batch_size 1. + :param planes: triplane features + :param use_texture_map: use texture map or vertex color + :param texture_resolution: the resolution of texure map + ''' + assert planes.shape[0] == 1 + device = planes.device + + # predict geometry first + mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) + vertices, faces = mesh_v[0], mesh_f[0] + + if not use_texture_map: + # query vertex colors + vertices_tensor = vertices.unsqueeze(0) + vertices_colors = self.synthesizer.get_texture_prediction( + planes, vertices_tensor).clamp(0, 1).squeeze(0).cpu().numpy() + vertices_colors = (vertices_colors * 255).astype(np.uint8) + + return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors + + # use x-atlas to get uv mapping for the mesh + ctx = dr.RasterizeCudaContext(device=device) + uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( + self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution) + tex_hard_mask = tex_hard_mask.float() + + # query the texture field to get the RGB color for texture map + tex_feat = self.get_texture_prediction( + planes, [gb_pos], tex_hard_mask) + background_feature = torch.zeros_like(tex_feat) + img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) + texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) + + return vertices, faces, uvs, mesh_tex_idx, texture_map \ No newline at end of file diff --git a/src/models/renderer/__init__.py b/src/models/renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/src/models/renderer/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. diff --git a/src/models/renderer/synthesizer.py b/src/models/renderer/synthesizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8db9fbdb1703b566117d227c8e4eef04157ccc93 --- /dev/null +++ b/src/models/renderer/synthesizer.py @@ -0,0 +1,203 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import ImportanceRenderer +from .utils.ray_sampler import RaySampler + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + self.net = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1 + 3), + ) + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def forward(self, sampled_features, ray_directions): + # Aggregate features by mean + # sampled_features = sampled_features.mean(1) + # Aggregate features by concatenation + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + x = sampled_features + + N, M, C = x.shape + x = x.contiguous().view(N*M, C) + + x = self.net(x) + x = x.view(N, M, -1) + rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + sigma = x[..., 0:1] + + return {'rgb': rgb, 'sigma': sigma} + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # renderings + self.renderer = ImportanceRenderer() + self.ray_sampler = RaySampler() + + # modules + self.decoder = OSGDecoder(n_features=triplane_dim) + + def forward(self, planes, cameras, render_size=128, crop_params=None): + # planes: (N, 3, D', H', W') + # cameras: (N, M, D_cam) + # render_size: int + assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" + N, M = cameras.shape[:2] + + cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) + intrinsics = cameras[..., 16:25].view(N, M, 3, 3) + + # Create a batch of rays for volume rendering + ray_origins, ray_directions = self.ray_sampler( + cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), + intrinsics=intrinsics.reshape(-1, 3, 3), + render_size=render_size, + ) + assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins" + assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" + + # Crop rays if crop_params is available + if crop_params is not None: + ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3) + ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3) + i, j, h, w = crop_params + ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) + ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) + + # Perform volume rendering + rgb_samples, depth_samples, weights_samples = self.renderer( + planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs, + ) + + # Reshape into 'raw' neural-rendered image + if crop_params is not None: + Himg, Wimg = crop_params[2:] + else: + Himg = Wimg = render_size + rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() + depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) + + out = { + 'images_rgb': rgb_images, + 'images_depth': depth_images, + 'images_weight': weight_images, + } + return out + + def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): + # planes: (N, 3, D', H', W') + # grid_size: int + # aabb: (N, 2, 3) + if aabb is None: + aabb = torch.tensor([ + [self.rendering_kwargs['sampler_bbox_min']] * 3, + [self.rendering_kwargs['sampler_bbox_max']] * 3, + ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) + assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" + N = planes.shape[0] + + # create grid points for triplane query + grid_points = [] + for i in range(N): + grid_points.append(torch.stack(torch.meshgrid( + torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), + torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), + indexing='ij', + ), dim=-1).reshape(-1, 3)) + cube_grid = torch.stack(grid_points, dim=0).to(planes.device) + + features = self.forward_points(planes, cube_grid) + + # reshape into grid + features = { + k: v.reshape(N, grid_size, grid_size, grid_size, -1) + for k, v in features.items() + } + return features + + def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): + # planes: (N, 3, D', H', W') + # points: (N, P, 3) + N, P = points.shape[:2] + + # query triplane in chunks + outs = [] + for i in range(0, points.shape[1], chunk_size): + chunk_points = points[:, i:i+chunk_size] + + # query triplane + chunk_out = self.renderer.run_model_activated( + planes=planes, + decoder=self.decoder, + sample_coordinates=chunk_points, + sample_directions=torch.zeros_like(chunk_points), + options=self.rendering_kwargs, + ) + outs.append(chunk_out) + + # concatenate the outputs + point_features = { + k: torch.cat([out[k] for out in outs], dim=1) + for k in outs[0].keys() + } + return point_features diff --git a/src/models/renderer/synthesizer_mesh.py b/src/models/renderer/synthesizer_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..dc31838315b33781560b3623c030443eeae24147 --- /dev/null +++ b/src/models/renderer/synthesizer_mesh.py @@ -0,0 +1,141 @@ +# ORIGINAL LICENSE +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + +import itertools +import torch +import torch.nn as nn + +from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes + + +class OSGDecoder(nn.Module): + """ + Triplane decoder that gives RGB and sigma values from sampled features. + Using ReLU here instead of Softplus in the original implementation. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 + """ + def __init__(self, n_features: int, + hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): + super().__init__() + + self.net_sdf = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 1), + ) + self.net_rgb = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 3), + ) + self.net_deformation = nn.Sequential( + nn.Linear(3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 3), + ) + self.net_weight = nn.Sequential( + nn.Linear(8 * 3 * n_features, hidden_dim), + activation(), + *itertools.chain(*[[ + nn.Linear(hidden_dim, hidden_dim), + activation(), + ] for _ in range(num_layers - 2)]), + nn.Linear(hidden_dim, 21), + ) + + # init all bias to zero + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.zeros_(m.bias) + + def get_geometry_prediction(self, sampled_features, flexicubes_indices): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + sdf = self.net_sdf(sampled_features) + deformation = self.net_deformation(sampled_features) + + grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) + grid_features = grid_features.reshape( + sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) + weight = self.net_weight(grid_features) * 0.1 + + return sdf, deformation, weight + + def get_texture_prediction(self, sampled_features): + _N, n_planes, _M, _C = sampled_features.shape + sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) + + rgb = self.net_rgb(sampled_features) + rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF + + return rgb + + +class TriplaneSynthesizer(nn.Module): + """ + Synthesizer that renders a triplane volume with planes and a camera. + + Reference: + EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 + """ + + DEFAULT_RENDERING_KWARGS = { + 'ray_start': 'auto', + 'ray_end': 'auto', + 'box_warp': 2., + 'white_back': True, + 'disparity_space_sampling': False, + 'clamp_mode': 'softplus', + 'sampler_bbox_min': -1., + 'sampler_bbox_max': 1., + } + + def __init__(self, triplane_dim: int, samples_per_ray: int): + super().__init__() + + # attributes + self.triplane_dim = triplane_dim + self.rendering_kwargs = { + **self.DEFAULT_RENDERING_KWARGS, + 'depth_resolution': samples_per_ray // 2, + 'depth_resolution_importance': samples_per_ray // 2, + } + + # modules + self.plane_axes = generate_planes() + self.decoder = OSGDecoder(n_features=triplane_dim) + + def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) + + sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) + return sdf, deformation, weight + + def get_texture_prediction(self, planes, sample_coordinates): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes( + plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) + + rgb = self.decoder.get_texture_prediction(sampled_features) + return rgb diff --git a/src/models/renderer/utils/__init__.py b/src/models/renderer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2c772e4fa331c678cfff50884be94d7d31835b34 --- /dev/null +++ b/src/models/renderer/utils/__init__.py @@ -0,0 +1,9 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. diff --git a/sgm/modules/encoders/math_utils.py b/src/models/renderer/utils/math_utils.py similarity index 86% rename from sgm/modules/encoders/math_utils.py rename to src/models/renderer/utils/math_utils.py index 35e4f6f876ebb5878fa05e93bdf10488cb73e297..4cf9d2b811e0acbc7923bc9126e010b52cb1a8af 100644 --- a/sgm/modules/encoders/math_utils.py +++ b/src/models/renderer/utils/math_utils.py @@ -22,7 +22,6 @@ import torch - def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: """ Left-multiplies MxM @ NxM. Returns NxM. @@ -37,7 +36,6 @@ def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: """ return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) - def torch_dot(x: torch.Tensor, y: torch.Tensor): """ Dot product of two tensors. @@ -57,16 +55,9 @@ def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_leng rays_o = rays_o.detach().reshape(-1, 3) rays_d = rays_d.detach().reshape(-1, 3) - bb_min = [ - -1 * (box_side_length / 2), - -1 * (box_side_length / 2), - -1 * (box_side_length / 2), - ] - bb_max = [ - 1 * (box_side_length / 2), - 1 * (box_side_length / 2), - 1 * (box_side_length / 2), - ] + + bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] + bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) @@ -75,20 +66,12 @@ def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_leng sign = (invdir < 0).long() # Intersect with YZ plane. - tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[ - ..., 0 - ] - tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[ - ..., 0 - ] + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] # Intersect with XZ plane. - tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[ - ..., 1 - ] - tymax = ( - bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1] - ) * invdir[..., 1] + tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] + tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] # Resolve parallel rays. is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False @@ -98,12 +81,8 @@ def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_leng tmax = torch.min(tmax, tymax) # Intersect with XY plane. - tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[ - ..., 2 - ] - tzmax = ( - bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2] - ) * invdir[..., 2] + tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] + tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] # Resolve parallel rays. is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False diff --git a/src/models/renderer/utils/ray_marcher.py b/src/models/renderer/utils/ray_marcher.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1db43478de703509cdd04c684f92f8e283c5ad --- /dev/null +++ b/src/models/renderer/utils/ray_marcher.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. +Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MipRayMarcher2(nn.Module): + def __init__(self, activation_factory): + super().__init__() + self.activation_factory = activation_factory + + def run_forward(self, colors, densities, depths, rendering_options, normals=None): + dtype = colors.dtype + deltas = depths[:, :, 1:] - depths[:, :, :-1] + colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + + # using factory mode for better usability + densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype) + + density_delta = densities_mid * deltas + + alpha = 1 - torch.exp(-density_delta).to(dtype) + + alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + weights = weights.to(dtype) + + composite_rgb = torch.sum(weights * colors_mid, -2) + weight_total = weights.sum(2) + # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total + composite_depth = torch.sum(weights * depths_mid, -2) + + # clip the composite to min/max range of depths + composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype) + composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) + + if rendering_options.get('white_back', False): + composite_rgb = composite_rgb + 1 - weight_total + + # rendered value scale is 0-1, comment out original mipnerf scaling + # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) + + return composite_rgb, composite_depth, weights + + + def forward(self, colors, densities, depths, rendering_options, normals=None): + if normals is not None: + composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals) + return composite_rgb, composite_depth, composite_normals, weights + + composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) + return composite_rgb, composite_depth, weights diff --git a/src/models/renderer/utils/ray_sampler.py b/src/models/renderer/utils/ray_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..ae5151dda467e826ce346986bd486d4465c906f2 --- /dev/null +++ b/src/models/renderer/utils/ray_sampler.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The ray sampler is a module that takes in camera matrices and resolution and batches of rays. +Expects cam2world matrices that use the OpenCV camera coordinate system conventions. +""" + +import torch + +class RaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, intrinsics, render_size): + """ + Create batches of rays and return origins and directions. + + cam2world_matrix: (N, 4, 4) + intrinsics: (N, 3, 3) + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 2) + """ + + dtype = cam2world_matrix.dtype + device = cam2world_matrix.device + N, M = cam2world_matrix.shape[0], render_size**2 + cam_locs_world = cam2world_matrix[:, :3, 3] + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size, dtype=dtype, device=device), + torch.arange(render_size, dtype=dtype, device=device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.ones((N, M), dtype=dtype, device=device) + + x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam + y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs = world_rel_points - cam_locs_world[:, None, :] + ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype) + + ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) + + return ray_origins, ray_dirs + + +class OrthoRaySampler(torch.nn.Module): + def __init__(self): + super().__init__() + self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None + + + def forward(self, cam2world_matrix, ortho_scale, render_size): + """ + Create batches of rays and return origins and directions. + + cam2world_matrix: (N, 4, 4) + ortho_scale: float + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 3) + """ + + N, M = cam2world_matrix.shape[0], render_size**2 + + uv = torch.stack(torch.meshgrid( + torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), + torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), + indexing='ij', + )) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) + z_cam = torch.zeros((N, M), device=cam2world_matrix.device) + + x_lift = (x_cam - 0.5) * ortho_scale + y_lift = (y_cam - 0.5) * ortho_scale + + cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) + + _opencv2blender = torch.tensor([ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] + + ray_dirs_cam = torch.stack([ + torch.zeros((N, M), device=cam2world_matrix.device), + torch.zeros((N, M), device=cam2world_matrix.device), + torch.ones((N, M), device=cam2world_matrix.device), + ], dim=-1) + ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1) + + return ray_origins, ray_dirs diff --git a/src/models/renderer/utils/renderer.py b/src/models/renderer/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..95c4c728efbd0283b8ddd7dc6a1b28d1510efa97 --- /dev/null +++ b/src/models/renderer/utils/renderer.py @@ -0,0 +1,323 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +# +# Modified by Jiale Xu +# The modifications are subject to the same license as the original. + + +""" +The renderer is a module that takes in rays, decides where to sample along each +ray, and computes pixel colors using the volume rendering equation. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .ray_marcher import MipRayMarcher2 +from . import math_utils + + +def generate_planes(): + """ + Defines planes by the three vectors that form the "axes" of the + plane. Should work with arbitrary number of planes and planes of + arbitrary orientation. + + Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 + """ + return torch.tensor([[[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], + [[1, 0, 0], + [0, 0, 1], + [0, 1, 0]], + [[0, 0, 1], + [0, 1, 0], + [1, 0, 0]]], dtype=torch.float32) + +def project_onto_planes(planes, coordinates): + """ + Does a projection of a 3D point onto a batch of 2D planes, + returning 2D plane coordinates. + + Takes plane axes of shape n_planes, 3, 3 + # Takes coordinates of shape N, M, 3 + # returns projections of shape N*n_planes, M, 2 + """ + N, M, C = coordinates.shape + n_planes, _, _ = planes.shape + coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) + inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) + projections = torch.bmm(coordinates, inv_planes) + return projections[..., :2] + +def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): + assert padding_mode == 'zeros' + N, n_planes, C, H, W = plane_features.shape + _, M, _ = coordinates.shape + plane_features = plane_features.view(N*n_planes, C, H, W) + dtype = plane_features.dtype + + coordinates = (2/box_warp) * coordinates # add specific box bounds + + projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) + output_features = torch.nn.functional.grid_sample( + plane_features, + projected_coordinates.to(dtype), + mode=mode, + padding_mode=padding_mode, + align_corners=False, + ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) + return output_features + +def sample_from_3dgrid(grid, coordinates): + """ + Expects coordinates in shape (batch_size, num_points_per_batch, 3) + Expects grid in shape (1, channels, H, W, D) + (Also works if grid has batch size) + Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) + """ + batch_size, n_coords, n_dims = coordinates.shape + sampled_features = torch.nn.functional.grid_sample( + grid.expand(batch_size, -1, -1, -1, -1), + coordinates.reshape(batch_size, 1, 1, -1, n_dims), + mode='bilinear', + padding_mode='zeros', + align_corners=False, + ) + N, C, H, W, D = sampled_features.shape + sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) + return sampled_features + +class ImportanceRenderer(torch.nn.Module): + """ + Modified original version to filter out-of-box samples as TensoRF does. + + Reference: + TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277 + """ + def __init__(self): + super().__init__() + self.activation_factory = self._build_activation_factory() + self.ray_marcher = MipRayMarcher2(self.activation_factory) + self.plane_axes = generate_planes() + + def _build_activation_factory(self): + def activation_factory(options: dict): + if options['clamp_mode'] == 'softplus': + return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better + else: + assert False, "Renderer only supports `clamp_mode`=`softplus`!" + return activation_factory + + def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor, + planes: torch.Tensor, decoder: nn.Module, rendering_options: dict): + """ + Additional filtering is applied to filter out-of-box samples. + Modifications made by Zexin He. + """ + + # context related variables + batch_size, num_rays, samples_per_ray, _ = depths.shape + device = depths.device + + # define sample points with depths + sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) + sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) + + # filter out-of-box samples + mask_inbox = \ + (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \ + (sample_coordinates <= rendering_options['sampler_bbox_max']) + mask_inbox = mask_inbox.all(-1) + + # forward model according to all samples + _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) + + # set out-of-box samples to zeros(rgb) & -inf(sigma) + SAFE_GUARD = 3 + DATA_TYPE = _out['sigma'].dtype + colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) + densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD + colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox] + + # reshape back + colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) + densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1]) + + return colors_pass, densities_pass + + def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options): + # self.plane_axes = self.plane_axes.to(ray_origins.device) + + if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': + ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + else: + # Create stratified depth samples + depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) + + # Coarse Pass + colors_coarse, densities_coarse = self._forward_pass( + depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + # Fine Pass + N_importance = rendering_options['depth_resolution_importance'] + if N_importance > 0: + _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + depths_fine = self.sample_importance(depths_coarse, weights, N_importance) + + colors_fine, densities_fine = self._forward_pass( + depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins, + planes=planes, decoder=decoder, rendering_options=rendering_options) + + all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, + depths_fine, colors_fine, densities_fine) + + rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) + else: + rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) + + return rgb_final, depth_final, weights.sum(2) + + def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): + plane_axes = self.plane_axes.to(planes.device) + sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) + + out = decoder(sampled_features, sample_directions) + if options.get('density_noise', 0) > 0: + out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] + return out + + def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options): + out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options) + out['sigma'] = self.activation_factory(options)(out['sigma']) + return out + + def sort_samples(self, all_depths, all_colors, all_densities): + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + return all_depths, all_colors, all_densities + + def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None): + all_depths = torch.cat([depths1, depths2], dim = -2) + all_colors = torch.cat([colors1, colors2], dim = -2) + all_densities = torch.cat([densities1, densities2], dim = -2) + + if normals1 is not None and normals2 is not None: + all_normals = torch.cat([normals1, normals2], dim = -2) + else: + all_normals = None + + _, indices = torch.sort(all_depths, dim=-2) + all_depths = torch.gather(all_depths, -2, indices) + all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) + all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) + + if all_normals is not None: + all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1])) + return all_depths, all_colors, all_normals, all_densities + + return all_depths, all_colors, all_densities + + def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): + """ + Return depths of approximately uniformly spaced samples along rays. + """ + N, M, _ = ray_origins.shape + if disparity_space_sampling: + depths_coarse = torch.linspace(0, + 1, + depth_resolution, + device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = 1/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) + else: + if type(ray_start) == torch.Tensor: + depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] + else: + depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start)/(depth_resolution - 1) + depths_coarse += torch.rand_like(depths_coarse) * depth_delta + + return depths_coarse + + def sample_importance(self, z_vals, weights, N_importance): + """ + Return depths of importance sampled points along rays. See NeRF importance sampling for more. + """ + with torch.no_grad(): + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher + + # smooth weights + weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1) + weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 + + z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) + importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], + N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) + return importance_z_vals + + def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): + """ + Sample @N_importance samples from @bins with distribution defined by @weights. + Inputs: + bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" + weights: (N_rays, N_samples_) + N_importance: the number of samples to draw from the distribution + det: deterministic or not + eps: a small number to prevent division by zero + Outputs: + samples: the sampled samples + """ + N_rays, N_samples_ = weights.shape + weights = weights + eps # prevent division by zero (don't do inplace op!) + pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) + cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function + cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) + # padded to 0~1 inclusive + + if det: + u = torch.linspace(0, 1, N_importance, device=bins.device) + u = u.expand(N_rays, N_importance) + else: + u = torch.rand(N_rays, N_importance, device=bins.device) + u = u.contiguous() + + inds = torch.searchsorted(cdf, u, right=True) + below = torch.clamp_min(inds-1, 0) + above = torch.clamp_max(inds, N_samples_) + + inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) + cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) + bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) + + denom = cdf_g[...,1]-cdf_g[...,0] + denom[denom 0 and radius > 0 + + elevation = np.deg2rad(elevation) + + camera_positions = [] + for i in range(M): + azimuth = 2 * np.pi * i / M + x = radius * np.cos(elevation) * np.cos(azimuth) + y = radius * np.cos(elevation) * np.sin(azimuth) + z = radius * np.sin(elevation) + camera_positions.append([x, y, z]) + camera_positions = np.array(camera_positions) + camera_positions = torch.from_numpy(camera_positions).float() + extrinsics = center_looking_at_camera_pose(camera_positions) + return extrinsics + + +def FOV_to_intrinsics(fov, device='cpu'): + """ + Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. + Note the intrinsics are returned as normalized by image size, rather than in pixel units. + Assumes principal point is at image center. + """ + focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) + intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) + return intrinsics + + +def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): + """ + Get the input camera parameters. + """ + azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) + elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) + + c2ws = spherical_camera_pose(azimuths, elevations, radius) + c2ws = c2ws.float().flatten(-2) + + Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2) + + extrinsics = c2ws[:, :12] + intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) + cameras = torch.cat([extrinsics, intrinsics], dim=-1) + + return cameras.unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/src/utils/infer_util.py b/src/utils/infer_util.py new file mode 100644 index 0000000000000000000000000000000000000000..89cd078214afcc0e3dadafea5fbbb9ac005ea476 --- /dev/null +++ b/src/utils/infer_util.py @@ -0,0 +1,84 @@ +import os +import imageio +import rembg +import torch +import numpy as np +import PIL.Image +from PIL import Image +from typing import Any + + +def remove_background(image: PIL.Image.Image, + rembg_session: Any = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + + +def resize_foreground( + image: PIL.Image.Image, + ratio: float, +) -> PIL.Image.Image: + image = np.array(image) + assert image.shape[-1] == 4 + alpha = np.where(image[..., 3] > 0) + y1, y2, x1, x2 = ( + alpha[0].min(), + alpha[0].max(), + alpha[1].min(), + alpha[1].max(), + ) + # crop the foreground + fg = image[y1:y2, x1:x2] + # pad to square + size = max(fg.shape[0], fg.shape[1]) + ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 + ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 + new_image = np.pad( + fg, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + + # compute padding according to the ratio + new_size = int(new_image.shape[0] / ratio) + # pad to size, double side + ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 + ph1, pw1 = new_size - size - ph0, new_size - size - pw0 + new_image = np.pad( + new_image, + ((ph0, ph1), (pw0, pw1), (0, 0)), + mode="constant", + constant_values=((0, 0), (0, 0), (0, 0)), + ) + new_image = PIL.Image.fromarray(new_image) + return new_image + + +def images_to_video( + images: torch.Tensor, + output_path: str, + fps: int = 30, +) -> None: + # images: (N, C, H, W) + video_dir = os.path.dirname(output_path) + video_name = os.path.basename(output_path) + os.makedirs(video_dir, exist_ok=True) + + frames = [] + for i in range(len(images)): + frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ + f"Frame shape mismatch: {frame.shape} vs {images.shape}" + assert frame.min() >= 0 and frame.max() <= 255, \ + f"Frame value out of range: {frame.min()} ~ {frame.max()}" + frames.append(frame) + imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) \ No newline at end of file diff --git a/src/utils/mesh_util.py b/src/utils/mesh_util.py new file mode 100644 index 0000000000000000000000000000000000000000..33b40dd7a34759ba81fe7037ee075882ad7a25bf --- /dev/null +++ b/src/utils/mesh_util.py @@ -0,0 +1,165 @@ +# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. + +import torch +import xatlas +import trimesh +import cv2 +import numpy as np +import nvdiffrast.torch as dr +from PIL import Image + + +def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fname): + mesh = trimesh.Trimesh( + vertices=pointnp_px3, + faces=facenp_fx3, + vertex_colors=colornp_px3, + ) + mesh.export(fname, 'obj') + + +def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): + import os + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = '%s/%s.mtl' % (fol, na) + fid = open(matname, 'w') + fid.write('newmtl material_0\n') + fid.write('Kd 1 1 1\n') + fid.write('Ka 0 0 0\n') + fid.write('Ks 0.4 0.4 0.4\n') + fid.write('Ns 10\n') + fid.write('illum 2\n') + fid.write('map_Kd %s.png\n' % na) + fid.close() + #### + + fid = open(fname, 'w') + fid.write('mtllib %s.mtl\n' % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write('vt %f %f\n' % (pp[0], pp[1])) + + fid.write('usemtl material_0\n') + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + # save texture map + lo, hi = 0, 1 + img = np.asarray(texmap_hxwx3, dtype=np.float32) + img = (img - lo) * (255 / (hi - lo)) + img = img.clip(0, 255) + mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) + mask = (mask <= 3.0).astype(np.float32) + kernel = np.ones((3, 3), 'uint8') + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') + + +def loadobj(meshfile): + v = [] + f = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if len(data) != 4: + continue + if data[0] == 'v': + v.append([float(d) for d in data[1:]]) + if data[0] == 'f': + data = [da.split('/')[0] for da in data] + f.append([int(d) for d in data[1:]]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + return pointnp_px3, facenp_fx3 + + +def loadobjtex(meshfile): + v = [] + vt = [] + f = [] + ft = [] + meshfp = open(meshfile, 'r') + for line in meshfp.readlines(): + data = line.strip().split(' ') + data = [da for da in data if len(da) > 0] + if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): + continue + if data[0] == 'v': + assert len(data) == 4 + + v.append([float(d) for d in data[1:]]) + if data[0] == 'vt': + if len(data) == 3 or len(data) == 4: + vt.append([float(d) for d in data[1:3]]) + if data[0] == 'f': + data = [da.split('/') for da in data] + if len(data) == 4: + f.append([int(d[0]) for d in data[1:]]) + ft.append([int(d[1]) for d in data[1:]]) + elif len(data) == 5: + idx1 = [1, 2, 3] + data1 = [data[i] for i in idx1] + f.append([int(d[0]) for d in data1]) + ft.append([int(d[1]) for d in data1]) + idx2 = [1, 3, 4] + data2 = [data[i] for i in idx2] + f.append([int(d[0]) for d in data2]) + ft.append([int(d[1]) for d in data2]) + meshfp.close() + + # torch need int64 + facenp_fx3 = np.array(f, dtype=np.int64) - 1 + ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 + pointnp_px3 = np.array(v, dtype=np.float32) + uvs = np.array(vt, dtype=np.float32) + return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 + + +# ============================================================================================== +def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') + + +def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): + vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] * 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + return uvs, mesh_tex_idx, gb_pos, mask diff --git a/src/utils/train_util.py b/src/utils/train_util.py new file mode 100644 index 0000000000000000000000000000000000000000..2e65421bffa8cc42c1517e86f2dfd8183caf52ab --- /dev/null +++ b/src/utils/train_util.py @@ -0,0 +1,26 @@ +import importlib + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) diff --git a/util/text_img.py b/util/text_img.py deleted file mode 100644 index 8c309f17b09b9c2a02e9c892205d80f8fb60a764..0000000000000000000000000000000000000000 --- a/util/text_img.py +++ /dev/null @@ -1,17 +0,0 @@ -import spaces -import rembg -import torch -from diffusers import DiffusionPipeline - -pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16") -pipe.to("cuda") - -# Function to generate an image from text using diffusion -@spaces.GPU -def generate_image(prompt): - prompt += "no background, side view, minimalist shot" - - image = pipe(prompt).images[0] - image2 = rembg.remove(image) - - return image, image2 \ No newline at end of file diff --git a/util/v3d.py b/util/v3d.py deleted file mode 100644 index 05d1344bee8a66110f89f74fc31828adf55dbb70..0000000000000000000000000000000000000000 --- a/util/v3d.py +++ /dev/null @@ -1,202 +0,0 @@ -# TODO -import numpy as np -import argparse -import torch -from torchvision.utils import make_grid -import tempfile -import gradio as gr -import spaces -from omegaconf import OmegaConf -from einops import rearrange -from scripts.pub.V3D_512 import ( - sample_one, - get_batch, - get_unique_embedder_keys_from_conditioner, - load_model, -) -from sgm.util import default, instantiate_from_config -from safetensors.torch import load_file as load_safetensors -from PIL import Image -from kiui.op import recenter -from torchvision.transforms import ToTensor -from einops import rearrange, repeat -import rembg -import os -from glob import glob -from mediapy import write_video -from pathlib import Path - -@spaces.GPU -def generate_v3d( - image, - model, - clip_model, - ae_model, - num_frames, - num_steps, - decoding_t, - border_ratio, - ignore_alpha, - rembg_session, - output_folder, - min_cfg, - max_cfg, - device="cuda", -): - change_model_params(model, min_cfg, max_cfg) - # if image.mode == "RGBA": - # image = image.convert("RGB") - image = Image.fromarray(image) - w, h = image.size - - if border_ratio > 0: - if image.mode != "RGBA" or ignore_alpha: - image = image.convert("RGB") - image = np.asarray(image) - carved_image = rembg.remove(image, session=rembg_session) # [H, W, 4] - else: - image = np.asarray(image) - carved_image = image - mask = carved_image[..., -1] > 0 - image = recenter(carved_image, mask, border_ratio=border_ratio) - image = image.astype(np.float32) / 255.0 - if image.shape[-1] == 4: - image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) - image = Image.fromarray((image * 255).astype(np.uint8)) - else: - print("Ignore border ratio") - image = image.resize((512, 512)) - - image = ToTensor()(image) - image = image * 2.0 - 1.0 - - image = image.unsqueeze(0).to(device) - H, W = image.shape[2:] - assert image.shape[1] == 3 - F = 8 - C = 4 - shape = (num_frames, C, H // F, W // F) - - value_dict = {} - value_dict["motion_bucket_id"] = 0 - value_dict["fps_id"] = 0 - value_dict["cond_aug"] = 0.05 - value_dict["cond_frames_without_noise"] = clip_model(image) - value_dict["cond_frames"] = ae_model.encode(image) - value_dict["cond_frames"] += 0.05 * torch.randn_like(value_dict["cond_frames"]) - value_dict["cond_aug"] = 0.05 - - with torch.no_grad(): - with torch.autocast(device): - batch, batch_uc = get_batch( - get_unique_embedder_keys_from_conditioner(model.conditioner), - value_dict, - [1, num_frames], - T=num_frames, - device=device, - ) - c, uc = model.conditioner.get_unconditional_conditioning( - batch, - batch_uc=batch_uc, - force_uc_zero_embeddings=[ - "cond_frames", - "cond_frames_without_noise", - ], - ) - - for k in ["crossattn", "concat"]: - uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) - uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) - c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) - c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) - - randn = torch.randn(shape, device=device) - randn = randn.to(device) - - additional_model_inputs = {} - additional_model_inputs["image_only_indicator"] = torch.zeros( - 2, num_frames - ).to(device) - additional_model_inputs["num_video_frames"] = batch["num_video_frames"] - - def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) - - samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) - model.en_and_decode_n_samples_a_time = decoding_t - samples_x = model.decode_first_stage(samples_z) - samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) - - os.makedirs(output_folder, exist_ok=True) - base_count = len(glob(os.path.join(output_folder, "*.mp4"))) - video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") - - frames = ( - (rearrange(samples, "t c h w -> t h w c") * 255) - .cpu() - .numpy() - .astype(np.uint8) - ) - write_video(video_path, frames, fps=6) - - return video_path - - -def change_model_params(model, min_cfg, max_cfg): - model.sampler.guider.max_scale = max_cfg - model.sampler.guider.min_scale = min_cfg - -@spaces.GPU -def prep(): - model_config = "scripts/pub/configs/V3D_512.yaml" - num_frames = OmegaConf.load( - model_config - ).model.params.sampler_config.params.guider_config.params.num_frames - print("Detected num_frames:", num_frames) - num_steps = 25 - output_folder = "outputs/V3D_512" - device = "cuda" - - sd = load_safetensors("./ckpts/svd_xt.safetensors") - clip_model_config = OmegaConf.load("configs/embedder/clip_image.yaml") - clip_model = instantiate_from_config(clip_model_config).eval() - clip_sd = dict() - for k, v in sd.items(): - if "conditioner.embedders.0" in k: - clip_sd[k.replace("conditioner.embedders.0.", "")] = v - clip_model.load_state_dict(clip_sd) - clip_model = clip_model.to(device) - - ae_model_config = OmegaConf.load("configs/ae/video.yaml") - ae_model = instantiate_from_config(ae_model_config).eval() - encoder_sd = dict() - for k, v in sd.items(): - if "first_stage_model" in k: - encoder_sd[k.replace("first_stage_model.", "")] = v - ae_model.load_state_dict(encoder_sd) - ae_model = ae_model.to(device) - rembg_session = rembg.new_session() - - model, _ = load_model( - model_config, device, num_frames, num_steps, min_cfg=3.5, max_cfg=3.5 - ) - - def download_if_need(path, url): - if Path(path).exists(): - return - import wget - - path.parent.mkdir(parents=True, exist_ok=True) - wget.download(url, out=str(path)) - - download_if_need( - "ckpts/svd_xt.safetensors", - "https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/svd_xt.safetensors -O ckpts/svd_xt.safetensors", - ) - download_if_need( - "ckpts/V3D_512.ckpt", "https://huggingface.co/heheyas/V3D/resolve/main/V3D.ckpt" - ) - - return model, clip_model, ae_model, num_frames, num_steps, rembg_session, device diff --git a/zero123plus/pipeline.Py b/zero123plus/pipeline.Py new file mode 100644 index 0000000000000000000000000000000000000000..402fb4d02a0b49f80de6e36b804adc8144dd6198 --- /dev/null +++ b/zero123plus/pipeline.Py @@ -0,0 +1,406 @@ +from typing import Any, Dict, Optional +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers + +import numpy +import torch +import torch.nn as nn +import torch.utils.checkpoint +import torch.distributed +import transformers +from collections import OrderedDict +from PIL import Image +from torchvision import transforms +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + EulerAncestralDiscreteScheduler, + UNet2DConditionModel, + ImagePipelineOutput +) +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0 +from diffusers.utils.import_utils import is_xformers_available + + +def to_rgb_image(maybe_rgba: Image.Image): + if maybe_rgba.mode == 'RGB': + return maybe_rgba + elif maybe_rgba.mode == 'RGBA': + rgba = maybe_rgba + img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8) + img = Image.fromarray(img, 'RGB') + img.paste(rgba, mask=rgba.getchannel('A')) + return img + else: + raise ValueError("Unsupported image type.", maybe_rgba.mode) + + +class ReferenceOnlyAttnProc(torch.nn.Module): + def __init__( + self, + chained_proc, + enabled=False, + name=None + ) -> None: + super().__init__() + self.enabled = enabled + self.chained_proc = chained_proc + self.name = name + + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, + mode="w", ref_dict: dict = None, is_cfg_guidance = False + ) -> Any: + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + if self.enabled and is_cfg_guidance: + res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask) + hidden_states = hidden_states[1:] + encoder_hidden_states = encoder_hidden_states[1:] + if self.enabled: + if mode == 'w': + ref_dict[self.name] = encoder_hidden_states + elif mode == 'r': + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1) + elif mode == 'm': + encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1) + else: + assert False, mode + res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask) + if self.enabled and is_cfg_guidance: + res = torch.cat([res0, res]) + return res + + +class RefOnlyNoisedUNet(torch.nn.Module): + def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None: + super().__init__() + self.unet = unet + self.train_sched = train_sched + self.val_sched = val_sched + + unet_lora_attn_procs = dict() + for name, _ in unet.attn_processors.items(): + if torch.__version__ >= '2.0': + default_attn_proc = AttnProcessor2_0() + elif is_xformers_available(): + default_attn_proc = XFormersAttnProcessor() + else: + default_attn_proc = AttnProcessor() + unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( + default_attn_proc, enabled=name.endswith("attn1.processor"), name=name + ) + unet.set_attn_processor(unet_lora_attn_procs) + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs): + if is_cfg_guidance: + encoder_hidden_states = encoder_hidden_states[1:] + class_labels = class_labels[1:] + self.unet( + noisy_cond_lat, timestep, + encoder_hidden_states=encoder_hidden_states, + class_labels=class_labels, + cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict), + **kwargs + ) + + def forward( + self, sample, timestep, encoder_hidden_states, class_labels=None, + *args, cross_attention_kwargs, + down_block_res_samples=None, mid_block_res_sample=None, + **kwargs + ): + cond_lat = cross_attention_kwargs['cond_lat'] + is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False) + noise = torch.randn_like(cond_lat) + if self.training: + noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep) + noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep) + else: + noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1)) + noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) + ref_dict = {} + self.forward_cond( + noisy_cond_lat, timestep, + encoder_hidden_states, class_labels, + ref_dict, is_cfg_guidance, **kwargs + ) + weight_dtype = self.unet.dtype + return self.unet( + sample, timestep, + encoder_hidden_states, *args, + class_labels=class_labels, + cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance), + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ] if down_block_res_samples is not None else None, + mid_block_additional_residual=( + mid_block_res_sample.to(dtype=weight_dtype) + if mid_block_res_sample is not None else None + ), + **kwargs + ) + + +def scale_latents(latents): + latents = (latents - 0.22) * 0.75 + return latents + + +def unscale_latents(latents): + latents = latents / 0.75 + 0.22 + return latents + + +def scale_image(image): + image = image * 0.5 / 0.8 + return image + + +def unscale_image(image): + image = image / 0.5 * 0.8 + return image + + +class DepthControlUNet(torch.nn.Module): + def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None: + super().__init__() + self.unet = unet + if controlnet is None: + self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet) + else: + self.controlnet = controlnet + DefaultAttnProc = AttnProcessor2_0 + if is_xformers_available(): + DefaultAttnProc = XFormersAttnProcessor + self.controlnet.set_attn_processor(DefaultAttnProc()) + self.conditioning_scale = conditioning_scale + + def __getattr__(self, name: str): + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.unet, name) + + def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs): + cross_attention_kwargs = dict(cross_attention_kwargs) + control_depth = cross_attention_kwargs.pop('control_depth') + down_block_res_samples, mid_block_res_sample = self.controlnet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=control_depth, + conditioning_scale=self.conditioning_scale, + return_dict=False, + ) + return self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_res_samples=down_block_res_samples, + mid_block_res_sample=mid_block_res_sample, + cross_attention_kwargs=cross_attention_kwargs + ) + + +class ModuleListDict(torch.nn.Module): + def __init__(self, procs: dict) -> None: + super().__init__() + self.keys = sorted(procs.keys()) + self.values = torch.nn.ModuleList(procs[k] for k in self.keys) + + def __getitem__(self, key): + return self.values[self.keys.index(key)] + + +class SuperNet(torch.nn.Module): + def __init__(self, state_dict: Dict[str, torch.Tensor]): + super().__init__() + state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys())) + self.layers = torch.nn.ModuleList(state_dict.values()) + self.mapping = dict(enumerate(state_dict.keys())) + self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + + # .processor for unet, .self_attn for text encoder + self.split_keys = [".processor", ".self_attn"] + + # we add a hook to state_dict() and load_state_dict() so that the + # naming fits with `unet.attn_processors` + def map_to(module, state_dict, *args, **kwargs): + new_state_dict = {} + for key, value in state_dict.items(): + num = int(key.split(".")[1]) # 0 is always "layers" + new_key = key.replace(f"layers.{num}", module.mapping[num]) + new_state_dict[new_key] = value + + return new_state_dict + + def remap_key(key, state_dict): + for k in self.split_keys: + if k in key: + return key.split(k)[0] + k + return key.split('.')[0] + + def map_from(module, state_dict, *args, **kwargs): + all_keys = list(state_dict.keys()) + for key in all_keys: + replace_key = remap_key(key, state_dict) + new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") + state_dict[new_key] = state_dict[key] + del state_dict[key] + + self._register_state_dict_hook(map_to) + self._register_load_state_dict_pre_hook(map_from, with_module=True) + + +class Zero123PlusPipeline(diffusers.StableDiffusionPipeline): + tokenizer: transformers.CLIPTokenizer + text_encoder: transformers.CLIPTextModel + vision_encoder: transformers.CLIPVisionModelWithProjection + + feature_extractor_clip: transformers.CLIPImageProcessor + unet: UNet2DConditionModel + scheduler: diffusers.schedulers.KarrasDiffusionSchedulers + + vae: AutoencoderKL + ramping: nn.Linear + + feature_extractor_vae: transformers.CLIPImageProcessor + + depth_transforms_multi = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]) + ]) + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + vision_encoder: transformers.CLIPVisionModelWithProjection, + feature_extractor_clip: CLIPImageProcessor, + feature_extractor_vae: CLIPImageProcessor, + ramping_coefficients: Optional[list] = None, + safety_checker=None, + ): + DiffusionPipeline.__init__(self) + + self.register_modules( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, + unet=unet, scheduler=scheduler, safety_checker=None, + vision_encoder=vision_encoder, + feature_extractor_clip=feature_extractor_clip, + feature_extractor_vae=feature_extractor_vae + ) + self.register_to_config(ramping_coefficients=ramping_coefficients) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def prepare(self): + train_sched = DDPMScheduler.from_config(self.scheduler.config) + if isinstance(self.unet, UNet2DConditionModel): + self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval() + + def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0): + self.prepare() + self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale) + return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)])) + + def encode_condition_image(self, image: torch.Tensor): + image = self.vae.encode(image).latent_dist.sample() + return image + + @torch.no_grad() + def __call__( + self, + image: Image.Image = None, + prompt = "", + *args, + num_images_per_prompt: Optional[int] = 1, + guidance_scale=4.0, + depth_image: Image.Image = None, + output_type: Optional[str] = "pil", + width=640, + height=960, + num_inference_steps=28, + return_dict=True, + **kwargs + ): + self.prepare() + if image is None: + raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.") + assert not isinstance(image, torch.Tensor) + image = to_rgb_image(image) + image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values + image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values + if depth_image is not None and hasattr(self.unet, "controlnet"): + depth_image = to_rgb_image(depth_image) + depth_image = self.depth_transforms_multi(depth_image).to( + device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype + ) + image = image_1.to(device=self.vae.device, dtype=self.vae.dtype) + image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype) + cond_lat = self.encode_condition_image(image) + if guidance_scale > 1: + negative_lat = self.encode_condition_image(torch.zeros_like(image)) + cond_lat = torch.cat([negative_lat, cond_lat]) + encoded = self.vision_encoder(image_2, output_hidden_states=False) + global_embeds = encoded.image_embeds + global_embeds = global_embeds.unsqueeze(-2) + + if hasattr(self, "encode_prompt"): + encoder_hidden_states = self.encode_prompt( + prompt, + self.device, + num_images_per_prompt, + False + )[0] + else: + encoder_hidden_states = self._encode_prompt( + prompt, + self.device, + num_images_per_prompt, + False + ) + ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1) + encoder_hidden_states = encoder_hidden_states + global_embeds * ramp + cak = dict(cond_lat=cond_lat) + if hasattr(self.unet, "controlnet"): + cak['control_depth'] = depth_image + latents: torch.Tensor = super().__call__( + None, + *args, + cross_attention_kwargs=cak, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + prompt_embeds=encoder_hidden_states, + num_inference_steps=num_inference_steps, + output_type='latent', + width=width, + height=height, + **kwargs + ).images + latents = unscale_latents(latents) + if not output_type == "latent": + image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]) + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) \ No newline at end of file