``` isn't optimized and could be much better?* There are several parts we didn't even have time to think about improving (yet). The performance you get with this prototype is probably a rather slow baseline for what is physically possible.
+
+- *Something is broken; how did this happen?* We tried hard to provide a solid, comprehensible basis for using the paper's method. We have refactored the code quite a bit, but we have limited capacity to test every usage scenario. So if part of the website, the code, or the performance is lacking, please create an issue; if we find the time, we will do our best to address it.
diff --git a/recon/arguments/__init__.py b/recon/arguments/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..36afc54420cdea24425ff3a2953b826f339160bf
--- /dev/null
+++ b/recon/arguments/__init__.py
@@ -0,0 +1,132 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+from argparse import ArgumentParser, Namespace
+import sys
+import os
+
+
+class GroupParams:
+ pass
+
+
+class ParamGroup:
+ def __init__(self, parser: ArgumentParser, name: str, fill_none=False):
+ group = parser.add_argument_group(name)
+ for key, value in vars(self).items():
+ shorthand = False
+ if key.startswith("_"):
+ shorthand = True
+ key = key[1:]
+ t = type(value)
+ value = value if not fill_none else None
+ if shorthand:
+ if t == bool:
+ group.add_argument(
+ "--" + key, ("-" + key[0:1]), default=value, action="store_true"
+ )
+ else:
+ group.add_argument(
+ "--" + key, ("-" + key[0:1]), default=value, type=t
+ )
+ else:
+ if t == bool:
+ group.add_argument("--" + key, default=value, action="store_true")
+ else:
+ group.add_argument("--" + key, default=value, type=t)
+
+ def extract(self, args):
+ group = GroupParams()
+ for arg in vars(args).items():
+ if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
+ setattr(group, arg[0], arg[1])
+ return group
+
+
+class ModelParams(ParamGroup):
+ def __init__(self, parser, sentinel=False):
+ self.sh_degree = 3
+ self._source_path = ""
+ self._model_path = ""
+ # self._images = "images"
+ self._resolution = -1
+ self._white_background = False
+ self.data_device = "cuda"
+ self.eval = False
+ self.num_frames = 18
+ self.radius = 2.0
+ self.elevation = 0.0
+ self.fov = 60.0
+ self.reso = 512
+ self.images = []
+ self.masks = []
+ self.num_pts = 100_000
+ self.train = True
+ super().__init__(parser, "Loading Parameters", sentinel)
+
+ def extract(self, args):
+ g = super().extract(args)
+ g.source_path = os.path.abspath(g.source_path)
+ return g
+
+
+class PipelineParams(ParamGroup):
+ def __init__(self, parser):
+ self.convert_SHs_python = False
+ self.compute_cov3D_python = False
+ self.debug = False
+ super().__init__(parser, "Pipeline Parameters")
+
+
+class OptimizationParams(ParamGroup):
+ def __init__(self, parser):
+ self.iterations = 30_000
+ self.position_lr_init = 0.00016
+ self.position_lr_final = 0.0000016
+ self.position_lr_delay_mult = 0.01
+ self.position_lr_max_steps = 30_000
+ self.feature_lr = 0.0025
+ self.opacity_lr = 0.05
+ self.scaling_lr = 0.005
+ self.rotation_lr = 0.001
+ self.percent_dense = 0.01
+ self.lambda_dssim = 0.2
+ self.lambda_lpips = 0.2
+ self.densification_interval = 100
+ self.opacity_reset_interval = 3000
+ self.densify_from_iter = 500
+ self.densify_until_iter = 15_000
+ self.densify_grad_threshold = 0.0002
+ self.random_background = False
+ super().__init__(parser, "Optimization Parameters")
+
+
+def get_combined_args(parser: ArgumentParser):
+ cmdlne_string = sys.argv[1:]
+ cfgfile_string = "Namespace()"
+ args_cmdline = parser.parse_args(cmdlne_string)
+
+ try:
+ cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
+ print("Looking for config file in", cfgfilepath)
+ with open(cfgfilepath) as cfg_file:
+ print("Config file found: {}".format(cfgfilepath))
+ cfgfile_string = cfg_file.read()
+    except (TypeError, FileNotFoundError):
+        # model_path unset or cfg_args missing: fall back to command line arguments only
+        print("Config file not found; using command line arguments only")
+ args_cfgfile = eval(cfgfile_string)
+
+ merged_dict = vars(args_cfgfile).copy()
+ for k, v in vars(args_cmdline).items():
+        if v is not None:
+ merged_dict[k] = v
+ return Namespace(**merged_dict)
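
A quick note on how these parameter groups are meant to be wired together. The sketch below is hypothetical and not part of the diff; it assumes it runs from inside `recon/` (so `arguments` is importable) and that `--model_path` points at a trained model directory containing a `cfg_args` file, which `get_combined_args` merges with the command line.

```python
from argparse import ArgumentParser
from arguments import ModelParams, OptimizationParams, PipelineParams, get_combined_args

parser = ArgumentParser(description="example entry point")
lp = ModelParams(parser)         # registers --source_path/-s, --model_path/-m, --num_frames, ...
op = OptimizationParams(parser)  # registers --iterations, learning rates, loss weights, ...
pp = PipelineParams(parser)      # registers --convert_SHs_python, --compute_cov3D_python, --debug

args = get_combined_args(parser)                 # CLI flags override values stored in cfg_args
dataset, opt, pipe = lp.extract(args), op.extract(args), pp.extract(args)
print(dataset.model_path, opt.iterations, pipe.debug)
```
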
diff --git a/recon/convert.py b/recon/convert.py
new file mode 100644
index 0000000000000000000000000000000000000000..78948848f4849a88d686542790cd04f34f34beb0
--- /dev/null
+++ b/recon/convert.py
@@ -0,0 +1,124 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import logging
+from argparse import ArgumentParser
+import shutil
+
+# This Python script is based on the shell converter script provided in the Mip-NeRF 360 repository.
+parser = ArgumentParser("Colmap converter")
+parser.add_argument("--no_gpu", action='store_true')
+parser.add_argument("--skip_matching", action='store_true')
+parser.add_argument("--source_path", "-s", required=True, type=str)
+parser.add_argument("--camera", default="OPENCV", type=str)
+parser.add_argument("--colmap_executable", default="", type=str)
+parser.add_argument("--resize", action="store_true")
+parser.add_argument("--magick_executable", default="", type=str)
+args = parser.parse_args()
+colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
+magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
+use_gpu = 1 if not args.no_gpu else 0
+
+if not args.skip_matching:
+ os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)
+
+ ## Feature extraction
+    feat_extraction_cmd = colmap_command + " feature_extractor "\
+ "--database_path " + args.source_path + "/distorted/database.db \
+ --image_path " + args.source_path + "/input \
+ --ImageReader.single_camera 1 \
+ --ImageReader.camera_model " + args.camera + " \
+ --SiftExtraction.use_gpu " + str(use_gpu)
+    exit_code = os.system(feat_extraction_cmd)
+ if exit_code != 0:
+ logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
+ exit(exit_code)
+
+ ## Feature matching
+ feat_matching_cmd = colmap_command + " exhaustive_matcher \
+ --database_path " + args.source_path + "/distorted/database.db \
+ --SiftMatching.use_gpu " + str(use_gpu)
+ exit_code = os.system(feat_matching_cmd)
+ if exit_code != 0:
+ logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
+ exit(exit_code)
+
+ ### Bundle adjustment
+ # The default Mapper tolerance is unnecessarily large,
+ # decreasing it speeds up bundle adjustment steps.
+ mapper_cmd = (colmap_command + " mapper \
+ --database_path " + args.source_path + "/distorted/database.db \
+ --image_path " + args.source_path + "/input \
+ --output_path " + args.source_path + "/distorted/sparse \
+ --Mapper.ba_global_function_tolerance=0.000001")
+ exit_code = os.system(mapper_cmd)
+ if exit_code != 0:
+ logging.error(f"Mapper failed with code {exit_code}. Exiting.")
+ exit(exit_code)
+
+### Image undistortion
+## We need to undistort our images into ideal pinhole intrinsics.
+img_undist_cmd = (colmap_command + " image_undistorter \
+ --image_path " + args.source_path + "/input \
+ --input_path " + args.source_path + "/distorted/sparse/0 \
+ --output_path " + args.source_path + "\
+ --output_type COLMAP")
+exit_code = os.system(img_undist_cmd)
+if exit_code != 0:
+    logging.error(f"Image undistortion failed with code {exit_code}. Exiting.")
+ exit(exit_code)
+
+files = os.listdir(args.source_path + "/sparse")
+os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
+# Move every reconstruction file (except the '0' folder itself) into sparse/0
+for file in files:
+ if file == '0':
+ continue
+ source_file = os.path.join(args.source_path, "sparse", file)
+ destination_file = os.path.join(args.source_path, "sparse", "0", file)
+ shutil.move(source_file, destination_file)
+
+if args.resize:
+ print("Copying and resizing...")
+
+ # Resize images.
+ os.makedirs(args.source_path + "/images_2", exist_ok=True)
+ os.makedirs(args.source_path + "/images_4", exist_ok=True)
+ os.makedirs(args.source_path + "/images_8", exist_ok=True)
+ # Get the list of files in the source directory
+ files = os.listdir(args.source_path + "/images")
+ # Copy each file from the source directory to the destination directory
+ for file in files:
+ source_file = os.path.join(args.source_path, "images", file)
+
+ destination_file = os.path.join(args.source_path, "images_2", file)
+ shutil.copy2(source_file, destination_file)
+ exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
+ if exit_code != 0:
+ logging.error(f"50% resize failed with code {exit_code}. Exiting.")
+ exit(exit_code)
+
+ destination_file = os.path.join(args.source_path, "images_4", file)
+ shutil.copy2(source_file, destination_file)
+ exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
+ if exit_code != 0:
+ logging.error(f"25% resize failed with code {exit_code}. Exiting.")
+ exit(exit_code)
+
+ destination_file = os.path.join(args.source_path, "images_8", file)
+ shutil.copy2(source_file, destination_file)
+ exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
+ if exit_code != 0:
+ logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
+ exit(exit_code)
+
+print("Done.")
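
For orientation, a hypothetical invocation of the converter above; the path is a placeholder, and the script only assumes that COLMAP (and ImageMagick, if `--resize` is used) is on the PATH or passed via `--colmap_executable` / `--magick_executable`.

```python
import subprocess

source = "/path/to/scene"   # placeholder: any folder with an input/ subdirectory of photos
subprocess.run(["python", "convert.py", "-s", source, "--resize"], check=True)

# Resulting layout (roughly):
#   <source>/input/          original photos, untouched
#   <source>/distorted/      COLMAP database and initial sparse model
#   <source>/images/         undistorted, pinhole-rectified images
#   <source>/images_2|4|8/   downscaled copies (only with --resize)
#   <source>/sparse/0/       the reconstruction consumed by training
```
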
diff --git a/recon/convert_mesh.py b/recon/convert_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/recon/convert_nerf_mesh.py b/recon/convert_nerf_mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c5cac9b7311e8832bedfc40a5a7ea8d63036df6
--- /dev/null
+++ b/recon/convert_nerf_mesh.py
@@ -0,0 +1,539 @@
+import os
+import tyro
+import tqdm
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from lgm.options import AllConfigs, Options
+from lgm.gs import GaussianRenderer
+
+import mcubes
+import nerfacc
+import nvdiffrast.torch as dr
+
+import kiui
+from kiui.mesh import Mesh
+from kiui.mesh_utils import clean_mesh, decimate_mesh
+from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
+from kiui.op import uv_padding, safe_normalize, inverse_sigmoid
+from kiui.cam import orbit_camera, get_perspective
+from kiui.nn import MLP, trunc_exp
+from kiui.gridencoder import GridEncoder
+
+
+def get_rays(pose, h, w, fovy, opengl=True):
+ x, y = torch.meshgrid(
+ torch.arange(w, device=pose.device),
+ torch.arange(h, device=pose.device),
+ indexing="xy",
+ )
+ x = x.flatten()
+ y = y.flatten()
+
+ cx = w * 0.5
+ cy = h * 0.5
+ focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))
+
+ camera_dirs = F.pad(
+ torch.stack(
+ [
+ (x - cx + 0.5) / focal,
+ (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
+ ],
+ dim=-1,
+ ),
+ (0, 1),
+ value=(-1.0 if opengl else 1.0),
+ ) # [hw, 3]
+
+ rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3]
+ rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3]
+
+ rays_d = safe_normalize(rays_d)
+
+ return rays_o, rays_d
+
+
+# Triple renderer: Gaussian splats, a NeRF distilled from them, and a differentiable mesh.
+# Conversion pipeline: gaussian --> nerf --> mesh
+class Converter(nn.Module):
+ def __init__(self, opt: Options):
+ super().__init__()
+
+ self.opt = opt
+ self.device = torch.device("cuda")
+
+ # gs renderer
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device)
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
+ self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[2, 3] = 1
+
+ self.gs_renderer = GaussianRenderer(opt)
+
+ self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device)
+
+ # nerf renderer
+ if not self.opt.force_cuda_rast:
+ self.glctx = dr.RasterizeGLContext()
+ else:
+ self.glctx = dr.RasterizeCudaContext()
+
+ self.step = 0
+ self.render_step_size = 5e-3
+ self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
+ self.estimator = nerfacc.OccGridEstimator(
+ roi_aabb=self.aabb, resolution=64, levels=1
+ )
+
+ self.encoder_density = GridEncoder(
+ num_levels=12
+ ) # VMEncoder(output_dim=16, mode='sum')
+ self.encoder = GridEncoder(num_levels=12)
+ self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
+ self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)
+
+ # mesh renderer
+ self.proj = (
+ torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device)
+ )
+ self.v = self.f = None
+ self.vt = self.ft = None
+ self.deform = None
+ self.albedo = None
+
+ @torch.no_grad()
+ def render_gs(self, pose):
+ cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device)
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
+
+ # cameras needed by gaussian rasterizer
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
+ cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4]
+ cam_pos = -cam_poses[:, :3, 3] # [V, 3]
+
+ out = self.gs_renderer.render(
+ self.gaussians.unsqueeze(0),
+ cam_view.unsqueeze(0),
+ cam_view_proj.unsqueeze(0),
+ cam_pos.unsqueeze(0),
+ )
+ image = out["image"].squeeze(1).squeeze(0) # [C, H, W]
+ alpha = out["alpha"].squeeze(2).squeeze(1).squeeze(0) # [H, W]
+
+ return image, alpha
+
+ def get_density(self, xs):
+ # xs: [..., 3]
+ prefix = xs.shape[:-1]
+ xs = xs.view(-1, 3)
+ feats = self.encoder_density(xs)
+ density = trunc_exp(self.mlp_density(feats))
+ density = density.view(*prefix, 1)
+ return density
+
+ def render_nerf(self, pose):
+ pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
+
+ # get rays
+ resolution = self.opt.output_size
+ rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy)
+
+ # update occ grid
+ if self.training:
+
+ def occ_eval_fn(xs):
+ sigmas = self.get_density(xs)
+ return self.render_step_size * sigmas
+
+ self.estimator.update_every_n_steps(
+ self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8
+ )
+ self.step += 1
+
+ # render
+ def sigma_fn(t_starts, t_ends, ray_indices):
+ t_origins = rays_o[ray_indices]
+ t_dirs = rays_d[ray_indices]
+ xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
+ sigmas = self.get_density(xs)
+ return sigmas.squeeze(-1)
+
+ with torch.no_grad():
+ ray_indices, t_starts, t_ends = self.estimator.sampling(
+ rays_o,
+ rays_d,
+ sigma_fn=sigma_fn,
+ near_plane=0.01,
+ far_plane=100,
+ render_step_size=self.render_step_size,
+ stratified=self.training,
+ cone_angle=0,
+ )
+
+ t_origins = rays_o[ray_indices]
+ t_dirs = rays_d[ray_indices]
+ xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
+ sigmas = self.get_density(xs).squeeze(-1)
+ rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))
+
+ n_rays = rays_o.shape[0]
+ weights, trans, alphas = nerfacc.render_weight_from_density(
+ t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays
+ )
+ color = nerfacc.accumulate_along_rays(
+ weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays
+ )
+ alpha = nerfacc.accumulate_along_rays(
+ weights, values=None, ray_indices=ray_indices, n_rays=n_rays
+ )
+
+ color = color + 1 * (1.0 - alpha)
+
+ color = (
+ color.view(resolution, resolution, 3)
+ .clamp(0, 1)
+ .permute(2, 0, 1)
+ .contiguous()
+ )
+ alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous()
+
+ return color, alpha
+
+ def fit_nerf(self, iters=512, resolution=128):
+ self.opt.output_size = resolution
+
+ optimizer = torch.optim.Adam(
+ [
+ {"params": self.encoder_density.parameters(), "lr": 1e-2},
+ {"params": self.encoder.parameters(), "lr": 1e-2},
+ {"params": self.mlp_density.parameters(), "lr": 1e-3},
+ {"params": self.mlp.parameters(), "lr": 1e-3},
+ ]
+ )
+
+ print(f"[INFO] fitting nerf...")
+ pbar = tqdm.trange(iters)
+ for i in pbar:
+ ver = np.random.randint(-45, 45)
+ hor = np.random.randint(-180, 180)
+ rad = np.random.uniform(1.5, 3.0)
+
+ pose = orbit_camera(ver, hor, rad)
+
+ image_gt, alpha_gt = self.render_gs(pose)
+ image_pred, alpha_pred = self.render_nerf(pose)
+
+ # if i % 200 == 0:
+ # kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred)
+
+ loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
+ alpha_pred, alpha_gt
+ )
+ loss = loss_mse # + 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss()
+
+ loss.backward()
+ self.encoder_density.grad_total_variation(1e-8)
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
+
+ print(f"[INFO] finished fitting nerf!")
+
+ def render_mesh(self, pose):
+ h = w = self.opt.output_size
+
+ v = self.v + self.deform
+ f = self.f
+
+ pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
+
+ # get v_clip and render rgb
+ v_cam = (
+ torch.matmul(
+ F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T
+ )
+ .float()
+ .unsqueeze(0)
+ )
+ v_clip = v_cam @ self.proj.T
+
+ rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w))
+
+ alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1]
+ alpha = (
+ dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0)
+ ) # [H, W] important to enable gradients!
+
+ if self.albedo is None:
+ xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3]
+ xyzs = xyzs.view(-1, 3)
+ mask = (alpha > 0).view(-1)
+ image = torch.zeros_like(xyzs, dtype=torch.float32)
+ if mask.any():
+ masked_albedo = torch.sigmoid(
+ self.mlp(self.encoder(xyzs[mask].detach(), bound=1))
+ )
+ image[mask] = masked_albedo.float()
+ else:
+ texc, texc_db = dr.interpolate(
+ self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs="all"
+ )
+ image = torch.sigmoid(
+ dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db)
+ ) # [1, H, W, 3]
+
+ image = image.view(1, h, w, 3)
+ # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1)
+ image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W]
+ image = alpha * image + (1 - alpha)
+
+ return image, alpha
+
+ def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4):
+ self.opt.output_size = resolution
+
+ # init mesh from nerf
+ grid_size = 256
+ sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32)
+
+ S = 128
+ density_thresh = 10
+
+ X = torch.linspace(-1, 1, grid_size).split(S)
+ Y = torch.linspace(-1, 1, grid_size).split(S)
+ Z = torch.linspace(-1, 1, grid_size).split(S)
+
+ for xi, xs in enumerate(X):
+ for yi, ys in enumerate(Y):
+ for zi, zs in enumerate(Z):
+ xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij")
+ pts = torch.cat(
+ [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)],
+ dim=-1,
+ ) # [S, 3]
+ val = self.get_density(pts.to(self.device))
+ sigmas[
+ xi * S : xi * S + len(xs),
+ yi * S : yi * S + len(ys),
+ zi * S : zi * S + len(zs),
+ ] = (
+ val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()
+ ) # [S, 1] --> [x, y, z]
+
+ print(
+ f"[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})"
+ )
+
+ vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh)
+ vertices = vertices / (grid_size - 1.0) * 2 - 1
+
+ # clean
+ vertices = vertices.astype(np.float32)
+ triangles = triangles.astype(np.int32)
+ vertices, triangles = clean_mesh(
+ vertices, triangles, remesh=True, remesh_size=0.01
+ )
+ if triangles.shape[0] > decimate_target:
+ vertices, triangles = decimate_mesh(
+ vertices, triangles, decimate_target, optimalplacement=False
+ )
+
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
+ self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
+
+ # fit mesh from gs
+ lr_factor = 1
+ optimizer = torch.optim.Adam(
+ [
+ {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
+ {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
+ {"params": self.deform, "lr": 1e-4},
+ ]
+ )
+
+ print(f"[INFO] fitting mesh...")
+ pbar = tqdm.trange(iters)
+ for i in pbar:
+ ver = np.random.randint(-10, 10)
+ hor = np.random.randint(-180, 180)
+ rad = self.opt.cam_radius # np.random.uniform(1, 2)
+
+ pose = orbit_camera(ver, hor, rad)
+
+ image_gt, alpha_gt = self.render_gs(pose)
+ image_pred, alpha_pred = self.render_mesh(pose)
+
+ loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(
+ alpha_pred, alpha_gt
+ )
+ # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f)
+ loss_normal = normal_consistency(self.v + self.deform, self.f)
+ loss_offsets = (self.deform**2).sum(-1).mean()
+ loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets
+
+ loss.backward()
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ # remesh periodically
+ if i > 0 and i % 512 == 0:
+ vertices = (self.v + self.deform).detach().cpu().numpy()
+ triangles = self.f.detach().cpu().numpy()
+ vertices, triangles = clean_mesh(
+ vertices, triangles, remesh=True, remesh_size=0.01
+ )
+ if triangles.shape[0] > decimate_target:
+ vertices, triangles = decimate_mesh(
+ vertices, triangles, decimate_target, optimalplacement=False
+ )
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
+ self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device)
+ lr_factor *= 0.5
+ optimizer = torch.optim.Adam(
+ [
+ {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor},
+ {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor},
+ {"params": self.deform, "lr": 1e-4},
+ ]
+ )
+
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
+
+ # last clean
+ vertices = (self.v + self.deform).detach().cpu().numpy()
+ triangles = self.f.detach().cpu().numpy()
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
+ self.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
+ self.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
+ self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device))
+
+ print(f"[INFO] finished fitting mesh!")
+
+ # uv mesh refine
+ def fit_mesh_uv(
+ self, iters=512, resolution=512, texture_resolution=1024, padding=2
+ ):
+ self.opt.output_size = resolution
+
+ # unwrap uv
+ print(f"[INFO] uv unwrapping...")
+ mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device)
+ mesh.auto_normal()
+ mesh.auto_uv()
+
+ self.vt = mesh.vt
+ self.ft = mesh.ft
+
+ # render uv maps
+ h = w = texture_resolution
+ uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1]
+ uv = torch.cat(
+ (uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1
+ ) # [N, 4]
+
+ rast, _ = dr.rasterize(
+ self.glctx, uv.unsqueeze(0), mesh.ft, (h, w)
+ ) # [1, h, w, 4]
+ xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3]
+ mask, _ = dr.interpolate(
+ torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f
+ ) # [1, h, w, 1]
+
+ # masked query
+ xyzs = xyzs.view(-1, 3)
+ mask = (mask > 0).view(-1)
+
+ albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32)
+
+ if mask.any():
+ print(f"[INFO] querying texture...")
+
+ xyzs = xyzs[mask] # [M, 3]
+
+ # batched inference to avoid OOM
+ batch = []
+ head = 0
+ while head < xyzs.shape[0]:
+ tail = min(head + 640000, xyzs.shape[0])
+ batch.append(
+ torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float()
+ )
+ head += 640000
+
+ albedo[mask] = torch.cat(batch, dim=0)
+
+ albedo = albedo.view(h, w, -1)
+ mask = mask.view(h, w)
+ albedo = uv_padding(albedo, mask, padding)
+
+ # optimize texture
+ self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device)
+
+ optimizer = torch.optim.Adam(
+ [
+ {"params": self.albedo, "lr": 1e-3},
+ ]
+ )
+
+ print(f"[INFO] fitting mesh texture...")
+ pbar = tqdm.trange(iters)
+ for i in pbar:
+ # shrink to front view as we care more about it...
+ ver = np.random.randint(-5, 5)
+ hor = np.random.randint(-15, 15)
+ rad = self.opt.cam_radius # np.random.uniform(1, 2)
+
+ pose = orbit_camera(ver, hor, rad)
+
+ image_gt, alpha_gt = self.render_gs(pose)
+ image_pred, alpha_pred = self.render_mesh(pose)
+
+ loss_mse = F.mse_loss(image_pred, image_gt)
+ loss = loss_mse
+
+ loss.backward()
+
+ optimizer.step()
+ optimizer.zero_grad()
+
+ pbar.set_description(f"MSE = {loss_mse.item():.6f}")
+
+ print(f"[INFO] finished fitting mesh texture!")
+
+ @torch.no_grad()
+ def export_mesh(self, path):
+ mesh = Mesh(
+ v=self.v,
+ f=self.f,
+ vt=self.vt,
+ ft=self.ft,
+ albedo=torch.sigmoid(self.albedo),
+ device=self.device,
+ )
+ mesh.auto_normal()
+ mesh.write(path)
+
+
+opt = tyro.cli(AllConfigs)
+
+# load a saved ply and convert to mesh
+assert opt.test_path.endswith(
+ ".ply"
+), "--test_path must be a .ply file saved by infer.py"
+
+converter = Converter(opt).cuda()
+converter.fit_nerf()
+converter.fit_mesh()
+converter.fit_mesh_uv()
+converter.export_mesh(opt.test_path.replace(".ply", ".glb"))
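
A hedged invocation sketch for the conversion script above: the preset name is a tyro subcommand defined in `lgm/options.py` (shown further below), and the `.ply` path is a placeholder for a file saved by the inference script.

```python
import subprocess

# Runs the full gaussian -> nerf -> mesh pipeline and writes out.glb next to the input.
subprocess.run(
    ["python", "convert_nerf_mesh.py", "big", "--test_path", "out.ply"],
    check=True,
)
```
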
diff --git a/recon/convert_to_blender.py b/recon/convert_to_blender.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f3d85dc1eff4efbdfa1708f5a57e33c2b85c51b
--- /dev/null
+++ b/recon/convert_to_blender.py
@@ -0,0 +1,102 @@
+import json
+import torch
+from scene import Scene
+from pathlib import Path
+from PIL import Image
+import numpy as np
+import sys
+import os
+from tqdm import tqdm
+from os import makedirs
+from gaussian_renderer import render
+import torchvision
+from utils.general_utils import safe_state
+from argparse import ArgumentParser
+from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
+from gaussian_renderer import GaussianModel
+from mediapy import write_video
+from tqdm import tqdm
+from einops import rearrange
+from utils.camera_utils import get_uniform_poses
+from mediapy import write_image
+
+
+@torch.no_grad()
+def render_spiral(dataset, opt, pipe, model_path):
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False)
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+ viewpoint_stack = scene.getTrainCameras().copy()
+ views = []
+ alphas = []
+ for view_cam in tqdm(viewpoint_stack):
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+ render_pkg = render(view_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+ views.append(image)
+ alphas.append(render_pkg["alpha"])
+ views = torch.stack(views)
+ alphas = torch.stack(alphas)
+
+ png_images = (
+ (torch.cat([views, alphas], dim=1).clamp(0.0, 1.0) * 255)
+ .cpu()
+ .numpy()
+ .astype(np.uint8)
+ )
+ png_images = rearrange(png_images, "t c h w -> t h w c")
+
+ poses = get_uniform_poses(
+ dataset.num_frames, dataset.radius, dataset.elevation, opengl=True
+ )
+ camera_angle_x = np.deg2rad(dataset.fov)
+ name = Path(dataset.model_path).stem
+ meta_dir = Path(f"blenders/{name}")
+ meta_dir.mkdir(exist_ok=True, parents=True)
+ meta = {}
+ meta["camera_angle_x"] = camera_angle_x
+ meta["frames"] = []
+ for idx, (pose, image) in enumerate(zip(poses, png_images)):
+ this_frames = {}
+ this_frames["file_path"] = f"{idx:06d}"
+ this_frames["transform_matrix"] = pose.tolist()
+ meta["frames"].append(this_frames)
+ write_image(meta_dir / f"{idx:06d}.png", image)
+
+ with open(meta_dir / "transforms_train.json", "w") as f:
+ json.dump(meta, f, indent=4)
+ with open(meta_dir / "transforms_val.json", "w") as f:
+ json.dump(meta, f, indent=4)
+ with open(meta_dir / "transforms_test.json", "w") as f:
+ json.dump(meta, f, indent=4)
+
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--iteration", default=-1, type=int)
+ parser.add_argument("--skip_train", action="store_true")
+ parser.add_argument("--skip_test", action="store_true")
+ parser.add_argument("--quiet", action="store_true")
+ args = parser.parse_args(sys.argv[1:])
+ print("Rendering " + args.model_path)
+ lp = lp.extract(args)
+ fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8))
+ lp.images = [fake_image] * args.num_frames
+
+ # Initialize system state (RNG)
+ render_spiral(
+ lp,
+ op.extract(args),
+ pp.extract(args),
+ model_path=args.model_path,
+ )
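
A hypothetical invocation of the exporter above; the model path is a placeholder for a directory produced by training, and the export lands under `blenders/<model name>/`.

```python
import subprocess

subprocess.run(["python", "convert_to_blender.py", "-m", "output/my_scene"], check=True)
# Writes blenders/my_scene/000000.png ... plus transforms_{train,val,test}.json,
# which all reference the same uniformly spaced poses from get_uniform_poses.
```
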
diff --git a/recon/environment.yml b/recon/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b17a50f6570ee66e157c8fd168b62c45bfba2fee
--- /dev/null
+++ b/recon/environment.yml
@@ -0,0 +1,17 @@
+name: gaussian_splatting
+channels:
+ - pytorch
+ - conda-forge
+ - defaults
+dependencies:
+ - cudatoolkit=11.6
+ - plyfile=0.8.1
+ - python=3.7.13
+ - pip=22.3.1
+ - pytorch=1.12.1
+ - torchaudio=0.12.1
+ - torchvision=0.13.1
+ - tqdm
+ - pip:
+ - submodules/diff-gaussian-rasterization
+ - submodules/simple-knn
\ No newline at end of file
diff --git a/recon/full_eval.py b/recon/full_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fbb12247724b25563e215b4409ded9af1cbdd04
--- /dev/null
+++ b/recon/full_eval.py
@@ -0,0 +1,75 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+from argparse import ArgumentParser
+
+mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
+mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
+tanks_and_temples_scenes = ["truck", "train"]
+deep_blending_scenes = ["drjohnson", "playroom"]
+
+parser = ArgumentParser(description="Full evaluation script parameters")
+parser.add_argument("--skip_training", action="store_true")
+parser.add_argument("--skip_rendering", action="store_true")
+parser.add_argument("--skip_metrics", action="store_true")
+parser.add_argument("--output_path", default="./eval")
+args, _ = parser.parse_known_args()
+
+all_scenes = []
+all_scenes.extend(mipnerf360_outdoor_scenes)
+all_scenes.extend(mipnerf360_indoor_scenes)
+all_scenes.extend(tanks_and_temples_scenes)
+all_scenes.extend(deep_blending_scenes)
+
+if not args.skip_training or not args.skip_rendering:
+ parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
+ parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
+ parser.add_argument("--deepblending", "-db", required=True, type=str)
+ args = parser.parse_args()
+
+if not args.skip_training:
+ common_args = " --quiet --eval --test_iterations -1 "
+ for scene in mipnerf360_outdoor_scenes:
+ source = args.mipnerf360 + "/" + scene
+ os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args)
+ for scene in mipnerf360_indoor_scenes:
+ source = args.mipnerf360 + "/" + scene
+ os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args)
+ for scene in tanks_and_temples_scenes:
+ source = args.tanksandtemples + "/" + scene
+ os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
+ for scene in deep_blending_scenes:
+ source = args.deepblending + "/" + scene
+ os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
+
+if not args.skip_rendering:
+ all_sources = []
+ for scene in mipnerf360_outdoor_scenes:
+ all_sources.append(args.mipnerf360 + "/" + scene)
+ for scene in mipnerf360_indoor_scenes:
+ all_sources.append(args.mipnerf360 + "/" + scene)
+ for scene in tanks_and_temples_scenes:
+ all_sources.append(args.tanksandtemples + "/" + scene)
+ for scene in deep_blending_scenes:
+ all_sources.append(args.deepblending + "/" + scene)
+
+ common_args = " --quiet --eval --skip_train"
+ for scene, source in zip(all_scenes, all_sources):
+ os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
+ os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
+
+if not args.skip_metrics:
+ scenes_string = ""
+ for scene in all_scenes:
+ scenes_string += "\"" + args.output_path + "/" + scene + "\" "
+
+ os.system("python metrics.py -m " + scenes_string)
\ No newline at end of file
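
A hypothetical end-to-end run of the evaluation driver above; the dataset roots are placeholders, and the script simply shells out to `train.py`, `render.py`, and `metrics.py` for every scene listed.

```python
import subprocess

subprocess.run(
    ["python", "full_eval.py",
     "--mipnerf360", "/data/mipnerf360",
     "--tanksandtemples", "/data/tandt",
     "--deepblending", "/data/db",
     "--output_path", "./eval"],
    check=True,
)
```
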
diff --git a/recon/gaussian_renderer/__init__.py b/recon/gaussian_renderer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..10bb78a587ecde17969800e3f781402d1d9a42f7
--- /dev/null
+++ b/recon/gaussian_renderer/__init__.py
@@ -0,0 +1,134 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import math
+from diff_gaussian_rasterization import (
+ GaussianRasterizationSettings,
+ GaussianRasterizer,
+)
+from scene.gaussian_model import GaussianModel
+from utils.sh_utils import eval_sh
+
+
+def render(
+ viewpoint_camera,
+ pc: GaussianModel,
+ pipe,
+ bg_color: torch.Tensor,
+ scaling_modifier=1.0,
+ override_color=None,
+):
+ """
+ Render the scene.
+
+ Background tensor (bg_color) must be on GPU!
+ """
+
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
+ screenspace_points = (
+ torch.zeros_like(
+ pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda"
+ )
+ + 0
+ )
+ try:
+ screenspace_points.retain_grad()
+ except:
+ pass
+
+ # Set up rasterization configuration
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+
+ raster_settings = GaussianRasterizationSettings(
+ image_height=int(viewpoint_camera.image_height),
+ image_width=int(viewpoint_camera.image_width),
+ tanfovx=tanfovx,
+ tanfovy=tanfovy,
+ bg=bg_color,
+ scale_modifier=scaling_modifier,
+ viewmatrix=viewpoint_camera.world_view_transform,
+ projmatrix=viewpoint_camera.full_proj_transform,
+ sh_degree=pc.active_sh_degree,
+ campos=viewpoint_camera.camera_center,
+ prefiltered=False,
+ debug=pipe.debug,
+ )
+
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+ means3D = pc.get_xyz
+ means2D = screenspace_points
+ opacity = pc.get_opacity
+
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
+ # scaling / rotation by the rasterizer.
+ scales = None
+ rotations = None
+ cov3D_precomp = None
+ if pipe.compute_cov3D_python:
+ cov3D_precomp = pc.get_covariance(scaling_modifier)
+ else:
+ scales = pc.get_scaling
+ rotations = pc.get_rotation
+
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
+ shs = None
+ colors_precomp = None
+ if override_color is None:
+ if pipe.convert_SHs_python:
+ shs_view = pc.get_features.transpose(1, 2).view(
+ -1, 3, (pc.max_sh_degree + 1) ** 2
+ )
+ dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(
+ pc.get_features.shape[0], 1
+ )
+ dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
+ sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
+ colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
+ else:
+ shs = pc.get_features
+ else:
+ colors_precomp = override_color
+
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
+ rendered_image, radii, depth, alpha = rasterizer(
+ means3D=means3D,
+ means2D=means2D,
+ shs=shs,
+ colors_precomp=colors_precomp,
+ opacities=opacity,
+ scales=scales,
+ rotations=rotations,
+ cov3D_precomp=cov3D_precomp,
+ )
+ # rendered_image, radii = rasterizer(
+ # means3D = means3D,
+ # means2D = means2D,
+ # shs = shs,
+ # colors_precomp = colors_precomp,
+ # opacities = opacity,
+ # scales = scales,
+ # rotations = rotations,
+ # cov3D_precomp = cov3D_precomp)
+
+ # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
+ # They will be excluded from value updates used in the splitting criteria.
+ return {
+ "render": rendered_image,
+ "viewspace_points": screenspace_points,
+ "visibility_filter": radii > 0,
+ "radii": radii,
+ "depth": depth,
+ "alpha": alpha,
+ }
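
For context, a minimal consumption sketch of the dictionary returned by `render` above. It is not part of the diff; `viewpoint_cam`, `gaussians`, and `pipe` are placeholders for the Camera, `GaussianModel`, and `PipelineParams` objects built elsewhere (e.g. in `convert_to_blender.py`), and a CUDA device is assumed.

```python
import torch
from gaussian_renderer import render

def render_rgba(viewpoint_cam, gaussians, pipe):
    """Render one view and return an RGBA image plus per-gaussian visibility."""
    bg = torch.tensor([1.0, 1.0, 1.0], device="cuda")    # white background
    out = render(viewpoint_cam, gaussians, pipe, bg)
    rgb, alpha = out["render"], out["alpha"]              # [3, H, W], [1, H, W]
    rgba = torch.cat([rgb, alpha], dim=0).clamp(0.0, 1.0)
    return rgba, out["visibility_filter"], out["depth"]   # visibility: radii > 0
```
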
diff --git a/recon/gaussian_renderer/network_gui.py b/recon/gaussian_renderer/network_gui.py
new file mode 100644
index 0000000000000000000000000000000000000000..df2f9dae782b24527ae5b09f91ca4009361de53f
--- /dev/null
+++ b/recon/gaussian_renderer/network_gui.py
@@ -0,0 +1,86 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import traceback
+import socket
+import json
+from scene.cameras import MiniCam
+
+host = "127.0.0.1"
+port = 6009
+
+conn = None
+addr = None
+
+listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+def init(wish_host, wish_port):
+ global host, port, listener
+ host = wish_host
+ port = wish_port
+ listener.bind((host, port))
+ listener.listen()
+ listener.settimeout(0)
+
+def try_connect():
+ global conn, addr, listener
+ try:
+ conn, addr = listener.accept()
+ print(f"\nConnected by {addr}")
+ conn.settimeout(None)
+ except Exception as inst:
+ pass
+
+def read():
+ global conn
+ messageLength = conn.recv(4)
+ messageLength = int.from_bytes(messageLength, 'little')
+ message = conn.recv(messageLength)
+ return json.loads(message.decode("utf-8"))
+
+def send(message_bytes, verify):
+ global conn
+ if message_bytes != None:
+ conn.sendall(message_bytes)
+ conn.sendall(len(verify).to_bytes(4, 'little'))
+ conn.sendall(bytes(verify, 'ascii'))
+
+def receive():
+ message = read()
+
+ width = message["resolution_x"]
+ height = message["resolution_y"]
+
+ if width != 0 and height != 0:
+ try:
+ do_training = bool(message["train"])
+ fovy = message["fov_y"]
+ fovx = message["fov_x"]
+ znear = message["z_near"]
+ zfar = message["z_far"]
+ do_shs_python = bool(message["shs_python"])
+ do_rot_scale_python = bool(message["rot_scale_python"])
+ keep_alive = bool(message["keep_alive"])
+ scaling_modifier = message["scaling_modifier"]
+ world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
+ world_view_transform[:,1] = -world_view_transform[:,1]
+ world_view_transform[:,2] = -world_view_transform[:,2]
+ full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
+ full_proj_transform[:,1] = -full_proj_transform[:,1]
+ custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
+ except Exception as e:
+ print("")
+ traceback.print_exc()
+ raise e
+ return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
+ else:
+ return None, None, None, None, None, None
\ No newline at end of file
diff --git a/recon/lgm/gs.py b/recon/lgm/gs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c67469d0c3ed92fa7f6f7575daf609b360bc98a5
--- /dev/null
+++ b/recon/lgm/gs.py
@@ -0,0 +1,213 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diff_gaussian_rasterization import (
+ GaussianRasterizationSettings,
+ GaussianRasterizer,
+)
+
+from .options import Options
+
+import kiui
+
+
+class GaussianRenderer:
+ def __init__(self, opt: Options):
+ self.opt = opt
+ self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda")
+
+ # intrinsics
+ self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy))
+ self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32)
+ self.proj_matrix[0, 0] = 1 / self.tan_half_fov
+ self.proj_matrix[1, 1] = 1 / self.tan_half_fov
+ self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
+ self.proj_matrix[2, 3] = 1
+
+ def render(
+ self,
+ gaussians,
+ cam_view,
+ cam_view_proj,
+ cam_pos,
+ bg_color=None,
+ scale_modifier=1,
+ ):
+ # gaussians: [B, N, 14]
+ # cam_view, cam_view_proj: [B, V, 4, 4]
+ # cam_pos: [B, V, 3]
+
+ device = gaussians.device
+ B, V = cam_view.shape[:2]
+
+        # loop over batch items and views: B x V independent rasterizer calls
+ images = []
+ alphas = []
+ for b in range(B):
+ # pos, opacity, scale, rotation, shs
+ means3D = gaussians[b, :, 0:3].contiguous().float()
+ opacity = gaussians[b, :, 3:4].contiguous().float()
+ scales = gaussians[b, :, 4:7].contiguous().float()
+ rotations = gaussians[b, :, 7:11].contiguous().float()
+ rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3]
+
+ for v in range(V):
+ # render novel views
+ view_matrix = cam_view[b, v].float()
+ view_proj_matrix = cam_view_proj[b, v].float()
+ campos = cam_pos[b, v].float()
+
+ raster_settings = GaussianRasterizationSettings(
+ image_height=self.opt.output_size,
+ image_width=self.opt.output_size,
+ tanfovx=self.tan_half_fov,
+ tanfovy=self.tan_half_fov,
+ bg=self.bg_color if bg_color is None else bg_color,
+ scale_modifier=scale_modifier,
+ viewmatrix=view_matrix,
+ projmatrix=view_proj_matrix,
+ sh_degree=0,
+ campos=campos,
+ prefiltered=False,
+ debug=False,
+ )
+
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
+
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
+ rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
+ means3D=means3D,
+ means2D=torch.zeros_like(
+ means3D, dtype=torch.float32, device=device
+ ),
+ shs=None,
+ colors_precomp=rgbs,
+ opacities=opacity,
+ scales=scales,
+ rotations=rotations,
+ cov3D_precomp=None,
+ )
+
+ rendered_image = rendered_image.clamp(0, 1)
+
+ images.append(rendered_image)
+ alphas.append(rendered_alpha)
+
+ images = torch.stack(images, dim=0).view(
+ B, V, 3, self.opt.output_size, self.opt.output_size
+ )
+ alphas = torch.stack(alphas, dim=0).view(
+ B, V, 1, self.opt.output_size, self.opt.output_size
+ )
+
+ return {
+ "image": images, # [B, V, 3, H, W]
+ "alpha": alphas, # [B, V, 1, H, W]
+ }
+
+ def save_ply(self, gaussians, path, compatible=True):
+ # gaussians: [B, N, 14]
+ # compatible: save pre-activated gaussians as in the original paper
+
+ assert gaussians.shape[0] == 1, "only support batch size 1"
+
+ from plyfile import PlyData, PlyElement
+
+ means3D = gaussians[0, :, 0:3].contiguous().float()
+ opacity = gaussians[0, :, 3:4].contiguous().float()
+ scales = gaussians[0, :, 4:7].contiguous().float()
+ rotations = gaussians[0, :, 7:11].contiguous().float()
+ shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3]
+
+ # prune by opacity
+ mask = opacity.squeeze(-1) >= 0.005
+ means3D = means3D[mask]
+ opacity = opacity[mask]
+ scales = scales[mask]
+ rotations = rotations[mask]
+ shs = shs[mask]
+
+ # invert activation to make it compatible with the original ply format
+ if compatible:
+ opacity = kiui.op.inverse_sigmoid(opacity)
+ scales = torch.log(scales + 1e-8)
+ shs = (shs - 0.5) / 0.28209479177387814
+
+ xyzs = means3D.detach().cpu().numpy()
+ f_dc = (
+ shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
+ )
+ opacities = opacity.detach().cpu().numpy()
+ scales = scales.detach().cpu().numpy()
+ rotations = rotations.detach().cpu().numpy()
+
+ l = ["x", "y", "z"]
+ # All channels except the 3 DC
+ for i in range(f_dc.shape[1]):
+ l.append("f_dc_{}".format(i))
+ l.append("opacity")
+ for i in range(scales.shape[1]):
+ l.append("scale_{}".format(i))
+ for i in range(rotations.shape[1]):
+ l.append("rot_{}".format(i))
+
+ dtype_full = [(attribute, "f4") for attribute in l]
+
+ elements = np.empty(xyzs.shape[0], dtype=dtype_full)
+ attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1)
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, "vertex")
+
+ PlyData([el]).write(path)
+
+ def load_ply(self, path, compatible=True):
+ from plyfile import PlyData, PlyElement
+
+ plydata = PlyData.read(path)
+
+ xyz = np.stack(
+ (
+ np.asarray(plydata.elements[0]["x"]),
+ np.asarray(plydata.elements[0]["y"]),
+ np.asarray(plydata.elements[0]["z"]),
+ ),
+ axis=1,
+ )
+ print("Number of points at loading : ", xyz.shape[0])
+
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+ shs = np.zeros((xyz.shape[0], 3))
+ shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+ shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"])
+ shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+ scale_names = [
+ p.name
+ for p in plydata.elements[0].properties
+ if p.name.startswith("scale_")
+ ]
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
+ for idx, attr_name in enumerate(scale_names):
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ rot_names = [
+ p.name for p in plydata.elements[0].properties if p.name.startswith("rot_")
+ ]
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
+ for idx, attr_name in enumerate(rot_names):
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1)
+ gaussians = torch.from_numpy(gaussians).float() # cpu
+
+ if compatible:
+ gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4])
+ gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7])
+ gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5
+
+ return gaussians
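
A hedged round-trip sketch for the packed format used by `GaussianRenderer` above: each gaussian is a 14-vector `[xyz(3), opacity(1), scale(3), rotation(4), rgb(3)]`. The camera tensors are placeholders built the same way as in `Converter.render_gs` earlier, and a CUDA device is assumed.

```python
import torch
from lgm.options import Options
from lgm.gs import GaussianRenderer

def reload_and_render(ply_path, cam_view, cam_view_proj, cam_pos):
    # cam_view / cam_view_proj: [1, V, 4, 4], cam_pos: [1, V, 3] (CUDA tensors, placeholders)
    renderer = GaussianRenderer(Options())
    gaussians = renderer.load_ply(ply_path).unsqueeze(0).cuda()   # [1, N, 14], activations applied
    out = renderer.render(gaussians, cam_view, cam_view_proj, cam_pos)
    renderer.save_ply(gaussians, ply_path.replace(".ply", "_roundtrip.ply"))
    return out["image"], out["alpha"]   # [1, V, 3, H, W], [1, V, 1, H, W]
```
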
diff --git a/recon/lgm/options.py b/recon/lgm/options.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cc31944f89ff14a4387204f1828edd785bc3498
--- /dev/null
+++ b/recon/lgm/options.py
@@ -0,0 +1,120 @@
+import tyro
+from dataclasses import dataclass
+from typing import Tuple, Literal, Dict, Optional
+
+
+@dataclass
+class Options:
+ ### model
+ # Unet image input size
+ input_size: int = 256
+ # Unet definition
+ down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024)
+ down_attention: Tuple[bool, ...] = (False, False, False, True, True, True)
+ mid_attention: bool = True
+ up_channels: Tuple[int, ...] = (1024, 1024, 512, 256)
+ up_attention: Tuple[bool, ...] = (True, True, True, False)
+ # Unet output size, dependent on the input_size and U-Net structure!
+ splat_size: int = 64
+ # gaussian render size
+ output_size: int = 256
+
+ ### dataset
+ # data mode (only support s3 now)
+ data_mode: Literal['s3'] = 's3'
+ # fovy of the dataset
+ fovy: float = 49.1
+ # camera near plane
+ znear: float = 0.5
+ # camera far plane
+ zfar: float = 2.5
+ # number of all views (input + output)
+ num_views: int = 12
+ # number of views
+ num_input_views: int = 4
+ # camera radius
+ cam_radius: float = 1.5 # to better use [-1, 1]^3 space
+ # num workers
+ num_workers: int = 8
+
+ ### training
+ # workspace
+ workspace: str = './workspace'
+ # resume
+ resume: Optional[str] = None
+ # batch size (per-GPU)
+ batch_size: int = 8
+ # gradient accumulation
+ gradient_accumulation_steps: int = 1
+ # training epochs
+ num_epochs: int = 30
+ # lpips loss weight
+ lambda_lpips: float = 1.0
+ # gradient clip
+ gradient_clip: float = 1.0
+ # mixed precision
+ mixed_precision: str = 'bf16'
+ # learning rate
+ lr: float = 4e-4
+ # augmentation prob for grid distortion
+ prob_grid_distortion: float = 0.5
+ # augmentation prob for camera jitter
+ prob_cam_jitter: float = 0.5
+
+ ### testing
+ # test image path
+ test_path: Optional[str] = None
+
+ ### misc
+ # nvdiffrast backend setting
+ force_cuda_rast: bool = False
+ # render fancy video with gaussian scaling effect
+ fancy_video: bool = False
+
+
+# all the default settings
+config_defaults: Dict[str, Options] = {}
+config_doc: Dict[str, str] = {}
+
+config_doc['lrm'] = 'the default settings for LGM'
+config_defaults['lrm'] = Options()
+
+config_doc['small'] = 'small model with lower resolution Gaussians'
+config_defaults['small'] = Options(
+ input_size=256,
+ splat_size=64,
+ output_size=256,
+ batch_size=8,
+ gradient_accumulation_steps=1,
+ mixed_precision='bf16',
+)
+
+config_doc['big'] = 'big model with higher resolution Gaussians'
+config_defaults['big'] = Options(
+ input_size=256,
+ up_channels=(1024, 1024, 512, 256, 128), # one more decoder
+ up_attention=(True, True, True, False, False),
+ splat_size=128,
+ output_size=512, # render & supervise Gaussians at a higher resolution.
+ batch_size=8,
+ num_views=8,
+ gradient_accumulation_steps=1,
+ mixed_precision='bf16',
+)
+
+config_doc['tiny'] = 'tiny model for ablation'
+config_defaults['tiny'] = Options(
+ input_size=256,
+ down_channels=(32, 64, 128, 256, 512),
+ down_attention=(False, False, False, False, True),
+ up_channels=(512, 256, 128),
+ up_attention=(True, False, False, False),
+ splat_size=64,
+ output_size=256,
+ batch_size=16,
+ num_views=8,
+ gradient_accumulation_steps=1,
+ mixed_precision='bf16',
+)
+
+AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc)
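
The presets above can be consumed either through tyro (as `convert_nerf_mesh.py` does) or programmatically; a small sketch follows, with the `.ply` path as a placeholder. `dataclasses.replace` is used so the shared preset instance is not mutated.

```python
from dataclasses import replace
import tyro
from lgm.options import AllConfigs, config_defaults

# Programmatic: start from the 'big' preset and override individual fields.
opt = replace(config_defaults["big"], test_path="out.ply")
print(opt.splat_size, opt.output_size)   # 128 512 for the 'big' preset

# CLI: the preset name becomes a subcommand, e.g. `python script.py big --test_path out.ply`.
if __name__ == "__main__":
    opt = tyro.cli(AllConfigs)
```
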
diff --git a/recon/lpipsPyTorch/__init__.py b/recon/lpipsPyTorch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6297daa457d1d041c9491dfdf6a75994ffe06e
--- /dev/null
+++ b/recon/lpipsPyTorch/__init__.py
@@ -0,0 +1,21 @@
+import torch
+
+from .modules.lpips import LPIPS
+
+
+def lpips(x: torch.Tensor,
+ y: torch.Tensor,
+ net_type: str = 'alex',
+ version: str = '0.1'):
+ r"""Function that measures
+ Learned Perceptual Image Patch Similarity (LPIPS).
+
+ Arguments:
+ x, y (torch.Tensor): the input tensors to compare.
+ net_type (str): the network type to compare the features:
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+ version (str): the version of LPIPS. Default: 0.1.
+ """
+ device = x.device
+ criterion = LPIPS(net_type, version).to(device)
+ return criterion(x, y)
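
A minimal usage sketch for the `lpips` helper above (dummy tensors, CUDA assumed). Note that each call builds a fresh `LPIPS` module and, on first use, downloads the linear-layer weights from the referenced PerceptualSimilarity repository as well as the torchvision backbone; `metrics.py` below calls the helper once per image pair, which is simple but re-creates the network each time.

```python
import torch
from lpipsPyTorch import lpips

x = torch.rand(1, 3, 256, 256, device="cuda")   # e.g. a rendered image batch
y = torch.rand(1, 3, 256, 256, device="cuda")   # e.g. the ground-truth batch
score = lpips(x, y, net_type="vgg")             # lower means more perceptually similar
print(score.item())
```
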
diff --git a/recon/lpipsPyTorch/modules/lpips.py b/recon/lpipsPyTorch/modules/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea2d1da4daa267cc5e6c2ce11e1dddec3a5e9406
--- /dev/null
+++ b/recon/lpipsPyTorch/modules/lpips.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+
+from .networks import get_network, LinLayers
+from .utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+ r"""Creates a criterion that measures
+ Learned Perceptual Image Patch Similarity (LPIPS).
+
+ Arguments:
+ net_type (str): the network type to compare the features:
+ 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+ version (str): the version of LPIPS. Default: 0.1.
+ """
+
+ def __init__(self, net_type: str = "alex", version: str = "0.1"):
+
+ assert version in ["0.1"], "v0.1 is only supported now"
+
+ super(LPIPS, self).__init__()
+
+ # pretrained network
+ self.net = get_network(net_type)
+
+ # linear layers
+ self.lin = LinLayers(self.net.n_channels_list)
+ self.lin.load_state_dict(get_state_dict(net_type, version))
+ self.eval()
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ feat_x, feat_y = self.net(x), self.net(y)
+
+ diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+ res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+ return torch.sum(torch.cat(res, 0), 0, True)
diff --git a/recon/lpipsPyTorch/modules/networks.py b/recon/lpipsPyTorch/modules/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..d36c6a56163004d49c321da5e26404af9baa4c2a
--- /dev/null
+++ b/recon/lpipsPyTorch/modules/networks.py
@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from .utils import normalize_activation
+
+
+def get_network(net_type: str):
+ if net_type == 'alex':
+ return AlexNet()
+ elif net_type == 'squeeze':
+ return SqueezeNet()
+ elif net_type == 'vgg':
+ return VGG16()
+ else:
+ raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+ def __init__(self, n_channels_list: Sequence[int]):
+ super(LinLayers, self).__init__([
+ nn.Sequential(
+ nn.Identity(),
+ nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+ ) for nc in n_channels_list
+ ])
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+ def __init__(self):
+ super(BaseNet, self).__init__()
+
+ # register buffer
+ self.register_buffer(
+ 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer(
+ 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def set_requires_grad(self, state: bool):
+ for param in chain(self.parameters(), self.buffers()):
+ param.requires_grad = state
+
+ def z_score(self, x: torch.Tensor):
+ return (x - self.mean) / self.std
+
+ def forward(self, x: torch.Tensor):
+ x = self.z_score(x)
+
+ output = []
+ for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+ x = layer(x)
+ if i in self.target_layers:
+ output.append(normalize_activation(x))
+ if len(output) == len(self.target_layers):
+ break
+ return output
+
+
+class SqueezeNet(BaseNet):
+ def __init__(self):
+ super(SqueezeNet, self).__init__()
+
+        self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
+ self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+ self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+ self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+ def __init__(self):
+ super(AlexNet, self).__init__()
+
+        self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
+ self.target_layers = [2, 5, 8, 10, 12]
+ self.n_channels_list = [64, 192, 384, 256, 256]
+
+ self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+ def __init__(self):
+ super(VGG16, self).__init__()
+
+ self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
+ self.target_layers = [4, 9, 16, 23, 30]
+ self.n_channels_list = [64, 128, 256, 512, 512]
+
+ self.set_requires_grad(False)
diff --git a/recon/lpipsPyTorch/modules/utils.py b/recon/lpipsPyTorch/modules/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5
--- /dev/null
+++ b/recon/lpipsPyTorch/modules/utils.py
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+ # build url
+ url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+ # download
+ old_state_dict = torch.hub.load_state_dict_from_url(
+ url, progress=True,
+ map_location=None if torch.cuda.is_available() else torch.device('cpu')
+ )
+
+ # rename keys
+ new_state_dict = OrderedDict()
+ for key, val in old_state_dict.items():
+ new_key = key
+ new_key = new_key.replace('lin', '')
+ new_key = new_key.replace('model.', '')
+ new_state_dict[new_key] = val
+
+ return new_state_dict
diff --git a/recon/metrics.py b/recon/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..faa3b698a68296a9a6226bc51c78407320c106fe
--- /dev/null
+++ b/recon/metrics.py
@@ -0,0 +1,131 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+from pathlib import Path
+import os
+from PIL import Image
+import torch
+import torchvision.transforms.functional as tf
+from utils.loss_utils import ssim
+from lpipsPyTorch import lpips
+import json
+from tqdm import tqdm
+from utils.image_utils import psnr
+from argparse import ArgumentParser
+
+
+def readImages(renders_dir, gt_dir):
+ renders = []
+ gts = []
+ image_names = []
+ for fname in os.listdir(renders_dir):
+ render = Image.open(renders_dir / fname)
+ gt = Image.open(gt_dir / fname)
+ renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
+ gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
+ image_names.append(fname)
+ return renders, gts, image_names
+
+
+def evaluate(model_paths):
+
+ full_dict = {}
+ per_view_dict = {}
+ full_dict_polytopeonly = {}
+ per_view_dict_polytopeonly = {}
+ print("")
+
+ for scene_dir in model_paths:
+ try:
+ print("Scene:", scene_dir)
+ full_dict[scene_dir] = {}
+ per_view_dict[scene_dir] = {}
+ full_dict_polytopeonly[scene_dir] = {}
+ per_view_dict_polytopeonly[scene_dir] = {}
+
+ test_dir = Path(scene_dir) / "test"
+
+ for method in os.listdir(test_dir):
+ print("Method:", method)
+
+ full_dict[scene_dir][method] = {}
+ per_view_dict[scene_dir][method] = {}
+ full_dict_polytopeonly[scene_dir][method] = {}
+ per_view_dict_polytopeonly[scene_dir][method] = {}
+
+ method_dir = test_dir / method
+ gt_dir = method_dir / "gt"
+ renders_dir = method_dir / "renders"
+ renders, gts, image_names = readImages(renders_dir, gt_dir)
+
+ ssims = []
+ psnrs = []
+ lpipss = []
+
+ for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
+ ssims.append(ssim(renders[idx], gts[idx]))
+ psnrs.append(psnr(renders[idx], gts[idx]))
+ lpipss.append(lpips(renders[idx], gts[idx], net_type="vgg"))
+
+                print("  SSIM : {:>12.7f}".format(torch.tensor(ssims).mean()))
+                print("  PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean()))
+                print("  LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean()))
+ print("")
+
+ full_dict[scene_dir][method].update(
+ {
+ "SSIM": torch.tensor(ssims).mean().item(),
+ "PSNR": torch.tensor(psnrs).mean().item(),
+ "LPIPS": torch.tensor(lpipss).mean().item(),
+ }
+ )
+ per_view_dict[scene_dir][method].update(
+ {
+ "SSIM": {
+ name: ssim
+ for ssim, name in zip(
+ torch.tensor(ssims).tolist(), image_names
+ )
+ },
+ "PSNR": {
+ name: psnr
+ for psnr, name in zip(
+ torch.tensor(psnrs).tolist(), image_names
+ )
+ },
+ "LPIPS": {
+ name: lp
+ for lp, name in zip(
+ torch.tensor(lpipss).tolist(), image_names
+ )
+ },
+ }
+ )
+
+ with open(scene_dir + "/results.json", "w") as fp:
+ json.dump(full_dict[scene_dir], fp, indent=True)
+ with open(scene_dir + "/per_view.json", "w") as fp:
+ json.dump(per_view_dict[scene_dir], fp, indent=True)
+        except Exception as e:
+            print("Unable to compute metrics for model", scene_dir, ":", e)
+
+
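+# Expected layout per scene directory (inferred from the paths used above):
+#   <model_path>/test/<method>/renders/*.png   rendered images
+#   <model_path>/test/<method>/gt/*.png        ground-truth images
+# Results are written to <model_path>/results.json and <model_path>/per_view.json.
+# Example: python metrics.py -m output/scene_a output/scene_b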
+if __name__ == "__main__":
+ device = torch.device("cuda:0")
+ torch.cuda.set_device(device)
+
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ parser.add_argument(
+ "--model_paths", "-m", required=True, nargs="+", type=str, default=[]
+ )
+ args = parser.parse_args()
+ evaluate(args.model_paths)
diff --git a/recon/render.py b/recon/render.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0d66379bf3227127ea18ed10b82bfe53ea2726d
--- /dev/null
+++ b/recon/render.py
@@ -0,0 +1,65 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+from scene import Scene
+import os
+from tqdm import tqdm
+from os import makedirs
+from gaussian_renderer import render
+import torchvision
+from utils.general_utils import safe_state
+from argparse import ArgumentParser
+from arguments import ModelParams, PipelineParams, get_combined_args
+from gaussian_renderer import GaussianModel
+
+def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
+ render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
+ gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
+
+ makedirs(render_path, exist_ok=True)
+ makedirs(gts_path, exist_ok=True)
+
+ for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
+ rendering = render(view, gaussians, pipeline, background)["render"]
+ gt = view.original_image[0:3, :, :]
+ torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
+ torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
+
+def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
+ with torch.no_grad():
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
+
+ bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+ if not skip_train:
+ render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
+
+ if not skip_test:
+ render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
+
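+# Example invocation (sketch): render the train and test splits of a trained
+# model at the latest saved iteration. The -m flag comes from ModelParams'
+# _model_path shorthand; get_combined_args is expected to also merge arguments
+# saved alongside the model, if present.
+#   python render.py -m <model_path> --iteration -1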
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Testing script parameters")
+ model = ModelParams(parser, sentinel=True)
+ pipeline = PipelineParams(parser)
+ parser.add_argument("--iteration", default=-1, type=int)
+ parser.add_argument("--skip_train", action="store_true")
+ parser.add_argument("--skip_test", action="store_true")
+ parser.add_argument("--quiet", action="store_true")
+ args = get_combined_args(parser)
+ print("Rendering " + args.model_path)
+
+ # Initialize system state (RNG)
+
+ render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
\ No newline at end of file
diff --git a/recon/render_depth.py b/recon/render_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a09acda32d991f257522f50d0c83c5f718f2cd9
--- /dev/null
+++ b/recon/render_depth.py
@@ -0,0 +1,79 @@
+import torch
+from scene import Scene
+from pathlib import Path
+from PIL import Image
+import numpy as np
+import sys
+import os
+from tqdm import tqdm
+from os import makedirs
+from gaussian_renderer import render
+import torchvision
+from utils.general_utils import safe_state
+from argparse import ArgumentParser
+from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
+from gaussian_renderer import GaussianModel
+from mediapy import write_video
+from tqdm import tqdm
+from einops import rearrange
+from utils.colormaps import apply_depth_colormap
+
+
+@torch.no_grad()
+def render_spiral(dataset, opt, pipe, model_path):
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False)
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+ viewpoint_stack = scene.getTrainCameras().copy()
+ views = []
+ for view_cam in tqdm(viewpoint_stack):
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+ render_pkg = render(view_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["depth"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+ views.append(
+ rearrange(
+ apply_depth_colormap(
+ image[0][..., None],
+ accumulation=render_pkg["alpha"][0][..., None],
+ ),
+ "h w c -> c h w",
+ )
+ )
+ views = torch.stack(views)
+
+ write_video(
+ f"./depth_spirals/{Path(dataset.model_path).stem}.mp4",
+ rearrange(views.cpu().numpy(), "t c h w -> t h w c"),
+ fps=3,
+ )
+
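+# Note (hedged): render_pkg["depth"] and render_pkg["alpha"] are assumed to be
+# exposed by this repo's rasterizer. apply_depth_colormap turns per-pixel depth
+# (weighted by the accumulated alpha) into an HxWx3 image, which is rearranged
+# to CHW and stacked into a (T, C, H, W) tensor for write_video above.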
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--iteration", default=-1, type=int)
+ parser.add_argument("--skip_train", action="store_true")
+ parser.add_argument("--skip_test", action="store_true")
+ parser.add_argument("--quiet", action="store_true")
+ args = parser.parse_args(sys.argv[1:])
+ print("Rendering " + args.model_path)
+ lp = lp.extract(args)
+ fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8))
+ lp.images = [fake_image] * args.num_frames
+
+ # Initialize system state (RNG)
+ render_spiral(
+ lp,
+ op.extract(args),
+ pp.extract(args),
+ model_path=args.model_path,
+ )
diff --git a/recon/render_points.py b/recon/render_points.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ad86e10151d899474d9dcecfe3a19afc2e34258
--- /dev/null
+++ b/recon/render_points.py
@@ -0,0 +1,70 @@
+import torch
+from scene import Scene
+from pathlib import Path
+from PIL import Image
+import numpy as np
+import sys
+import os
+from tqdm import tqdm
+from os import makedirs
+from gaussian_renderer import render
+import torchvision
+from utils.general_utils import safe_state
+from argparse import ArgumentParser
+from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
+from gaussian_renderer import GaussianModel
+from mediapy import write_video
+from tqdm import tqdm
+from einops import rearrange
+
+
+@torch.no_grad()
+def render_spiral(dataset, opt, pipe, model_path):
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False)
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+ viewpoint_stack = scene.getTrainCameras().copy()
+ views = []
+ for view_cam in tqdm(viewpoint_stack):
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+ render_pkg = render(view_cam, gaussians, pipe, bg, scaling_modifier=0.1)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+ views.append(image)
+ views = torch.stack(views)
+
+ write_video(
+ f"./paper/specials/{Path(dataset.model_path).stem}.mp4",
+ rearrange(views.cpu().numpy(), "t c h w -> t h w c"),
+ fps=30,
+ )
+
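+# scaling_modifier=0.1 in the render call above shrinks every Gaussian, so the
+# resulting video approximates a point-cloud visualisation of the Gaussian centres.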
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--iteration", default=-1, type=int)
+ parser.add_argument("--skip_train", action="store_true")
+ parser.add_argument("--skip_test", action="store_true")
+ parser.add_argument("--quiet", action="store_true")
+ args = parser.parse_args(sys.argv[1:])
+ print("Rendering " + args.model_path)
+ lp = lp.extract(args)
+ fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8))
+ lp.images = [fake_image] * args.num_frames
+
+ # Initialize system state (RNG)
+ render_spiral(
+ lp,
+ op.extract(args),
+ pp.extract(args),
+ model_path=args.model_path,
+ )
diff --git a/recon/render_spiral.py b/recon/render_spiral.py
new file mode 100644
index 0000000000000000000000000000000000000000..23aea3afc972638315311e96e30c3aec8d75aee1
--- /dev/null
+++ b/recon/render_spiral.py
@@ -0,0 +1,75 @@
+import torch
+from scene import Scene
+from pathlib import Path
+from PIL import Image
+import numpy as np
+import sys
+import os
+from tqdm import tqdm
+from os import makedirs
+from gaussian_renderer import render
+import torchvision
+from utils.general_utils import safe_state
+from argparse import ArgumentParser
+from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams
+from gaussian_renderer import GaussianModel
+from mediapy import write_video
+from tqdm import tqdm
+from einops import rearrange
+
+
+@torch.no_grad()
+def render_spiral(dataset, opt, pipe, model_path):
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False)
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+ viewpoint_stack = scene.getTrainCameras().copy()
+ views = []
+ for view_cam in tqdm(viewpoint_stack):
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+ render_pkg = render(view_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+ views.append(image)
+ views = torch.stack(views)
+
+ write_video(
+ f"./spirals/{Path(dataset.model_path).stem}.mp4",
+ rearrange(views.cpu().numpy(), "t c h w -> t h w c"),
+ fps=30,
+ )
+ write_video(
+ f"tmp/test_spiral.mp4",
+ rearrange(views.cpu().numpy(), "t c h w -> t h w c"),
+ fps=30,
+ )
+
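+# Two copies of the spiral video are written: ./spirals/<model>.mp4 and a scratch
+# copy at tmp/test_spiral.mp4; both output directories are assumed to exist, since
+# write_video is not expected to create them.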
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--iteration", default=-1, type=int)
+ parser.add_argument("--skip_train", action="store_true")
+ parser.add_argument("--skip_test", action="store_true")
+ parser.add_argument("--quiet", action="store_true")
+ args = parser.parse_args(sys.argv[1:])
+ print("Rendering " + args.model_path)
+ lp = lp.extract(args)
+ fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8))
+ lp.images = [fake_image] * args.num_frames
+
+ # Initialize system state (RNG)
+ render_spiral(
+ lp,
+ op.extract(args),
+ pp.extract(args),
+ model_path=args.model_path,
+ )
diff --git a/recon/restore.py b/recon/restore.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c907a210bf62ef2453ccc9d06559c1cfe3b33f2
--- /dev/null
+++ b/recon/restore.py
@@ -0,0 +1,3 @@
+import torch
+
+pass
diff --git a/recon/scene/__init__.py b/recon/scene/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..faea2365d4aa15be843b7adfdb7efbf22dc60554
--- /dev/null
+++ b/recon/scene/__init__.py
@@ -0,0 +1,139 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import random
+import json
+from utils.system_utils import searchForMaxIteration
+from scene.dataset_readers import sceneLoadTypeCallbacks
+from scene.gaussian_model import GaussianModel
+from arguments import ModelParams
+from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
+
+
+class Scene:
+ gaussians: GaussianModel
+
+ def __init__(
+ self,
+ args: ModelParams,
+ gaussians: GaussianModel,
+ load_iteration=None,
+ shuffle=True,
+ resolution_scales=[1.0],
+ skip_gaussians=False,
+ ):
+        """
+        :param args: ModelParams describing the scene source (COLMAP, Blender or VideoNVS).
+        """
+ self.model_path = args.model_path
+ self.loaded_iter = None
+ self.gaussians = gaussians
+
+ if load_iteration:
+ if load_iteration == -1:
+ self.loaded_iter = searchForMaxIteration(
+ os.path.join(self.model_path, "point_cloud")
+ )
+ else:
+ self.loaded_iter = load_iteration
+ print("Loading trained model at iteration {}".format(self.loaded_iter))
+
+ self.train_cameras = {}
+ self.test_cameras = {}
+
+ if os.path.exists(os.path.join(args.source_path, "sparse")):
+ scene_info = sceneLoadTypeCallbacks["Colmap"](
+ args.source_path, args.images, args.eval
+ )
+ elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
+ print("Found transforms_train.json file, assuming Blender data set!")
+ scene_info = sceneLoadTypeCallbacks["Blender"](
+ args.source_path, args.white_background, args.eval
+ )
+ elif hasattr(args, "num_frames"):
+ print("using video-nvs target")
+ scene_info = sceneLoadTypeCallbacks["VideoNVS"](
+ args.num_frames,
+ args.radius,
+ args.elevation,
+ args.fov,
+ args.reso,
+ args.images,
+ args.masks,
+ args.num_pts,
+ args.train,
+ )
+ else:
+ assert False, "Could not recognize scene type!"
+
+ if not self.loaded_iter:
+ with open(scene_info.ply_path, "rb") as src_file, open(
+ os.path.join(self.model_path, "input.ply"), "wb"
+ ) as dest_file:
+ dest_file.write(src_file.read())
+ json_cams = []
+ camlist = []
+ if scene_info.test_cameras:
+ camlist.extend(scene_info.test_cameras)
+ if scene_info.train_cameras:
+ camlist.extend(scene_info.train_cameras)
+ for id, cam in enumerate(camlist):
+ json_cams.append(camera_to_JSON(id, cam))
+ with open(os.path.join(self.model_path, "cameras.json"), "w") as file:
+ json.dump(json_cams, file)
+
+ if shuffle:
+ random.shuffle(
+ scene_info.train_cameras
+ ) # Multi-res consistent random shuffling
+ random.shuffle(
+ scene_info.test_cameras
+ ) # Multi-res consistent random shuffling
+
+ self.cameras_extent = scene_info.nerf_normalization["radius"]
+
+ for resolution_scale in resolution_scales:
+ print("Loading Training Cameras")
+ self.train_cameras[resolution_scale] = cameraList_from_camInfos(
+ scene_info.train_cameras, resolution_scale, args
+ )
+ print("Loading Test Cameras")
+ self.test_cameras[resolution_scale] = cameraList_from_camInfos(
+ scene_info.test_cameras, resolution_scale, args
+ )
+
+ if not skip_gaussians:
+ if self.loaded_iter:
+ self.gaussians.load_ply(
+ os.path.join(
+ self.model_path,
+ "point_cloud",
+ "iteration_" + str(self.loaded_iter),
+ "point_cloud.ply",
+ )
+ )
+ else:
+ self.gaussians.create_from_pcd(
+ scene_info.point_cloud, self.cameras_extent
+ )
+
+ def save(self, iteration):
+ point_cloud_path = os.path.join(
+ self.model_path, "point_cloud/iteration_{}".format(iteration)
+ )
+ self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
+
+ def getTrainCameras(self, scale=1.0):
+ return self.train_cameras[scale]
+
+ def getTestCameras(self, scale=1.0):
+ return self.test_cameras[scale]
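+
+
+# Minimal usage sketch (hedged; assumes a trained model directory containing
+# point_cloud/iteration_*/point_cloud.ply):
+#   gaussians = GaussianModel(sh_degree=3)
+#   scene = Scene(model_params, gaussians, load_iteration=-1, shuffle=False)
+#   train_cams = scene.getTrainCameras()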
diff --git a/recon/scene/cameras.py b/recon/scene/cameras.py
new file mode 100644
index 0000000000000000000000000000000000000000..abf6e5242bc46ef1915ce24619a8319d0b7591c7
--- /dev/null
+++ b/recon/scene/cameras.py
@@ -0,0 +1,71 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+from torch import nn
+import numpy as np
+from utils.graphics_utils import getWorld2View2, getProjectionMatrix
+
+class Camera(nn.Module):
+ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
+ image_name, uid,
+ trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
+ ):
+ super(Camera, self).__init__()
+
+ self.uid = uid
+ self.colmap_id = colmap_id
+ self.R = R
+ self.T = T
+ self.FoVx = FoVx
+ self.FoVy = FoVy
+ self.image_name = image_name
+
+ try:
+ self.data_device = torch.device(data_device)
+ except Exception as e:
+ print(e)
+ print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
+ self.data_device = torch.device("cuda")
+
+ self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
+ self.image_width = self.original_image.shape[2]
+ self.image_height = self.original_image.shape[1]
+
+ if gt_alpha_mask is not None:
+ self.original_image *= gt_alpha_mask.to(self.data_device)
+ else:
+ self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
+
+ self.zfar = 100.0
+ self.znear = 0.01
+
+ self.trans = trans
+ self.scale = scale
+
+ self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
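+        # Note: view and projection matrices are stored transposed (row-vector
+        # convention used by the CUDA rasterizer), so the camera centre is the
+        # translation row of the inverse view matrix.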
+
+class MiniCam:
+ def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
+ self.image_width = width
+ self.image_height = height
+ self.FoVy = fovy
+ self.FoVx = fovx
+ self.znear = znear
+ self.zfar = zfar
+ self.world_view_transform = world_view_transform
+ self.full_proj_transform = full_proj_transform
+ view_inv = torch.inverse(self.world_view_transform)
+ self.camera_center = view_inv[3][:3]
+
diff --git a/recon/scene/colmap_loader.py b/recon/scene/colmap_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6fba6a9c961f52c88780ecb44d7821b4cb73ee
--- /dev/null
+++ b/recon/scene/colmap_loader.py
@@ -0,0 +1,294 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import numpy as np
+import collections
+import struct
+
+CameraModel = collections.namedtuple(
+ "CameraModel", ["model_id", "model_name", "num_params"])
+Camera = collections.namedtuple(
+ "Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
+}
+CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
+ for camera_model in CAMERA_MODELS])
+CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
+ for camera_model in CAMERA_MODELS])
+
+
+def qvec2rotmat(qvec):
+ return np.array([
+ [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
+ [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
+ [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = np.array([
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
+
+def read_points3D_text(path):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ xyzs = None
+ rgbs = None
+ errors = None
+ num_points = 0
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ num_points += 1
+
+
+ xyzs = np.empty((num_points, 3))
+ rgbs = np.empty((num_points, 3))
+ errors = np.empty((num_points, 1))
+ count = 0
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = np.array(float(elems[7]))
+ xyzs[count] = xyz
+ rgbs[count] = rgb
+ errors[count] = error
+ count += 1
+
+ return xyzs, rgbs, errors
+
+def read_points3D_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+
+
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+
+ xyzs = np.empty((num_points, 3))
+ rgbs = np.empty((num_points, 3))
+ errors = np.empty((num_points, 1))
+
+ for p_id in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd")
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
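+            # Read (and discard) the 2D track of this point so the file offset
+            # stays correct; only xyz, rgb and error are kept.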
+ track_length = read_next_bytes(
+ fid, num_bytes=8, format_char_sequence="Q")[0]
+ track_elems = read_next_bytes(
+ fid, num_bytes=8*track_length,
+ format_char_sequence="ii"*track_length)
+ xyzs[p_id] = xyz
+ rgbs[p_id] = rgb
+ errors[p_id] = error
+ return xyzs, rgbs, errors
+
+def read_intrinsics_text(path):
+ """
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
+ """
+ cameras = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+                assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(id=camera_id, model=model,
+ width=width, height=height,
+ params=params)
+ return cameras
+
+def read_extrinsics_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi")
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ image_name = ""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ image_name += current_char.decode("utf-8")
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ num_points2D = read_next_bytes(fid, num_bytes=8,
+ format_char_sequence="Q")[0]
+ x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
+ format_char_sequence="ddq"*num_points2D)
+ xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
+ tuple(map(float, x_y_id_s[1::3]))])
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=xys, point3D_ids=point3D_ids)
+ return images
+
+
+def read_intrinsics_binary(path_to_model_file):
+ """
+ see: src/base/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_cameras):
+ camera_properties = read_next_bytes(
+ fid, num_bytes=24, format_char_sequence="iiQQ")
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(fid, num_bytes=8*num_params,
+ format_char_sequence="d"*num_params)
+ cameras[camera_id] = Camera(id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params))
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def read_extrinsics_text(path):
+ """
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
+ """
+ images = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack([tuple(map(float, elems[0::3])),
+ tuple(map(float, elems[1::3]))])
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id, qvec=qvec, tvec=tvec,
+ camera_id=camera_id, name=image_name,
+ xys=xys, point3D_ids=point3D_ids)
+ return images
+
+
+def read_colmap_bin_array(path):
+ """
+ Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
+
+ :param path: path to the colmap binary file.
+ :return: nd array with the floating point values in the value
+ """
+ with open(path, "rb") as fid:
+ width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
+ usecols=(0, 1, 2), dtype=int)
+ fid.seek(0)
+ num_delimiter = 0
+ byte = fid.read(1)
+ while True:
+ if byte == b"&":
+ num_delimiter += 1
+ if num_delimiter >= 3:
+ break
+ byte = fid.read(1)
+ array = np.fromfile(fid, np.float32)
+ array = array.reshape((width, height, channels), order="F")
+ return np.transpose(array, (1, 0, 2)).squeeze()
diff --git a/recon/scene/dataset_readers.py b/recon/scene/dataset_readers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e09a82ea5726f3f7a46fd24bacb5f1b93ef5231
--- /dev/null
+++ b/recon/scene/dataset_readers.py
@@ -0,0 +1,512 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import sys
+from PIL import Image
+from typing import NamedTuple
+from scene.colmap_loader import (
+ read_extrinsics_text,
+ read_intrinsics_text,
+ qvec2rotmat,
+ read_extrinsics_binary,
+ read_intrinsics_binary,
+ read_points3D_binary,
+ read_points3D_text,
+)
+from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
+from utils.camera_utils import get_uniform_poses
+import numpy as np
+import json
+from pathlib import Path
+from plyfile import PlyData, PlyElement
+from utils.sh_utils import SH2RGB
+from scene.gaussian_model import BasicPointCloud
+from scene.cameras import Camera
+import torch
+import rembg
+import mcubes
+import trimesh
+
+
+class CameraInfo(NamedTuple):
+ uid: int
+ R: np.array
+ T: np.array
+ FovY: np.array
+ FovX: np.array
+ image: np.array
+ image_path: str
+ image_name: str
+ width: int
+ height: int
+
+
+class SceneInfo(NamedTuple):
+ point_cloud: BasicPointCloud
+ train_cameras: list
+ test_cameras: list
+ nerf_normalization: dict
+ ply_path: str
+
+
+def getNerfppNorm(cam_info):
+ def get_center_and_diag(cam_centers):
+ cam_centers = np.hstack(cam_centers)
+ avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
+ center = avg_cam_center
+ dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
+ diagonal = np.max(dist)
+ return center.flatten(), diagonal
+
+ cam_centers = []
+
+ for cam in cam_info:
+ W2C = getWorld2View2(cam.R, cam.T)
+ C2W = np.linalg.inv(W2C)
+ cam_centers.append(C2W[:3, 3:4])
+
+ center, diagonal = get_center_and_diag(cam_centers)
+ radius = diagonal * 1.1
+
+ translate = -center
+
+ return {"translate": translate, "radius": radius}
+
+
+def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
+ cam_infos = []
+ for idx, key in enumerate(cam_extrinsics):
+ sys.stdout.write("\r")
+        # overwrite the previous progress line in place
+ sys.stdout.write("Reading camera {}/{}".format(idx + 1, len(cam_extrinsics)))
+ sys.stdout.flush()
+
+ extr = cam_extrinsics[key]
+ intr = cam_intrinsics[extr.camera_id]
+ height = intr.height
+ width = intr.width
+
+ uid = intr.id
+ R = np.transpose(qvec2rotmat(extr.qvec))
+ T = np.array(extr.tvec)
+
+ if intr.model == "SIMPLE_PINHOLE":
+ focal_length_x = intr.params[0]
+ FovY = focal2fov(focal_length_x, height)
+ FovX = focal2fov(focal_length_x, width)
+ elif intr.model == "PINHOLE":
+ focal_length_x = intr.params[0]
+ focal_length_y = intr.params[1]
+ FovY = focal2fov(focal_length_y, height)
+ FovX = focal2fov(focal_length_x, width)
+ else:
+ assert (
+ False
+ ), "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
+
+ image_path = os.path.join(images_folder, os.path.basename(extr.name))
+ image_name = os.path.basename(image_path).split(".")[0]
+ image = Image.open(image_path)
+
+ cam_info = CameraInfo(
+ uid=uid,
+ R=R,
+ T=T,
+ FovY=FovY,
+ FovX=FovX,
+ image=image,
+ image_path=image_path,
+ image_name=image_name,
+ width=width,
+ height=height,
+ )
+ cam_infos.append(cam_info)
+ sys.stdout.write("\n")
+ return cam_infos
+
+
+def fetchPly(path):
+ plydata = PlyData.read(path)
+ vertices = plydata["vertex"]
+ positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T
+ colors = np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0
+ normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T
+ return BasicPointCloud(points=positions, colors=colors, normals=normals)
+
+
+def storePly(path, xyz, rgb):
+ # Define the dtype for the structured array
+ dtype = [
+ ("x", "f4"),
+ ("y", "f4"),
+ ("z", "f4"),
+ ("nx", "f4"),
+ ("ny", "f4"),
+ ("nz", "f4"),
+ ("red", "u1"),
+ ("green", "u1"),
+ ("blue", "u1"),
+ ]
+
+ normals = np.zeros_like(xyz)
+
+ elements = np.empty(xyz.shape[0], dtype=dtype)
+ attributes = np.concatenate((xyz, normals, rgb), axis=1)
+ elements[:] = list(map(tuple, attributes))
+
+ # Create the PlyData object and write to file
+ vertex_element = PlyElement.describe(elements, "vertex")
+ ply_data = PlyData([vertex_element])
+ ply_data.write(path)
+
+
+def readColmapSceneInfo(path, images, eval, llffhold=8):
+ try:
+ cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
+ cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
+ cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
+ cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
+ except:
+ cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
+ cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
+ cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
+ cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
+
+    reading_dir = "images" if images is None else images
+ cam_infos_unsorted = readColmapCameras(
+ cam_extrinsics=cam_extrinsics,
+ cam_intrinsics=cam_intrinsics,
+ images_folder=os.path.join(path, reading_dir),
+ )
+ cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name)
+
+ if eval:
+ train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
+ test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
+ else:
+ train_cam_infos = cam_infos
+ test_cam_infos = []
+
+ nerf_normalization = getNerfppNorm(train_cam_infos)
+
+ ply_path = os.path.join(path, "sparse/0/points3D.ply")
+ bin_path = os.path.join(path, "sparse/0/points3D.bin")
+ txt_path = os.path.join(path, "sparse/0/points3D.txt")
+ if not os.path.exists(ply_path):
+ print(
+ "Converting point3d.bin to .ply, will happen only the first time you open the scene."
+ )
+ try:
+ xyz, rgb, _ = read_points3D_binary(bin_path)
+ except:
+ xyz, rgb, _ = read_points3D_text(txt_path)
+ storePly(ply_path, xyz, rgb)
+ try:
+ pcd = fetchPly(ply_path)
+ except:
+ pcd = None
+
+ scene_info = SceneInfo(
+ point_cloud=pcd,
+ train_cameras=train_cam_infos,
+ test_cameras=test_cam_infos,
+ nerf_normalization=nerf_normalization,
+ ply_path=ply_path,
+ )
+ return scene_info
+
+
+def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
+ cam_infos = []
+
+ with open(os.path.join(path, transformsfile)) as json_file:
+ contents = json.load(json_file)
+ fovx = contents["camera_angle_x"]
+
+ frames = contents["frames"]
+ for idx, frame in enumerate(frames):
+ cam_name = os.path.join(path, frame["file_path"] + extension)
+
+ # NeRF 'transform_matrix' is a camera-to-world transform
+ c2w = np.array(frame["transform_matrix"])
+ # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
+ c2w[:3, 1:3] *= -1
+
+ # get the world-to-camera transform and set R, T
+ w2c = np.linalg.inv(c2w)
+ R = np.transpose(
+ w2c[:3, :3]
+ ) # R is stored transposed due to 'glm' in CUDA code
+ T = w2c[:3, 3]
+
+ image_path = os.path.join(path, cam_name)
+ image_name = Path(cam_name).stem
+ image = Image.open(image_path)
+
+ im_data = np.array(image.convert("RGBA"))
+
+ bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0])
+
+ norm_data = im_data / 255.0
+ if norm_data.shape[-1] != 3:
+ arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * (
+ 1 - norm_data[:, :, 3:4]
+ )
+ image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB")
+
+ fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
+ FovY = fovy
+ FovX = fovx
+
+ cam_infos.append(
+ CameraInfo(
+ uid=idx,
+ R=R,
+ T=T,
+ FovY=FovY,
+ FovX=FovX,
+ image=image,
+ image_path=image_path,
+ image_name=image_name,
+ width=image.size[0],
+ height=image.size[1],
+ )
+ )
+
+ return cam_infos
+
+
+def uniform_surface_sampling_from_vertices_and_faces(
+ vertices, faces, num_points: int
+) -> torch.Tensor:
+ """
+ Uniformly sample points from the surface of a mesh.
+
+ Args:
+ vertices (torch.Tensor): Vertices of the mesh.
+ faces (torch.Tensor): Faces of the mesh.
+ num_points (int): Number of points to sample.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Sampled surface points and a tensor of
+        random values with the same shape (used as placeholder colors).
+ """
+ mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+ n = num_points
+ points = []
+ while n > 0:
+ p, _ = trimesh.sample.sample_surface_even(mesh, n)
+ n -= p.shape[0]
+ if n >= 0:
+ points.append(p)
+ else:
+ points.append(p[:n])
+
+ if len(points) > 1:
+ points = np.concatenate(points, axis=0)
+ else:
+ points = points[0]
+
+ points = torch.from_numpy(points.astype(np.float32))
+
+ return points, torch.rand_like(points)
+
+
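+# Hedged summary of the helper below: it carves a coarse visual hull. Views are
+# matted with rembg, only voxels whose projections land on foreground pixels in
+# every view are kept, a mesh is extracted with marching cubes, and num_points
+# surface samples are returned as an initial point cloud.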
+def occ_from_sparse_initialize(poses, images, cameras, grid_reso, num_points):
+ # fov is in degrees
+ this_session = rembg.new_session()
+
+ imgs = [rembg.remove(im, session=this_session) for im in images]
+
+ reso = grid_reso
+ occ_grid = torch.ones((reso, reso, reso), dtype=torch.bool, device="cuda")
+
+ c2ws = poses
+ center = c2ws[..., :3, 3].mean(axis=0)
+ radius = np.linalg.norm(c2ws[..., :3, 3] - center, axis=-1).mean()
+ xx, yy, zz = torch.meshgrid(
+ torch.linspace(-radius, radius, reso, device="cuda"),
+ torch.linspace(-radius, radius, reso, device="cuda"),
+ torch.linspace(-radius, radius, reso, device="cuda"),
+ indexing="ij",
+ )
+ print("radius", radius)
+
+ # xyz_grid = torch.stack((xx.flatten(), yy.flatten(), zz.flatten()), dim=-1)
+ ww = torch.ones((reso, reso, reso), dtype=torch.float32, device="cuda")
+ xyzw_grid = torch.stack((xx, yy, zz, ww), dim=-1)
+ xyzw_grid[..., :3] += torch.from_numpy(center).cuda()
+
+ c2ws = torch.tensor(c2ws, dtype=torch.float32)
+
+ for c2w, camera, img in zip(c2ws, cameras, imgs):
+ img = np.asarray(img)
+ alpha = img[..., 3].astype(np.float32) / 255.0
+ is_foreground = alpha > 0.05
+ is_foreground = torch.from_numpy(is_foreground).cuda()
+
+ full_proj_mtx = Camera(
+ colmap_id=camera.uid,
+ R=camera.R,
+ T=camera.T,
+ FoVx=camera.FovX,
+ FoVy=camera.FovY,
+ image=torch.randn(3, 10, 10),
+ gt_alpha_mask=None,
+ image_name="no",
+ uid=0,
+ data_device="cuda",
+ ).full_proj_transform
+ # check the scale
+
+ ij = xyzw_grid @ full_proj_mtx
+ ij = (ij + 1) / 2.0
+ h, w = img.shape[:2]
+ ij = ij[..., :2] * torch.tensor([w, h], dtype=torch.float32, device="cuda")
+ ij = (
+ ij.clamp(
+ min=torch.tensor([0.0, 0.0], device="cuda"),
+ max=torch.tensor([w - 1, h - 1], dtype=torch.float32, device="cuda"),
+ )
+ .to(torch.long)
+ .cuda()
+ )
+
+ occ_grid = torch.logical_and(occ_grid, is_foreground[ij[..., 1], ij[..., 0]])
+
+ # To mesh
+ occ_grid = occ_grid.to(torch.float32).cpu().numpy()
+ vertices, triangles = mcubes.marching_cubes(occ_grid, 0.5)
+
+ # vertices = (vertices / reso - 0.5) * radius * 2 + center
+ # vertices = (vertices / (reso - 1.0) - 0.5) * radius * 2 * 2 + center
+ vertices = vertices / (grid_reso - 1) * 2 - 1
+ vertices = vertices * radius + center
+ # mcubes.export_obj(vertices, triangles, "./tmp/occ_voxel.obj")
+
+ xyz, rgb = uniform_surface_sampling_from_vertices_and_faces(
+ vertices, triangles, num_points
+ )
+
+ return xyz
+
+
+def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
+ print("Reading Training Transforms")
+ train_cam_infos = readCamerasFromTransforms(
+ path, "transforms_train.json", white_background, extension
+ )
+ print("Reading Test Transforms")
+ test_cam_infos = readCamerasFromTransforms(
+ path, "transforms_test.json", white_background, extension
+ )
+
+ if not eval:
+ train_cam_infos.extend(test_cam_infos)
+ test_cam_infos = []
+
+ nerf_normalization = getNerfppNorm(train_cam_infos)
+
+ ply_path = os.path.join(path, "points3d.ply")
+ if not os.path.exists(ply_path):
+ # Since this data set has no colmap data, we start with random points
+ num_pts = 100_000
+ print(f"Generating random point cloud ({num_pts})...")
+
+ # We create random points inside the bounds of the synthetic Blender scenes
+ xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
+ shs = np.random.random((num_pts, 3)) / 255.0
+ pcd = BasicPointCloud(
+ points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
+ )
+
+ storePly(ply_path, xyz, SH2RGB(shs) * 255)
+ try:
+ pcd = fetchPly(ply_path)
+ except:
+ pcd = None
+
+ scene_info = SceneInfo(
+ point_cloud=pcd,
+ train_cameras=train_cam_infos,
+ test_cameras=test_cam_infos,
+ nerf_normalization=nerf_normalization,
+ ply_path=ply_path,
+ )
+ return scene_info
+
+
+def constructVideoNVSInfo(
+ num_frames,
+ radius,
+ elevation,
+ fov,
+ reso,
+ images,
+ masks,
+ num_pts=100_000,
+ train=True,
+):
+ poses = get_uniform_poses(num_frames, radius, elevation)
+ w2cs = np.linalg.inv(poses)
+ train_cam_infos = []
+
+ for idx, pose in enumerate(w2cs):
+ train_cam_infos.append(
+ CameraInfo(
+ uid=idx,
+ R=np.transpose(pose[:3, :3]),
+ T=pose[:3, 3],
+ FovY=np.deg2rad(fov),
+ FovX=np.deg2rad(fov),
+ image=images[idx],
+ image_path=None,
+ image_name=idx,
+ width=reso,
+ height=reso,
+ )
+ )
+
+ nerf_normalization = getNerfppNorm(train_cam_infos)
+    # Alternative initialisations kept for reference:
+    # xyz = np.random.random((num_pts, 3)) * radius / 3 - radius / 3
+    # if len(poses) <= 24:
+    #     xyz = occ_from_sparse_initialize(poses, images, train_cam_infos, 256, num_pts)
+    #     num_pts = xyz.shape[0]
+    # else:
+    #     xyz = np.random.randn(num_pts, 3) * radius / 16
+    # Default: random Gaussian blob around the origin with std = radius / 16
+    xyz = np.random.randn(num_pts, 3) * radius / 16
+ # shs = np.random.random((num_pts, 3)) / 255.0
+ shs = np.ones((num_pts, 3)) * 0.2
+ pcd = BasicPointCloud(
+ points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
+ )
+
+ ply_path = "./tmp/points3d.ply"
+ storePly(ply_path, xyz, SH2RGB(shs) * 255)
+ pcd = fetchPly(ply_path)
+
+ scene_info = SceneInfo(
+ point_cloud=pcd,
+ train_cameras=train_cam_infos,
+ test_cameras=[],
+ nerf_normalization=nerf_normalization,
+ ply_path="./tmp/points3d.ply",
+ )
+
+ return scene_info
+
+
+sceneLoadTypeCallbacks = {
+ "Colmap": readColmapSceneInfo,
+ "Blender": readNerfSyntheticInfo,
+ "VideoNVS": constructVideoNVSInfo,
+}
diff --git a/recon/scene/gaussian_model.py b/recon/scene/gaussian_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef893f2183e28a1d6a7cfa4455c446fbaed3ff76
--- /dev/null
+++ b/recon/scene/gaussian_model.py
@@ -0,0 +1,570 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import numpy as np
+from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
+from torch import nn
+import os
+from utils.system_utils import mkdir_p
+from plyfile import PlyData, PlyElement
+from utils.sh_utils import RGB2SH
+from simple_knn._C import distCUDA2
+from utils.graphics_utils import BasicPointCloud
+from utils.general_utils import strip_symmetric, build_scaling_rotation
+
+
+class GaussianModel:
+ def setup_functions(self):
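+        # Covariance is assembled as Sigma = R S S^T R^T (build_scaling_rotation
+        # is expected to return R*S); strip_symmetric keeps its 6 unique entries.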
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+ actual_covariance = L @ L.transpose(1, 2)
+ symm = strip_symmetric(actual_covariance)
+ return symm
+
+ self.scaling_activation = torch.exp
+ self.scaling_inverse_activation = torch.log
+
+ self.covariance_activation = build_covariance_from_scaling_rotation
+
+ self.opacity_activation = torch.sigmoid
+ self.inverse_opacity_activation = inverse_sigmoid
+
+ self.rotation_activation = torch.nn.functional.normalize
+
+ def __init__(self, sh_degree: int):
+ self.active_sh_degree = 0
+ self.max_sh_degree = sh_degree
+ self._xyz = torch.empty(0)
+ self._features_dc = torch.empty(0)
+ self._features_rest = torch.empty(0)
+ self._scaling = torch.empty(0)
+ self._rotation = torch.empty(0)
+ self._opacity = torch.empty(0)
+ self.max_radii2D = torch.empty(0)
+ self.xyz_gradient_accum = torch.empty(0)
+ self.denom = torch.empty(0)
+ self.optimizer = None
+ self.percent_dense = 0
+ self.spatial_lr_scale = 0
+ self.setup_functions()
+
+ def capture(self):
+ return (
+ self.active_sh_degree,
+ self._xyz,
+ self._features_dc,
+ self._features_rest,
+ self._scaling,
+ self._rotation,
+ self._opacity,
+ self.max_radii2D,
+ self.xyz_gradient_accum,
+ self.denom,
+ self.optimizer.state_dict(),
+ self.spatial_lr_scale,
+ )
+
+ def restore(self, model_args, training_args):
+ (
+ self.active_sh_degree,
+ self._xyz,
+ self._features_dc,
+ self._features_rest,
+ self._scaling,
+ self._rotation,
+ self._opacity,
+ self.max_radii2D,
+ xyz_gradient_accum,
+ denom,
+ opt_dict,
+ self.spatial_lr_scale,
+ ) = model_args
+ self.training_setup(training_args)
+ self.xyz_gradient_accum = xyz_gradient_accum
+ self.denom = denom
+ self.optimizer.load_state_dict(opt_dict)
+
+ @property
+ def get_scaling(self):
+ return self.scaling_activation(self._scaling)
+
+ @property
+ def get_rotation(self):
+ return self.rotation_activation(self._rotation)
+
+ @property
+ def get_xyz(self):
+ return self._xyz
+
+ @property
+ def get_features(self):
+ features_dc = self._features_dc
+ features_rest = self._features_rest
+ return torch.cat((features_dc, features_rest), dim=1)
+
+ @property
+ def get_opacity(self):
+ return self.opacity_activation(self._opacity)
+
+ def get_covariance(self, scaling_modifier=1):
+ return self.covariance_activation(
+ self.get_scaling, scaling_modifier, self._rotation
+ )
+
+ def oneupSHdegree(self):
+ if self.active_sh_degree < self.max_sh_degree:
+ self.active_sh_degree += 1
+
+ def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float):
+ self.spatial_lr_scale = spatial_lr_scale
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
+ features = (
+ torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2))
+ .float()
+ .cuda()
+ )
+ features[:, :3, 0] = fused_color
+ features[:, 3:, 1:] = 0.0
+
+ print("Number of points at initialisation : ", fused_point_cloud.shape[0])
+
+ dist2 = torch.clamp_min(
+ distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()),
+ 0.0000001,
+ )
+ scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
+ rots[:, 0] = 1
+
+ opacities = inverse_sigmoid(
+ 0.5
+ * torch.ones(
+ (fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"
+ )
+ )
+
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
+ self._features_dc = nn.Parameter(
+ features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True)
+ )
+ self._features_rest = nn.Parameter(
+ features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True)
+ )
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+
+ def training_setup(self, training_args):
+ self.percent_dense = training_args.percent_dense
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+
+ l = [
+ {
+ "params": [self._xyz],
+ "lr": training_args.position_lr_init * self.spatial_lr_scale,
+ "name": "xyz",
+ },
+ {
+ "params": [self._features_dc],
+ "lr": training_args.feature_lr,
+ "name": "f_dc",
+ },
+ {
+ "params": [self._features_rest],
+ "lr": training_args.feature_lr / 20.0,
+ "name": "f_rest",
+ },
+ {
+ "params": [self._opacity],
+ "lr": training_args.opacity_lr,
+ "name": "opacity",
+ },
+ {
+ "params": [self._scaling],
+ "lr": training_args.scaling_lr,
+ "name": "scaling",
+ },
+ {
+ "params": [self._rotation],
+ "lr": training_args.rotation_lr,
+ "name": "rotation",
+ },
+ ]
+
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
+ self.xyz_scheduler_args = get_expon_lr_func(
+ lr_init=training_args.position_lr_init * self.spatial_lr_scale,
+ lr_final=training_args.position_lr_final * self.spatial_lr_scale,
+ lr_delay_mult=training_args.position_lr_delay_mult,
+ max_steps=training_args.position_lr_max_steps,
+ )
+
+ def update_learning_rate(self, iteration):
+ """Learning rate scheduling per step"""
+ for param_group in self.optimizer.param_groups:
+ if param_group["name"] == "xyz":
+ lr = self.xyz_scheduler_args(iteration)
+ param_group["lr"] = lr
+ return lr
+
+ def construct_list_of_attributes(self):
+ l = ["x", "y", "z", "nx", "ny", "nz"]
+ # All channels except the 3 DC
+ for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
+ l.append("f_dc_{}".format(i))
+ for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]):
+ l.append("f_rest_{}".format(i))
+ l.append("opacity")
+ for i in range(self._scaling.shape[1]):
+ l.append("scale_{}".format(i))
+ for i in range(self._rotation.shape[1]):
+ l.append("rot_{}".format(i))
+ return l
+
+ def save_ply(self, path):
+ mkdir_p(os.path.dirname(path))
+
+ xyz = self._xyz.detach().cpu().numpy()
+ normals = np.zeros_like(xyz)
+ f_dc = (
+ self._features_dc.detach()
+ .transpose(1, 2)
+ .flatten(start_dim=1)
+ .contiguous()
+ .cpu()
+ .numpy()
+ )
+ f_rest = (
+ self._features_rest.detach()
+ .transpose(1, 2)
+ .flatten(start_dim=1)
+ .contiguous()
+ .cpu()
+ .numpy()
+ )
+ opacities = self._opacity.detach().cpu().numpy()
+ scale = self._scaling.detach().cpu().numpy()
+ rotation = self._rotation.detach().cpu().numpy()
+
+ dtype_full = [
+ (attribute, "f4") for attribute in self.construct_list_of_attributes()
+ ]
+
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
+ attributes = np.concatenate(
+ (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1
+ )
+ elements[:] = list(map(tuple, attributes))
+ el = PlyElement.describe(elements, "vertex")
+ PlyData([el]).write(path)
+
+ def reset_opacity(self):
+ opacities_new = inverse_sigmoid(
+ torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01)
+ )
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
+ self._opacity = optimizable_tensors["opacity"]
+
+ def load_ply(self, path):
+ plydata = PlyData.read(path)
+
+ xyz = np.stack(
+ (
+ np.asarray(plydata.elements[0]["x"]),
+ np.asarray(plydata.elements[0]["y"]),
+ np.asarray(plydata.elements[0]["z"]),
+ ),
+ axis=1,
+ )
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
+
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
+
+ extra_f_names = [
+ p.name
+ for p in plydata.elements[0].properties
+ if p.name.startswith("f_rest_")
+ ]
+ extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1]))
+ assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
+ for idx, attr_name in enumerate(extra_f_names):
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
+ features_extra = features_extra.reshape(
+ (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)
+ )
+
+ scale_names = [
+ p.name
+ for p in plydata.elements[0].properties
+ if p.name.startswith("scale_")
+ ]
+ scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
+ for idx, attr_name in enumerate(scale_names):
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ rot_names = [
+ p.name for p in plydata.elements[0].properties if p.name.startswith("rot")
+ ]
+ rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
+ for idx, attr_name in enumerate(rot_names):
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
+
+ self._xyz = nn.Parameter(
+ torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)
+ )
+ self._features_dc = nn.Parameter(
+ torch.tensor(features_dc, dtype=torch.float, device="cuda")
+ .transpose(1, 2)
+ .contiguous()
+ .requires_grad_(True)
+ )
+ self._features_rest = nn.Parameter(
+ torch.tensor(features_extra, dtype=torch.float, device="cuda")
+ .transpose(1, 2)
+ .contiguous()
+ .requires_grad_(True)
+ )
+ self._opacity = nn.Parameter(
+ torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(
+ True
+ )
+ )
+ self._scaling = nn.Parameter(
+ torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)
+ )
+ self._rotation = nn.Parameter(
+ torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)
+ )
+
+ self.active_sh_degree = self.max_sh_degree
+
+ def replace_tensor_to_optimizer(self, tensor, name):
+ optimizable_tensors = {}
+ for group in self.optimizer.param_groups:
+ if group["name"] == name:
+ stored_state = self.optimizer.state.get(group["params"][0], None)
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
+
+ del self.optimizer.state[group["params"][0]]
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
+ self.optimizer.state[group["params"][0]] = stored_state
+
+ optimizable_tensors[group["name"]] = group["params"][0]
+ return optimizable_tensors
+
+ def _prune_optimizer(self, mask):
+ optimizable_tensors = {}
+ for group in self.optimizer.param_groups:
+ stored_state = self.optimizer.state.get(group["params"][0], None)
+ if stored_state is not None:
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
+
+ del self.optimizer.state[group["params"][0]]
+ group["params"][0] = nn.Parameter(
+ (group["params"][0][mask].requires_grad_(True))
+ )
+ self.optimizer.state[group["params"][0]] = stored_state
+
+ optimizable_tensors[group["name"]] = group["params"][0]
+ else:
+ group["params"][0] = nn.Parameter(
+ group["params"][0][mask].requires_grad_(True)
+ )
+ optimizable_tensors[group["name"]] = group["params"][0]
+ return optimizable_tensors
+
+ def prune_points(self, mask):
+ valid_points_mask = ~mask
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
+
+ self._xyz = optimizable_tensors["xyz"]
+ self._features_dc = optimizable_tensors["f_dc"]
+ self._features_rest = optimizable_tensors["f_rest"]
+ self._opacity = optimizable_tensors["opacity"]
+ self._scaling = optimizable_tensors["scaling"]
+ self._rotation = optimizable_tensors["rotation"]
+
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
+
+ self.denom = self.denom[valid_points_mask]
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
+
+ def cat_tensors_to_optimizer(self, tensors_dict):
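+ # Append the new tensors to each parameter group and zero-initialize the Adam
+ # moments of the appended rows so newly added Gaussians start with fresh state.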
+ optimizable_tensors = {}
+ for group in self.optimizer.param_groups:
+ assert len(group["params"]) == 1
+ extension_tensor = tensors_dict[group["name"]]
+ stored_state = self.optimizer.state.get(group["params"][0], None)
+ if stored_state is not None:
+ stored_state["exp_avg"] = torch.cat(
+ (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0
+ )
+ stored_state["exp_avg_sq"] = torch.cat(
+ (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)),
+ dim=0,
+ )
+
+ del self.optimizer.state[group["params"][0]]
+ group["params"][0] = nn.Parameter(
+ torch.cat(
+ (group["params"][0], extension_tensor), dim=0
+ ).requires_grad_(True)
+ )
+ self.optimizer.state[group["params"][0]] = stored_state
+
+ optimizable_tensors[group["name"]] = group["params"][0]
+ else:
+ group["params"][0] = nn.Parameter(
+ torch.cat(
+ (group["params"][0], extension_tensor), dim=0
+ ).requires_grad_(True)
+ )
+ optimizable_tensors[group["name"]] = group["params"][0]
+
+ return optimizable_tensors
+
+ def densification_postfix(
+ self,
+ new_xyz,
+ new_features_dc,
+ new_features_rest,
+ new_opacities,
+ new_scaling,
+ new_rotation,
+ ):
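+ # Register newly created Gaussians with the optimizer and reset the per-point
+ # densification statistics (gradient accumulator, denominator, max 2D radii).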
+ d = {
+ "xyz": new_xyz,
+ "f_dc": new_features_dc,
+ "f_rest": new_features_rest,
+ "opacity": new_opacities,
+ "scaling": new_scaling,
+ "rotation": new_rotation,
+ }
+
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
+ self._xyz = optimizable_tensors["xyz"]
+ self._features_dc = optimizable_tensors["f_dc"]
+ self._features_rest = optimizable_tensors["f_rest"]
+ self._opacity = optimizable_tensors["opacity"]
+ self._scaling = optimizable_tensors["scaling"]
+ self._rotation = optimizable_tensors["rotation"]
+
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
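+ # Split large Gaussians with high view-space gradients: draw N samples inside each
+ # selected Gaussian, shrink the scales by 1/(0.8*N), then prune the originals.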
+ n_init_points = self.get_xyz.shape[0]
+ # Extract points that satisfy the gradient condition
+ padded_grad = torch.zeros((n_init_points), device="cuda")
+ padded_grad[: grads.shape[0]] = grads.squeeze()
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
+ selected_pts_mask = torch.logical_and(
+ selected_pts_mask,
+ torch.max(self.get_scaling, dim=1).values
+ > self.percent_dense * scene_extent,
+ )
+
+ stds = self.get_scaling[selected_pts_mask].repeat(N, 1)
+ means = torch.zeros((stds.size(0), 3), device="cuda")
+ samples = torch.normal(mean=means, std=stds)
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1)
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[
+ selected_pts_mask
+ ].repeat(N, 1)
+ new_scaling = self.scaling_inverse_activation(
+ self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N)
+ )
+ new_rotation = self._rotation[selected_pts_mask].repeat(N, 1)
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1)
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1)
+ new_opacity = self._opacity[selected_pts_mask].repeat(N, 1)
+
+ self.densification_postfix(
+ new_xyz,
+ new_features_dc,
+ new_features_rest,
+ new_opacity,
+ new_scaling,
+ new_rotation,
+ )
+
+ prune_filter = torch.cat(
+ (
+ selected_pts_mask,
+ torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool),
+ )
+ )
+ self.prune_points(prune_filter)
+
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
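+ # Clone small Gaussians (scale at most percent_dense * scene extent) whose
+ # view-space gradient exceeds the threshold, duplicating them in place.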
+ # Extract points that satisfy the gradient condition
+ selected_pts_mask = torch.where(
+ torch.norm(grads, dim=-1) >= grad_threshold, True, False
+ )
+ selected_pts_mask = torch.logical_and(
+ selected_pts_mask,
+ torch.max(self.get_scaling, dim=1).values
+ <= self.percent_dense * scene_extent,
+ )
+
+ new_xyz = self._xyz[selected_pts_mask]
+ new_features_dc = self._features_dc[selected_pts_mask]
+ new_features_rest = self._features_rest[selected_pts_mask]
+ new_opacities = self._opacity[selected_pts_mask]
+ new_scaling = self._scaling[selected_pts_mask]
+ new_rotation = self._rotation[selected_pts_mask]
+
+ self.densification_postfix(
+ new_xyz,
+ new_features_dc,
+ new_features_rest,
+ new_opacities,
+ new_scaling,
+ new_rotation,
+ )
+
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
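+ # Densify (clone + split), then prune points that are nearly transparent, too
+ # large in screen space, or larger than 10% of the scene extent.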
+ grads = self.xyz_gradient_accum / self.denom
+ grads[grads.isnan()] = 0.0
+
+ self.densify_and_clone(grads, max_grad, extent)
+ self.densify_and_split(grads, max_grad, extent)
+
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
+ if max_screen_size:
+ big_points_vs = self.max_radii2D > max_screen_size
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
+ prune_mask = torch.logical_or(
+ torch.logical_or(prune_mask, big_points_vs), big_points_ws
+ )
+ self.prune_points(prune_mask)
+
+ torch.cuda.empty_cache()
+
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
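+ # Accumulate the norm of the 2D view-space positional gradient for visible points;
+ # its running average drives the densification decisions above.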
+ self.xyz_gradient_accum[update_filter] += torch.norm(
+ viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True
+ )
+ self.denom[update_filter] += 1
diff --git a/recon/sparse_pcd.py b/recon/sparse_pcd.py
new file mode 100644
index 0000000000000000000000000000000000000000..523a3d5ebd6d758b500d27f83b6405ff3ada76af
--- /dev/null
+++ b/recon/sparse_pcd.py
@@ -0,0 +1,141 @@
+import argparse
+import subprocess
+from pathlib import Path
+import os
+import numpy as np
+
+from skimage.io import imread, imsave
+from transforms3d.quaternions import mat2quat
+
+from colmap.database import COLMAPDatabase
+from colmap.read_write_model import CAMERA_MODEL_NAMES
+import open3d as o3d
+
+from ldm.base_utils import read_pickle
+
+# Fixed 16-view camera rig: shared intrinsics K and per-view poses, rendered at 256x256
+K, _, _, _, POSES = read_pickle('meta_info/camera-16.pkl')
+H, W, NUM_IMAGES = 256, 256, 16
+
+def extract_and_match_sift(colmap_path, database_path, image_dir):
+ cmd = [
+ str(colmap_path), 'feature_extractor',
+ '--database_path', str(database_path),
+ '--image_path', str(image_dir),
+ ]
+ print(' '.join(cmd))
+ subprocess.run(cmd, check=True)
+ cmd = [
+ str(colmap_path), 'exhaustive_matcher',
+ '--database_path', str(database_path),
+ ]
+ print(' '.join(cmd))
+ subprocess.run(cmd, check=True)
+
+def run_triangulation(colmap_path, model_path, in_sparse_model, database_path, image_dir):
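+ # Triangulate 3D points with COLMAP using the known poses in `in_sparse_model`;
+ # refinement of the camera intrinsics is disabled so the given cameras stay fixed.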
+ print('Running the triangulation...')
+ model_path.mkdir(exist_ok=True, parents=True)
+ cmd = [
+ str(colmap_path), 'point_triangulator',
+ '--database_path', str(database_path),
+ '--image_path', str(image_dir),
+ '--input_path', str(in_sparse_model),
+ '--output_path', str(model_path),
+ '--Mapper.ba_refine_focal_length', '0',
+ '--Mapper.ba_refine_principal_point', '0',
+ '--Mapper.ba_refine_extra_params', '0']
+ print(' '.join(cmd))
+ subprocess.run(cmd, check=True)
+
+def run_patch_match(colmap_path, sparse_model: Path, image_dir: Path, dense_model: Path):
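+ # Undistort the images into a dense workspace, then run COLMAP patch-match MVS
+ # to estimate per-view depth and normal maps.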
+ print('Running patch match...')
+ assert sparse_model.exists()
+ dense_model.mkdir(parents=True, exist_ok=True)
+ cmd = [str(colmap_path), 'image_undistorter', '--input_path', str(sparse_model), '--image_path', str(image_dir), '--output_path', str(dense_model),]
+ print(' '.join(cmd))
+ subprocess.run(cmd, check=True)
+ cmd = [str(colmap_path), 'patch_match_stereo','--workspace_path', str(dense_model),]
+ print(' '.join(cmd))
+ subprocess.run(cmd, check=True)
+
+def dump_images(in_image_dir, image_dir):
+ for index in range(NUM_IMAGES):
+ img = imread(f'{in_image_dir}/{index:03}.png')
+ imsave(f'{str(image_dir)}/{index:03}.png', img)
+
+def build_db_known_poses_fixed(db_path, in_sparse_path):
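+ # Create a COLMAP database and a text sparse model (cameras.txt, images.txt and an
+ # empty points3D.txt) from the fixed PINHOLE intrinsics K and the known POSES, so
+ # matching and triangulation can run without any pose estimation.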
+ db = COLMAPDatabase.connect(db_path)
+ db.create_tables()
+
+ # insert intrinsics
+ with open(f'{str(in_sparse_path)}/cameras.txt', 'w') as f:
+ for index in range(NUM_IMAGES):
+ fx, fy = K[0,0], K[1,1]
+ cx, cy = K[0,2], K[1,2]
+ model, width, height, params = CAMERA_MODEL_NAMES['PINHOLE'].model_id, W, H, np.array((fx, fy, cx, cy),np.float32)
+ db.add_camera(model, width, height, params, prior_focal_length=(fx+fy)/2, camera_id=index+1)
+ f.write(f'{index+1} PINHOLE {W} {H} {fx:.3f} {fy:.3f} {cx:.3f} {cy:.3f}\n')
+
+ with open(f'{str(in_sparse_path)}/images.txt','w') as f:
+ for index in range(NUM_IMAGES):
+ pose = POSES[index]
+ q = mat2quat(pose[:,:3])
+ t = pose[:,3]
+ img_id = db.add_image(f"{index:03}.png", camera_id=index+1, prior_q=q, prior_t=t)
+ f.write(f'{img_id} {q[0]:.5f} {q[1]:.5f} {q[2]:.5f} {q[3]:.5f} {t[0]:.5f} {t[1]:.5f} {t[2]:.5f} {index+1} {index:03}.png\n\n')
+
+ db.commit()
+ db.close()
+
+ with open(f'{in_sparse_path}/points3D.txt','w') as f:
+ f.write('\n')
+
+
+def patch_match_with_known_poses(in_image_dir, project_dir, colmap_path='colmap'):
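+ # Full MVS pipeline with known poses: dump images, build the pose-annotated database,
+ # extract and match SIFT, triangulate, run patch-match stereo, and fuse into points.ply.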
+ Path(project_dir).mkdir(exist_ok=True, parents=True)
+ if os.path.exists(f'{str(project_dir)}/dense/stereo/depth_maps'): return
+
+ # output poses
+ db_path = f'{str(project_dir)}/database.db'
+ image_dir = Path(f'{str(project_dir)}/images')
+ sparse_dir = Path(f'{str(project_dir)}/sparse')
+ in_sparse_dir = Path(f'{str(project_dir)}/sparse_in')
+ dense_dir = Path(f'{str(project_dir)}/dense')
+
+ image_dir.mkdir(exist_ok=True,parents=True)
+ sparse_dir.mkdir(exist_ok=True,parents=True)
+ in_sparse_dir.mkdir(exist_ok=True,parents=True)
+ dense_dir.mkdir(exist_ok=True,parents=True)
+
+ dump_images(in_image_dir, image_dir)
+ build_db_known_poses_fixed(db_path, in_sparse_dir)
+ extract_and_match_sift(colmap_path, db_path, image_dir)
+ run_triangulation(colmap_path,sparse_dir, in_sparse_dir, db_path, image_dir)
+ run_patch_match(colmap_path, sparse_dir, image_dir, dense_dir)
+
+ # fuse
+ cmd = [str(colmap_path), 'stereo_fusion',
+ '--workspace_path', f'{project_dir}/dense',
+ '--workspace_format', 'COLMAP',
+ '--input_type', 'geometric',
+ '--output_path', f'{project_dir}/points.ply',]
+ print(' '.join(cmd))
+ subprocess.run(cmd, check=True)
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--dir',type=str)
+ parser.add_argument('--project',type=str)
+ parser.add_argument('--name',type=str)
+ parser.add_argument('--colmap',type=str, default='colmap')
+ args = parser.parse_args()
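+ # Example invocation (hypothetical paths):
+ #   python sparse_pcd.py --dir renders/chair --project colmap/chair --name chair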
+
+ if not os.path.exists(f'{args.project}/points.ply'):
+ patch_match_with_known_poses(args.dir, args.project, colmap_path=args.colmap)
+
+ mesh = o3d.io.read_triangle_mesh(f'{args.project}/points.ply',)
+ vn = len(mesh.vertices)
+ with open('colmap-results.log', 'a') as f:
+ f.write(f'{args.name}\t{vn}\n')
+
+if __name__=="__main__":
+ main()
\ No newline at end of file
diff --git a/recon/sync_submodules.sh b/recon/sync_submodules.sh
new file mode 100755
index 0000000000000000000000000000000000000000..2e7d433381b376ead37297aab5cde26205b1f7c6
--- /dev/null
+++ b/recon/sync_submodules.sh
@@ -0,0 +1,14 @@
+#!/bin/sh
+
+set -e
+
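+# Re-register every submodule listed in .gitmodules, using its recorded URL and branch
+# (defaulting to master); submodules that are already added are skipped via `|| continue`.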
+git config -f .gitmodules --get-regexp '^submodule\..*\.path$' |
+ while read path_key path
+ do
+ name=$(echo "$path_key" | sed 's/^submodule\.\(.*\)\.path$/\1/')
+ url_key=$(echo "$path_key" | sed 's/\.path$/.url/')
+ branch_key=$(echo "$path_key" | sed 's/\.path$/.branch/')
+ url=$(git config -f .gitmodules --get "$url_key")
+ branch=$(git config -f .gitmodules --get "$branch_key" || echo "master")
+ git submodule add -b "$branch" --name "$name" "$url" "$path" || continue
+ done
diff --git a/recon/train.py b/recon/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a89ca9785071178a1b8505ebe512a328118f41d
--- /dev/null
+++ b/recon/train.py
@@ -0,0 +1,369 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import torch
+from random import randint
+from utils.loss_utils import l1_loss, ssim, lpips
+from gaussian_renderer import render, network_gui
+import sys
+from scene import Scene, GaussianModel
+from utils.general_utils import safe_state
+import uuid
+from tqdm import tqdm
+from utils.image_utils import psnr
+from argparse import ArgumentParser, Namespace
+from arguments import ModelParams, PipelineParams, OptimizationParams
+
+from scripts.sampling.simple_mv_sample import sample_one
+
+try:
+ from torch.utils.tensorboard import SummaryWriter
+
+ TENSORBOARD_FOUND = True
+except ImportError:
+ TENSORBOARD_FOUND = False
+
+
+def training(
+ dataset,
+ opt,
+ pipe,
+ testing_iterations,
+ saving_iterations,
+ checkpoint_iterations,
+ checkpoint,
+ debug_from,
+):
+ first_iter = 0
+ tb_writer = prepare_output_and_logger(dataset)
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians)
+ gaussians.training_setup(opt)
+ if checkpoint:
+ (model_params, first_iter) = torch.load(checkpoint)
+ gaussians.restore(model_params, opt)
+
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+ iter_start = torch.cuda.Event(enable_timing=True)
+ iter_end = torch.cuda.Event(enable_timing=True)
+
+ viewpoint_stack = None
+ ema_loss_for_log = 0.0
+ progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+ first_iter += 1
+ for iteration in range(first_iter, opt.iterations + 1):
+ if network_gui.conn == None:
+ network_gui.try_connect()
+ while network_gui.conn != None:
+ try:
+ net_image_bytes = None
+ (
+ custom_cam,
+ do_training,
+ pipe.convert_SHs_python,
+ pipe.compute_cov3D_python,
+ keep_alive,
+ scaling_modifer,
+ ) = network_gui.receive()
+ if custom_cam != None:
+ net_image = render(
+ custom_cam, gaussians, pipe, background, scaling_modifer
+ )["render"]
+ net_image_bytes = memoryview(
+ (torch.clamp(net_image, min=0, max=1.0) * 255)
+ .byte()
+ .permute(1, 2, 0)
+ .contiguous()
+ .cpu()
+ .numpy()
+ )
+ network_gui.send(net_image_bytes, dataset.source_path)
+ if do_training and (
+ (iteration < int(opt.iterations)) or not keep_alive
+ ):
+ break
+ except Exception as e:
+ network_gui.conn = None
+
+ iter_start.record()
+
+ gaussians.update_learning_rate(iteration)
+
+ # Every 1000 its we increase the levels of SH up to a maximum degree
+ if iteration % 1000 == 0:
+ gaussians.oneupSHdegree()
+
+ # Pick a random Camera
+ if not viewpoint_stack:
+ viewpoint_stack = scene.getTrainCameras().copy()
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
+
+ # Render
+ if (iteration - 1) == debug_from:
+ pipe.debug = True
+
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+
+ render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+
+ # Loss
+ gt_image = viewpoint_cam.original_image.cuda()
+ Ll1 = l1_loss(image, gt_image)
+ loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
+ 1.0 - ssim(image, gt_image)
+ )
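+ # Optional LPIPS perceptual term on top of the standard L1 / D-SSIM objective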
+ if opt.lambda_lpips > 0:
+ loss += opt.lambda_lpips * lpips(image, gt_image)
+ loss.backward()
+
+ iter_end.record()
+
+ with torch.no_grad():
+ # Progress bar
+ ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+ if iteration % 10 == 0:
+ progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+ progress_bar.update(10)
+ if iteration == opt.iterations:
+ progress_bar.close()
+
+ # Log and save
+ training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ iter_start.elapsed_time(iter_end),
+ testing_iterations,
+ scene,
+ render,
+ (pipe, background),
+ )
+ if iteration in saving_iterations:
+ print("\n[ITER {}] Saving Gaussians".format(iteration))
+ scene.save(iteration)
+
+ # Densification
+ if iteration < opt.densify_until_iter:
+ # Keep track of max radii in image-space for pruning
+ gaussians.max_radii2D[visibility_filter] = torch.max(
+ gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
+ )
+ gaussians.add_densification_stats(
+ viewspace_point_tensor, visibility_filter
+ )
+
+ if (
+ iteration > opt.densify_from_iter
+ and iteration % opt.densification_interval == 0
+ ):
+ size_threshold = (
+ 20 if iteration > opt.opacity_reset_interval else None
+ )
+ gaussians.densify_and_prune(
+ opt.densify_grad_threshold,
+ 0.005,
+ scene.cameras_extent,
+ size_threshold,
+ )
+
+ if iteration % opt.opacity_reset_interval == 0 or (
+ dataset.white_background and iteration == opt.densify_from_iter
+ ):
+ gaussians.reset_opacity()
+
+ # Optimizer step
+ if iteration < opt.iterations:
+ gaussians.optimizer.step()
+ gaussians.optimizer.zero_grad(set_to_none=True)
+
+ if iteration in checkpoint_iterations:
+ print("\n[ITER {}] Saving Checkpoint".format(iteration))
+ torch.save(
+ (gaussians.capture(), iteration),
+ scene.model_path + "/chkpnt" + str(iteration) + ".pth",
+ )
+
+
+def prepare_output_and_logger(args):
+ if not args.model_path:
+ if os.getenv("OAR_JOB_ID"):
+ unique_str = os.getenv("OAR_JOB_ID")
+ else:
+ unique_str = str(uuid.uuid4())
+ args.model_path = os.path.join("./output/", unique_str[0:10])
+
+ # Set up output folder
+ print("Output folder: {}".format(args.model_path))
+ os.makedirs(args.model_path, exist_ok=True)
+ with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f:
+ cfg_log_f.write(str(Namespace(**vars(args))))
+
+ # Create Tensorboard writer
+ tb_writer = None
+ if TENSORBOARD_FOUND:
+ tb_writer = SummaryWriter(args.model_path)
+ else:
+ print("Tensorboard not available: not logging progress")
+ return tb_writer
+
+
+def training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ elapsed,
+ testing_iterations,
+ scene: Scene,
+ renderFunc,
+ renderArgs,
+):
+ if tb_writer:
+ tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration)
+ tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration)
+ tb_writer.add_scalar("iter_time", elapsed, iteration)
+
+ # Report test and samples of training set
+ if iteration in testing_iterations:
+ torch.cuda.empty_cache()
+ validation_configs = (
+ {"name": "test", "cameras": scene.getTestCameras()},
+ {
+ "name": "train",
+ "cameras": [
+ scene.getTrainCameras()[idx % len(scene.getTrainCameras())]
+ for idx in range(5, 30, 5)
+ ],
+ },
+ )
+
+ for config in validation_configs:
+ if config["cameras"] and len(config["cameras"]) > 0:
+ l1_test = 0.0
+ psnr_test = 0.0
+ for idx, viewpoint in enumerate(config["cameras"]):
+ image = torch.clamp(
+ renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"],
+ 0.0,
+ 1.0,
+ )
+ gt_image = torch.clamp(
+ viewpoint.original_image.to("cuda"), 0.0, 1.0
+ )
+ if tb_writer and (idx < 5):
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/render".format(viewpoint.image_name),
+ image[None],
+ global_step=iteration,
+ )
+ if iteration == testing_iterations[0]:
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/ground_truth".format(viewpoint.image_name),
+ gt_image[None],
+ global_step=iteration,
+ )
+ l1_test += l1_loss(image, gt_image).mean().double()
+ psnr_test += psnr(image, gt_image).mean().double()
+ psnr_test /= len(config["cameras"])
+ l1_test /= len(config["cameras"])
+ print(
+ "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(
+ iteration, config["name"], l1_test, psnr_test
+ )
+ )
+ if tb_writer:
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration
+ )
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration
+ )
+
+ if tb_writer:
+ tb_writer.add_histogram(
+ "scene/opacity_histogram", scene.gaussians.get_opacity, iteration
+ )
+ tb_writer.add_scalar(
+ "total_points", scene.gaussians.get_xyz.shape[0], iteration
+ )
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--image", type=str, default="assets/images/ceramic.png")
+ parser.add_argument("--ckpt_path", type=str, required=True)
+ parser.add_argument("--ip", type=str, default="127.0.0.1")
+ parser.add_argument("--port", type=int, default=6009)
+ parser.add_argument("--debug_from", type=int, default=-1)
+ parser.add_argument("--detect_anomaly", action="store_true", default=False)
+ parser.add_argument(
+ "--test_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument(
+ "--save_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument("--quiet", action="store_true")
+ parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
+ parser.add_argument("--start_checkpoint", type=str, default=None)
+ args = parser.parse_args(sys.argv[1:])
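+ # Example (hypothetical checkpoint path):
+ #   python train.py --image assets/images/ceramic.png --ckpt_path checkpoints/mv_model.ckpt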
+ args.save_iterations.append(args.iterations)
+
+ print("Optimizing " + args.model_path)
+
+ # Initialize system state (RNG)
+ safe_state(args.quiet)
+
+ # Start GUI server, configure and run training
+ network_gui.init(args.ip, args.port)
+ torch.autograd.set_detect_anomaly(args.detect_anomaly)
+
+ print("=====Start generating MV Images=====")
+
+ images, _ = sample_one(args.image, args.ckpt_path, seed=args.seed)
+
+ print("=====Finish generating MV Images=====")
+
+ lp = lp.extract(args)
+ lp.images = images
+
+ training(
+ lp,
+ op.extract(args),
+ pp.extract(args),
+ args.test_iterations,
+ args.save_iterations,
+ args.checkpoint_iterations,
+ args.start_checkpoint,
+ args.debug_from,
+ )
+
+ # All done
+ print("\nTraining complete.")
diff --git a/recon/train_512.py b/recon/train_512.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7e70fe62e842d7ac16f937e6ae7a6c878ce5def
--- /dev/null
+++ b/recon/train_512.py
@@ -0,0 +1,381 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import torch
+from random import randint
+from utils.loss_utils import l1_loss, ssim, lpips
+from gaussian_renderer import render, network_gui
+import sys
+from scene import Scene, GaussianModel
+from utils.general_utils import safe_state
+import uuid
+from tqdm import tqdm
+from utils.image_utils import psnr
+from argparse import ArgumentParser, Namespace
+from arguments import ModelParams, PipelineParams, OptimizationParams
+
+from scripts.sampling.simple_mv_latent_sample import sample_one
+
+try:
+ from torch.utils.tensorboard import SummaryWriter
+
+ TENSORBOARD_FOUND = True
+except ImportError:
+ TENSORBOARD_FOUND = False
+
+
+def training(
+ dataset,
+ opt,
+ pipe,
+ testing_iterations,
+ saving_iterations,
+ checkpoint_iterations,
+ checkpoint,
+ debug_from,
+):
+ first_iter = 0
+ tb_writer = prepare_output_and_logger(dataset)
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians)
+ gaussians.training_setup(opt)
+ if checkpoint:
+ (model_params, first_iter) = torch.load(checkpoint)
+ gaussians.restore(model_params, opt)
+
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+ iter_start = torch.cuda.Event(enable_timing=True)
+ iter_end = torch.cuda.Event(enable_timing=True)
+
+ viewpoint_stack = None
+ ema_loss_for_log = 0.0
+ progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+ first_iter += 1
+ for iteration in range(first_iter, opt.iterations + 1):
+ if network_gui.conn == None:
+ network_gui.try_connect()
+ while network_gui.conn != None:
+ try:
+ net_image_bytes = None
+ (
+ custom_cam,
+ do_training,
+ pipe.convert_SHs_python,
+ pipe.compute_cov3D_python,
+ keep_alive,
+ scaling_modifer,
+ ) = network_gui.receive()
+ if custom_cam != None:
+ net_image = render(
+ custom_cam, gaussians, pipe, background, scaling_modifer
+ )["render"]
+ net_image_bytes = memoryview(
+ (torch.clamp(net_image, min=0, max=1.0) * 255)
+ .byte()
+ .permute(1, 2, 0)
+ .contiguous()
+ .cpu()
+ .numpy()
+ )
+ network_gui.send(net_image_bytes, dataset.source_path)
+ if do_training and (
+ (iteration < int(opt.iterations)) or not keep_alive
+ ):
+ break
+ except Exception as e:
+ network_gui.conn = None
+
+ iter_start.record()
+
+ gaussians.update_learning_rate(iteration)
+
+ # Every 1000 its we increase the levels of SH up to a maximum degree
+ if iteration % 1000 == 0:
+ gaussians.oneupSHdegree()
+
+ # Pick a random Camera
+ if not viewpoint_stack:
+ viewpoint_stack = scene.getTrainCameras().copy()
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
+
+ # Render
+ if (iteration - 1) == debug_from:
+ pipe.debug = True
+
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+
+ render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+
+ # Loss
+ gt_image = viewpoint_cam.original_image.cuda()
+ Ll1 = l1_loss(image, gt_image)
+ loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
+ 1.0 - ssim(image, gt_image)
+ )
+ if opt.lambda_lpips > 0:
+ loss += opt.lambda_lpips * lpips(image, gt_image)
+ loss.backward()
+
+ iter_end.record()
+
+ with torch.no_grad():
+ # Progress bar
+ ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+ if iteration % 10 == 0:
+ progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+ progress_bar.update(10)
+ if iteration == opt.iterations:
+ progress_bar.close()
+
+ # Log and save
+ training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ iter_start.elapsed_time(iter_end),
+ testing_iterations,
+ scene,
+ render,
+ (pipe, background),
+ )
+ if iteration in saving_iterations:
+ print("\n[ITER {}] Saving Gaussians".format(iteration))
+ scene.save(iteration)
+
+ # Densification
+ if iteration < opt.densify_until_iter:
+ # Keep track of max radii in image-space for pruning
+ gaussians.max_radii2D[visibility_filter] = torch.max(
+ gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
+ )
+ gaussians.add_densification_stats(
+ viewspace_point_tensor, visibility_filter
+ )
+
+ if (
+ iteration > opt.densify_from_iter
+ and iteration % opt.densification_interval == 0
+ ):
+ size_threshold = (
+ 20 if iteration > opt.opacity_reset_interval else None
+ )
+ gaussians.densify_and_prune(
+ opt.densify_grad_threshold,
+ 0.005,
+ scene.cameras_extent,
+ size_threshold,
+ )
+
+ if iteration % opt.opacity_reset_interval == 0 or (
+ dataset.white_background and iteration == opt.densify_from_iter
+ ):
+ gaussians.reset_opacity()
+
+ # Optimizer step
+ if iteration < opt.iterations:
+ gaussians.optimizer.step()
+ gaussians.optimizer.zero_grad(set_to_none=True)
+
+ if iteration in checkpoint_iterations:
+ print("\n[ITER {}] Saving Checkpoint".format(iteration))
+ torch.save(
+ (gaussians.capture(), iteration),
+ scene.model_path + "/chkpnt" + str(iteration) + ".pth",
+ )
+
+
+def prepare_output_and_logger(args):
+ if not args.model_path:
+ if os.getenv("OAR_JOB_ID"):
+ unique_str = os.getenv("OAR_JOB_ID")
+ else:
+ unique_str = str(uuid.uuid4())
+ args.model_path = os.path.join("./output/", unique_str[0:10])
+
+ # Set up output folder
+ print("Output folder: {}".format(args.model_path))
+ os.makedirs(args.model_path, exist_ok=True)
+ with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f:
+ cfg_log_f.write(str(Namespace(**vars(args))))
+
+ # Create Tensorboard writer
+ tb_writer = None
+ if TENSORBOARD_FOUND:
+ tb_writer = SummaryWriter(args.model_path)
+ else:
+ print("Tensorboard not available: not logging progress")
+ return tb_writer
+
+
+def training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ elapsed,
+ testing_iterations,
+ scene: Scene,
+ renderFunc,
+ renderArgs,
+):
+ if tb_writer:
+ tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration)
+ tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration)
+ tb_writer.add_scalar("iter_time", elapsed, iteration)
+
+ # Report test and samples of training set
+ if iteration in testing_iterations:
+ torch.cuda.empty_cache()
+ validation_configs = (
+ {"name": "test", "cameras": scene.getTestCameras()},
+ {
+ "name": "train",
+ "cameras": [
+ scene.getTrainCameras()[idx % len(scene.getTrainCameras())]
+ for idx in range(5, 30, 5)
+ ],
+ },
+ )
+
+ for config in validation_configs:
+ if config["cameras"] and len(config["cameras"]) > 0:
+ l1_test = 0.0
+ psnr_test = 0.0
+ for idx, viewpoint in enumerate(config["cameras"]):
+ image = torch.clamp(
+ renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"],
+ 0.0,
+ 1.0,
+ )
+ gt_image = torch.clamp(
+ viewpoint.original_image.to("cuda"), 0.0, 1.0
+ )
+ if tb_writer and (idx < 5):
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/render".format(viewpoint.image_name),
+ image[None],
+ global_step=iteration,
+ )
+ if iteration == testing_iterations[0]:
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/ground_truth".format(viewpoint.image_name),
+ gt_image[None],
+ global_step=iteration,
+ )
+ l1_test += l1_loss(image, gt_image).mean().double()
+ psnr_test += psnr(image, gt_image).mean().double()
+ psnr_test /= len(config["cameras"])
+ l1_test /= len(config["cameras"])
+ print(
+ "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(
+ iteration, config["name"], l1_test, psnr_test
+ )
+ )
+ if tb_writer:
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration
+ )
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration
+ )
+
+ if tb_writer:
+ tb_writer.add_histogram(
+ "scene/opacity_histogram", scene.gaussians.get_opacity, iteration
+ )
+ tb_writer.add_scalar(
+ "total_points", scene.gaussians.get_xyz.shape[0], iteration
+ )
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--image", type=str, default="assets/images/ceramic.png")
+ parser.add_argument("--ckpt_path", type=str, required=True)
+ parser.add_argument("--ip", type=str, default="127.0.0.1")
+ parser.add_argument("--port", type=int, default=6009)
+ parser.add_argument("--debug_from", type=int, default=-1)
+ parser.add_argument("--detect_anomaly", action="store_true", default=False)
+ parser.add_argument(
+ "--test_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument(
+ "--save_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument("--quiet", action="store_true")
+ parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
+ parser.add_argument("--start_checkpoint", type=str, default=None)
+ parser.add_argument("--border_ratio", type=float, default=0.3)
+ parser.add_argument("--min_guidance_scale", type=float, default=1.0)
+ parser.add_argument("--max_guidance_scale", type=float, default=2.5)
+ parser.add_argument("--sigma_max", type=float, default=None)
+ args = parser.parse_args(sys.argv[1:])
+ args.save_iterations.append(args.iterations)
+
+ print("Optimizing " + args.model_path)
+
+ # Initialize system state (RNG)
+ safe_state(args.quiet)
+
+ # Start GUI server, configure and run training
+ network_gui.init(args.ip, args.port)
+ torch.autograd.set_detect_anomaly(args.detect_anomaly)
+
+ print("=====Start generating MV Images=====")
+
+ images, _ = sample_one(
+ args.image,
+ args.ckpt_path,
+ seed=args.seed,
+ border_ratio=args.border_ratio,
+ min_guidance_scale=args.min_guidance_scale,
+ max_guidance_scale=args.max_guidance_scale,
+ sigma_max=args.sigma_max,
+ )
+
+ print("=====Finish generating MV Images=====")
+
+ lp = lp.extract(args)
+ lp.images = images
+
+ training(
+ lp,
+ op.extract(args),
+ pp.extract(args),
+ args.test_iterations,
+ args.save_iterations,
+ args.checkpoint_iterations,
+ args.start_checkpoint,
+ args.debug_from,
+ )
+
+ # All done
+ print("\nTraining complete.")
diff --git a/recon/train_autoaggressive.py b/recon/train_autoaggressive.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/recon/train_from_vid.py b/recon/train_from_vid.py
new file mode 100644
index 0000000000000000000000000000000000000000..88b5fd8cb8144e0d81dafbebd89ebae46b9ff9de
--- /dev/null
+++ b/recon/train_from_vid.py
@@ -0,0 +1,389 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import torch
+from random import randint
+from PIL import Image
+from mediapy import read_video
+from utils.loss_utils import l1_loss, ssim, lpips
+from gaussian_renderer import render, network_gui
+import sys
+from scene import Scene, GaussianModel
+from utils.general_utils import safe_state
+import uuid
+from tqdm import tqdm
+from utils.image_utils import psnr
+from argparse import ArgumentParser, Namespace
+from arguments import ModelParams, PipelineParams, OptimizationParams
+
+from scripts.sampling.simple_mv_latent_sample import sample_one
+
+try:
+ from torch.utils.tensorboard import SummaryWriter
+
+ TENSORBOARD_FOUND = True
+except ImportError:
+ TENSORBOARD_FOUND = False
+
+
+def training(
+ dataset,
+ opt,
+ pipe,
+ testing_iterations,
+ saving_iterations,
+ checkpoint_iterations,
+ checkpoint,
+ debug_from,
+):
+ first_iter = 0
+ tb_writer = prepare_output_and_logger(dataset)
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians)
+ gaussians.training_setup(opt)
+ if checkpoint:
+ (model_params, first_iter) = torch.load(checkpoint)
+ gaussians.restore(model_params, opt)
+
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+ iter_start = torch.cuda.Event(enable_timing=True)
+ iter_end = torch.cuda.Event(enable_timing=True)
+
+ viewpoint_stack = None
+ ema_loss_for_log = 0.0
+ progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+ first_iter += 1
+ for iteration in range(first_iter, opt.iterations + 1):
+ if network_gui.conn == None:
+ network_gui.try_connect()
+ while network_gui.conn != None:
+ try:
+ net_image_bytes = None
+ (
+ custom_cam,
+ do_training,
+ pipe.convert_SHs_python,
+ pipe.compute_cov3D_python,
+ keep_alive,
+ scaling_modifer,
+ ) = network_gui.receive()
+ if custom_cam != None:
+ net_image = render(
+ custom_cam, gaussians, pipe, background, scaling_modifer
+ )["render"]
+ net_image_bytes = memoryview(
+ (torch.clamp(net_image, min=0, max=1.0) * 255)
+ .byte()
+ .permute(1, 2, 0)
+ .contiguous()
+ .cpu()
+ .numpy()
+ )
+ network_gui.send(net_image_bytes, dataset.source_path)
+ if do_training and (
+ (iteration < int(opt.iterations)) or not keep_alive
+ ):
+ break
+ except Exception as e:
+ network_gui.conn = None
+
+ iter_start.record()
+
+ gaussians.update_learning_rate(iteration)
+
+ # Every 1000 its we increase the levels of SH up to a maximum degree
+ if iteration % 1000 == 0:
+ gaussians.oneupSHdegree()
+
+ # Pick a random Camera
+ if not viewpoint_stack:
+ viewpoint_stack = scene.getTrainCameras().copy()
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
+
+ # Render
+ if (iteration - 1) == debug_from:
+ pipe.debug = True
+
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+
+ render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+
+ # Loss
+ gt_image = viewpoint_cam.original_image.cuda()
+ Ll1 = l1_loss(image, gt_image)
+ loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
+ 1.0 - ssim(image, gt_image)
+ )
+ if opt.lambda_lpips > 0:
+ loss += opt.lambda_lpips * lpips(image, gt_image)
+
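+ # Additional sparsity term: penalize the mean Gaussian opacity with weight 0.1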
+ loss += torch.mean(gaussians.get_opacity) * 0.1
+
+ loss.backward()
+
+ iter_end.record()
+
+ with torch.no_grad():
+ # Progress bar
+ ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+ if iteration % 10 == 0:
+ progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+ progress_bar.update(10)
+ if iteration == opt.iterations:
+ progress_bar.close()
+
+ # Log and save
+ training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ iter_start.elapsed_time(iter_end),
+ testing_iterations,
+ scene,
+ render,
+ (pipe, background),
+ )
+ if iteration in saving_iterations:
+ print("\n[ITER {}] Saving Gaussians".format(iteration))
+ scene.save(iteration)
+
+ # Densification
+ if iteration < opt.densify_until_iter:
+ # Keep track of max radii in image-space for pruning
+ gaussians.max_radii2D[visibility_filter] = torch.max(
+ gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
+ )
+ gaussians.add_densification_stats(
+ viewspace_point_tensor, visibility_filter
+ )
+
+ if (
+ iteration > opt.densify_from_iter
+ and iteration % opt.densification_interval == 0
+ ):
+ size_threshold = (
+ 20 if iteration > opt.opacity_reset_interval else None
+ )
+ gaussians.densify_and_prune(
+ opt.densify_grad_threshold,
+ 0.005,
+ scene.cameras_extent,
+ size_threshold,
+ )
+
+ if iteration % opt.opacity_reset_interval == 0 or (
+ dataset.white_background and iteration == opt.densify_from_iter
+ ):
+ gaussians.reset_opacity()
+
+ # Optimizer step
+ if iteration < opt.iterations:
+ gaussians.optimizer.step()
+ gaussians.optimizer.zero_grad(set_to_none=True)
+
+ if iteration in checkpoint_iterations:
+ print("\n[ITER {}] Saving Checkpoint".format(iteration))
+ torch.save(
+ (gaussians.capture(), iteration),
+ scene.model_path + "/chkpnt" + str(iteration) + ".pth",
+ )
+
+
+def prepare_output_and_logger(args):
+ if not args.model_path:
+ if os.getenv("OAR_JOB_ID"):
+ unique_str = os.getenv("OAR_JOB_ID")
+ else:
+ unique_str = str(uuid.uuid4())
+ args.model_path = os.path.join("./output/", unique_str[0:10])
+
+ # Set up output folder
+ print("Output folder: {}".format(args.model_path))
+ os.makedirs(args.model_path, exist_ok=True)
+ with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f:
+ cfg_log_f.write(str(Namespace(**vars(args))))
+
+ # Create Tensorboard writer
+ tb_writer = None
+ if TENSORBOARD_FOUND:
+ tb_writer = SummaryWriter(args.model_path)
+ else:
+ print("Tensorboard not available: not logging progress")
+ return tb_writer
+
+
+def training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ elapsed,
+ testing_iterations,
+ scene: Scene,
+ renderFunc,
+ renderArgs,
+):
+ if tb_writer:
+ tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration)
+ tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration)
+ tb_writer.add_scalar("iter_time", elapsed, iteration)
+
+ # Report test and samples of training set
+ if iteration in testing_iterations:
+ torch.cuda.empty_cache()
+ validation_configs = (
+ {"name": "test", "cameras": scene.getTestCameras()},
+ {
+ "name": "train",
+ "cameras": [
+ scene.getTrainCameras()[idx % len(scene.getTrainCameras())]
+ for idx in range(5, 30, 5)
+ ],
+ },
+ )
+
+ for config in validation_configs:
+ if config["cameras"] and len(config["cameras"]) > 0:
+ l1_test = 0.0
+ psnr_test = 0.0
+ for idx, viewpoint in enumerate(config["cameras"]):
+ image = torch.clamp(
+ renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"],
+ 0.0,
+ 1.0,
+ )
+ gt_image = torch.clamp(
+ viewpoint.original_image.to("cuda"), 0.0, 1.0
+ )
+ if tb_writer and (idx < 5):
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/render".format(viewpoint.image_name),
+ image[None],
+ global_step=iteration,
+ )
+ if iteration == testing_iterations[0]:
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/ground_truth".format(viewpoint.image_name),
+ gt_image[None],
+ global_step=iteration,
+ )
+ l1_test += l1_loss(image, gt_image).mean().double()
+ psnr_test += psnr(image, gt_image).mean().double()
+ psnr_test /= len(config["cameras"])
+ l1_test /= len(config["cameras"])
+ print(
+ "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(
+ iteration, config["name"], l1_test, psnr_test
+ )
+ )
+ if tb_writer:
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration
+ )
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration
+ )
+
+ if tb_writer:
+ tb_writer.add_histogram(
+ "scene/opacity_histogram", scene.gaussians.get_opacity, iteration
+ )
+ tb_writer.add_scalar(
+ "total_points", scene.gaussians.get_xyz.shape[0], iteration
+ )
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--video", type=str, default="")
+ parser.add_argument("--ip", type=str, default="127.0.0.1")
+ parser.add_argument("--port", type=int, default=6009)
+ parser.add_argument("--debug_from", type=int, default=-1)
+ parser.add_argument("--detect_anomaly", action="store_true", default=False)
+ parser.add_argument(
+ "--test_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument(
+ "--save_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument("--quiet", action="store_true")
+ parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
+ parser.add_argument("--start_checkpoint", type=str, default=None)
+ parser.add_argument("--border_ratio", type=float, default=0.3)
+ parser.add_argument("--min_guidance_scale", type=float, default=1.0)
+ parser.add_argument("--max_guidance_scale", type=float, default=2.5)
+ parser.add_argument("--sigma_max", type=float, default=None)
+ args = parser.parse_args(sys.argv[1:])
+ args.save_iterations.append(args.iterations)
+
+ print("Optimizing " + args.model_path)
+
+ # Initialize system state (RNG)
+ safe_state(args.quiet)
+
+ # Start GUI server, configure and run training
+ network_gui.init(args.ip, args.port)
+ torch.autograd.set_detect_anomaly(args.detect_anomaly)
+
+ print("=====Start generating MV Images=====")
+
+ # images, _ = sample_one(
+ # args.image,
+ # args.ckpt_path,
+ # seed=args.seed,
+ # border_ratio=args.border_ratio,
+ # min_guidance_scale=args.min_guidance_scale,
+ # max_guidance_scale=args.max_guidance_scale,
+ # sigma_max=args.sigma_max,
+ # )
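+ # Instead of sampling a diffusion model, use the decoded video frames directly
+ # as the multi-view training images.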
+ images = []
+ frames = read_video(args.video)
+ for frame in frames:
+ images.append(Image.fromarray(frame))
+
+ print("=====Finish generating MV Images=====")
+
+ lp = lp.extract(args)
+ lp.images = images
+
+ training(
+ lp,
+ op.extract(args),
+ pp.extract(args),
+ args.test_iterations,
+ args.save_iterations,
+ args.checkpoint_iterations,
+ args.start_checkpoint,
+ args.debug_from,
+ )
+
+ # All done
+ print("\nTraining complete.")
diff --git a/recon/train_iterative.py b/recon/train_iterative.py
new file mode 100644
index 0000000000000000000000000000000000000000..91d2a79d7307367f864718fbded50863a6cc7d11
--- /dev/null
+++ b/recon/train_iterative.py
@@ -0,0 +1,400 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import torch
+import numpy as np
+from torchvision.transforms.functional import pil_to_tensor, to_tensor
+from torchvision.utils import make_grid, save_image
+from random import randint
+from utils.loss_utils import l1_loss, ssim, lpips
+from gaussian_renderer import render, network_gui
+import sys
+from scene import Scene, GaussianModel
+from utils.general_utils import safe_state
+import uuid
+from tqdm import tqdm
+from utils.image_utils import psnr
+from argparse import ArgumentParser, Namespace
+from arguments import ModelParams, PipelineParams, OptimizationParams
+
+from scripts.sampling.simple_mv_sample import sample_one
+
+try:
+ from torch.utils.tensorboard import SummaryWriter
+
+ TENSORBOARD_FOUND = True
+except ImportError:
+ TENSORBOARD_FOUND = False
+
+
+def training(
+ dataset,
+ opt,
+ pipe,
+ testing_iterations,
+ saving_iterations,
+ checkpoint_iterations,
+ checkpoint,
+ debug_from,
+ resample_period=500,
+ resample_sigma=0.1,
+ resample_start=1000,
+ model=None,
+):
+ first_iter = 0
+ tb_writer = prepare_output_and_logger(dataset)
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians, shuffle=False)
+ gaussians.training_setup(opt)
+ if checkpoint:
+ (model_params, first_iter) = torch.load(checkpoint)
+ gaussians.restore(model_params, opt)
+
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+ iter_start = torch.cuda.Event(enable_timing=True)
+ iter_end = torch.cuda.Event(enable_timing=True)
+
+ viewpoint_stack = None
+ ema_loss_for_log = 0.0
+ progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+ first_iter += 1
+ for iteration in range(first_iter, opt.iterations + 1):
+ iter_start.record()
+
+ gaussians.update_learning_rate(iteration)
+
+ # Every 1000 its we increase the levels of SH up to a maximum degree
+ if iteration % 1000 == 0:
+ gaussians.oneupSHdegree()
+
+ with torch.no_grad():
+ if iteration % resample_period == 0 and iteration > resample_start:
+ # if iteration % resample_period:
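+ # Periodic resampling: render all training views, encode them to latents, perturb
+ # the latents with noise of scale resample_sigma, and re-run the multi-view sampler
+ # to obtain refreshed pseudo ground-truth images, which replace the training set
+ # via a rebuilt Scene.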
+ views = []
+ viewpoint_stack = scene.getTrainCameras().copy()
+ for view_cam in viewpoint_stack:
+ bg = (
+ torch.rand((3), device="cuda")
+ if opt.random_background
+ else background
+ )
+ render_pkg = render(view_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+ views.append(image)
+ views = torch.stack(views)
+ save_image(views, f"tmp/views_{iteration}.png")
+ views = views * 2.0 - 1.0
+ views = model.encode_first_stage(views)
+ # Variance-preserving perturbation of the latents with noise level resample_sigma
+ noisy_views = (
+ np.sqrt(1 - resample_sigma**2) * views
+ + torch.randn_like(views) * resample_sigma
+ )
+ resampled_images = sample_one(
+ args.image,
+ args.ckpt_path,
+ noise=noisy_views,
+ cached_model=model,
+ )[0]
+ dataset.images = resampled_images
+ scene = Scene(
+ dataset,
+ gaussians,
+ shuffle=False,
+ skip_gaussians=True,
+ )
+ resampled_images_grid = []
+ for img in resampled_images:
+ resampled_images_grid.append(to_tensor(img))
+ resampled_images_grid = torch.stack(resampled_images_grid)
+ resampled_images_grid = make_grid(resampled_images_grid, nrow=3)
+ save_image(
+ resampled_images_grid, f"tmp/resampled_images_{iteration}.png"
+ )
+
+ # Pick a random Camera
+ if not viewpoint_stack:
+ viewpoint_stack = scene.getTrainCameras().copy()
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
+
+ # Render
+ if (iteration - 1) == debug_from:
+ pipe.debug = True
+
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+
+ render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+
+ # Loss
+ gt_image = viewpoint_cam.original_image.cuda()
+ Ll1 = l1_loss(image, gt_image)
+ loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
+ 1.0 - ssim(image, gt_image)
+ )
+ if opt.lambda_lpips > 0:
+ loss += opt.lambda_lpips * lpips(image, gt_image)
+ loss.backward()
+
+ iter_end.record()
+
+ with torch.no_grad():
+ # Progress bar
+ ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+ if iteration % 10 == 0:
+ progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+ progress_bar.update(10)
+ if iteration == opt.iterations:
+ progress_bar.close()
+
+ # Log and save
+ training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ iter_start.elapsed_time(iter_end),
+ testing_iterations,
+ scene,
+ render,
+ (pipe, background),
+ )
+ if iteration in saving_iterations:
+ print("\n[ITER {}] Saving Gaussians".format(iteration))
+ scene.save(iteration)
+
+ # Densification
+ if iteration < opt.densify_until_iter:
+ # Keep track of max radii in image-space for pruning
+ gaussians.max_radii2D[visibility_filter] = torch.max(
+ gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
+ )
+ gaussians.add_densification_stats(
+ viewspace_point_tensor, visibility_filter
+ )
+
+ if (
+ iteration > opt.densify_from_iter
+ and iteration % opt.densification_interval == 0
+ ):
+ size_threshold = (
+ 20 if iteration > opt.opacity_reset_interval else None
+ )
+ gaussians.densify_and_prune(
+ opt.densify_grad_threshold,
+ 0.005,
+ scene.cameras_extent,
+ size_threshold,
+ )
+
+ if iteration % opt.opacity_reset_interval == 0 or (
+ dataset.white_background and iteration == opt.densify_from_iter
+ ):
+ gaussians.reset_opacity()
+
+ # Optimizer step
+ if iteration < opt.iterations:
+ gaussians.optimizer.step()
+ gaussians.optimizer.zero_grad(set_to_none=True)
+
+ if iteration in checkpoint_iterations:
+ print("\n[ITER {}] Saving Checkpoint".format(iteration))
+ torch.save(
+ (gaussians.capture(), iteration),
+ scene.model_path + "/chkpnt" + str(iteration) + ".pth",
+ )
+
+
+def prepare_output_and_logger(args):
+ if not args.model_path:
+ if os.getenv("OAR_JOB_ID"):
+ unique_str = os.getenv("OAR_JOB_ID")
+ else:
+ unique_str = str(uuid.uuid4())
+ args.model_path = os.path.join("./output/", unique_str[0:10])
+
+ # Set up output folder
+ print("Output folder: {}".format(args.model_path))
+ os.makedirs(args.model_path, exist_ok=True)
+ with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f:
+ cfg_log_f.write(str(Namespace(**vars(args))))
+
+ # Create Tensorboard writer
+ tb_writer = None
+ if TENSORBOARD_FOUND:
+ tb_writer = SummaryWriter(args.model_path)
+ else:
+ print("Tensorboard not available: not logging progress")
+ return tb_writer
+
+
+def training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ elapsed,
+ testing_iterations,
+ scene: Scene,
+ renderFunc,
+ renderArgs,
+):
+ if tb_writer:
+ tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration)
+ tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration)
+ tb_writer.add_scalar("iter_time", elapsed, iteration)
+
+ # Report test and samples of training set
+ if iteration in testing_iterations:
+ torch.cuda.empty_cache()
+ validation_configs = (
+ {"name": "test", "cameras": scene.getTestCameras()},
+ {
+ "name": "train",
+ "cameras": [
+ scene.getTrainCameras()[idx % len(scene.getTrainCameras())]
+ for idx in range(5, 30, 5)
+ ],
+ },
+ )
+
+ for config in validation_configs:
+ if config["cameras"] and len(config["cameras"]) > 0:
+ l1_test = 0.0
+ psnr_test = 0.0
+ for idx, viewpoint in enumerate(config["cameras"]):
+ image = torch.clamp(
+ renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"],
+ 0.0,
+ 1.0,
+ )
+ gt_image = torch.clamp(
+ viewpoint.original_image.to("cuda"), 0.0, 1.0
+ )
+ if tb_writer and (idx < 5):
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/render".format(viewpoint.image_name),
+ image[None],
+ global_step=iteration,
+ )
+ if iteration == testing_iterations[0]:
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/ground_truth".format(viewpoint.image_name),
+ gt_image[None],
+ global_step=iteration,
+ )
+ l1_test += l1_loss(image, gt_image).mean().double()
+ psnr_test += psnr(image, gt_image).mean().double()
+ psnr_test /= len(config["cameras"])
+ l1_test /= len(config["cameras"])
+ print(
+ "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(
+ iteration, config["name"], l1_test, psnr_test
+ )
+ )
+ if tb_writer:
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration
+ )
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration
+ )
+
+ if tb_writer:
+ tb_writer.add_histogram(
+ "scene/opacity_histogram", scene.gaussians.get_opacity, iteration
+ )
+ tb_writer.add_scalar(
+ "total_points", scene.gaussians.get_xyz.shape[0], iteration
+ )
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--seed", type=int, default=42)
+ parser.add_argument("--image", type=str, default="assets/images/ceramic.png")
+ parser.add_argument("--ckpt_path", type=str, required=True)
+ parser.add_argument("--ip", type=str, default="127.0.0.1")
+ parser.add_argument("--port", type=int, default=6009)
+ parser.add_argument("--debug_from", type=int, default=-1)
+ parser.add_argument("--detect_anomaly", action="store_true", default=False)
+ parser.add_argument(
+ "--test_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument(
+ "--save_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument("--quiet", action="store_true")
+ parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
+ parser.add_argument("--start_checkpoint", type=str, default=None)
+ parser.add_argument("--resample_period", type=int, default=500)
+ parser.add_argument("--resample_sigma", type=float, default=0.1)
+ parser.add_argument("--resample_start", type=int, default=500)
+ args = parser.parse_args(sys.argv[1:])
+ args.save_iterations.append(args.iterations)
+
+ print("Optimizing " + args.model_path)
+
+ # Initialize system state (RNG)
+ safe_state(args.quiet)
+
+ # Start GUI server, configure and run training
+ network_gui.init(args.ip, args.port)
+ torch.autograd.set_detect_anomaly(args.detect_anomaly)
+
+ print("=====Start generating MV Images=====")
+
+ images, model = sample_one(args.image, args.ckpt_path, seed=args.seed)
+
+ print("=====Finish generating MV Images=====")
+
+ lp = lp.extract(args)
+ lp.images = images
+
+ training(
+ lp,
+ op.extract(args),
+ pp.extract(args),
+ args.test_iterations,
+ args.save_iterations,
+ args.checkpoint_iterations,
+ args.start_checkpoint,
+ args.debug_from,
+ args.resample_period,
+ args.resample_sigma,
+ args.resample_start,
+ model,
+ )
+
+ # All done
+ print("\nTraining complete.")
diff --git a/recon/train_scene.py b/recon/train_scene.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3e373b9b9589a17e3bc188b29baeb4db4ab6fd2
--- /dev/null
+++ b/recon/train_scene.py
@@ -0,0 +1,352 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import os
+import torch
+from random import randint
+from utils.loss_utils import l1_loss, ssim
+from gaussian_renderer import render, network_gui
+import sys
+from scene import Scene, GaussianModel
+from utils.general_utils import safe_state
+import uuid
+from tqdm import tqdm
+from utils.image_utils import psnr
+from argparse import ArgumentParser, Namespace
+from arguments import ModelParams, PipelineParams, OptimizationParams
+
+try:
+ from torch.utils.tensorboard import SummaryWriter
+
+ TENSORBOARD_FOUND = True
+except ImportError:
+ TENSORBOARD_FOUND = False
+
+
+def training(
+ dataset,
+ opt,
+ pipe,
+ testing_iterations,
+ saving_iterations,
+ checkpoint_iterations,
+ checkpoint,
+ debug_from,
+):
+ first_iter = 0
+ tb_writer = prepare_output_and_logger(dataset)
+ gaussians = GaussianModel(dataset.sh_degree)
+ scene = Scene(dataset, gaussians)
+ gaussians.training_setup(opt)
+ if checkpoint:
+ (model_params, first_iter) = torch.load(checkpoint)
+ gaussians.restore(model_params, opt)
+
+ bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
+ background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
+
+ iter_start = torch.cuda.Event(enable_timing=True)
+ iter_end = torch.cuda.Event(enable_timing=True)
+
+ viewpoint_stack = None
+ ema_loss_for_log = 0.0
+ progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
+ first_iter += 1
+ for iteration in range(first_iter, opt.iterations + 1):
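+ # Serve the interactive network viewer: while a client is connected, render the
+ # camera it requests and stream the frame back, until it disconnects or asks
+ # training to continue.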
+ if network_gui.conn is None:
+ network_gui.try_connect()
+ while network_gui.conn is not None:
+ try:
+ net_image_bytes = None
+ (
+ custom_cam,
+ do_training,
+ pipe.convert_SHs_python,
+ pipe.compute_cov3D_python,
+ keep_alive,
+ scaling_modifer,
+ ) = network_gui.receive()
+ if custom_cam is not None:
+ net_image = render(
+ custom_cam, gaussians, pipe, background, scaling_modifer
+ )["render"]
+ net_image_bytes = memoryview(
+ (torch.clamp(net_image, min=0, max=1.0) * 255)
+ .byte()
+ .permute(1, 2, 0)
+ .contiguous()
+ .cpu()
+ .numpy()
+ )
+ network_gui.send(net_image_bytes, dataset.source_path)
+ if do_training and (
+ (iteration < int(opt.iterations)) or not keep_alive
+ ):
+ break
+ except Exception as e:
+ network_gui.conn = None
+
+ iter_start.record()
+
+ gaussians.update_learning_rate(iteration)
+
+ # Every 1000 iterations, increase the SH degree by one level, up to the maximum
+ if iteration % 1000 == 0:
+ gaussians.oneupSHdegree()
+
+ # Pick a random Camera
+ if not viewpoint_stack:
+ viewpoint_stack = scene.getTrainCameras().copy()
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1))
+
+ # Render
+ if (iteration - 1) == debug_from:
+ pipe.debug = True
+
+ bg = torch.rand((3), device="cuda") if opt.random_background else background
+
+ render_pkg = render(viewpoint_cam, gaussians, pipe, bg)
+ image, viewspace_point_tensor, visibility_filter, radii = (
+ render_pkg["render"],
+ render_pkg["viewspace_points"],
+ render_pkg["visibility_filter"],
+ render_pkg["radii"],
+ )
+
+ # Loss
+ gt_image = viewpoint_cam.original_image.cuda()
+ Ll1 = l1_loss(image, gt_image)
+ loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (
+ 1.0 - ssim(image, gt_image)
+ )
+ loss.backward()
+
+ iter_end.record()
+
+ with torch.no_grad():
+ # Progress bar
+ ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
+ if iteration % 10 == 0:
+ progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
+ progress_bar.update(10)
+ if iteration == opt.iterations:
+ progress_bar.close()
+
+ # Log and save
+ training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ iter_start.elapsed_time(iter_end),
+ testing_iterations,
+ scene,
+ render,
+ (pipe, background),
+ )
+ if iteration in saving_iterations:
+ print("\n[ITER {}] Saving Gaussians".format(iteration))
+ scene.save(iteration)
+
+ # Densification
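+ # Clone/split Gaussians whose accumulated view-space position gradients exceed
+ # the threshold and prune those that are nearly transparent or too large;
+ # opacity is reset at a fixed interval (and once after densification starts
+ # when training on a white background).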
+ if iteration < opt.densify_until_iter:
+ # Keep track of max radii in image-space for pruning
+ gaussians.max_radii2D[visibility_filter] = torch.max(
+ gaussians.max_radii2D[visibility_filter], radii[visibility_filter]
+ )
+ gaussians.add_densification_stats(
+ viewspace_point_tensor, visibility_filter
+ )
+
+ if (
+ iteration > opt.densify_from_iter
+ and iteration % opt.densification_interval == 0
+ ):
+ size_threshold = (
+ 20 if iteration > opt.opacity_reset_interval else None
+ )
+ gaussians.densify_and_prune(
+ opt.densify_grad_threshold,
+ 0.005,
+ scene.cameras_extent,
+ size_threshold,
+ )
+
+ if iteration % opt.opacity_reset_interval == 0 or (
+ dataset.white_background and iteration == opt.densify_from_iter
+ ):
+ gaussians.reset_opacity()
+
+ # Optimizer step
+ if iteration < opt.iterations:
+ gaussians.optimizer.step()
+ gaussians.optimizer.zero_grad(set_to_none=True)
+
+ if iteration in checkpoint_iterations:
+ print("\n[ITER {}] Saving Checkpoint".format(iteration))
+ torch.save(
+ (gaussians.capture(), iteration),
+ scene.model_path + "/chkpnt" + str(iteration) + ".pth",
+ )
+
+
+def prepare_output_and_logger(args):
+ if not args.model_path:
+ if os.getenv("OAR_JOB_ID"):
+ unique_str = os.getenv("OAR_JOB_ID")
+ else:
+ unique_str = str(uuid.uuid4())
+ args.model_path = os.path.join("./output/", unique_str[0:10])
+
+ # Set up output folder
+ print("Output folder: {}".format(args.model_path))
+ os.makedirs(args.model_path, exist_ok=True)
+ with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f:
+ cfg_log_f.write(str(Namespace(**vars(args))))
+
+ # Create Tensorboard writer
+ tb_writer = None
+ if TENSORBOARD_FOUND:
+ tb_writer = SummaryWriter(args.model_path)
+ else:
+ print("Tensorboard not available: not logging progress")
+ return tb_writer
+
+
+def training_report(
+ tb_writer,
+ iteration,
+ Ll1,
+ loss,
+ l1_loss,
+ elapsed,
+ testing_iterations,
+ scene: Scene,
+ renderFunc,
+ renderArgs,
+):
+ if tb_writer:
+ tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration)
+ tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration)
+ tb_writer.add_scalar("iter_time", elapsed, iteration)
+
+ # Report test and samples of training set
+ if iteration in testing_iterations:
+ torch.cuda.empty_cache()
+ validation_configs = (
+ {"name": "test", "cameras": scene.getTestCameras()},
+ {
+ "name": "train",
+ "cameras": [
+ scene.getTrainCameras()[idx % len(scene.getTrainCameras())]
+ for idx in range(5, 30, 5)
+ ],
+ },
+ )
+
+ for config in validation_configs:
+ if config["cameras"] and len(config["cameras"]) > 0:
+ l1_test = 0.0
+ psnr_test = 0.0
+ for idx, viewpoint in enumerate(config["cameras"]):
+ image = torch.clamp(
+ renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"],
+ 0.0,
+ 1.0,
+ )
+ gt_image = torch.clamp(
+ viewpoint.original_image.to("cuda"), 0.0, 1.0
+ )
+ if tb_writer and (idx < 5):
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/render".format(viewpoint.image_name),
+ image[None],
+ global_step=iteration,
+ )
+ if iteration == testing_iterations[0]:
+ tb_writer.add_images(
+ config["name"]
+ + "_view_{}/ground_truth".format(viewpoint.image_name),
+ gt_image[None],
+ global_step=iteration,
+ )
+ l1_test += l1_loss(image, gt_image).mean().double()
+ psnr_test += psnr(image, gt_image).mean().double()
+ psnr_test /= len(config["cameras"])
+ l1_test /= len(config["cameras"])
+ print(
+ "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(
+ iteration, config["name"], l1_test, psnr_test
+ )
+ )
+ if tb_writer:
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration
+ )
+ tb_writer.add_scalar(
+ config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration
+ )
+
+ if tb_writer:
+ tb_writer.add_histogram(
+ "scene/opacity_histogram", scene.gaussians.get_opacity, iteration
+ )
+ tb_writer.add_scalar(
+ "total_points", scene.gaussians.get_xyz.shape[0], iteration
+ )
+ torch.cuda.empty_cache()
+
+
+if __name__ == "__main__":
+ # Set up command line argument parser
+ parser = ArgumentParser(description="Training script parameters")
+ lp = ModelParams(parser)
+ op = OptimizationParams(parser)
+ pp = PipelineParams(parser)
+ parser.add_argument("--ip", type=str, default="127.0.0.1")
+ parser.add_argument("--port", type=int, default=6009)
+ parser.add_argument("--debug_from", type=int, default=-1)
+ parser.add_argument("--detect_anomaly", action="store_true", default=False)
+ parser.add_argument(
+ "--test_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument(
+ "--save_iterations", nargs="+", type=int, default=[7_000, 30_000]
+ )
+ parser.add_argument("--quiet", action="store_true")
+ parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
+ parser.add_argument("--start_checkpoint", type=str, default=None)
+ args = parser.parse_args(sys.argv[1:])
+ args.save_iterations.append(args.iterations)
+
+ print("Optimizing " + args.model_path)
+
+ # Initialize system state (RNG)
+ safe_state(args.quiet)
+
+ # Start GUI server, configure and run training
+ network_gui.init(args.ip, args.port)
+ torch.autograd.set_detect_anomaly(args.detect_anomaly)
+ training(
+ lp.extract(args),
+ op.extract(args),
+ pp.extract(args),
+ args.test_iterations,
+ args.save_iterations,
+ args.checkpoint_iterations,
+ args.start_checkpoint,
+ args.debug_from,
+ )
+
+ # All done
+ print("\nTraining complete.")
diff --git a/recon/utils/camera_utils.py b/recon/utils/camera_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ace2474dd723975515438ae6f5d8a64e0c819317
--- /dev/null
+++ b/recon/utils/camera_utils.py
@@ -0,0 +1,151 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+from pathlib import Path
+from mediapy import read_video, write_video
+from scene.cameras import Camera
+import numpy as np
+from utils.general_utils import PILtoTorch
+from utils.graphics_utils import fov2focal
+
+WARNED = False
+
+
+def loadCam(args, id, cam_info, resolution_scale):
+ orig_w, orig_h = cam_info.image.size
+
+ if args.resolution in [1, 2, 4, 8]:
+ resolution = round(orig_w / (resolution_scale * args.resolution)), round(
+ orig_h / (resolution_scale * args.resolution)
+ )
+ else: # should be a type that converts to float
+ if args.resolution == -1:
+ if orig_w > 1600:
+ global WARNED
+ if not WARNED:
+ print(
+ "[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
+ "If this is not desired, please explicitly specify '--resolution/-r' as 1"
+ )
+ WARNED = True
+ global_down = orig_w / 1600
+ else:
+ global_down = 1
+ else:
+ global_down = orig_w / args.resolution
+
+ scale = float(global_down) * float(resolution_scale)
+ resolution = (int(orig_w / scale), int(orig_h / scale))
+
+ resized_image_rgb = PILtoTorch(cam_info.image, resolution)
+
+ gt_image = resized_image_rgb[:3, ...]
+ loaded_mask = None
+
+ if resized_image_rgb.shape[1] == 4:
+ loaded_mask = resized_image_rgb[3:4, ...]
+
+ return Camera(
+ colmap_id=cam_info.uid,
+ R=cam_info.R,
+ T=cam_info.T,
+ FoVx=cam_info.FovX,
+ FoVy=cam_info.FovY,
+ image=gt_image,
+ gt_alpha_mask=loaded_mask,
+ image_name=cam_info.image_name,
+ uid=id,
+ data_device=args.data_device,
+ )
+
+
+def cameraList_from_camInfos(cam_infos, resolution_scale, args):
+ camera_list = []
+
+ for id, c in enumerate(cam_infos):
+ camera_list.append(loadCam(args, id, c, resolution_scale))
+
+ return camera_list
+
+
+def camera_to_JSON(id, camera: Camera):
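+ # Serialize a camera as a JSON-friendly dict: camera center and rotation in
+ # world space (the variable named W2C below actually holds the camera-to-world
+ # transform), plus focal lengths derived from the FoV.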
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = camera.R.transpose()
+ Rt[:3, 3] = camera.T
+ Rt[3, 3] = 1.0
+
+ W2C = np.linalg.inv(Rt)
+ pos = W2C[:3, 3]
+ rot = W2C[:3, :3]
+ serializable_array_2d = [x.tolist() for x in rot]
+ camera_entry = {
+ "id": id,
+ "img_name": camera.image_name,
+ "width": camera.width,
+ "height": camera.height,
+ "position": pos.tolist(),
+ "rotation": serializable_array_2d,
+ "fy": fov2focal(camera.FovY, camera.height),
+ "fx": fov2focal(camera.FovX, camera.width),
+ }
+ return camera_entry
+
+
+def get_c2w_from_up_and_look_at(
+ up,
+ look_at,
+ pos,
+ opengl=False,
+):
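+ # Build an OpenCV-style camera-to-world matrix: +z points from the camera toward
+ # `look_at`, +y is the image-down direction (-up), and +x completes the
+ # right-handed basis; `opengl=True` flips the y/z axes to the OpenGL convention.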
+ up = up / np.linalg.norm(up)
+ z = look_at - pos
+ z = z / np.linalg.norm(z)
+ y = -up
+ x = np.cross(y, z)
+ x /= np.linalg.norm(x)
+ y = np.cross(z, x)
+
+ c2w = np.zeros([4, 4], dtype=np.float32)
+ c2w[:3, 0] = x
+ c2w[:3, 1] = y
+ c2w[:3, 2] = z
+ c2w[:3, 3] = pos
+ c2w[3, 3] = 1.0
+
+ # opencv to opengl
+ if opengl:
+ c2w[..., 1:3] *= -1
+
+ return c2w
+
+
+def get_uniform_poses(num_frames, radius, elevation, opengl=False):
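+ # Place `num_frames` cameras on a circle of the given radius and elevation,
+ # uniformly spaced in azimuth over [0, 360), all looking at the origin with the
+ # world +z axis as up.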
+ T = num_frames
+ azimuths = np.deg2rad(np.linspace(0, 360, T + 1)[:T])
+ elevations = np.full_like(azimuths, np.deg2rad(elevation))
+ cam_dists = np.full_like(azimuths, radius)
+
+ campos = np.stack(
+ [
+ cam_dists * np.cos(elevations) * np.cos(azimuths),
+ cam_dists * np.cos(elevations) * np.sin(azimuths),
+ cam_dists * np.sin(elevations),
+ ],
+ axis=-1,
+ )
+
+ center = np.array([0, 0, 0], dtype=np.float32)
+ up = np.array([0, 0, 1], dtype=np.float32)
+ poses = []
+ for t in range(T):
+ poses.append(get_c2w_from_up_and_look_at(up, center, campos[t], opengl=opengl))
+
+ return np.stack(poses, axis=0)
diff --git a/recon/utils/colormaps.py b/recon/utils/colormaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ee85b4ff33b5aeb84e6779befb3601e167d744c
--- /dev/null
+++ b/recon/utils/colormaps.py
@@ -0,0 +1,220 @@
+# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Helper functions for visualizing outputs """
+
+from dataclasses import dataclass
+
+# from utils.typing import *
+from typing import *
+
+import matplotlib
+import torch
+from jaxtyping import Bool, Float
+from torch import Tensor
+
+from utils import colors
+
+Colormaps = Literal[
+ "default", "turbo", "viridis", "magma", "inferno", "cividis", "gray", "pca"
+]
+
+
+@dataclass(frozen=True)
+class ColormapOptions:
+ """Options for colormap"""
+
+ colormap: Colormaps = "default"
+ """ The colormap to use """
+ normalize: bool = False
+ """ Whether to normalize the input tensor image """
+ colormap_min: float = 0
+ """ Minimum value for the output colormap """
+ colormap_max: float = 1
+ """ Maximum value for the output colormap """
+ invert: bool = False
+ """ Whether to invert the output colormap """
+
+
+def apply_colormap(
+ image: Float[Tensor, "*bs channels"],
+ colormap_options: ColormapOptions = ColormapOptions(),
+ eps: float = 1e-9,
+) -> Float[Tensor, "*bs rgb"]:
+ """
+ Applies a colormap to a tensor image.
+ If single channel, applies a colormap to the image.
+ If 3 channel, treats the channels as RGB.
+ If more than 3 channel, applies a PCA reduction on the dimensions to 3 channels
+
+ Args:
+ image: Input tensor image.
+ colormap_options: Options controlling colormap choice, normalization, range, and inversion.
+ eps: Epsilon value for numerical stability.
+
+ Returns:
+ Tensor with the colormap applied.
+ """
+
+ # default for rgb images
+ if image.shape[-1] == 3:
+ return image
+
+ # rendering depth outputs
+ if image.shape[-1] == 1 and torch.is_floating_point(image):
+ output = image
+ if colormap_options.normalize:
+ output = output - torch.min(output)
+ output = output / (torch.max(output) + eps)
+ output = (
+ output * (colormap_options.colormap_max - colormap_options.colormap_min)
+ + colormap_options.colormap_min
+ )
+ output = torch.clip(output, 0, 1)
+ if colormap_options.invert:
+ output = 1 - output
+ return apply_float_colormap(output, colormap=colormap_options.colormap)
+
+ # rendering boolean outputs
+ if image.dtype == torch.bool:
+ return apply_boolean_colormap(image)
+
+ if image.shape[-1] > 3:
+ return apply_pca_colormap(image)
+
+ raise NotImplementedError
+
+
+def apply_float_colormap(
+ image: Float[Tensor, "*bs 1"], colormap: Colormaps = "viridis"
+) -> Float[Tensor, "*bs rgb"]:
+ """Convert single channel to a color image.
+
+ Args:
+ image: Single channel image.
+ colormap: Colormap for image.
+
+ Returns:
+ Tensor: Colored image with colors in [0, 1]
+ """
+ if colormap == "default":
+ colormap = "turbo"
+
+ image = torch.nan_to_num(image, 0)
+ if colormap == "gray":
+ return image.repeat(1, 1, 3)
+ image = image.clamp(0, 1)
+ image_long = (image * 255).long()
+ image_long_min = torch.min(image_long)
+ image_long_max = torch.max(image_long)
+ assert image_long_min >= 0, f"the min value is {image_long_min}"
+ assert image_long_max <= 255, f"the max value is {image_long_max}"
+ return torch.tensor(matplotlib.colormaps[colormap].colors, device=image.device)[
+ image_long[..., 0]
+ ]
+
+
+def apply_depth_colormap(
+ depth: Float[Tensor, "*bs 1"],
+ accumulation: Optional[Float[Tensor, "*bs 1"]] = None,
+ near_plane: Optional[float] = None,
+ far_plane: Optional[float] = None,
+ colormap_options: ColormapOptions = ColormapOptions(),
+) -> Float[Tensor, "*bs rgb"]:
+ """Converts a depth image to color for easier analysis.
+
+ Args:
+ depth: Depth image.
+ accumulation: Ray accumulation used for masking vis.
+ near_plane: Closest depth to consider. If None, use min image value.
+ far_plane: Furthest depth to consider. If None, use max image value.
+ colormap_options: Options for the colormap to apply.
+
+ Returns:
+ Colored depth image with colors in [0, 1]
+ """
+
+ near_plane = near_plane or float(torch.min(depth))
+ far_plane = far_plane or float(torch.max(depth))
+
+ depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
+ depth = torch.clip(depth, 0, 1)
+ # depth = torch.nan_to_num(depth, nan=0.0) # TODO(ethan): remove this
+
+ colored_image = apply_colormap(depth, colormap_options=colormap_options)
+
+ if accumulation is not None:
+ colored_image = colored_image * accumulation + (1 - accumulation)
+
+ return colored_image
+
+
+def apply_boolean_colormap(
+ image: Bool[Tensor, "*bs 1"],
+ true_color: Float[Tensor, "*bs rgb"] = colors.WHITE,
+ false_color: Float[Tensor, "*bs rgb"] = colors.BLACK,
+) -> Float[Tensor, "*bs rgb"]:
+ """Converts a depth image to color for easier analysis.
+
+ Args:
+ image: Boolean image.
+ true_color: Color to use for True.
+ false_color: Color to use for False.
+
+ Returns:
+ Colored boolean image
+ """
+
+ colored_image = torch.ones(image.shape[:-1] + (3,))
+ colored_image[image[..., 0], :] = true_color
+ colored_image[~image[..., 0], :] = false_color
+ return colored_image
+
+
+def apply_pca_colormap(image: Float[Tensor, "*bs dim"]) -> Float[Tensor, "*bs rgb"]:
+ """Convert feature image to 3-channel RGB via PCA. The first three principle
+ components are used for the color channels, with outlier rejection per-channel
+
+ Args:
+ image: image of arbitrary vectors
+
+ Returns:
+ Tensor: Colored image
+ """
+ original_shape = image.shape
+ image = image.view(-1, image.shape[-1])
+ _, _, v = torch.pca_lowrank(image)
+ image = torch.matmul(image, v[..., :3])
+ d = torch.abs(image - torch.median(image, dim=0).values)
+ mdev = torch.median(d, dim=0).values
+ s = d / mdev
+ m = 3.0 # values farther than m median-absolute-deviations from the per-channel median are treated as outliers
+ rins = image[s[:, 0] < m, 0]
+ gins = image[s[:, 1] < m, 1]
+ bins = image[s[:, 2] < m, 2]
+
+ image[:, 0] -= rins.min()
+ image[:, 1] -= gins.min()
+ image[:, 2] -= bins.min()
+
+ image[:, 0] /= rins.max() - rins.min()
+ image[:, 1] /= gins.max() - gins.min()
+ image[:, 2] /= bins.max() - bins.min()
+
+ image = torch.clamp(image, 0, 1)
+ image_long = (image * 255).long()
+ image_long_min = torch.min(image_long)
+ image_long_max = torch.max(image_long)
+ assert image_long_min >= 0, f"the min value is {image_long_min}"
+ assert image_long_max <= 255, f"the max value is {image_long_max}"
+ return image.view(*original_shape[:-1], 3)
diff --git a/recon/utils/colors.py b/recon/utils/colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..66ac8d24357d0c6f5c0db9f560f13dff459a3c83
--- /dev/null
+++ b/recon/utils/colors.py
@@ -0,0 +1,55 @@
+# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common Colors"""
+from typing import Union
+
+import torch
+from jaxtyping import Float
+from torch import Tensor
+
+WHITE = torch.tensor([1.0, 1.0, 1.0])
+BLACK = torch.tensor([0.0, 0.0, 0.0])
+RED = torch.tensor([1.0, 0.0, 0.0])
+GREEN = torch.tensor([0.0, 1.0, 0.0])
+BLUE = torch.tensor([0.0, 0.0, 1.0])
+
+COLORS_DICT = {
+ "white": WHITE,
+ "black": BLACK,
+ "red": RED,
+ "green": GREEN,
+ "blue": BLUE,
+}
+
+
+def get_color(color: Union[str, list]) -> Float[Tensor, "3"]:
+ """
+ Args:
+ Color as a string or a rgb list
+
+ Returns:
+ Parsed color
+ """
+ if isinstance(color, str):
+ color = color.lower()
+ if color not in COLORS_DICT:
+ raise ValueError(f"{color} is not a valid preset color")
+ return COLORS_DICT[color]
+ if isinstance(color, list):
+ if len(color) != 3:
+ raise ValueError(f"Color should be 3 values (RGB) instead got {color}")
+ return torch.tensor(color)
+
+ raise ValueError(f"Color should be an RGB list or string, instead got {type(color)}")
diff --git a/recon/utils/diffusion.py b/recon/utils/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..0217efb13b1342dcb4b7f22b8528c53c58a11627
--- /dev/null
+++ b/recon/utils/diffusion.py
@@ -0,0 +1,42 @@
+import torch
+from PIL import Image
+from pathlib import Path
+from omegaconf import OmegaConf
+
+from scripts.demo.streamlit_helpers import (
+ load_model_from_config,
+ get_sampler,
+ get_batch,
+ do_sample,
+)
+
+
+def load_config_and_model(ckpt: Path):
+ if (ckpt.parent.parent / "configs").exists():
+ config_path = list((ckpt.parent.parent / "configs").glob("*-project.yaml"))[0]
+ else:
+ config_path = list(
+ (ckpt.parent.parent.parent / "configs").glob("*-project.yaml")
+ )[0]
+
+ config = OmegaConf.load(config_path)
+
+ model, msg = load_model_from_config(config, ckpt)
+
+ return config, model
+
+
+def load_sampler(sampler_cfg):
+ return get_sampler(**sampler_cfg)
+
+
+def load_batch():
+ pass
+
+
+class DiffusionEngine:
+ def __init__(self, cfg) -> None:
+ self.cfg = cfg
+
+ def sample(self):
+ pass
diff --git a/recon/utils/general_utils.py b/recon/utils/general_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..541c0825229a2d86e84460b765879f86f724a59d
--- /dev/null
+++ b/recon/utils/general_utils.py
@@ -0,0 +1,133 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import sys
+from datetime import datetime
+import numpy as np
+import random
+
+def inverse_sigmoid(x):
+ return torch.log(x/(1-x))
+
+def PILtoTorch(pil_image, resolution):
+ resized_image_PIL = pil_image.resize(resolution)
+ resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
+ if len(resized_image.shape) == 3:
+ return resized_image.permute(2, 0, 1)
+ else:
+ return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
+
+def get_expon_lr_func(
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
+):
+ """
+ Copied from Plenoxels
+
+ Continuous learning rate decay function. Adapted from JaxNeRF
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
+ function of lr_delay_mult, such that the initial learning rate is
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
+ to the normal learning rate when steps>lr_delay_steps.
+ :param lr_init: initial learning rate (at step 0).
+ :param lr_final: final learning rate (at step max_steps).
+ :param lr_delay_steps: number of warm-up steps over which the rate is eased in.
+ :param lr_delay_mult: multiplier applied to the learning rate at the start of the delay phase.
+ :param max_steps: int, the number of steps during optimization.
+ :return HoF which takes step as input
+ """
+
+ def helper(step):
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
+ # Disable this parameter
+ return 0.0
+ if lr_delay_steps > 0:
+ # A kind of reverse cosine decay.
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
+ )
+ else:
+ delay_rate = 1.0
+ t = np.clip(step / max_steps, 0, 1)
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
+ return delay_rate * log_lerp
+
+ return helper
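+
+# Illustrative usage (values follow the default position learning-rate schedule):
+# xyz_scheduler = get_expon_lr_func(1.6e-4, 1.6e-6, max_steps=30_000)
+# lr = xyz_scheduler(iteration) # query once per training step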
+
+def strip_lowerdiag(L):
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
+
+ uncertainty[:, 0] = L[:, 0, 0]
+ uncertainty[:, 1] = L[:, 0, 1]
+ uncertainty[:, 2] = L[:, 0, 2]
+ uncertainty[:, 3] = L[:, 1, 1]
+ uncertainty[:, 4] = L[:, 1, 2]
+ uncertainty[:, 5] = L[:, 2, 2]
+ return uncertainty
+
+def strip_symmetric(sym):
+ return strip_lowerdiag(sym)
+
+def build_rotation(r):
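+ # Convert batched quaternions (w, x, y, z) to 3x3 rotation matrices, normalizing
+ # each quaternion to unit length first.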
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
+
+ q = r / norm[:, None]
+
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
+
+ r = q[:, 0]
+ x = q[:, 1]
+ y = q[:, 2]
+ z = q[:, 3]
+
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
+ R[:, 0, 1] = 2 * (x*y - r*z)
+ R[:, 0, 2] = 2 * (x*z + r*y)
+ R[:, 1, 0] = 2 * (x*y + r*z)
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
+ R[:, 1, 2] = 2 * (y*z - r*x)
+ R[:, 2, 0] = 2 * (x*z - r*y)
+ R[:, 2, 1] = 2 * (y*z + r*x)
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
+ return R
+
+def build_scaling_rotation(s, r):
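+ # Compose L = R @ diag(s); a Gaussian covariance can then be formed elsewhere as
+ # Sigma = L @ L^T = R S S^T R^T.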
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
+ R = build_rotation(r)
+
+ L[:,0,0] = s[:,0]
+ L[:,1,1] = s[:,1]
+ L[:,2,2] = s[:,2]
+
+ L = R @ L
+ return L
+
+def safe_state(silent):
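+ # Redirect stdout through a thin wrapper that timestamps every line (or silences
+ # output entirely when `silent` is set), then fix all RNG seeds for reproducibility.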
+ old_f = sys.stdout
+ class F:
+ def __init__(self, silent):
+ self.silent = silent
+
+ def write(self, x):
+ if not self.silent:
+ if x.endswith("\n"):
+ old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
+ else:
+ old_f.write(x)
+
+ def flush(self):
+ old_f.flush()
+
+ sys.stdout = F(silent)
+
+ random.seed(0)
+ np.random.seed(0)
+ torch.manual_seed(0)
+ torch.cuda.set_device(torch.device("cuda:0"))
diff --git a/recon/utils/graphics_utils.py b/recon/utils/graphics_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4627d837c74fcdffc898fa0c3071cb7b316802b
--- /dev/null
+++ b/recon/utils/graphics_utils.py
@@ -0,0 +1,77 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import math
+import numpy as np
+from typing import NamedTuple
+
+class BasicPointCloud(NamedTuple):
+ points : np.array
+ colors : np.array
+ normals : np.array
+
+def geom_transform_points(points, transf_matrix):
+ P, _ = points.shape
+ ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
+ points_hom = torch.cat([points, ones], dim=1)
+ points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
+
+ denom = points_out[..., 3:] + 0.0000001
+ return (points_out[..., :3] / denom).squeeze(dim=0)
+
+def getWorld2View(R, t):
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = R.transpose()
+ Rt[:3, 3] = t
+ Rt[3, 3] = 1.0
+ return np.float32(Rt)
+
+def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
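+ # Same as getWorld2View, but optionally re-centers and rescales the camera
+ # position (e.g. to normalize the scene extent) before converting back to a
+ # world-to-view matrix.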
+ Rt = np.zeros((4, 4))
+ Rt[:3, :3] = R.transpose()
+ Rt[:3, 3] = t
+ Rt[3, 3] = 1.0
+
+ C2W = np.linalg.inv(Rt)
+ cam_center = C2W[:3, 3]
+ cam_center = (cam_center + translate) * scale
+ C2W[:3, 3] = cam_center
+ Rt = np.linalg.inv(C2W)
+ return np.float32(Rt)
+
+def getProjectionMatrix(znear, zfar, fovX, fovY):
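+ # Perspective projection built from the horizontal/vertical fields of view;
+ # view-space depth in [znear, zfar] is mapped to normalized depth in [0, 1].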
+ tanHalfFovY = math.tan((fovY / 2))
+ tanHalfFovX = math.tan((fovX / 2))
+
+ top = tanHalfFovY * znear
+ bottom = -top
+ right = tanHalfFovX * znear
+ left = -right
+
+ P = torch.zeros(4, 4)
+
+ z_sign = 1.0
+
+ P[0, 0] = 2.0 * znear / (right - left)
+ P[1, 1] = 2.0 * znear / (top - bottom)
+ P[0, 2] = (right + left) / (right - left)
+ P[1, 2] = (top + bottom) / (top - bottom)
+ P[3, 2] = z_sign
+ P[2, 2] = z_sign * zfar / (zfar - znear)
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
+ return P
+
+def fov2focal(fov, pixels):
+ return pixels / (2 * math.tan(fov / 2))
+
+def focal2fov(focal, pixels):
+ return 2*math.atan(pixels/(2*focal))
\ No newline at end of file
diff --git a/recon/utils/image_utils.py b/recon/utils/image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdeaa1b6d250e549181ab165070f82ccd31b3eb9
--- /dev/null
+++ b/recon/utils/image_utils.py
@@ -0,0 +1,19 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+
+def mse(img1, img2):
+ return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+
+def psnr(img1, img2):
+ mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))
diff --git a/recon/utils/loss_utils.py b/recon/utils/loss_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1824708789f207d58f86c0e8350bc70e4b4037a
--- /dev/null
+++ b/recon/utils/loss_utils.py
@@ -0,0 +1,96 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+from math import exp
+from lpipsPyTorch import lpips as lpips_fn
+from lpipsPyTorch.modules.lpips import LPIPS
+
+_lpips = None
+
+
+def l1_loss(network_output, gt):
+ return torch.abs((network_output - gt)).mean()
+
+
+def l2_loss(network_output, gt):
+ return ((network_output - gt) ** 2).mean()
+
+
+def gaussian(window_size, sigma):
+ gauss = torch.Tensor(
+ [
+ exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
+ for x in range(window_size)
+ ]
+ )
+ return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = Variable(
+ _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+ )
+ return window
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+ channel = img1.size(-3)
+ window = create_window(window_size, channel)
+
+ if img1.is_cuda:
+ window = window.cuda(img1.get_device())
+ window = window.type_as(img1)
+
+ return _ssim(img1, img2, window, window_size, channel, size_average)
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+ mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+ mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = (
+ F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+ )
+ sigma2_sq = (
+ F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+ )
+ sigma12 = (
+ F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
+ - mu1_mu2
+ )
+
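+ # Standard SSIM stability constants: C1 = (k1 * L)^2 and C2 = (k2 * L)^2 with
+ # k1 = 0.01, k2 = 0.03 and dynamic range L = 1.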
+ C1 = 0.01**2
+ C2 = 0.03**2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
+ (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
+ )
+
+ if size_average:
+ return ssim_map.mean()
+ else:
+ return ssim_map.mean(1).mean(1).mean(1)
+
+
+def lpips(img1, img2):
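+ # Construct a single VGG-based LPIPS model lazily on first use and cache it in a
+ # module-level variable so the weights are not reloaded on every call.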
+ global _lpips
+ if _lpips is None:
+ _lpips = LPIPS("vgg", "0.1").to("cuda")
+ return _lpips(img1, img2).mean()
diff --git a/recon/utils/sh_utils.py b/recon/utils/sh_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785
--- /dev/null
+++ b/recon/utils/sh_utils.py
@@ -0,0 +1,118 @@
+# Copyright 2021 The PlenOctree Authors.
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+import torch
+
+C0 = 0.28209479177387814
+C1 = 0.4886025119029199
+C2 = [
+ 1.0925484305920792,
+ -1.0925484305920792,
+ 0.31539156525252005,
+ -1.0925484305920792,
+ 0.5462742152960396
+]
+C3 = [
+ -0.5900435899266435,
+ 2.890611442640554,
+ -0.4570457994644658,
+ 0.3731763325901154,
+ -0.4570457994644658,
+ 1.445305721320277,
+ -0.5900435899266435
+]
+C4 = [
+ 2.5033429417967046,
+ -1.7701307697799304,
+ 0.9461746957575601,
+ -0.6690465435572892,
+ 0.10578554691520431,
+ -0.6690465435572892,
+ 0.47308734787878004,
+ -1.7701307697799304,
+ 0.6258357354491761,
+]
+
+
+def eval_sh(deg, sh, dirs):
+ """
+ Evaluate spherical harmonics at unit directions
+ using hardcoded SH polynomials.
+ Works with torch/np/jnp.
+ ... Can be 0 or more batch dimensions.
+ Args:
+ deg: int SH deg. Currently, 0-4 supported
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
+ dirs: jnp.ndarray unit directions [..., 3]
+ Returns:
+ [..., C]
+ """
+ assert deg <= 4 and deg >= 0
+ coeff = (deg + 1) ** 2
+ assert sh.shape[-1] >= coeff
+
+ result = C0 * sh[..., 0]
+ if deg > 0:
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+ result = (result -
+ C1 * y * sh[..., 1] +
+ C1 * z * sh[..., 2] -
+ C1 * x * sh[..., 3])
+
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result = (result +
+ C2[0] * xy * sh[..., 4] +
+ C2[1] * yz * sh[..., 5] +
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+ C2[3] * xz * sh[..., 7] +
+ C2[4] * (xx - yy) * sh[..., 8])
+
+ if deg > 2:
+ result = (result +
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+ C3[1] * xy * z * sh[..., 10] +
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+ C3[5] * z * (xx - yy) * sh[..., 14] +
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+
+ if deg > 3:
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+ return result
+
+def RGB2SH(rgb):
+ return (rgb - 0.5) / C0
+
+def SH2RGB(sh):
+ return sh * C0 + 0.5
\ No newline at end of file
diff --git a/recon/utils/system_utils.py b/recon/utils/system_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ca6d7f77610c967affe313398777cd86920e8e
--- /dev/null
+++ b/recon/utils/system_utils.py
@@ -0,0 +1,28 @@
+#
+# Copyright (C) 2023, Inria
+# GRAPHDECO research group, https://team.inria.fr/graphdeco
+# All rights reserved.
+#
+# This software is free for non-commercial, research and evaluation use
+# under the terms of the LICENSE.md file.
+#
+# For inquiries contact george.drettakis@inria.fr
+#
+
+from errno import EEXIST
+from os import makedirs, path
+import os
+
+def mkdir_p(folder_path):
+ # Creates a directory; equivalent to using mkdir -p on the command line.
+ try:
+ makedirs(folder_path)
+ except OSError as exc: # Python >2.5
+ if exc.errno == EEXIST and path.isdir(folder_path):
+ pass
+ else:
+ raise
+
+def searchForMaxIteration(folder):
+ saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
+ return max(saved_iters)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..16898baf6dfc3f0f4ad7b3b63accac8b1834921a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,43 @@
+black==23.7.0
+chardet==5.1.0
+clip @ git+https://github.com/openai/CLIP.git
+einops>=0.6.1
+fairscale>=0.4.13
+fire>=0.5.0
+fsspec>=2023.6.0
+invisible-watermark>=0.2.0
+kornia==0.6.9
+matplotlib>=3.7.2
+natsort>=8.4.0
+ninja>=1.11.1
+numpy>=1.24.4
+omegaconf>=2.3.0
+open-clip-torch>=2.20.0
+opencv-python==4.6.0.66
+pandas>=2.0.3
+pillow>=9.5.0
+pudb>=2022.1.3
+pytorch-lightning==2.0.1
+pyyaml>=6.0.1
+scipy>=1.10.1
+streamlit>=0.73.1
+tensorboardx==2.6
+timm>=0.9.2
+tokenizers==0.12.1
+torch>=2.0.1
+torchaudio>=2.0.2
+torchdata==0.6.1
+torchmetrics>=1.0.1
+torchvision>=0.15.2
+tqdm>=4.65.0
+transformers==4.19.1
+triton==2.0.0
+urllib3<1.27,>=1.25.4
+wandb>=0.15.6
+webdataset>=0.2.33
+wheel>=0.41.0
+xformers>=0.0.20
+streamlit-keyup==0.2.0
+mediapy
+tyro
+wget
diff --git a/scripts/__init__.py b/scripts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/pub/V3D_512.py b/scripts/pub/V3D_512.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1fd579348d477d395958e13f7e7002bb9be1f2
--- /dev/null
+++ b/scripts/pub/V3D_512.py
@@ -0,0 +1,317 @@
+import math
+import os
+from glob import glob
+from pathlib import Path
+from typing import Optional
+
+import cv2
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from fire import Fire
+import tyro
+from omegaconf import OmegaConf
+from PIL import Image
+from torchvision.transforms import ToTensor
+from mediapy import write_video
+import rembg
+from kiui.op import recenter
+from safetensors.torch import load_file as load_safetensors
+from typing import Any
+
+from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
+from sgm.inference.helpers import embed_watermark
+from sgm.util import default, instantiate_from_config
+
+
+def get_unique_embedder_keys_from_conditioner(conditioner):
+ return list(set([x.input_key for x in conditioner.embedders]))
+
+
+def get_batch(keys, value_dict, N, T, device):
+ batch = {}
+ batch_uc = {}
+
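+ # Scalar conditionings (fps_id, motion_bucket_id, cond_aug) are broadcast over
+ # the batch, image conditionings (cond_frames*) are repeated per sample, and
+ # batch_uc mirrors every tensor entry for the unconditional branch.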
+ for key in keys:
+ if key == "fps_id":
+ batch[key] = (
+ torch.tensor([value_dict["fps_id"]])
+ .to(device)
+ .repeat(int(math.prod(N)))
+ )
+ elif key == "motion_bucket_id":
+ batch[key] = (
+ torch.tensor([value_dict["motion_bucket_id"]])
+ .to(device)
+ .repeat(int(math.prod(N)))
+ )
+ elif key == "cond_aug":
+ batch[key] = repeat(
+ torch.tensor([value_dict["cond_aug"]]).to(device),
+ "1 -> b",
+ b=math.prod(N),
+ )
+ elif key == "cond_frames":
+ batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
+ elif key == "cond_frames_without_noise":
+ batch[key] = repeat(
+ value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
+ )
+ else:
+ batch[key] = value_dict[key]
+
+ if T is not None:
+ batch["num_video_frames"] = T
+
+ for key in batch.keys():
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+ batch_uc[key] = torch.clone(batch[key])
+ return batch, batch_uc
+
+
+def load_model(
+ config: str,
+ device: str,
+ num_frames: int,
+ num_steps: int,
+ ckpt_path: Optional[str] = None,
+ min_cfg: Optional[float] = None,
+ max_cfg: Optional[float] = None,
+ sigma_max: Optional[float] = None,
+):
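+ # Load the video diffusion model from its OmegaConf config, overriding the
+ # sampler settings (number of steps and frames, CFG scale range, sigma_max)
+ # before instantiation; optionally point ckpt_path at a different checkpoint.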
+ config = OmegaConf.load(config)
+
+ config.model.params.sampler_config.params.num_steps = num_steps
+ config.model.params.sampler_config.params.guider_config.params.num_frames = (
+ num_frames
+ )
+ if max_cfg is not None:
+ config.model.params.sampler_config.params.guider_config.params.max_scale = (
+ max_cfg
+ )
+ if min_cfg is not None:
+ config.model.params.sampler_config.params.guider_config.params.min_scale = (
+ min_cfg
+ )
+ if sigma_max is not None:
+ print("Overriding sigma_max to ", sigma_max)
+ config.model.params.sampler_config.params.discretization_config.params.sigma_max = (
+ sigma_max
+ )
+
+ config.model.params.from_scratch = False
+
+ if ckpt_path is not None:
+ config.model.params.ckpt_path = str(ckpt_path)
+ if device == "cuda":
+ with torch.device(device):
+ model = instantiate_from_config(config.model).to(device).eval()
+ else:
+ model = instantiate_from_config(config.model).to(device).eval()
+
+ return model, None
+
+
+def sample_one(
+ input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
+ checkpoint_path: Optional[str] = None,
+ num_frames: Optional[int] = None,
+ num_steps: Optional[int] = None,
+ fps_id: int = 1,
+ motion_bucket_id: int = 300,
+ cond_aug: float = 0.02,
+ seed: int = 23,
+ decoding_t: int = 24, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+ device: str = "cuda",
+ output_folder: Optional[str] = None,
+ noise: torch.Tensor = None,
+ save: bool = False,
+ cached_model: Any = None,
+ border_ratio: float = 0.3,
+ min_guidance_scale: float = 3.5,
+ max_guidance_scale: float = 3.5,
+ sigma_max: float = None,
+ ignore_alpha: bool = False,
+):
+ model_config = "scripts/pub/configs/V3D_512.yaml"
+ num_frames = OmegaConf.load(
+ model_config
+ ).model.params.sampler_config.params.guider_config.params.num_frames
+ print("Detected num_frames:", num_frames)
+ num_steps = default(num_steps, 25)
+ output_folder = default(output_folder, f"outputs/V3D_512")
+ decoding_t = min(decoding_t, num_frames)
+
+ sd = load_safetensors("./ckpts/svd_xt.safetensors")
+ clip_model_config = OmegaConf.load("configs/embedder/clip_image.yaml")
+ clip_model = instantiate_from_config(clip_model_config).eval()
+ clip_sd = dict()
+ for k, v in sd.items():
+ if "conditioner.embedders.0" in k:
+ clip_sd[k.replace("conditioner.embedders.0.", "")] = v
+ clip_model.load_state_dict(clip_sd)
+ clip_model = clip_model.to(device)
+
+ ae_model_config = OmegaConf.load("configs/ae/video.yaml")
+ ae_model = instantiate_from_config(ae_model_config).eval()
+ encoder_sd = dict()
+ for k, v in sd.items():
+ if "first_stage_model" in k:
+ encoder_sd[k.replace("first_stage_model.", "")] = v
+ ae_model.load_state_dict(encoder_sd)
+ ae_model = ae_model.to(device)
+
+ if cached_model is None:
+ model, filter = load_model(
+ model_config,
+ device,
+ num_frames,
+ num_steps,
+ ckpt_path=checkpoint_path,
+ min_cfg=min_guidance_scale,
+ max_cfg=max_guidance_scale,
+ sigma_max=sigma_max,
+ )
+ else:
+ model = cached_model
+ torch.manual_seed(seed)
+
+ need_return = True
+ path = Path(input_path)
+ if path.is_file():
+ if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
+ all_img_paths = [input_path]
+ else:
+ raise ValueError("Path is not valid image file.")
+ elif path.is_dir():
+ all_img_paths = sorted(
+ [
+ f
+ for f in path.iterdir()
+ if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
+ ]
+ )
+ need_return = False
+ if len(all_img_paths) == 0:
+ raise ValueError("Folder does not contain any images.")
+ else:
+ raise ValueError(f"Input path {input_path} is neither a file nor a directory.")
+
+ for input_path in all_img_paths:
+ with Image.open(input_path) as image:
+ # if image.mode == "RGBA":
+ # image = image.convert("RGB")
+ w, h = image.size
+
+ if border_ratio > 0:
+ if image.mode != "RGBA" or ignore_alpha:
+ image = image.convert("RGB")
+ image = np.asarray(image)
+ carved_image = rembg.remove(image) # [H, W, 4]
+ else:
+ image = np.asarray(image)
+ carved_image = image
+ mask = carved_image[..., -1] > 0
+ image = recenter(carved_image, mask, border_ratio=border_ratio)
+ image = image.astype(np.float32) / 255.0
+ if image.shape[-1] == 4:
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
+ image = Image.fromarray((image * 255).astype(np.uint8))
+ else:
+ print("Ignore border ratio")
+ image = image.resize((512, 512))
+
+ image = ToTensor()(image)
+ image = image * 2.0 - 1.0
+
+ image = image.unsqueeze(0).to(device)
+ H, W = image.shape[2:]
+ assert image.shape[1] == 3
+ F = 8
+ C = 4
+ shape = (num_frames, C, H // F, W // F)
+
+ value_dict = {}
+ value_dict["motion_bucket_id"] = motion_bucket_id
+ value_dict["fps_id"] = fps_id
+ value_dict["cond_aug"] = cond_aug
+ value_dict["cond_frames_without_noise"] = clip_model(image)
+ value_dict["cond_frames"] = ae_model.encode(image)
+ value_dict["cond_frames"] += cond_aug * torch.randn_like(
+ value_dict["cond_frames"]
+ )
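+ # cond_aug sets the strength of the noise added to the encoded conditioning
+ # frame; the same scalar is also fed to the model through its cond_aug embedder.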
+ value_dict["cond_aug"] = cond_aug
+
+ with torch.no_grad():
+ with torch.autocast(device):
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
+ value_dict,
+ [1, num_frames],
+ T=num_frames,
+ device=device,
+ )
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=[
+ "cond_frames",
+ "cond_frames_without_noise",
+ ],
+ )
+
+ for k in ["crossattn", "concat"]:
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+ randn = torch.randn(shape, device=device) if noise is None else noise
+ randn = randn.to(device)
+
+ additional_model_inputs = {}
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
+ 2, num_frames
+ ).to(device)
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
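+ # The sampler only sees a denoiser(x, sigma, cond) callable, so the
+ # video-specific inputs above are captured by this closure and forwarded to
+ # the UNet on every call.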
+ def denoiser(input, sigma, c):
+ return model.denoiser(
+ model.model, input, sigma, c, **additional_model_inputs
+ )
+
+ samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
+ model.en_and_decode_n_samples_a_time = decoding_t
+ samples_x = model.decode_first_stage(samples_z)
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ os.makedirs(output_folder, exist_ok=True)
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+ # writer = cv2.VideoWriter(
+ # video_path,
+ # cv2.VideoWriter_fourcc(*"MP4V"),
+ # fps_id + 1,
+ # (samples.shape[-1], samples.shape[-2]),
+ # )
+
+ frames = (
+ (rearrange(samples, "t c h w -> t h w c") * 255)
+ .cpu()
+ .numpy()
+ .astype(np.uint8)
+ )
+
+ if save:
+ write_video(video_path, frames, fps=3)
+
+ images = []
+ for frame in frames:
+ images.append(Image.fromarray(frame))
+
+ if need_return:
+ return images, model
+
+
+if __name__ == "__main__":
+ tyro.cli(sample_one)
diff --git a/scripts/pub/configs/V3D_512.yaml b/scripts/pub/configs/V3D_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aee4108e741a50a75336d277e72a72d9b1df8ade
--- /dev/null
+++ b/scripts/pub/configs/V3D_512.yaml
@@ -0,0 +1,161 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: sgm.models.video_diffusion.DiffusionEngine
+ params:
+ ckpt_path: ckpts/V3D_512.ckpt
+ scale_factor: 0.18215
+ disable_first_stage_autocast: true
+ input_key: latents
+ log_keys: []
+ scheduler_config:
+ target: sgm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps:
+ - 1
+ cycle_lengths:
+ - 10000000000000
+ f_start:
+ - 1.0e-06
+ f_max:
+ - 1.0
+ f_min:
+ - 1.0
+ denoiser_config:
+ target: sgm.modules.diffusionmodules.denoiser.Denoiser
+ params:
+ scaling_config:
+ target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+ network_config:
+ target: sgm.modules.diffusionmodules.video_model.VideoUNet
+ params:
+ adm_in_channels: 768
+ num_classes: sequential
+ use_checkpoint: true
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions:
+ - 4
+ - 2
+ - 1
+ num_res_blocks: 2
+ channel_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_head_channels: 64
+ use_linear_in_transformer: true
+ transformer_depth: 1
+ context_dim: 1024
+ spatial_transformer_attn_type: softmax-xformers
+ extra_ff_mix_layer: true
+ use_spatial_context: true
+ merge_strategy: learned_with_images
+ video_kernel_size:
+ - 3
+ - 1
+ - 1
+ conditioner_config:
+ target: sgm.modules.GeneralConditioner
+ params:
+ emb_models:
+ - is_trainable: false
+ ucg_rate: 0.2
+ input_key: cond_frames_without_noise
+ target: sgm.modules.encoders.modules.IdentityEncoder
+ - input_key: fps_id
+ is_trainable: true
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256
+ - input_key: motion_bucket_id
+ is_trainable: true
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256
+ - input_key: cond_frames
+ is_trainable: false
+ ucg_rate: 0.2
+ target: sgm.modules.encoders.modules.IdentityEncoder
+ - input_key: cond_aug
+ is_trainable: true
+ target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+ params:
+ outdim: 256
+ first_stage_config:
+ target: sgm.models.autoencoder.AutoencodingEngine
+ params:
+ loss_config:
+ target: torch.nn.Identity
+ regularizer_config:
+ target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+ encoder_config:
+ target: sgm.modules.diffusionmodules.model.Encoder
+ params:
+ attn_type: vanilla
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ decoder_config:
+ target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
+ params:
+ attn_type: vanilla
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ video_kernel_size:
+ - 3
+ - 1
+ - 1
+ sampler_config:
+ target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+ params:
+ num_steps: 30
+ discretization_config:
+ target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
+ params:
+ sigma_max: 700.0
+ guider_config:
+ target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
+ params:
+ max_scale: 3.5
+ min_scale: 3.5
+ num_frames: 18
+ loss_fn_config:
+ target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+ params:
+ batch2model_keys:
+ - num_video_frames
+ - image_only_indicator
+ loss_weighting_config:
+ target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
+ params:
+ sigma_data: 1.0
+ sigma_sampler_config:
+ target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
+ params:
+ p_mean: 1.5
+ p_std: 2.0
\ No newline at end of file
diff --git a/scripts/tests/attention.py b/scripts/tests/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c3f7c8da27c577a7ce0ea3a01ab7f9e9c1baa2
--- /dev/null
+++ b/scripts/tests/attention.py
@@ -0,0 +1,319 @@
+import einops
+import torch
+import torch.nn.functional as F
+import torch.utils.benchmark as benchmark
+from torch.backends.cuda import SDPBackend
+
+from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer
+
+
+def benchmark_attn():
+ # Let's define a helpful benchmarking function:
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+ )
+ return t0.blocked_autorange().mean * 1e6
+
+ # Let's define the hyper-parameters of our input
+ batch_size = 32
+ max_sequence_len = 1024
+ num_heads = 32
+ embed_dimension = 32
+
+ dtype = torch.float16
+
+ query = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+ key = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+ value = torch.rand(
+ batch_size,
+ num_heads,
+ max_sequence_len,
+ embed_dimension,
+ device=device,
+ dtype=dtype,
+ )
+
+ print(f"q/k/v shape:", query.shape, key.shape, value.shape)
+
+ # Let's explore the speed of each of the 3 implementations
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
+ # Helpful arguments mapper
+ backend_map = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ }
+
+ from torch.profiler import ProfilerActivity, profile, record_function
+
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+ print(
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("Default detailed stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ print(
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("Math implmentation stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+ try:
+ print(
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ except RuntimeError:
+ print("FlashAttention is not supported. See warnings for reasons.")
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("FlashAttention stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+ try:
+ print(
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+ )
+ except RuntimeError:
+ print("EfficientAttention is not supported. See warnings for reasons.")
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("EfficientAttention stats"):
+ for _ in range(25):
+ o = F.scaled_dot_product_attention(query, key, value)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
+def run_model(model, x, context):
+ return model(x, context)
+
+
+def benchmark_transformer_blocks():
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ import torch.utils.benchmark as benchmark
+
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+ )
+ return t0.blocked_autorange().mean * 1e6
+
+ checkpoint = True
+ compile = False
+
+ batch_size = 32
+ h, w = 64, 64
+ context_len = 77
+ embed_dimension = 1024
+ context_dim = 1024
+ d_head = 64
+
+ transformer_depth = 4
+
+ n_heads = embed_dimension // d_head
+
+ dtype = torch.float16
+
+ model_native = SpatialTransformer(
+ embed_dimension,
+ n_heads,
+ d_head,
+ context_dim=context_dim,
+ use_linear=True,
+ use_checkpoint=checkpoint,
+ attn_type="softmax",
+ depth=transformer_depth,
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
+ ).to(device)
+ model_efficient_attn = SpatialTransformer(
+ embed_dimension,
+ n_heads,
+ d_head,
+ context_dim=context_dim,
+ use_linear=True,
+ depth=transformer_depth,
+ use_checkpoint=checkpoint,
+ attn_type="softmax-xformers",
+ ).to(device)
+ if not checkpoint and compile:
+ print("compiling models")
+ model_native = torch.compile(model_native)
+ model_efficient_attn = torch.compile(model_efficient_attn)
+
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
+
+ from torch.profiler import ProfilerActivity, profile, record_function
+
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+ with torch.autocast("cuda"):
+ print(
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
+ )
+ print(
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
+ )
+
+ print(75 * "+")
+ print("NATIVE")
+ print(75 * "+")
+ torch.cuda.reset_peak_memory_stats()
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("NativeAttention stats"):
+ for _ in range(25):
+ model_native(x, c)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
+
+ print(75 * "+")
+ print("Xformers")
+ print(75 * "+")
+ torch.cuda.reset_peak_memory_stats()
+ with profile(
+ activities=activities, record_shapes=False, profile_memory=True
+ ) as prof:
+ with record_function("xformers stats"):
+ for _ in range(25):
+ model_efficient_attn(x, c)
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
+
+
+def test01():
+ # conv1x1 vs linear
+ from sgm.util import count_params
+
+ conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
+ print(count_params(conv))
+ linear = torch.nn.Linear(3, 32).cuda()
+ print(count_params(linear))
+
+ print(conv.weight.shape)
+
+ # use same initialization
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
+ linear.bias = torch.nn.Parameter(conv.bias)
+
+ print(linear.weight.shape)
+
+ x = torch.randn(11, 3, 64, 64).cuda()
+
+ xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
+ print(xr.shape)
+ out_linear = linear(xr)
+ print(out_linear.mean(), out_linear.shape)
+
+ out_conv = conv(x)
+ print(out_conv.mean(), out_conv.shape)
+ print("done with test01.\n")
+
+
+def test02():
+ # try cosine flash attention
+ import time
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cudnn.benchmark = True
+ print("testing cosine flash attention...")
+ DIM = 1024
+ SEQLEN = 4096
+ BS = 16
+
+ print(" softmax (vanilla) first...")
+ model = BasicTransformerBlock(
+ dim=DIM,
+ n_heads=16,
+ d_head=64,
+ dropout=0.0,
+ context_dim=None,
+ attn_mode="softmax",
+ ).cuda()
+ try:
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
+ tic = time.time()
+ y = model(x)
+ toc = time.time()
+ print(y.shape, toc - tic)
+ except RuntimeError as e:
+ # likely oom
+ print(str(e))
+
+ print("\n now flash-cosine...")
+ model = BasicTransformerBlock(
+ dim=DIM,
+ n_heads=16,
+ d_head=64,
+ dropout=0.0,
+ context_dim=None,
+ attn_mode="flash-cosine",
+ ).cuda()
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
+ tic = time.time()
+ y = model(x)
+ toc = time.time()
+ print(y.shape, toc - tic)
+ print("done with test02.\n")
+
+
+if __name__ == "__main__":
+ # test01()
+ # test02()
+ # test03()
+
+ # benchmark_attn()
+ benchmark_transformer_blocks()
+
+ print("done.")
diff --git a/scripts/util/__init__.py b/scripts/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/util/detection/__init__.py b/scripts/util/detection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/scripts/util/detection/nsfw_and_watermark_dectection.py b/scripts/util/detection/nsfw_and_watermark_dectection.py
new file mode 100644
index 0000000000000000000000000000000000000000..1096b8177d8e3dbcf8e913f924e98d5ce58cb120
--- /dev/null
+++ b/scripts/util/detection/nsfw_and_watermark_dectection.py
@@ -0,0 +1,110 @@
+import os
+
+import clip
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image
+
+RESOURCES_ROOT = "scripts/util/detection/"
+
+
+def predict_proba(X, weights, biases):
+ logits = X @ weights.T + biases
+ proba = np.where(
+ logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits))
+ )
+ return proba.T
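+
+# Illustrative note (not part of the original file): predict_proba applies a
+# numerically stable sigmoid to a linear head, so for X of shape (N, D),
+# weights of shape (C, D) and biases of shape (C,), it returns class
+# probabilities of shape (C, N). A hypothetical sanity check:
+#   proba = predict_proba(np.zeros((4, 8)), np.zeros((2, 8)), np.zeros(2))
+#   # proba.shape == (2, 4) and every entry equals 0.5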
+
+
+def load_model_weights(path: str):
+ model_weights = np.load(path)
+ return model_weights["weights"], model_weights["biases"]
+
+
+def clip_process_images(images: torch.Tensor) -> torch.Tensor:
+ min_size = min(images.shape[-2:])
+ return T.Compose(
+ [
+ T.CenterCrop(min_size), # TODO: this might affect the watermark, check this
+ T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
+ T.Normalize(
+ (0.48145466, 0.4578275, 0.40821073),
+ (0.26862954, 0.26130258, 0.27577711),
+ ),
+ ]
+ )(images)
+
+
+class DeepFloydDataFiltering(object):
+ def __init__(
+ self, verbose: bool = False, device: torch.device = torch.device("cpu")
+ ):
+ super().__init__()
+ self.verbose = verbose
+ self._device = None
+ self.clip_model, _ = clip.load("ViT-L/14", device=device)
+ self.clip_model.eval()
+
+ self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
+ os.path.join(RESOURCES_ROOT, "w_head_v1.npz")
+ )
+ self.cpu_p_weights, self.cpu_p_biases = load_model_weights(
+ os.path.join(RESOURCES_ROOT, "p_head_v1.npz")
+ )
+ self.w_threshold, self.p_threshold = 0.5, 0.5
+
+ @torch.inference_mode()
+ def __call__(self, images: torch.Tensor) -> torch.Tensor:
+ imgs = clip_process_images(images)
+ if self._device is None:
+ self._device = next(p for p in self.clip_model.parameters()).device
+ image_features = self.clip_model.encode_image(imgs.to(self._device))
+ image_features = image_features.detach().cpu().numpy().astype(np.float16)
+ p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
+ w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
+ print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None
+ query = p_pred > self.p_threshold
+ if query.sum() > 0:
+ print(f"Hit for p_threshold: {p_pred}") if self.verbose else None
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
+ query = w_pred > self.w_threshold
+ if query.sum() > 0:
+ print(f"Hit for w_threshold: {w_pred}") if self.verbose else None
+ images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query])
+ return images
+
+
+def load_img(path: str) -> torch.Tensor:
+ image = Image.open(path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image_transforms = T.Compose(
+ [
+ T.ToTensor(),
+ ]
+ )
+ return image_transforms(image)[None, ...]
+
+
+def test(root):
+ from einops import rearrange
+
+ filter = DeepFloydDataFiltering(verbose=True)
+ for p in os.listdir(root):
+ print(f"running on {p}...")
+ img = load_img(os.path.join(root, p))
+ filtered_img = filter(img)
+ filtered_img = rearrange(
+ 255.0 * (filtered_img.numpy())[0], "c h w -> h w c"
+ ).astype(np.uint8)
+ Image.fromarray(filtered_img).save(
+ os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg")
+ )
+
+
+if __name__ == "__main__":
+ import fire
+
+ fire.Fire(test)
+ print("done.")
diff --git a/sgm/__init__.py b/sgm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..24bc84af8b1041de34b9816e0507cb1ac207bd13
--- /dev/null
+++ b/sgm/__init__.py
@@ -0,0 +1,4 @@
+from .models import AutoencodingEngine, DiffusionEngine
+from .util import get_configs_path, instantiate_from_config
+
+__version__ = "0.1.0"
diff --git a/sgm/data/__init__.py b/sgm/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7664a25c655c376bd1a7b0ccbaca7b983a2bf9ad
--- /dev/null
+++ b/sgm/data/__init__.py
@@ -0,0 +1 @@
+from .dataset import StableDataModuleFromConfig
diff --git a/sgm/data/cam_utils.py b/sgm/data/cam_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d44b38721dafc771c092887d93726b38e1ec0a6
--- /dev/null
+++ b/sgm/data/cam_utils.py
@@ -0,0 +1,1253 @@
+'''
+Common camera utilities
+'''
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+from pytorch3d.renderer import PerspectiveCameras
+from pytorch3d.renderer.cameras import look_at_view_transform
+from pytorch3d.renderer.implicit.raysampling import _xy_to_ray_bundle
+
+class RelativeCameraLoader(nn.Module):
+ def __init__(self,
+ query_batch_size=1,
+ rand_query=True,
+ relative=True,
+ center_at_origin=False,
+ ):
+ super().__init__()
+
+ self.query_batch_size = query_batch_size
+ self.rand_query = rand_query
+ self.relative = relative
+ self.center_at_origin = center_at_origin
+
+ def plot_cameras(self, cameras_1, cameras_2):
+ '''
+ Helper function to plot cameras
+
+ Args:
+ cameras_1 (PyTorch3D camera): cameras object to plot
+ cameras_2 (PyTorch3D camera): cameras object to plot
+ '''
+ from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
+ import plotly.graph_objects as go
+ plotlyplot = plot_scene(
+ {
+ 'scene_batch': {
+ 'cameras': cameras_1.to('cpu'),
+ 'rel_cameras': cameras_2.to('cpu'),
+ }
+ },
+ camera_scale=0.5,  # alternatively 0.05
+ pointcloud_max_points=10000,
+ pointcloud_marker_size=1.0,
+ raybundle_max_rays=100
+ )
+ plotlyplot.show()
+
+ def concat_cameras(self, camera_list):
+ '''
+ Returns a concatenation of a list of cameras
+
+ Args:
+ camera_list (List[PyTorch3D camera]): a list of PyTorch3D cameras
+ '''
+ R_list, T_list, f_list, c_list, size_list = [], [], [], [], []
+ for cameras in camera_list:
+ R_list.append(cameras.R)
+ T_list.append(cameras.T)
+ f_list.append(cameras.focal_length)
+ c_list.append(cameras.principal_point)
+ size_list.append(cameras.image_size)
+
+ camera_slice = PerspectiveCameras(
+ R = torch.cat(R_list),
+ T = torch.cat(T_list),
+ focal_length = torch.cat(f_list),
+ principal_point = torch.cat(c_list),
+ image_size = torch.cat(size_list),
+ device = camera_list[0].device,
+ )
+ return camera_slice
+
+ def get_camera_slice(self, scene_cameras, indices):
+ '''
+ Return a subset of cameras from a super set given indices
+
+ Args:
+ scene_cameras (PyTorch3D Camera): cameras object
+ indices (tensor or List): a flat list or tensor of indices
+
+ Returns:
+ camera_slice (PyTorch3D Camera) - cameras subset
+ '''
+ camera_slice = PerspectiveCameras(
+ R = scene_cameras.R[indices],
+ T = scene_cameras.T[indices],
+ focal_length = scene_cameras.focal_length[indices],
+ principal_point = scene_cameras.principal_point[indices],
+ image_size = scene_cameras.image_size[indices],
+ device = scene_cameras.device,
+ )
+ return camera_slice
+
+
+ def get_relative_camera(self, scene_cameras:PerspectiveCameras, query_idx, center_at_origin=False):
+ """
+ Transform context cameras relative to a base query camera
+
+ Args:
+ scene_cameras (PyTorch3D Camera): cameras object
+ query_idx (tensor or List): a length 1 list defining query idx
+
+ Returns:
+ cams_relative (PyTorch3D Camera): cameras object relative to query camera
+ """
+
+ query_camera = self.get_camera_slice(scene_cameras, query_idx)
+ query_world2view = query_camera.get_world_to_view_transform()
+ all_world2view = scene_cameras.get_world_to_view_transform()
+
+ if center_at_origin:
+ identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=query_camera.T)
+ else:
+ T = torch.zeros((1, 3))
+ identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=T)
+
+ identity_world2view = identity_cam.get_world_to_view_transform()
+
+ # compose the relative transformation as g_i^{-1} g_j
+ relative_world2view = identity_world2view.inverse().compose(all_world2view)
+
+ # generate a camera from the relative transform
+ relative_matrix = relative_world2view.get_matrix()
+ cams_relative = PerspectiveCameras(
+ R = relative_matrix[:, :3, :3],
+ T = relative_matrix[:, 3, :3],
+ focal_length = scene_cameras.focal_length,
+ principal_point = scene_cameras.principal_point,
+ image_size = scene_cameras.image_size,
+ device = scene_cameras.device,
+ )
+ return cams_relative
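+
+ # Illustrative usage sketch (not part of the original file), with a hypothetical
+ # `loader = RelativeCameraLoader()` and a PerspectiveCameras batch `scene_cameras`:
+ #   rel = loader.get_relative_camera(scene_cameras, query_idx=[0])
+ #   # all cameras are re-expressed relative to view 0; with center_at_origin=True
+ #   # the query camera itself maps to (approximately) the identity transform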
+
+ def forward(self, scene_cameras, scene_rgb=None, scene_masks=None, query_idx=None, context_size=3, context_idx=None, return_context=False):
+ '''
+ Return a sampled batch of query and context cameras (used in training)
+
+ Args:
+ scene_cameras (PyTorch3D Camera): a batch of PyTorch3D cameras
+ scene_rgb (Tensor): a batch of RGB images
+ scene_masks (Tensor): a batch of masks (optional)
+ query_idx (List or Tensor): desired query idx (optional)
+ context_size (int): number of views for context
+ context_idx (List or Tensor): desired context view indices (optional)
+ return_context (bool): if True, additionally return the sampled context_idx
+
+ Returns:
+ query_cameras, query_rgb, query_masks: random query view
+ context_cameras, context_rgb, context_masks: context views
+ '''
+
+ if query_idx is None:
+ query_idx = [0]
+ if self.rand_query:
+ rand = torch.randperm(len(scene_cameras))
+ query_idx = rand[:1]
+
+ if context_idx is None:
+ rand = torch.randperm(len(scene_cameras))
+ context_idx = rand[:context_size]
+
+
+ if self.relative:
+ rel_cameras = self.get_relative_camera(scene_cameras, query_idx, center_at_origin=self.center_at_origin)
+ else:
+ rel_cameras = scene_cameras
+
+ query_cameras = self.get_camera_slice(rel_cameras, query_idx)
+ query_rgb = None
+ if scene_rgb is not None:
+ query_rgb = scene_rgb[query_idx]
+ query_masks = None
+ if scene_masks is not None:
+ query_masks = scene_masks[query_idx]
+
+ context_cameras = self.get_camera_slice(rel_cameras, context_idx)
+ context_rgb = None
+ if scene_rgb is not None:
+ context_rgb = scene_rgb[context_idx]
+ context_masks = None
+ if scene_masks is not None:
+ context_masks = scene_masks[context_idx]
+
+ if return_context:
+ return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks, context_idx
+ return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks
+
+
+def get_interpolated_path(cameras: PerspectiveCameras, n=50, method='circle', theta_offset_max=0.0):
+ '''
+ Given a camera object containing a set of cameras, fit a circle and get
+ interpolated cameras
+
+ Args:
+ cameras (PyTorch3D Camera): input camera object
+ n (int): length of cameras in new path
+ method (str): 'circle'
+ theta_offset_max (int): max camera jitter in radians
+
+ Returns:
+ path_cameras (PyTorch3D Camera): interpolated cameras
+ '''
+ device = cameras.device
+ cameras = cameras.cpu()
+
+ if method == 'circle':
+
+ #@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
+ #@ Fit plane
+ P = cameras.get_camera_center().cpu()
+ P_mean = P.mean(axis=0)
+ P_centered = P - P_mean
+ U,s,V = torch.linalg.svd(P_centered)
+ normal = V[2,:]
+ if (normal*2 - P_mean).norm() < (normal - P_mean).norm():
+ normal = - normal
+ d = -torch.dot(P_mean, normal)  # plane offset d = -<P_mean, normal>
+
+ #@ Project pts to plane
+ P_xy = rodrigues_rot(P_centered, normal, torch.tensor([0.0,0.0,1.0]))
+
+ #@ Fit circle in 2D
+ xc, yc, r = fit_circle_2d(P_xy[:,0], P_xy[:,1])
+ t = torch.linspace(0, 2*math.pi, 100)
+ xx = xc + r*torch.cos(t)
+ yy = yc + r*torch.sin(t)
+
+ #@ Project circle to 3D
+ C = rodrigues_rot(torch.tensor([xc,yc,0.0]), torch.tensor([0.0,0.0,1.0]), normal) + P_mean
+ C = C.flatten()
+
+ #@ Get pts in 3D
+ t = torch.linspace(0, 2*math.pi, n)
+ u = P[0] - C
+ new_camera_centers = generate_circle_by_vectors(t, C, r, normal, u)
+
+ #@ OPTIONAL THETA OFFSET
+ if theta_offset_max > 0.0:
+ aug_theta = (torch.rand((new_camera_centers.shape[0])) * (2*theta_offset_max)) - theta_offset_max
+ new_camera_centers = rodrigues_rot2(new_camera_centers, normal, aug_theta)
+
+ #@ Get camera look at
+ new_camera_look_at = get_nearest_centroid(cameras)
+
+ #@ Get R T
+ up_vec = -normal
+ R, T = look_at_view_transform(eye=new_camera_centers, at=new_camera_look_at.unsqueeze(0), up=up_vec.unsqueeze(0), device=cameras.device)
+ else:
+ raise NotImplementedError
+
+ c = (cameras.principal_point).mean(dim=0, keepdim=True).expand(R.shape[0],-1)
+ f = (cameras.focal_length).mean(dim=0, keepdim=True).expand(R.shape[0],-1)
+ image_size = cameras.image_size[:1].expand(R.shape[0],-1)
+
+
+ path_cameras = PerspectiveCameras(R=R,T=T,focal_length=f,principal_point=c,image_size=image_size, device=device)
+ cameras = cameras.to(device)
+ return path_cameras
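+
+# Illustrative usage sketch (not part of the original file): given a PerspectiveCameras
+# object `cams` from a roughly circular capture, one could sample a smooth 50-view orbit:
+#   path = get_interpolated_path(cams, n=50, method='circle')
+#   # `path` is a PerspectiveCameras object with 50 views looking at the estimated centroid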
+
+def np_normalize(vec, axis=-1):
+ vec = vec / (np.linalg.norm(vec, axis=axis, keepdims=True) + 1e-9)
+ return vec
+
+
+#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
+#-------------------------------------------------------------------------------
+# Generate points on circle
+# P(t) = r*cos(t)*u + r*sin(t)*(n x u) + C
+#-------------------------------------------------------------------------------
+def generate_circle_by_vectors(t, C, r, n, u):
+ n = n/torch.linalg.norm(n)
+ u = u/torch.linalg.norm(u)
+ P_circle = r*torch.cos(t)[:,None]*u + r*torch.sin(t)[:,None]*torch.cross(n,u) + C
+ return P_circle
+
+#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
+#-------------------------------------------------------------------------------
+# FIT CIRCLE 2D
+# - Find center [xc, yc] and radius r of circle fitting to set of 2D points
+# - Optionally specify weights for points
+#
+# - Implicit circle function:
+# (x-xc)^2 + (y-yc)^2 = r^2
+# (2*xc)*x + (2*yc)*y + (r^2-xc^2-yc^2) = x^2+y^2
+# c[0]*x + c[1]*y + c[2] = x^2+y^2
+#
+# - Solution by method of least squares:
+# A*c = b, c' = argmin(||A*c - b||^2)
+# A = [x y 1], b = [x^2+y^2]
+#-------------------------------------------------------------------------------
+def fit_circle_2d(x, y, w=[]):
+
+ A = torch.stack([x, y, torch.ones(len(x))]).T
+ b = x**2 + y**2
+
+ # Modify A,b for weighted least squares
+ if len(w) == len(x):
+ W = torch.diag(w)
+ A = W @ A
+ b = W @ b
+
+ # Solve by method of least squares
+ c = torch.linalg.lstsq(A,b,rcond=None)[0]
+
+ # Get circle parameters from solution c
+ xc = c[0]/2
+ yc = c[1]/2
+ r = torch.sqrt(c[2] + xc**2 + yc**2)
+ return xc, yc, r
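+
+# Illustrative check (not part of the original file): points sampled from a circle of
+# radius 2 centred at (1, -1) should give back roughly those parameters:
+#   t = torch.linspace(0, 2 * math.pi, 50)
+#   xc, yc, r = fit_circle_2d(1 + 2 * torch.cos(t), -1 + 2 * torch.sin(t))
+#   # expect xc ~ 1, yc ~ -1, r ~ 2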
+
+#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
+#-------------------------------------------------------------------------------
+# RODRIGUES ROTATION
+# - Rotate given points based on a starting and ending vector
+# - Axis k and angle of rotation theta given by vectors n0,n1
+# P_rot = P*cos(theta) + (k x P)*sin(theta) + k**(1-cos(theta))
+#-------------------------------------------------------------------------------
+def rodrigues_rot(P, n0, n1):
+
+ # If P is only 1d array (coords of single point), fix it to be matrix
+ if P.ndim == 1:
+ P = P[None,...]
+
+ # Get vector of rotation k and angle theta
+ n0 = n0/torch.linalg.norm(n0)
+ n1 = n1/torch.linalg.norm(n1)
+ k = torch.cross(n0,n1)
+ k = k/torch.linalg.norm(k)
+ theta = torch.arccos(torch.dot(n0,n1))
+
+ # Compute rotated points
+ P_rot = torch.zeros((len(P),3))
+ for i in range(len(P)):
+ P_rot[i] = P[i]*torch.cos(theta) + torch.cross(k,P[i])*torch.sin(theta) + k*torch.dot(k,P[i])*(1-torch.cos(theta))
+
+ return P_rot
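+
+# Illustrative check (not part of the original file): rotating by the transform that takes
+# +z to +x should map the point [0, 0, 1] onto [1, 0, 0]:
+#   p = rodrigues_rot(torch.tensor([0.0, 0.0, 1.0]),
+#                     torch.tensor([0.0, 0.0, 1.0]),
+#                     torch.tensor([1.0, 0.0, 0.0]))
+#   # p ~ [[1.0, 0.0, 0.0]]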
+
+def rodrigues_rot2(P, n1, theta):
+ '''
+ Rotate points P wrt axis k by theta radians
+ '''
+
+ # If P is only 1d array (coords of single point), fix it to be matrix
+ if P.ndim == 1:
+ P = P[None,...]
+
+ k = torch.cross(P, n1.unsqueeze(0))
+ k = k/torch.linalg.norm(k)
+
+ # Compute rotated points
+ P_rot = torch.zeros((len(P),3))
+ for i in range(len(P)):
+ P_rot[i] = P[i]*torch.cos(theta[i]) + torch.cross(k[i],P[i])*torch.sin(theta[i]) + k[i]*torch.dot(k[i],P[i])*(1-torch.cos(theta[i]))
+
+ return P_rot
+
+#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/
+#-------------------------------------------------------------------------------
+# ANGLE BETWEEN
+# - Get angle between vectors u,v with sign based on plane with unit normal n
+#-------------------------------------------------------------------------------
+def angle_between(u, v, n=None):
+ if n is None:
+ return torch.arctan2(torch.linalg.norm(torch.cross(u,v)), torch.dot(u,v))
+ else:
+ return torch.arctan2(torch.dot(n,torch.cross(u,v)), torch.dot(u,v))
+
+#@ https://www.crewes.org/Documents/ResearchReports/2010/CRR201032.pdf
+def get_nearest_centroid(cameras: PerspectiveCameras):
+ '''
+ Given PyTorch3D cameras, find the nearest point along their principal ray
+ '''
+
+ #@ GET CAMERA CENTERS AND DIRECTIONS
+ camera_centers = cameras.get_camera_center()
+
+ c_mean = (cameras.principal_point).mean(dim=0)
+ xy_grid = c_mean.unsqueeze(0).unsqueeze(0)
+ ray_vis = _xy_to_ray_bundle(cameras, xy_grid.expand(len(cameras),-1,-1), 1.0, 15.0, 20, True)
+ camera_directions = ray_vis.directions
+
+ #@ CONSTRUCT MATRICES
+ A = torch.zeros((3*len(cameras)), len(cameras)+3)
+ b = torch.zeros((3*len(cameras), 1))
+ A[:,:3] = torch.eye(3).repeat(len(cameras),1)
+ for ci in range(len(camera_directions)):
+ A[3*ci:3*ci+3, ci+3] = -camera_directions[ci]
+ b[3*ci:3*ci+3, 0] = camera_centers[ci]
+ #' A (3*N, N+3) b (3*N, 1)
+
+ #@ SVD
+ U, s, VT = torch.linalg.svd(A)
+ Sinv = torch.diag(1/s)
+ if len(s) < 3*len(cameras):
+ Sinv = torch.cat((Sinv, torch.zeros((Sinv.shape[0], 3*len(cameras) - Sinv.shape[1]), device=Sinv.device)), dim=1)
+ x = torch.matmul(VT.T, torch.matmul(Sinv,torch.matmul(U.T, b)))
+
+ centroid = x[:3,0]
+ return centroid
+
+
+def get_angles(target_camera: PerspectiveCameras, context_cameras: PerspectiveCameras, centroid=None):
+ '''
+ Get angles between cameras wrt a centroid
+
+ Args:
+ target_camera (Pytorch3D Camera): a camera object with a single camera
+ context_cameras (PyTorch3D Camera): a camera object
+
+ Returns:
+ theta_deg (Tensor): a tensor containing angles in degrees
+ '''
+ a1 = target_camera.get_camera_center()
+ b1 = context_cameras.get_camera_center()
+
+ a = a1 - centroid.unsqueeze(0)
+ a = a.expand(len(context_cameras), -1)
+ b = b1 - centroid.unsqueeze(0)
+
+ ab_dot = (a*b).sum(dim=-1)
+ theta = torch.acos((ab_dot)/(torch.linalg.norm(a, dim=-1) * torch.linalg.norm(b, dim=-1)))
+ theta_deg = theta * 180 / math.pi
+
+ return theta_deg
+
+
+import math
+from typing import List, Literal, Optional, Tuple
+
+import numpy as np
+import torch
+from jaxtyping import Float
+from numpy.typing import NDArray
+from torch import Tensor
+
+_EPS = np.finfo(float).eps * 4.0
+
+
+def unit_vector(data: NDArray, axis: Optional[int] = None) -> np.ndarray:
+ """Return ndarray normalized by length, i.e. Euclidean norm, along axis.
+
+ Args:
+ data: the array to normalize
+ axis: the axis along which to normalize into unit vector
+ """
+ data = np.array(data, dtype=np.float64, copy=True)
+ if data.ndim == 1:
+ data /= math.sqrt(np.dot(data, data))
+ return data
+ length = np.atleast_1d(np.sum(data * data, axis))
+ np.sqrt(length, length)
+ if axis is not None:
+ length = np.expand_dims(length, axis)
+ data /= length
+ return data
+
+
+def quaternion_from_matrix(matrix: NDArray, isprecise: bool = False) -> np.ndarray:
+ """Return quaternion from rotation matrix.
+
+ Args:
+ matrix: rotation matrix to obtain quaternion
+ isprecise: if True, the input matrix is assumed to be a precise rotation matrix and a faster algorithm is used.
+ """
+ M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4]
+ if isprecise:
+ q = np.empty((4,))
+ t = np.trace(M)
+ if t > M[3, 3]:
+ q[0] = t
+ q[3] = M[1, 0] - M[0, 1]
+ q[2] = M[0, 2] - M[2, 0]
+ q[1] = M[2, 1] - M[1, 2]
+ else:
+ i, j, k = 1, 2, 3
+ if M[1, 1] > M[0, 0]:
+ i, j, k = 2, 3, 1
+ if M[2, 2] > M[i, i]:
+ i, j, k = 3, 1, 2
+ t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
+ q[i] = t
+ q[j] = M[i, j] + M[j, i]
+ q[k] = M[k, i] + M[i, k]
+ q[3] = M[k, j] - M[j, k]
+ q *= 0.5 / math.sqrt(t * M[3, 3])
+ else:
+ m00 = M[0, 0]
+ m01 = M[0, 1]
+ m02 = M[0, 2]
+ m10 = M[1, 0]
+ m11 = M[1, 1]
+ m12 = M[1, 2]
+ m20 = M[2, 0]
+ m21 = M[2, 1]
+ m22 = M[2, 2]
+ # symmetric matrix K
+ K = [
+ [m00 - m11 - m22, 0.0, 0.0, 0.0],
+ [m01 + m10, m11 - m00 - m22, 0.0, 0.0],
+ [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0],
+ [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
+ ]
+ K = np.array(K)
+ K /= 3.0
+ # quaternion is eigenvector of K that corresponds to largest eigenvalue
+ w, V = np.linalg.eigh(K)
+ q = V[np.array([3, 0, 1, 2]), np.argmax(w)]
+ if q[0] < 0.0:
+ np.negative(q, q)
+ return q
+
+
+def quaternion_slerp(
+ quat0: NDArray, quat1: NDArray, fraction: float, spin: int = 0, shortestpath: bool = True
+) -> np.ndarray:
+ """Return spherical linear interpolation between two quaternions.
+ Args:
+ quat0: first quaternion
+ quat1: second quaternion
+ fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1)
+ spin: how much of an additional spin to place on the interpolation
+ shortestpath: whether to return the short or long path to rotation
+ """
+ q0 = unit_vector(quat0[:4])
+ q1 = unit_vector(quat1[:4])
+ if q0 is None or q1 is None:
+ raise ValueError("Input quaternions invalid.")
+ if fraction == 0.0:
+ return q0
+ if fraction == 1.0:
+ return q1
+ d = np.dot(q0, q1)
+ if abs(abs(d) - 1.0) < _EPS:
+ return q0
+ if shortestpath and d < 0.0:
+ # invert rotation
+ d = -d
+ np.negative(q1, q1)
+ angle = math.acos(d) + spin * math.pi
+ if abs(angle) < _EPS:
+ return q0
+ isin = 1.0 / math.sin(angle)
+ q0 *= math.sin((1.0 - fraction) * angle) * isin
+ q1 *= math.sin(fraction * angle) * isin
+ q0 += q1
+ return q0
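+
+# Illustrative check (not part of the original file): in the [w, x, y, z] convention used
+# by quaternion_from_matrix, interpolating halfway between the identity and a 90-degree
+# rotation about z gives a 45-degree rotation about z:
+#   q = quaternion_slerp(np.array([1.0, 0.0, 0.0, 0.0]),
+#                        np.array([np.cos(np.pi / 4), 0.0, 0.0, np.sin(np.pi / 4)]),
+#                        fraction=0.5)
+#   # q ~ [cos(pi/8), 0, 0, sin(pi/8)]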
+
+
+def quaternion_matrix(quaternion: NDArray) -> np.ndarray:
+ """Return homogeneous rotation matrix from quaternion.
+
+ Args:
+ quaternion: value to convert to matrix
+ """
+ q = np.array(quaternion, dtype=np.float64, copy=True)
+ n = np.dot(q, q)
+ if n < _EPS:
+ return np.identity(4)
+ q *= math.sqrt(2.0 / n)
+ q = np.outer(q, q)
+ return np.array(
+ [
+ [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0],
+ [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0],
+ [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ]
+ )
+
+
+def get_interpolated_poses(pose_a: NDArray, pose_b: NDArray, steps: int = 10) -> List[np.ndarray]:
+ """Return interpolation of poses with specified number of steps.
+ Args:
+ pose_a: first pose
+ pose_b: second pose
+ steps: number of steps the interpolated pose path should contain
+ """
+
+ quat_a = quaternion_from_matrix(pose_a[:3, :3])
+ quat_b = quaternion_from_matrix(pose_b[:3, :3])
+
+ ts = np.linspace(0, 1, steps)
+ quats = [quaternion_slerp(quat_a, quat_b, t) for t in ts]
+ trans = [(1 - t) * pose_a[:3, 3] + t * pose_b[:3, 3] for t in ts]
+
+ poses_ab = []
+ for quat, tran in zip(quats, trans):
+ pose = np.identity(4)
+ pose[:3, :3] = quaternion_matrix(quat)[:3, :3]
+ pose[:3, 3] = tran
+ poses_ab.append(pose[:3])
+ return poses_ab
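+
+# Illustrative usage sketch (not part of the original file): interpolating between two
+# camera-to-world poses that share the identity rotation:
+#   pose_a = np.eye(4)
+#   pose_b = np.eye(4); pose_b[0, 3] = 1.0
+#   path = get_interpolated_poses(pose_a, pose_b, steps=5)
+#   # `path` holds five 3x4 poses whose x-translations are 0.0, 0.25, 0.5, 0.75, 1.0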
+
+
+def get_interpolated_k(
+ k_a: Float[Tensor, "3 3"], k_b: Float[Tensor, "3 3"], steps: int = 10
+) -> List[Float[Tensor, "3 3"]]:
+ """
+ Returns interpolated intrinsics between two camera matrices with the specified number of steps.
+
+ Args:
+ k_a: camera matrix 1
+ k_b: camera matrix 2
+ steps: number of steps the interpolated pose path should contain
+
+ Returns:
+ List of interpolated camera intrinsics
+ """
+ Ks: List[Float[Tensor, "3 3"]] = []
+ ts = np.linspace(0, 1, steps)
+ for t in ts:
+ new_k = k_a * (1.0 - t) + k_b * t
+ Ks.append(new_k)
+ return Ks
+
+
+def get_ordered_poses_and_k(
+ poses: Float[Tensor, "num_poses 3 4"],
+ Ks: Float[Tensor, "num_poses 3 3"],
+) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
+ """
+ Returns ordered poses and intrinsics by euclidean distance between poses.
+
+ Args:
+ poses: list of camera poses
+ Ks: list of camera intrinsics
+
+ Returns:
+ tuple of ordered poses and intrinsics
+
+ """
+
+ poses_num = len(poses)
+
+ ordered_poses = torch.unsqueeze(poses[0], 0)
+ ordered_ks = torch.unsqueeze(Ks[0], 0)
+
+ # remove the first pose from poses
+ poses = poses[1:]
+ Ks = Ks[1:]
+
+ for _ in range(poses_num - 1):
+ distances = torch.norm(ordered_poses[-1][:, 3] - poses[:, :, 3], dim=1)
+ idx = torch.argmin(distances)
+ ordered_poses = torch.cat((ordered_poses, torch.unsqueeze(poses[idx], 0)), dim=0)
+ ordered_ks = torch.cat((ordered_ks, torch.unsqueeze(Ks[idx], 0)), dim=0)
+ poses = torch.cat((poses[0:idx], poses[idx + 1 :]), dim=0)
+ Ks = torch.cat((Ks[0:idx], Ks[idx + 1 :]), dim=0)
+
+ return ordered_poses, ordered_ks
+
+
+def get_interpolated_poses_many(
+ poses: Float[Tensor, "num_poses 3 4"],
+ Ks: Float[Tensor, "num_poses 3 3"],
+ steps_per_transition: int = 10,
+ order_poses: bool = False,
+) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]:
+ """Return interpolated poses for many camera poses.
+
+ Args:
+ poses: list of camera poses
+ Ks: list of camera intrinsics
+ steps_per_transition: number of steps per transition
+ order_poses: whether to order poses by euclidean distance
+
+ Returns:
+ tuple of new poses and intrinsics
+ """
+ traj = []
+ k_interp = []
+
+ if order_poses:
+ poses, Ks = get_ordered_poses_and_k(poses, Ks)
+
+ for idx in range(poses.shape[0] - 1):
+ pose_a = poses[idx].cpu().numpy()
+ pose_b = poses[idx + 1].cpu().numpy()
+ poses_ab = get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition)
+ traj += poses_ab
+ k_interp += get_interpolated_k(Ks[idx], Ks[idx + 1], steps=steps_per_transition)
+
+ traj = np.stack(traj, axis=0)
+ k_interp = torch.stack(k_interp, dim=0)
+
+ return torch.tensor(traj, dtype=torch.float32), k_interp.to(dtype=torch.float32)
+
+
+def normalize(x: torch.Tensor) -> Float[Tensor, "*batch"]:
+ """Returns a normalized vector."""
+ return x / torch.linalg.norm(x)
+
+
+def normalize_with_norm(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Normalize tensor along axis and return normalized value with norms.
+
+ Args:
+ x: tensor to normalize.
+ dim: axis along which to normalize.
+
+ Returns:
+ Tuple of normalized tensor and corresponding norm.
+ """
+
+ norm = torch.maximum(torch.linalg.vector_norm(x, dim=dim, keepdim=True), torch.tensor([_EPS]).to(x))
+ return x / norm, norm
+
+
+def viewmatrix(lookat: torch.Tensor, up: torch.Tensor, pos: torch.Tensor) -> Float[Tensor, "*batch"]:
+ """Returns a camera transformation matrix.
+
+ Args:
+ lookat: The direction the camera is looking.
+ up: The upward direction of the camera.
+ pos: The position of the camera.
+
+ Returns:
+ A camera transformation matrix.
+ """
+ vec2 = normalize(lookat)
+ vec1_avg = normalize(up)
+ vec0 = normalize(torch.cross(vec1_avg, vec2))
+ vec1 = normalize(torch.cross(vec2, vec0))
+ m = torch.stack([vec0, vec1, vec2, pos], 1)
+ return m
+
+
+def get_distortion_params(
+ k1: float = 0.0,
+ k2: float = 0.0,
+ k3: float = 0.0,
+ k4: float = 0.0,
+ p1: float = 0.0,
+ p2: float = 0.0,
+) -> Float[Tensor, "*batch"]:
+ """Returns a distortion parameters matrix.
+
+ Args:
+ k1: The first radial distortion parameter.
+ k2: The second radial distortion parameter.
+ k3: The third radial distortion parameter.
+ k4: The fourth radial distortion parameter.
+ p1: The first tangential distortion parameter.
+ p2: The second tangential distortion parameter.
+ Returns:
+ torch.Tensor: A distortion parameters matrix.
+ """
+ return torch.Tensor([k1, k2, k3, k4, p1, p2])
+
+
+def _compute_residual_and_jacobian(
+ x: torch.Tensor,
+ y: torch.Tensor,
+ xd: torch.Tensor,
+ yd: torch.Tensor,
+ distortion_params: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Auxiliary function of radial_and_tangential_undistort() that computes residuals and jacobians.
+ Adapted from MultiNeRF:
+ https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L427-L474
+
+ Args:
+ x: The updated x coordinates.
+ y: The updated y coordinates.
+ xd: The distorted x coordinates.
+ yd: The distorted y coordinates.
+ distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2].
+
+ Returns:
+ The residuals (fx, fy) and jacobians (fx_x, fx_y, fy_x, fy_y).
+ """
+
+ k1 = distortion_params[..., 0]
+ k2 = distortion_params[..., 1]
+ k3 = distortion_params[..., 2]
+ k4 = distortion_params[..., 3]
+ p1 = distortion_params[..., 4]
+ p2 = distortion_params[..., 5]
+
+ # let r(x, y) = x^2 + y^2;
+ # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 +
+ # k4 * r(x, y)^4;
+ r = x * x + y * y
+ d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4)))
+
+ # The perfect projection is:
+ # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2);
+ # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2);
+ #
+ # Let's define
+ #
+ # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd;
+ # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd;
+ #
+ # We are looking for a solution that satisfies
+ # fx(x, y) = fy(x, y) = 0;
+ fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd
+ fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd
+
+ # Compute derivative of d over [x, y]
+ d_r = k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4))
+ d_x = 2.0 * x * d_r
+ d_y = 2.0 * y * d_r
+
+ # Compute derivative of fx over x and y.
+ fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x
+ fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y
+
+ # Compute derivative of fy over x and y.
+ fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x
+ fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y
+
+ return fx, fy, fx_x, fx_y, fy_x, fy_y
+
+
+# @torch_compile(dynamic=True, mode="reduce-overhead", backend="eager")
+def radial_and_tangential_undistort(
+ coords: torch.Tensor,
+ distortion_params: torch.Tensor,
+ eps: float = 1e-3,
+ max_iterations: int = 10,
+) -> torch.Tensor:
+ """Computes undistorted coords given opencv distortion parameters.
+ Adapted from MultiNeRF
+ https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L477-L509
+
+ Args:
+ coords: The distorted coordinates.
+ distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2].
+ eps: The epsilon for the convergence.
+ max_iterations: The maximum number of iterations to perform.
+
+ Returns:
+ The undistorted coordinates.
+ """
+
+ # Initialize from the distorted point.
+ x = coords[..., 0]
+ y = coords[..., 1]
+
+ for _ in range(max_iterations):
+ fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian(
+ x=x, y=y, xd=coords[..., 0], yd=coords[..., 1], distortion_params=distortion_params
+ )
+ denominator = fy_x * fx_y - fx_x * fy_y
+ x_numerator = fx * fy_y - fy * fx_y
+ y_numerator = fy * fx_x - fx * fy_x
+ step_x = torch.where(torch.abs(denominator) > eps, x_numerator / denominator, torch.zeros_like(denominator))
+ step_y = torch.where(torch.abs(denominator) > eps, y_numerator / denominator, torch.zeros_like(denominator))
+
+ x = x + step_x
+ y = y + step_y
+
+ return torch.stack([x, y], dim=-1)
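+
+# Illustrative check (not part of the original file): with all-zero distortion parameters
+# the Newton iterations are a no-op, so the input coordinates come back unchanged:
+#   params = get_distortion_params()            # k1 = k2 = k3 = k4 = p1 = p2 = 0
+#   coords = torch.rand(8, 2) - 0.5
+#   assert torch.allclose(radial_and_tangential_undistort(coords, params), coords)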
+
+
+def rotation_matrix(a: Float[Tensor, "3"], b: Float[Tensor, "3"]) -> Float[Tensor, "3 3"]:
+ """Compute the rotation matrix that rotates vector a to vector b.
+
+ Args:
+ a: The vector to rotate.
+ b: The vector to rotate to.
+ Returns:
+ The rotation matrix.
+ """
+ a = a / torch.linalg.norm(a)
+ b = b / torch.linalg.norm(b)
+ v = torch.cross(a, b)
+ c = torch.dot(a, b)
+ # If vectors are exactly opposite, we add a little noise to one of them
+ if c < -1 + 1e-8:
+ eps = (torch.rand(3) - 0.5) * 0.01
+ return rotation_matrix(a + eps, b)
+ s = torch.linalg.norm(v)
+ skew_sym_mat = torch.Tensor(
+ [
+ [0, -v[2], v[1]],
+ [v[2], 0, -v[0]],
+ [-v[1], v[0], 0],
+ ]
+ )
+ return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8))
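+
+# Illustrative check (not part of the original file): the matrix rotating +z onto +x
+# should send [0, 0, 1] to approximately [1, 0, 0]:
+#   R = rotation_matrix(torch.tensor([0.0, 0.0, 1.0]), torch.tensor([1.0, 0.0, 0.0]))
+#   # R @ torch.tensor([0.0, 0.0, 1.0]) ~ [1.0, 0.0, 0.0]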
+
+
+def focus_of_attention(poses: Float[Tensor, "*num_poses 4 4"], initial_focus: Float[Tensor, "3"]) -> Float[Tensor, "3"]:
+ """Compute the focus of attention of a set of cameras. Only cameras
+ that have the focus of attention in front of them are considered.
+
+ Args:
+ poses: The poses to orient.
+ initial_focus: The initial guess for the focus point, used to decide which cameras are initially activated.
+
+ Returns:
+ The 3D position of the focus of attention.
+ """
+ # References to the same method in third-party code:
+ # https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L145
+ # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L197
+ active_directions = -poses[:, :3, 2:3]
+ active_origins = poses[:, :3, 3:4]
+ # initial value for testing if the focus_pt is in front or behind
+ focus_pt = initial_focus
+ # Prune cameras which currently have the focus_pt behind them.
+ active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
+ done = False
+ # We need at least two active cameras, else fallback on the previous solution.
+ # This may be the "poses" solution if no cameras are active on first iteration, e.g.
+ # they are in an outward-looking configuration.
+ while torch.sum(active.int()) > 1 and not done:
+ active_directions = active_directions[active]
+ active_origins = active_origins[active]
+ # https://en.wikipedia.org/wiki/Line–line_intersection#In_more_than_two_dimensions
+ m = torch.eye(3) - active_directions * torch.transpose(active_directions, -2, -1)
+ mt_m = torch.transpose(m, -2, -1) @ m
+ focus_pt = torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0)[:, 0]
+ active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
+ if active.all():
+ # the set of active cameras did not change, so we're done.
+ done = True
+ return focus_pt
+
+
+def auto_orient_and_center_poses(
+ poses: Float[Tensor, "*num_poses 4 4"],
+ method: Literal["pca", "up", "vertical", "none"] = "up",
+ center_method: Literal["poses", "focus", "none"] = "poses",
+) -> Tuple[Float[Tensor, "*num_poses 3 4"], Float[Tensor, "3 4"]]:
+ """Orients and centers the poses.
+
+ We provide three methods for orientation:
+
+ - pca: Orient the poses so that the principal directions of the camera centers are aligned
+ with the axes, Z corresponding to the smallest principal component.
+ This method works well when all of the cameras are in the same plane, for example when
+ images are taken using a mobile robot.
+ - up: Orient the poses so that the average up vector is aligned with the z axis.
+ This method works well when images are not at arbitrary angles.
+ - vertical: Orient the poses so that the Z 3D direction projects close to the
+ y axis in images. This method works better if cameras are not all
+ looking in the same 3D direction, which may happen in camera arrays or in LLFF.
+
+ There are two centering methods:
+
+ - poses: The poses are centered around the origin.
+ - focus: The origin is set to the focus of attention of all cameras (the
+ closest point to cameras optical axes). Recommended for inward-looking
+ camera configurations.
+
+ Args:
+ poses: The poses to orient.
+ method: The method to use for orientation.
+ center_method: The method to use to center the poses.
+
+ Returns:
+ Tuple of the oriented poses and the transform matrix.
+ """
+
+ origins = poses[..., :3, 3]
+
+ mean_origin = torch.mean(origins, dim=0)
+ translation_diff = origins - mean_origin
+
+ if center_method == "poses":
+ translation = mean_origin
+ elif center_method == "focus":
+ translation = focus_of_attention(poses, mean_origin)
+ elif center_method == "none":
+ translation = torch.zeros_like(mean_origin)
+ else:
+ raise ValueError(f"Unknown value for center_method: {center_method}")
+
+ if method == "pca":
+ _, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff)
+ eigvec = torch.flip(eigvec, dims=(-1,))
+
+ if torch.linalg.det(eigvec) < 0:
+ eigvec[:, 2] = -eigvec[:, 2]
+
+ transform = torch.cat([eigvec, eigvec @ -translation[..., None]], dim=-1)
+ oriented_poses = transform @ poses
+
+ if oriented_poses.mean(dim=0)[2, 1] < 0:
+ oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3]
+ elif method in ("up", "vertical"):
+ up = torch.mean(poses[:, :3, 1], dim=0)
+ up = up / torch.linalg.norm(up)
+ if method == "vertical":
+ # If cameras are not all parallel (e.g. not in an LLFF configuration),
+ # we can find the 3D direction that most projects vertically in all
+ # cameras by minimizing ||Xu|| s.t. ||u||=1. This total least squares
+ # problem is solved by SVD.
+ x_axis_matrix = poses[:, :3, 0]
+ _, S, Vh = torch.linalg.svd(x_axis_matrix, full_matrices=False)
+ # Singular values are S_i=||Xv_i|| for each right singular vector v_i.
+ # ||S|| = sqrt(n) because lines of X are all unit vectors and the v_i
+ # are an orthonormal basis.
+ # ||Xv_i|| = sqrt(sum(dot(x_axis_j,v_i)^2)), thus S_i/sqrt(n) is the
+ # RMS of cosines between x axes and v_i. If the second smallest singular
+ # value corresponds to an angle error less than 10° (cos(80°)=0.17),
+ # this is probably a degenerate camera configuration (typical values
+ # are around 5° average error for the true vertical). In this case,
+ # rather than taking the vector corresponding to the smallest singular
+ # value, we project the "up" vector on the plane spanned by the two
+ # best singular vectors. We could also just fallback to the "up"
+ # solution.
+ if S[1] > 0.17 * math.sqrt(poses.shape[0]):
+ # regular non-degenerate configuration
+ up_vertical = Vh[2, :]
+ # It may be pointing up or down. Use "up" to disambiguate the sign.
+ up = up_vertical if torch.dot(up_vertical, up) > 0 else -up_vertical
+ else:
+ # Degenerate configuration: project "up" on the plane spanned by
+ # the last two right singular vectors (which are orthogonal to the
+ # first). v_0 is a unit vector, no need to divide by its norm when
+ # projecting.
+ up = up - Vh[0, :] * torch.dot(up, Vh[0, :])
+ # re-normalize
+ up = up / torch.linalg.norm(up)
+
+ rotation = rotation_matrix(up, torch.Tensor([0, 0, 1]))
+ transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1)
+ oriented_poses = transform @ poses
+ elif method == "none":
+ transform = torch.eye(4)
+ transform[:3, 3] = -translation
+ transform = transform[:3, :]
+ oriented_poses = transform @ poses
+ else:
+ raise ValueError(f"Unknown value for method: {method}")
+
+ return oriented_poses, transform
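+
+# Illustrative usage sketch (not part of the original file), assuming `poses` is a
+# (num_poses, 4, 4) tensor of camera-to-world matrices:
+#   oriented, transform = auto_orient_and_center_poses(poses, method="up", center_method="poses")
+#   # `oriented` has shape (num_poses, 3, 4); `transform` is the applied 3x4 world transform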
+
+
+@torch.jit.script
+def fisheye624_project(xyz, params):
+ """
+ Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera
+ model project() function.
+ Inputs:
+ xyz: BxNx3 tensor of 3D points to be projected
+ params: Bx16 tensor of Fisheye624 parameters formatted like this:
+ [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
+ or Bx15 tensor of Fisheye624 parameters formatted like this:
+ [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
+ Outputs:
+ uv: BxNx2 tensor of 2D projections of xyz in image plane
+ Model for fisheye cameras with radial, tangential, and thin-prism distortion.
+ This model allows fu != fv.
+ Specifically, the model is:
+ uvDistorted = [x_r] + tangentialDistortion + thinPrismDistortion
+ [y_r]
+ proj = diag(fu,fv) * uvDistorted + [cu;cv];
+ where:
+ a = x/z, b = y/z, r = (a^2+b^2)^(1/2)
+ th = atan(r)
+ cosPhi = a/r, sinPhi = b/r
+ [x_r] = (th+ k0 * th^3 + k1* th^5 + ...) [cosPhi]
+ [y_r] [sinPhi]
+ the number of terms in the series is determined by the template parameter numK.
+ tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1]
+ [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0]
+ where rd^2 = x_r^2 + y_r^2
+ thinPrismDistortion = [s0 * rd^2 + s1 rd^4]
+ [s2 * rd^2 + s3 rd^4]
+ Author: Daniel DeTone (ddetone@meta.com)
+ """
+
+ assert xyz.ndim == 3
+ assert params.ndim == 2
+ assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy"
+ eps = 1e-9
+ B, N = xyz.shape[0], xyz.shape[1]
+
+ # Radial correction.
+ z = xyz[:, :, 2].reshape(B, N, 1)
+ z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
+ ab = xyz[:, :, :2] / z
+ r = torch.norm(ab, dim=-1, p=2, keepdim=True)
+ th = torch.atan(r)
+ th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r)
+ th_k = th.reshape(B, N, 1).clone()
+ for i in range(6):
+ th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2)
+ xr_yr = th_k * th_divr
+ uv_dist = xr_yr
+
+ # Tangential correction.
+ p0 = params[:, -6].reshape(B, 1)
+ p1 = params[:, -5].reshape(B, 1)
+ xr = xr_yr[:, :, 0].reshape(B, N)
+ yr = xr_yr[:, :, 1].reshape(B, N)
+ xr_yr_sq = torch.square(xr_yr)
+ xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
+ yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
+ rd_sq = xr_sq + yr_sq
+ uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1)
+ uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0)
+ uv_dist = torch.stack([uv_dist_tu, uv_dist_tv], dim=-1) # Avoids in-place complaint.
+
+ # Thin Prism correction.
+ s0 = params[:, -4].reshape(B, 1)
+ s1 = params[:, -3].reshape(B, 1)
+ s2 = params[:, -2].reshape(B, 1)
+ s3 = params[:, -1].reshape(B, 1)
+ rd_4 = torch.square(rd_sq)
+ uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
+ uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
+
+ # Finally, apply standard terms: focal length and camera centers.
+ if params.shape[-1] == 15:
+ fx_fy = params[:, 0].reshape(B, 1, 1)
+ cx_cy = params[:, 1:3].reshape(B, 1, 2)
+ else:
+ fx_fy = params[:, 0:2].reshape(B, 1, 2)
+ cx_cy = params[:, 2:4].reshape(B, 1, 2)
+ result = uv_dist * fx_fy + cx_cy
+
+ return result
+
+
+# Core implementation of fisheye 624 unprojection. More details are documented here:
+# https://facebookresearch.github.io/projectaria_tools/docs/tech_insights/camera_intrinsic_models#the-fisheye62-model
+@torch.jit.script
+def fisheye624_unproject_helper(uv, params, max_iters: int = 5):
+ """
+ Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera
+ model. There is no analytical solution for the inverse of the project()
+ function so this solves an optimization problem using Newton's method to get
+ the inverse.
+ Inputs:
+ uv: BxNx2 tensor of 2D pixels to be unprojected
+ params: Bx16 tensor of Fisheye624 parameters formatted like this:
+ [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
+ or Bx15 tensor of Fisheye624 parameters formatted like this:
+ [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}]
+ Outputs:
+ xyz: BxNx3 tensor of 3D rays of uv points with z = 1.
+ Model for fisheye cameras with radial, tangential, and thin-prism distortion.
+ This model assumes fu=fv. This unproject function holds that:
+ X = unproject(project(X)) [for X=(x,y,z) in R^3, z>0]
+ and
+ x = project(unproject(s*x)) [for s!=0 and x=(u,v) in R^2]
+ Author: Daniel DeTone (ddetone@meta.com)
+ """
+
+ assert uv.ndim == 3, "Expected batched input shaped BxNx3"
+ assert params.ndim == 2
+ assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy"
+ eps = 1e-6
+ B, N = uv.shape[0], uv.shape[1]
+
+ if params.shape[-1] == 15:
+ fx_fy = params[:, 0].reshape(B, 1, 1)
+ cx_cy = params[:, 1:3].reshape(B, 1, 2)
+ else:
+ fx_fy = params[:, 0:2].reshape(B, 1, 2)
+ cx_cy = params[:, 2:4].reshape(B, 1, 2)
+
+ uv_dist = (uv - cx_cy) / fx_fy
+
+ # Compute xr_yr using Newton's method.
+ xr_yr = uv_dist.clone() # Initial guess.
+ for _ in range(max_iters):
+ uv_dist_est = xr_yr.clone()
+ # Tangential terms.
+ p0 = params[:, -6].reshape(B, 1)
+ p1 = params[:, -5].reshape(B, 1)
+ xr = xr_yr[:, :, 0].reshape(B, N)
+ yr = xr_yr[:, :, 1].reshape(B, N)
+ xr_yr_sq = torch.square(xr_yr)
+ xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
+ yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
+ rd_sq = xr_sq + yr_sq
+ uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1)
+ uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0)
+ # Thin Prism terms.
+ s0 = params[:, -4].reshape(B, 1)
+ s1 = params[:, -3].reshape(B, 1)
+ s2 = params[:, -2].reshape(B, 1)
+ s3 = params[:, -1].reshape(B, 1)
+ rd_4 = torch.square(rd_sq)
+ uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
+ uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
+ # Compute the derivative of uv_dist w.r.t. xr_yr.
+ duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)
+ duv_dist_dxr_yr[:, :, 0, 0] = 1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1
+ offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0)
+ duv_dist_dxr_yr[:, :, 0, 1] = offdiag
+ duv_dist_dxr_yr[:, :, 1, 0] = offdiag
+ duv_dist_dxr_yr[:, :, 1, 1] = 1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0
+ xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1]
+ temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)
+ duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + (xr_yr[:, :, 0] * temp1)
+ duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + (xr_yr[:, :, 1] * temp1)
+ temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)
+ duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + (xr_yr[:, :, 0] * temp2)
+ duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + (xr_yr[:, :, 1] * temp2)
+ # Compute 2x2 inverse manually here since torch.inverse() is very slow.
+ # Because this is slow: inv = duv_dist_dxr_yr.inverse()
+ # The line above is roughly 10x slower than the manual 2x2 inverse.
+ mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
+ a = mat[:, 0, 0].reshape(-1, 1, 1)
+ b = mat[:, 0, 1].reshape(-1, 1, 1)
+ c = mat[:, 1, 0].reshape(-1, 1, 1)
+ d = mat[:, 1, 1].reshape(-1, 1, 1)
+ det = 1.0 / ((a * d) - (b * c))
+ top = torch.cat([d, -b], dim=2)
+ bot = torch.cat([-c, a], dim=2)
+ inv = det * torch.cat([top, bot], dim=1)
+ inv = inv.reshape(B, N, 2, 2)
+ # Manually compute 2x2 @ 2x1 matrix multiply.
+ # Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0]
+ diff = uv_dist - uv_dist_est
+ a = inv[:, :, 0, 0]
+ b = inv[:, :, 0, 1]
+ c = inv[:, :, 1, 0]
+ d = inv[:, :, 1, 1]
+ e = diff[:, :, 0]
+ f = diff[:, :, 1]
+ step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
+ # Newton step.
+ xr_yr = xr_yr + step
+
+ # Compute theta using Newton's method.
+ xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
+ th = xr_yr_norm.clone()
+ for _ in range(max_iters):
+ th_radial = uv.new_ones(B, N, 1)
+ dthd_th = uv.new_ones(B, N, 1)
+ for k in range(6):
+ r_k = params[:, -12 + k].reshape(B, 1, 1)
+ th_radial = th_radial + (r_k * torch.pow(th, 2 + k * 2))
+ dthd_th = dthd_th + ((3.0 + 2.0 * k) * r_k * torch.pow(th, 2 + k * 2))
+ th_radial = th_radial * th
+ step = (xr_yr_norm - th_radial) / dthd_th
+ # handle dthd_th close to 0.
+ step = torch.where(dthd_th.abs() > eps, step, torch.sign(step) * eps * 10.0)
+ th = th + step
+ # Compute the ray direction using theta and xr_yr.
+ close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
+ ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr)
+ ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
+ return ray
+
+
+# unproject 2D point to 3D with fisheye624 model
+def fisheye624_unproject(coords: torch.Tensor, distortion_params: torch.Tensor) -> torch.Tensor:
+ dirs = fisheye624_unproject_helper(coords.unsqueeze(0), distortion_params[0].unsqueeze(0))
+ # correct for camera space differences:
+ dirs[..., 1] = -dirs[..., 1]
+ dirs[..., 2] = -dirs[..., 2]
+ return dirs
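+
+# Illustrative consistency sketch (not part of the original file): project/unproject are
+# documented as inverses for points in front of the camera, e.g. with a distortion-free
+# 16-parameter setup (hypothetical focal length 300 px, principal point 256 px):
+#   params = torch.zeros(1, 16)
+#   params[0, 0] = params[0, 1] = 300.0   # f_u, f_v
+#   params[0, 2] = params[0, 3] = 256.0   # c_u, c_v
+#   xyz = torch.tensor([[[0.1, -0.2, 1.0]]])                 # B=1, N=1
+#   ray = fisheye624_unproject_helper(fisheye624_project(xyz, params), params)
+#   # ray ~ xyz / xyz[..., 2:3] (a direction with z = 1)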
diff --git a/sgm/data/cifar10.py b/sgm/data/cifar10.py
new file mode 100644
index 0000000000000000000000000000000000000000..6083646f136bad308a0485843b89234cf7a9d6cd
--- /dev/null
+++ b/sgm/data/cifar10.py
@@ -0,0 +1,67 @@
+import pytorch_lightning as pl
+import torchvision
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class CIFAR10DataDictWrapper(Dataset):
+ def __init__(self, dset):
+ super().__init__()
+ self.dset = dset
+
+ def __getitem__(self, i):
+ x, y = self.dset[i]
+ return {"jpg": x, "cls": y}
+
+ def __len__(self):
+ return len(self.dset)
+
+
+class CIFAR10Loader(pl.LightningDataModule):
+ def __init__(self, batch_size, num_workers=0, shuffle=True):
+ super().__init__()
+
+ transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+ )
+
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.shuffle = shuffle
+ self.train_dataset = CIFAR10DataDictWrapper(
+ torchvision.datasets.CIFAR10(
+ root=".data/", train=True, download=True, transform=transform
+ )
+ )
+ self.test_dataset = CIFAR10DataDictWrapper(
+ torchvision.datasets.CIFAR10(
+ root=".data/", train=False, download=True, transform=transform
+ )
+ )
+
+ def prepare_data(self):
+ pass
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ )
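+
+# Illustrative usage sketch (not part of the original file):
+#   datamodule = CIFAR10Loader(batch_size=64, num_workers=2)
+#   batch = next(iter(datamodule.train_dataloader()))
+#   # batch["jpg"]: (64, 3, 32, 32) tensor scaled to [-1, 1]; batch["cls"]: class labels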
diff --git a/sgm/data/co3d.py b/sgm/data/co3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba95cfbb540e4664b0fdb313f67bb5013bdea6bf
--- /dev/null
+++ b/sgm/data/co3d.py
@@ -0,0 +1,1367 @@
+"""
+adapted from SparseFusion
+Wrapper for the full CO3Dv2 dataset
+#@ Modified from https://github.com/facebookresearch/pytorch3d
+"""
+
+import json
+import logging
+import math
+import os
+import random
+import time
+import warnings
+from collections import defaultdict
+from itertools import islice
+from typing import (
+ Any,
+ ClassVar,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypedDict,
+ Union,
+)
+from einops import rearrange, repeat
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.functional as TF
+from pytorch3d.utils import opencv_from_cameras_projection
+from pytorch3d.implicitron.dataset import types
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
+from sgm.data.json_index_dataset import (
+ FrameAnnotsEntry,
+ _bbox_xywh_to_xyxy,
+ _bbox_xyxy_to_xywh,
+ _clamp_box_to_image_bounds_and_round,
+ _crop_around_box,
+ _get_1d_bounds,
+ _get_bbox_from_mask,
+ _get_clamp_bbox,
+ _load_1bit_png_mask,
+ _load_16big_png_depth,
+ _load_depth,
+ _load_depth_mask,
+ _load_image,
+ _load_mask,
+ _load_pointcloud,
+ _rescale_bbox,
+ _safe_as_tensor,
+ _seq_name_to_seed,
+)
+from sgm.data.objaverse import video_collate_fn
+from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
+ get_available_subset_names,
+)
+from pytorch3d.renderer.cameras import PerspectiveCameras
+
+logger = logging.getLogger(__name__)
+
+
+from dataclasses import dataclass, field, fields
+
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
+from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
+from pytorch_lightning import LightningDataModule
+from torch.utils.data import DataLoader
+
+CO3D_ALL_CATEGORIES = list(
+ reversed(
+ [
+ "baseballbat",
+ "banana",
+ "bicycle",
+ "microwave",
+ "tv",
+ "cellphone",
+ "toilet",
+ "hairdryer",
+ "couch",
+ "kite",
+ "pizza",
+ "umbrella",
+ "wineglass",
+ "laptop",
+ "hotdog",
+ "stopsign",
+ "frisbee",
+ "baseballglove",
+ "cup",
+ "parkingmeter",
+ "backpack",
+ "toyplane",
+ "toybus",
+ "handbag",
+ "chair",
+ "keyboard",
+ "car",
+ "motorcycle",
+ "carrot",
+ "bottle",
+ "sandwich",
+ "remote",
+ "bowl",
+ "skateboard",
+ "toaster",
+ "mouse",
+ "toytrain",
+ "book",
+ "toytruck",
+ "orange",
+ "broccoli",
+ "plant",
+ "teddybear",
+ "suitcase",
+ "bench",
+ "ball",
+ "cake",
+ "vase",
+ "hydrant",
+ "apple",
+ "donut",
+ ]
+ )
+)
+
+CO3D_ALL_TEN = [
+ "donut",
+ "apple",
+ "hydrant",
+ "vase",
+ "cake",
+ "ball",
+ "bench",
+ "suitcase",
+ "teddybear",
+ "plant",
+]
+
+
+# @ FROM https://github.com/facebookresearch/pytorch3d
+@dataclass
+class FrameData(Mapping[str, Any]):
+ """
+ A type of the elements returned by indexing the dataset object.
+ It can represent both individual frames and batches thereof;
+ in this documentation, the sizes of tensors refer to single frames;
+ add the first batch dimension for the collation result.
+ Args:
+ frame_number: The number of the frame within its sequence.
+ 0-based continuous integers.
+ sequence_name: The unique name of the frame's sequence.
+ sequence_category: The object category of the sequence.
+ frame_timestamp: The time elapsed since the start of a sequence in sec.
+ image_size_hw: The size of the image in pixels; (height, width) tensor
+ of shape (2,).
+ image_path: The qualified path to the loaded image (with dataset_root).
+ image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
+ of the frame; elements are floats in [0, 1].
+ mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
+ regions. Regions can be invalid (mask_crop[i,j]=0) in case they
+ are a result of zero-padding of the image after cropping around
+ the object bounding box; elements are floats in {0.0, 1.0}.
+ depth_path: The qualified path to the frame's depth map.
+ depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
+ of the frame; values correspond to distances from the camera;
+ use `depth_mask` and `mask_crop` to filter for valid pixels.
+ depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
+ depth map that are valid for evaluation, they have been checked for
+ consistency across views; elements are floats in {0.0, 1.0}.
+ mask_path: A qualified path to the foreground probability mask.
+ fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
+ pixels belonging to the captured object; elements are floats
+ in [0, 1].
+ bbox_xywh: The bounding box tightly enclosing the foreground object in the
+ format (x0, y0, width, height). The convention assumes that
+            `x0+width` and `y0+height` include the boundary of the box.
+ I.e., to slice out the corresponding crop from an image tensor `I`
+ we execute `crop = I[..., y0:y0+height, x0:x0+width]`
+ crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
+ in the original image coordinates in the format (x0, y0, width, height).
+ The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
+ from `bbox_xywh` due to padding (which can happen e.g. due to
+ setting `JsonIndexDataset.box_crop_context > 0`)
+        camera: A PyTorch3D camera object corresponding to the frame's viewpoint,
+ corrected for cropping if it happened.
+ camera_quality_score: The score proportional to the confidence of the
+ frame's camera estimation (the higher the more accurate).
+ point_cloud_quality_score: The score proportional to the accuracy of the
+ frame's sequence point cloud (the higher the more accurate).
+ sequence_point_cloud_path: The path to the sequence's point cloud.
+ sequence_point_cloud: A PyTorch3D Pointclouds object holding the
+ point cloud corresponding to the frame's sequence. When the object
+ represents a batch of frames, point clouds may be deduplicated;
+ see `sequence_point_cloud_idx`.
+ sequence_point_cloud_idx: Integer indices mapping frame indices to the
+ corresponding point clouds in `sequence_point_cloud`; to get the
+ corresponding point cloud to `image_rgb[i]`, use
+ `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
+ frame_type: The type of the loaded frame specified in
+ `subset_lists_file`, if provided.
+ meta: A dict for storing additional frame information.
+ """
+
+ frame_number: Optional[torch.LongTensor]
+ sequence_name: Union[str, List[str]]
+ sequence_category: Union[str, List[str]]
+ frame_timestamp: Optional[torch.Tensor] = None
+ image_size_hw: Optional[torch.Tensor] = None
+ image_path: Union[str, List[str], None] = None
+ image_rgb: Optional[torch.Tensor] = None
+ # masks out padding added due to cropping the square bit
+ mask_crop: Optional[torch.Tensor] = None
+ depth_path: Union[str, List[str], None] = ""
+ depth_map: Optional[torch.Tensor] = torch.zeros(1)
+ depth_mask: Optional[torch.Tensor] = torch.zeros(1)
+ mask_path: Union[str, List[str], None] = None
+ fg_probability: Optional[torch.Tensor] = None
+ bbox_xywh: Optional[torch.Tensor] = None
+ crop_bbox_xywh: Optional[torch.Tensor] = None
+ camera: Optional[PerspectiveCameras] = None
+ camera_quality_score: Optional[torch.Tensor] = None
+ point_cloud_quality_score: Optional[torch.Tensor] = None
+ sequence_point_cloud_path: Union[str, List[str], None] = ""
+ sequence_point_cloud: Optional[Pointclouds] = torch.zeros(1)
+ sequence_point_cloud_idx: Optional[torch.Tensor] = torch.zeros(1)
+ frame_type: Union[str, List[str], None] = "" # known | unseen
+ meta: dict = field(default_factory=lambda: {})
+ valid_region: Optional[torch.Tensor] = None
+ category_one_hot: Optional[torch.Tensor] = None
+
+ def to(self, *args, **kwargs):
+ new_params = {}
+ for f in fields(self):
+ value = getattr(self, f.name)
+ if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
+ new_params[f.name] = value.to(*args, **kwargs)
+ else:
+ new_params[f.name] = value
+ return type(self)(**new_params)
+
+ def cpu(self):
+ return self.to(device=torch.device("cpu"))
+
+ def cuda(self):
+ return self.to(device=torch.device("cuda"))
+
+ # the following functions make sure **frame_data can be passed to functions
+ def __iter__(self):
+ for f in fields(self):
+ yield f.name
+
+ def __getitem__(self, key):
+ return getattr(self, key)
+
+ def __len__(self):
+ return len(fields(self))
+
+ @classmethod
+ def collate(cls, batch):
+ """
+ Given a list objects `batch` of class `cls`, collates them into a batched
+ representation suitable for processing with deep networks.
+ """
+
+ elem = batch[0]
+
+ if isinstance(elem, cls):
+ pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
+ id_to_idx = defaultdict(list)
+ for i, pc_id in enumerate(pointcloud_ids):
+ id_to_idx[pc_id].append(i)
+
+ sequence_point_cloud = []
+ sequence_point_cloud_idx = -np.ones((len(batch),))
+ for i, ind in enumerate(id_to_idx.values()):
+ sequence_point_cloud_idx[ind] = i
+ sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
+ assert (sequence_point_cloud_idx >= 0).all()
+
+ override_fields = {
+ "sequence_point_cloud": sequence_point_cloud,
+ "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
+ }
+ # note that the pre-collate value of sequence_point_cloud_idx is unused
+
+ collated = {}
+ for f in fields(elem):
+ list_values = override_fields.get(
+ f.name, [getattr(d, f.name) for d in batch]
+ )
+ collated[f.name] = (
+ cls.collate(list_values)
+ if all(list_value is not None for list_value in list_values)
+ else None
+ )
+ return cls(**collated)
+
+ elif isinstance(elem, Pointclouds):
+ return join_pointclouds_as_batch(batch)
+
+ elif isinstance(elem, CamerasBase):
+ # TODO: don't store K; enforce working in NDC space
+ return join_cameras_as_batch(batch)
+ else:
+ return torch.utils.data._utils.collate.default_collate(batch)
+
+
+# @ MODIFIED FROM https://github.com/facebookresearch/pytorch3d
+class CO3Dv2Wrapper(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ root_dir="/drive/datasets/co3d/",
+ category="hydrant",
+ subset="fewview_train",
+ stage="train",
+ sample_batch_size=20,
+ image_size=256,
+ masked=False,
+ deprecated_val_region=False,
+ return_frame_data_list=False,
+ reso: int = 256,
+ mask_type: str = "random",
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ condition_on_elevation=False,
+ fps_id=0.0,
+ motion_bucket_id=300.0,
+ num_frames: int = 20,
+ use_mask: bool = True,
+ load_pixelnerf: bool = True,
+ scale_pose: bool = True,
+ max_n_cond: int = 5,
+ min_n_cond: int = 2,
+ cond_on_multi: bool = False,
+ ):
+ root = root_dir
+ from typing import List
+
+ from co3d.dataset.data_types import (
+ FrameAnnotation,
+ SequenceAnnotation,
+ load_dataclass_jgzip,
+ )
+
+ self.dataset_root = root
+ self.path_manager = None
+ self.subset = subset
+ self.stage = stage
+ self.subset_lists_file: List[str] = [
+ f"{self.dataset_root}/{category}/set_lists/set_lists_{subset}.json"
+ ]
+ self.subsets: Optional[List[str]] = [subset]
+ self.sample_batch_size = sample_batch_size
+ self.limit_to: int = 0
+ self.limit_sequences_to: int = 0
+ self.pick_sequence: Tuple[str, ...] = ()
+ self.exclude_sequence: Tuple[str, ...] = ()
+ self.limit_category_to: Tuple[int, ...] = ()
+ self.load_images: bool = True
+ self.load_depths: bool = False
+ self.load_depth_masks: bool = False
+ self.load_masks: bool = True
+ self.load_point_clouds: bool = False
+ self.max_points: int = 0
+ self.mask_images: bool = False
+ self.mask_depths: bool = False
+ self.image_height: Optional[int] = image_size
+ self.image_width: Optional[int] = image_size
+ self.box_crop: bool = True
+ self.box_crop_mask_thr: float = 0.4
+ self.box_crop_context: float = 0.3
+ self.remove_empty_masks: bool = True
+ self.n_frames_per_sequence: int = -1
+ self.seed: int = 0
+ self.sort_frames: bool = False
+ self.eval_batches: Any = None
+
+ self.img_h = self.image_height
+ self.img_w = self.image_width
+ self.masked = masked
+ self.deprecated_val_region = deprecated_val_region
+ self.return_frame_data_list = return_frame_data_list
+
+ self.reso = reso
+ self.num_frames = num_frames
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ self.condition_on_elevation = condition_on_elevation
+ self.fps_id = fps_id
+ self.motion_bucket_id = motion_bucket_id
+ self.mask_type = mask_type
+ self.use_mask = use_mask
+ self.load_pixelnerf = load_pixelnerf
+ self.scale_pose = scale_pose
+ self.max_n_cond = max_n_cond
+ self.min_n_cond = min_n_cond
+ self.cond_on_multi = cond_on_multi
+
+ if self.cond_on_multi:
+ assert self.min_n_cond == self.max_n_cond
+
+ start_time = time.time()
+ if "all_" in category or category == "all":
+ self.category_frame_annotations = []
+ self.category_sequence_annotations = []
+ self.subset_lists_file = []
+
+ if category == "all":
+ cats = CO3D_ALL_CATEGORIES
+ elif category == "all_four":
+ cats = ["hydrant", "teddybear", "motorcycle", "bench"]
+ elif category == "all_ten":
+ cats = [
+ "donut",
+ "apple",
+ "hydrant",
+ "vase",
+ "cake",
+ "ball",
+ "bench",
+ "suitcase",
+ "teddybear",
+ "plant",
+ ]
+ elif category == "all_15":
+ cats = [
+ "hydrant",
+ "teddybear",
+ "motorcycle",
+ "bench",
+ "hotdog",
+ "remote",
+ "suitcase",
+ "donut",
+ "plant",
+ "toaster",
+ "keyboard",
+ "handbag",
+ "toyplane",
+ "tv",
+ "orange",
+ ]
+ else:
+ print("UNSPECIFIED CATEGORY SUBSET")
+ cats = ["hydrant", "teddybear"]
+ print("loading", cats)
+ for cat in cats:
+ self.category_frame_annotations.extend(
+ load_dataclass_jgzip(
+ f"{self.dataset_root}/{cat}/frame_annotations.jgz",
+ List[FrameAnnotation],
+ )
+ )
+ self.category_sequence_annotations.extend(
+ load_dataclass_jgzip(
+ f"{self.dataset_root}/{cat}/sequence_annotations.jgz",
+ List[SequenceAnnotation],
+ )
+ )
+ self.subset_lists_file.append(
+ f"{self.dataset_root}/{cat}/set_lists/set_lists_{subset}.json"
+ )
+
+ else:
+ self.category_frame_annotations = load_dataclass_jgzip(
+ f"{self.dataset_root}/{category}/frame_annotations.jgz",
+ List[FrameAnnotation],
+ )
+ self.category_sequence_annotations = load_dataclass_jgzip(
+ f"{self.dataset_root}/{category}/sequence_annotations.jgz",
+ List[SequenceAnnotation],
+ )
+
+ self.subset_to_image_path = None
+ self._load_frames()
+ self._load_sequences()
+ self._sort_frames()
+ self._load_subset_lists()
+ self._filter_db() # also computes sequence indices
+ # self._extract_and_set_eval_batches()
+ # print(self.eval_batches)
+ logger.info(str(self))
+
+ self.seq_to_frames = {}
+ for fi, item in enumerate(self.frame_annots):
+ if item["frame_annotation"].sequence_name in self.seq_to_frames:
+ self.seq_to_frames[item["frame_annotation"].sequence_name].append(fi)
+ else:
+ self.seq_to_frames[item["frame_annotation"].sequence_name] = [fi]
+
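+        # keep only sequences with more than 10 frames (skipped for fewview_test evaluation)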
+ if self.stage != "test" or self.subset != "fewview_test":
+ count = 0
+ new_seq_to_frames = {}
+ for item in self.seq_to_frames:
+ if len(self.seq_to_frames[item]) > 10:
+ count += 1
+ new_seq_to_frames[item] = self.seq_to_frames[item]
+ self.seq_to_frames = new_seq_to_frames
+
+ self.seq_list = list(self.seq_to_frames.keys())
+
+        # @ REMOVE A FEW TRAINING SEQS THAT CAUSE BUGS
+ remove_list = ["411_55952_107659", "376_42884_85882"]
+ for remove_idx in remove_list:
+ if remove_idx in self.seq_to_frames:
+ self.seq_list.remove(remove_idx)
+ print("removing", remove_idx)
+
+ print("total training seq", len(self.seq_to_frames))
+ print("data loading took", time.time() - start_time, "seconds")
+
+ self.all_category_list = list(CO3D_ALL_CATEGORIES)
+ self.all_category_list.sort()
+ self.cat_to_idx = {}
+ for ci, cname in enumerate(self.all_category_list):
+ self.cat_to_idx[cname] = ci
+
+ def __len__(self):
+ return len(self.seq_list)
+
+ def __getitem__(self, index):
+ seq_index = self.seq_list[index]
+
+ if self.subset == "fewview_test" and self.stage == "test":
+ batch_idx = torch.arange(len(self.seq_to_frames[seq_index]))
+
+ elif self.stage == "test":
+ batch_idx = (
+ torch.linspace(
+ 0, len(self.seq_to_frames[seq_index]) - 1, self.sample_batch_size
+ )
+ .long()
+ .tolist()
+ )
+ else:
+ rand = torch.randperm(len(self.seq_to_frames[seq_index]))
+ batch_idx = rand[: min(len(rand), self.sample_batch_size)]
+
+ frame_data_list = []
+ idx_list = []
+ timestamp_list = []
+ for idx in batch_idx:
+ idx_list.append(self.seq_to_frames[seq_index][idx])
+ timestamp_list.append(
+ self.frame_annots[self.seq_to_frames[seq_index][idx]][
+ "frame_annotation"
+ ].frame_timestamp
+ )
+ frame_data_list.append(
+ self._get_frame(int(self.seq_to_frames[seq_index][idx]))
+ )
+
+ time_order = torch.argsort(torch.tensor(timestamp_list))
+ frame_data_list = [frame_data_list[i] for i in time_order]
+
+ frame_data = FrameData.collate(frame_data_list)
+ image_size = torch.Tensor([self.image_height]).repeat(
+ frame_data.camera.R.shape[0], 2
+ )
+ frame_dict = {
+ "R": frame_data.camera.R,
+ "T": frame_data.camera.T,
+ "f": frame_data.camera.focal_length,
+ "c": frame_data.camera.principal_point,
+ "images": frame_data.image_rgb * frame_data.fg_probability
+ + (1 - frame_data.fg_probability),
+ "valid_region": frame_data.mask_crop,
+ "bbox": frame_data.valid_region,
+ "image_size": image_size,
+ "frame_type": frame_data.frame_type,
+ "idx": seq_index,
+ "category": frame_data.category_one_hot,
+ }
+ if not self.masked:
+ frame_dict["images_full"] = frame_data.image_rgb
+ frame_dict["masks"] = frame_data.fg_probability
+ frame_dict["mask_crop"] = frame_data.mask_crop
+
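+        # sample a log-normal noise-augmentation level for the conditioning frame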
+ cond_aug = np.exp(
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
+ )
+
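+        # mirror-pad short sequences up to num_frames by appending time-reversed frames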
+ def _pad(input):
+ return torch.cat([input, torch.flip(input, dims=[0])], dim=0)[
+ : self.num_frames
+ ]
+
+ if len(frame_dict["images"]) < self.num_frames:
+ for k in frame_dict:
+ if isinstance(frame_dict[k], torch.Tensor):
+ frame_dict[k] = _pad(frame_dict[k])
+
+ data = dict()
+ if "images_full" in frame_dict:
+ frames = frame_dict["images_full"] * 2 - 1
+ else:
+ frames = frame_dict["images"] * 2 - 1
+ data["frames"] = frames
+ cond = frames[0]
+ data["cond_frames_without_noise"] = cond
+ data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
+ data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
+ data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
+ data["motion_bucket_id"] = torch.as_tensor(
+ [self.motion_bucket_id] * self.num_frames
+ )
+ data["num_video_frames"] = self.num_frames
+ data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)
+
+ if self.load_pixelnerf:
+ data["pixelnerf_input"] = dict()
+ # Rs = frame_dict["R"].transpose(-1, -2)
+ # Ts = frame_dict["T"]
+ # Rs[:, :, 2] *= -1
+ # Rs[:, :, 0] *= -1
+ # Ts[:, 2] *= -1
+ # Ts[:, 0] *= -1
+ # c2ws = torch.zeros(Rs.shape[0], 4, 4)
+ # c2ws[:, :3, :3] = Rs
+ # c2ws[:, :3, 3] = Ts
+ # c2ws[:, 3, 3] = 1
+ # c2ws = c2ws.inverse()
+ # # c2ws[..., 0] *= -1
+ # # c2ws[..., 2] *= -1
+ # cx = frame_dict["c"][:, 0]
+ # cy = frame_dict["c"][:, 1]
+ # fx = frame_dict["f"][:, 0]
+ # fy = frame_dict["f"][:, 1]
+ # intrinsics = torch.zeros(cx.shape[0], 3, 3)
+ # intrinsics[:, 2, 2] = 1
+ # intrinsics[:, 0, 0] = fx
+ # intrinsics[:, 1, 1] = fy
+ # intrinsics[:, 0, 2] = cx
+ # intrinsics[:, 1, 2] = cy
+
+ scene_cameras = PerspectiveCameras(
+ R=frame_dict["R"],
+ T=frame_dict["T"],
+ focal_length=frame_dict["f"],
+ principal_point=frame_dict["c"],
+ image_size=frame_dict["image_size"],
+ )
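+            # convert PyTorch3D cameras to OpenCV extrinsics, invert to camera-to-world
+            # and flip the y/z axes; intrinsics are normalized by 256 (the default crop size)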
+ R, T, intrinsics = opencv_from_cameras_projection(
+ scene_cameras, frame_dict["image_size"]
+ )
+ c2ws = torch.zeros(R.shape[0], 4, 4)
+ c2ws[:, :3, :3] = R
+ c2ws[:, :3, 3] = T
+ c2ws[:, 3, 3] = 1.0
+ c2ws = c2ws.inverse()
+ c2ws[..., 1:3] *= -1
+ intrinsics[:, :2] /= 256
+
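+            # pack each camera as a 25-dim vector: flattened 4x4 camera-to-world (16) + flattened 3x3 intrinsics (9)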
+ cameras = torch.zeros(c2ws.shape[0], 25)
+ cameras[..., :16] = c2ws.reshape(-1, 16)
+ cameras[..., 16:] = intrinsics.reshape(-1, 9)
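+            # optionally recenter the camera positions and rescale so the farthest camera lies at radius 1.5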
+ if self.scale_pose:
+ c2ws = cameras[..., :16].reshape(-1, 4, 4)
+ center = c2ws[:, :3, 3].mean(0)
+ radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
+ scale = 1.5 / radius
+ c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
+ cameras[..., :16] = c2ws.reshape(-1, 16)
+
+ data["pixelnerf_input"]["frames"] = frames
+ data["pixelnerf_input"]["cameras"] = cameras
+ data["pixelnerf_input"]["rgb"] = (
+ F.interpolate(
+ frames,
+ (self.image_width // 8, self.image_height // 8),
+ mode="bilinear",
+ align_corners=False,
+ )
+ + 1
+ ) * 0.5
+
+ return data
+ # if self.return_frame_data_list:
+ # return (frame_dict, frame_data_list)
+ # return frame_dict
+
+ def collate_fn(self, batch):
+        # a hack to add the source index and keep it consistent within a batch
+ if self.max_n_cond > 1:
+ # TODO implement this
+ n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
+ # debug
+ # source_index = [0]
+ if n_cond > 1:
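+                # always keep frame 0 as a conditioning view and sample the remaining source views without replacement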
+ for b in batch:
+ source_index = [0] + np.random.choice(
+ np.arange(1, self.num_frames),
+ self.max_n_cond - 1,
+ replace=False,
+ ).tolist()
+ b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
+ b["pixelnerf_input"]["n_cond"] = n_cond
+ b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
+ b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
+ "cameras"
+ ][source_index]
+
+ if self.cond_on_multi:
+ b["cond_frames_without_noise"] = b["frames"][source_index]
+
+ ret = video_collate_fn(batch)
+
+ if self.cond_on_multi:
+ ret["cond_frames_without_noise"] = rearrange(
+ ret["cond_frames_without_noise"], "b t ... -> (b t) ..."
+ )
+
+ return ret
+
+ def _get_frame(self, index):
+ # if index >= len(self.frame_annots):
+ # raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
+
+ entry = self.frame_annots[index]["frame_annotation"]
+ # pyre-ignore[16]
+ point_cloud = self.seq_annots[entry.sequence_name].point_cloud
+ frame_data = FrameData(
+ frame_number=_safe_as_tensor(entry.frame_number, torch.long),
+ frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
+ sequence_name=entry.sequence_name,
+ sequence_category=self.seq_annots[entry.sequence_name].category,
+ camera_quality_score=_safe_as_tensor(
+ self.seq_annots[entry.sequence_name].viewpoint_quality_score,
+ torch.float,
+ ),
+ point_cloud_quality_score=_safe_as_tensor(
+ point_cloud.quality_score, torch.float
+ )
+ if point_cloud is not None
+ else None,
+ )
+
+ # The rest of the fields are optional
+ frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
+
+ (
+ frame_data.fg_probability,
+ frame_data.mask_path,
+ frame_data.bbox_xywh,
+ clamp_bbox_xyxy,
+ frame_data.crop_bbox_xywh,
+ ) = self._load_crop_fg_probability(entry)
+
+ scale = 1.0
+ if self.load_images and entry.image is not None:
+ # original image size
+ frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
+
+ (
+ frame_data.image_rgb,
+ frame_data.image_path,
+ frame_data.mask_crop,
+ scale,
+ ) = self._load_crop_images(
+ entry, frame_data.fg_probability, clamp_bbox_xyxy
+ )
+ # print(frame_data.fg_probability.sum())
+ # print('scale', scale)
+
+ #! INSERT
+ if self.deprecated_val_region:
+ # print(frame_data.crop_bbox_xywh)
+ valid_bbox = _bbox_xywh_to_xyxy(frame_data.crop_bbox_xywh).float()
+ # print(valid_bbox, frame_data.image_size_hw)
+ valid_bbox[0] = torch.clip(
+ (
+ valid_bbox[0]
+ - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor")
+ )
+ / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"),
+ -1.0,
+ 1.0,
+ )
+ valid_bbox[1] = torch.clip(
+ (
+ valid_bbox[1]
+ - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor")
+ )
+ / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"),
+ -1.0,
+ 1.0,
+ )
+ valid_bbox[2] = torch.clip(
+ (
+ valid_bbox[2]
+ - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor")
+ )
+ / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"),
+ -1.0,
+ 1.0,
+ )
+ valid_bbox[3] = torch.clip(
+ (
+ valid_bbox[3]
+ - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor")
+ )
+ / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"),
+ -1.0,
+ 1.0,
+ )
+ # print(valid_bbox)
+ frame_data.valid_region = valid_bbox
+ else:
+ #! UPDATED VALID BBOX
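+            # tight bbox of the un-padded (valid) region in mask_crop, with corners normalized to [-1, 1]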
+ if self.stage == "train":
+ assert self.image_height == 256 and self.image_width == 256
+ valid = torch.nonzero(frame_data.mask_crop[0])
+ min_y = valid[:, 0].min()
+ min_x = valid[:, 1].min()
+ max_y = valid[:, 0].max()
+ max_x = valid[:, 1].max()
+ valid_bbox = torch.tensor(
+ [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device
+ ).unsqueeze(0)
+ valid_bbox = torch.clip(
+ (valid_bbox - (256 // 2)) / (256 // 2), -1.0, 1.0
+ )
+ frame_data.valid_region = valid_bbox[0]
+ else:
+ valid = torch.nonzero(frame_data.mask_crop[0])
+ min_y = valid[:, 0].min()
+ min_x = valid[:, 1].min()
+ max_y = valid[:, 0].max()
+ max_x = valid[:, 1].max()
+ valid_bbox = torch.tensor(
+ [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device
+ ).unsqueeze(0)
+ valid_bbox = torch.clip(
+ (valid_bbox - (self.image_height // 2)) / (self.image_height // 2),
+ -1.0,
+ 1.0,
+ )
+ frame_data.valid_region = valid_bbox[0]
+
+ #! SET CLASS ONEHOT
+ frame_data.category_one_hot = torch.zeros(
+ (len(self.all_category_list)), device=frame_data.image_rgb.device
+ )
+ frame_data.category_one_hot[self.cat_to_idx[frame_data.sequence_category]] = 1
+
+ if self.load_depths and entry.depth is not None:
+ (
+ frame_data.depth_map,
+ frame_data.depth_path,
+ frame_data.depth_mask,
+ ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
+
+ if entry.viewpoint is not None:
+ frame_data.camera = self._get_pytorch3d_camera(
+ entry,
+ scale,
+ clamp_bbox_xyxy,
+ )
+
+ if self.load_point_clouds and point_cloud is not None:
+ frame_data.sequence_point_cloud_path = pcl_path = os.path.join(
+ self.dataset_root, point_cloud.path
+ )
+ frame_data.sequence_point_cloud = _load_pointcloud(
+ self._local_path(pcl_path), max_points=self.max_points
+ )
+
+ # for key in frame_data:
+ # if frame_data[key] == None:
+ # print(key)
+ return frame_data
+
+ def _extract_and_set_eval_batches(self):
+ """
+ Sets eval_batches based on input eval_batch_index.
+ """
+ if self.eval_batch_index is not None:
+ if self.eval_batches is not None:
+ raise ValueError(
+ "Cannot define both eval_batch_index and eval_batches."
+ )
+ self.eval_batches = self.seq_frame_index_to_dataset_index(
+ self.eval_batch_index
+ )
+
+ def _load_crop_fg_probability(
+ self, entry: types.FrameAnnotation
+ ) -> Tuple[
+ Optional[torch.Tensor],
+ Optional[str],
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ ]:
+ fg_probability = None
+ full_path = None
+ bbox_xywh = None
+ clamp_bbox_xyxy = None
+ crop_box_xywh = None
+
+ if (self.load_masks or self.box_crop) and entry.mask is not None:
+ full_path = os.path.join(self.dataset_root, entry.mask.path)
+ mask = _load_mask(self._local_path(full_path))
+
+ if mask.shape[-2:] != entry.image.size:
+ raise ValueError(
+ f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
+ )
+
+ bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
+
+ if self.box_crop:
+ clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
+ _get_clamp_bbox(
+ bbox_xywh,
+ image_path=entry.image.path,
+ box_crop_context=self.box_crop_context,
+ ),
+ image_size_hw=tuple(mask.shape[-2:]),
+ )
+ crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
+
+ mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
+
+ fg_probability, _, _ = self._resize_image(mask, mode="nearest")
+
+ return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
+
+ def _load_crop_images(
+ self,
+ entry: types.FrameAnnotation,
+ fg_probability: Optional[torch.Tensor],
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
+ assert self.dataset_root is not None and entry.image is not None
+ path = os.path.join(self.dataset_root, entry.image.path)
+ image_rgb = _load_image(self._local_path(path))
+
+ if image_rgb.shape[-2:] != entry.image.size:
+ raise ValueError(
+ f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
+ )
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
+
+ image_rgb, scale, mask_crop = self._resize_image(image_rgb)
+
+ if self.mask_images:
+ assert fg_probability is not None
+ image_rgb *= fg_probability
+
+ return image_rgb, path, mask_crop, scale
+
+ def _load_mask_depth(
+ self,
+ entry: types.FrameAnnotation,
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ fg_probability: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, str, torch.Tensor]:
+ entry_depth = entry.depth
+ assert entry_depth is not None
+ path = os.path.join(self.dataset_root, entry_depth.path)
+ depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ depth_bbox_xyxy = _rescale_bbox(
+ clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
+ )
+ depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
+
+ depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
+
+ if self.mask_depths:
+ assert fg_probability is not None
+ depth_map *= fg_probability
+
+ if self.load_depth_masks:
+ assert entry_depth.mask_path is not None
+ mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
+ depth_mask = _load_depth_mask(self._local_path(mask_path))
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ depth_mask_bbox_xyxy = _rescale_bbox(
+ clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
+ )
+ depth_mask = _crop_around_box(
+ depth_mask, depth_mask_bbox_xyxy, mask_path
+ )
+
+ depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
+ else:
+ depth_mask = torch.ones_like(depth_map)
+
+ return depth_map, path, depth_mask
+
+ def _get_pytorch3d_camera(
+ self,
+ entry: types.FrameAnnotation,
+ scale: float,
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ ) -> PerspectiveCameras:
+ entry_viewpoint = entry.viewpoint
+ assert entry_viewpoint is not None
+ # principal point and focal length
+ principal_point = torch.tensor(
+ entry_viewpoint.principal_point, dtype=torch.float
+ )
+ focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
+
+ half_image_size_wh_orig = (
+ torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
+ )
+
+ # first, we convert from the dataset's NDC convention to pixels
+ format = entry_viewpoint.intrinsics_format
+ if format.lower() == "ndc_norm_image_bounds":
+ # this is e.g. currently used in CO3D for storing intrinsics
+ rescale = half_image_size_wh_orig
+ elif format.lower() == "ndc_isotropic":
+ rescale = half_image_size_wh_orig.min()
+ else:
+ raise ValueError(f"Unknown intrinsics format: {format}")
+
+ # principal point and focal length in pixels
+ principal_point_px = half_image_size_wh_orig - principal_point * rescale
+ focal_length_px = focal_length * rescale
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ principal_point_px -= clamp_bbox_xyxy[:2]
+
+ # now, convert from pixels to PyTorch3D v0.5+ NDC convention
+ if self.image_height is None or self.image_width is None:
+ out_size = list(reversed(entry.image.size))
+ else:
+ out_size = [self.image_width, self.image_height]
+
+ half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
+ half_min_image_size_output = half_image_size_output.min()
+
+ # rescaled principal point and focal length in ndc
+ principal_point = (
+ half_image_size_output - principal_point_px * scale
+ ) / half_min_image_size_output
+ focal_length = focal_length_px * scale / half_min_image_size_output
+
+ return PerspectiveCameras(
+ focal_length=focal_length[None],
+ principal_point=principal_point[None],
+ R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
+ T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
+ )
+
+ def _load_frames(self) -> None:
+ self.frame_annots = [
+ FrameAnnotsEntry(frame_annotation=a, subset=None)
+ for a in self.category_frame_annotations
+ ]
+
+ def _load_sequences(self) -> None:
+ self.seq_annots = {
+ entry.sequence_name: entry for entry in self.category_sequence_annotations
+ }
+
+ def _load_subset_lists(self) -> None:
+ logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
+ if not self.subset_lists_file:
+ return
+
+ frame_path_to_subset = {}
+
+ for subset_list_file in self.subset_lists_file:
+ with open(self._local_path(subset_list_file), "r") as f:
+ subset_to_seq_frame = json.load(f)
+
+ #! PRINT SUBSET_LIST STATS
+ # if len(self.subset_lists_file) == 1:
+ # print('train frames', len(subset_to_seq_frame['train']))
+ # print('val frames', len(subset_to_seq_frame['val']))
+ # print('test frames', len(subset_to_seq_frame['test']))
+
+ for set_ in subset_to_seq_frame:
+ for _, _, path in subset_to_seq_frame[set_]:
+ if path in frame_path_to_subset:
+ frame_path_to_subset[path].add(set_)
+ else:
+ frame_path_to_subset[path] = {set_}
+
+ # pyre-ignore[16]
+ for frame in self.frame_annots:
+ frame["subset"] = frame_path_to_subset.get(
+ frame["frame_annotation"].image.path, None
+ )
+
+            if frame["subset"] is None:
+                warnings.warn(
+                    "Subset lists are given but don't include "
+                    + frame["frame_annotation"].image.path
+                )
+                continue
+
+ def _sort_frames(self) -> None:
+ # Sort frames to have them grouped by sequence, ordered by timestamp
+ # pyre-ignore[16]
+ self.frame_annots = sorted(
+ self.frame_annots,
+ key=lambda f: (
+ f["frame_annotation"].sequence_name,
+ f["frame_annotation"].frame_timestamp or 0,
+ ),
+ )
+
+ def _filter_db(self) -> None:
+ if self.remove_empty_masks:
+ logger.info("Removing images with empty masks.")
+ # pyre-ignore[16]
+ old_len = len(self.frame_annots)
+
+ msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
+
+ def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
+ mask = frame_annot.mask
+ if mask is None:
+ return False
+ if mask.mass is None:
+ raise ValueError(msg)
+ return mask.mass > 1
+
+ self.frame_annots = [
+ frame
+ for frame in self.frame_annots
+ if positive_mass(frame["frame_annotation"])
+ ]
+ logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
+
+ # this has to be called after joining with categories!!
+ subsets = self.subsets
+ if subsets:
+ if not self.subset_lists_file:
+ raise ValueError(
+ "Subset filter is on but subset_lists_file was not given"
+ )
+
+ logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
+
+ # truncate the list of subsets to the valid one
+ self.frame_annots = [
+ entry
+ for entry in self.frame_annots
+ if (entry["subset"] is not None and self.stage in entry["subset"])
+ ]
+
+ if len(self.frame_annots) == 0:
+ raise ValueError(f"There are no frames in the '{subsets}' subsets!")
+
+ self._invalidate_indexes(filter_seq_annots=True)
+
+ if len(self.limit_category_to) > 0:
+ logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
+ # pyre-ignore[16]
+ self.seq_annots = {
+ name: entry
+ for name, entry in self.seq_annots.items()
+ if entry.category in self.limit_category_to
+ }
+
+ # sequence filters
+ for prefix in ("pick", "exclude"):
+ orig_len = len(self.seq_annots)
+ attr = f"{prefix}_sequence"
+ arr = getattr(self, attr)
+ if len(arr) > 0:
+ logger.info(f"{attr}: {str(arr)}")
+ self.seq_annots = {
+ name: entry
+ for name, entry in self.seq_annots.items()
+ if (name in arr) == (prefix == "pick")
+ }
+ logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
+
+ if self.limit_sequences_to > 0:
+ self.seq_annots = dict(
+ islice(self.seq_annots.items(), self.limit_sequences_to)
+ )
+
+ # retain only frames from retained sequences
+ self.frame_annots = [
+ f
+ for f in self.frame_annots
+ if f["frame_annotation"].sequence_name in self.seq_annots
+ ]
+
+ self._invalidate_indexes()
+
+ if self.n_frames_per_sequence > 0:
+ logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
+ keep_idx = []
+ # pyre-ignore[16]
+ for seq, seq_indices in self._seq_to_idx.items():
+ # infer the seed from the sequence name, this is reproducible
+ # and makes the selection differ for different sequences
+ seed = _seq_name_to_seed(seq) + self.seed
+ seq_idx_shuffled = random.Random(seed).sample(
+ sorted(seq_indices), len(seq_indices)
+ )
+ keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
+
+ logger.info(
+ "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
+ )
+ self.frame_annots = [self.frame_annots[i] for i in keep_idx]
+ self._invalidate_indexes(filter_seq_annots=False)
+ # sequences are not decimated, so self.seq_annots is valid
+
+ if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
+ logger.info(
+ "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
+ )
+ self.frame_annots = self.frame_annots[: self.limit_to]
+ self._invalidate_indexes(filter_seq_annots=True)
+
+ def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
+ # update _seq_to_idx and filter seq_meta according to frame_annots change
+        # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
+ self._invalidate_seq_to_idx()
+
+ if filter_seq_annots:
+ # pyre-ignore[16]
+ self.seq_annots = {
+ k: v
+ for k, v in self.seq_annots.items()
+ # pyre-ignore[16]
+ if k in self._seq_to_idx
+ }
+
+ def _invalidate_seq_to_idx(self) -> None:
+ seq_to_idx = defaultdict(list)
+ # pyre-ignore[16]
+ for idx, entry in enumerate(self.frame_annots):
+ seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
+ # pyre-ignore[16]
+ self._seq_to_idx = seq_to_idx
+
+ def _resize_image(
+ self, image, mode="bilinear"
+ ) -> Tuple[torch.Tensor, float, torch.Tensor]:
+ image_height, image_width = self.image_height, self.image_width
+ if image_height is None or image_width is None:
+ # skip the resizing
+ imre_ = torch.from_numpy(image)
+ return imre_, 1.0, torch.ones_like(imre_[:1])
+ # takes numpy array, returns pytorch tensor
+ minscale = min(
+ image_height / image.shape[-2],
+ image_width / image.shape[-1],
+ )
+ imre = torch.nn.functional.interpolate(
+ torch.from_numpy(image)[None],
+ scale_factor=minscale,
+ mode=mode,
+ align_corners=False if mode == "bilinear" else None,
+ recompute_scale_factor=True,
+ )[0]
+ # pyre-fixme[19]: Expected 1 positional argument.
+ imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
+ imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
+ # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
+ # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
+ mask = torch.zeros(1, self.image_height, self.image_width)
+ mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
+ return imre_, minscale, mask
+
+ def _local_path(self, path: str) -> str:
+ if self.path_manager is None:
+ return path
+ return self.path_manager.get_local_path(path)
+
+ def get_frame_numbers_and_timestamps(
+ self, idxs: Sequence[int]
+ ) -> List[Tuple[int, float]]:
+ out: List[Tuple[int, float]] = []
+ for idx in idxs:
+ # pyre-ignore[16]
+ frame_annotation = self.frame_annots[idx]["frame_annotation"]
+ out.append(
+ (frame_annotation.frame_number, frame_annotation.frame_timestamp)
+ )
+ return out
+
+ def get_eval_batches(self) -> Optional[List[List[int]]]:
+ return self.eval_batches
+
+ def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
+ return entry["frame_annotation"].meta["frame_type"]
+
+
+class CO3DDataset(LightningDataModule):
+ def __init__(
+ self,
+ root_dir,
+ batch_size=2,
+ shuffle=True,
+ num_workers=10,
+ prefetch_factor=2,
+ category="hydrant",
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.prefetch_factor = prefetch_factor
+ self.shuffle = shuffle
+
+ self.train_dataset = CO3Dv2Wrapper(
+ root_dir=root_dir,
+ stage="train",
+ category=category,
+ **kwargs,
+ )
+
+ self.test_dataset = CO3Dv2Wrapper(
+ root_dir=root_dir,
+ stage="test",
+ subset="fewview_dev",
+ category=category,
+ **kwargs,
+ )
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=self.train_dataset.collate_fn,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=self.test_dataset.collate_fn,
+ )
+
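+    # note: validation reuses the test split with the plain video_collate_fn (no extra conditioning views are attached)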
+ def val_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=video_collate_fn,
+ )
diff --git a/sgm/data/colmap.py b/sgm/data/colmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..b739f2e9637c0c96b80c42fce05dfeab6c5e1228
--- /dev/null
+++ b/sgm/data/colmap.py
@@ -0,0 +1,605 @@
+# Copyright (c) 2023, ETH Zurich and UNC Chapel Hill.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+
+import os
+import collections
+import numpy as np
+import struct
+import argparse
+
+
+CameraModel = collections.namedtuple(
+ "CameraModel", ["model_id", "model_name", "num_params"]
+)
+Camera = collections.namedtuple(
+ "Camera", ["id", "model", "width", "height", "params"]
+)
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
+)
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
+)
+
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
+}
+CAMERA_MODEL_IDS = dict(
+ [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
+)
+CAMERA_MODEL_NAMES = dict(
+ [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS]
+)
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
+
+
+def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
+ """pack and write to a binary file.
+ :param fid:
+ :param data: data to send, if multiple elements are sent at the same time,
+ they should be encapsuled either in a list or a tuple
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ should be the same length as the data list or tuple
+ :param endian_character: Any of {@, =, <, >, !}
+ """
+ if isinstance(data, (list, tuple)):
+ bytes = struct.pack(endian_character + format_char_sequence, *data)
+ else:
+ bytes = struct.pack(endian_character + format_char_sequence, data)
+ fid.write(bytes)
+
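+# Illustrative usage sketch (paths are hypothetical): the binary readers below start
+# from a uint64 record count and then unpack fixed-format fields, e.g.
+#
+#   with open("cameras.bin", "rb") as fid:
+#       num_cameras = read_next_bytes(fid, 8, "Q")[0]
+#       camera_properties = read_next_bytes(fid, 24, "iiQQ")  # id, model_id, width, height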
+
+def read_cameras_text(path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ cameras = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(
+ id=camera_id,
+ model=model,
+ width=width,
+ height=height,
+ params=params,
+ )
+ return cameras
+
+
+def read_cameras_binary(path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_cameras):
+ camera_properties = read_next_bytes(
+ fid, num_bytes=24, format_char_sequence="iiQQ"
+ )
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(
+ fid,
+ num_bytes=8 * num_params,
+ format_char_sequence="d" * num_params,
+ )
+ cameras[camera_id] = Camera(
+ id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params),
+ )
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def write_cameras_text(cameras, path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ HEADER = (
+ "# Camera list with one line of data per camera:\n"
+ + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
+ + "# Number of cameras: {}\n".format(len(cameras))
+ )
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, cam in cameras.items():
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
+ line = " ".join([str(elem) for elem in to_write])
+ fid.write(line + "\n")
+
+
+def write_cameras_binary(cameras, path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(cameras), "Q")
+ for _, cam in cameras.items():
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
+ camera_properties = [cam.id, model_id, cam.width, cam.height]
+ write_next_bytes(fid, camera_properties, "iiQQ")
+ for p in cam.params:
+ write_next_bytes(fid, float(p), "d")
+ return cameras
+
+
+def read_images_text(path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ images = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack(
+ [
+ tuple(map(float, elems[0::3])),
+ tuple(map(float, elems[1::3])),
+ ]
+ )
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def read_images_binary(path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi"
+ )
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ binary_image_name = b""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ binary_image_name += current_char
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ image_name = binary_image_name.decode("utf-8")
+ num_points2D = read_next_bytes(
+ fid, num_bytes=8, format_char_sequence="Q"
+ )[0]
+ x_y_id_s = read_next_bytes(
+ fid,
+ num_bytes=24 * num_points2D,
+ format_char_sequence="ddq" * num_points2D,
+ )
+ xys = np.column_stack(
+ [
+ tuple(map(float, x_y_id_s[0::3])),
+ tuple(map(float, x_y_id_s[1::3])),
+ ]
+ )
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def write_images_text(images, path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ if len(images) == 0:
+ mean_observations = 0
+ else:
+ mean_observations = sum(
+ (len(img.point3D_ids) for _, img in images.items())
+ ) / len(images)
+ HEADER = (
+ "# Image list with two lines of data per image:\n"
+ + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
+ + "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
+ + "# Number of images: {}, mean observations per image: {}\n".format(
+ len(images), mean_observations
+ )
+ )
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, img in images.items():
+ image_header = [
+ img.id,
+ *img.qvec,
+ *img.tvec,
+ img.camera_id,
+ img.name,
+ ]
+ first_line = " ".join(map(str, image_header))
+ fid.write(first_line + "\n")
+
+ points_strings = []
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
+ fid.write(" ".join(points_strings) + "\n")
+
+
+def write_images_binary(images, path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(images), "Q")
+ for _, img in images.items():
+ write_next_bytes(fid, img.id, "i")
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
+ write_next_bytes(fid, img.camera_id, "i")
+ for char in img.name:
+ write_next_bytes(fid, char.encode("utf-8"), "c")
+ write_next_bytes(fid, b"\x00", "c")
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
+
+def read_points3D_text(path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ points3D = {}
+ with open(path, "r") as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ point3D_id = int(elems[0])
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = float(elems[7])
+ image_ids = np.array(tuple(map(int, elems[8::2])))
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id,
+ xyz=xyz,
+ rgb=rgb,
+ error=error,
+ image_ids=image_ids,
+ point2D_idxs=point2D_idxs,
+ )
+ return points3D
+
+
+def read_points3D_binary(path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ points3D = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd"
+ )
+ point3D_id = binary_point_line_properties[0]
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(
+ fid, num_bytes=8, format_char_sequence="Q"
+ )[0]
+ track_elems = read_next_bytes(
+ fid,
+ num_bytes=8 * track_length,
+ format_char_sequence="ii" * track_length,
+ )
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id,
+ xyz=xyz,
+ rgb=rgb,
+ error=error,
+ image_ids=image_ids,
+ point2D_idxs=point2D_idxs,
+ )
+ return points3D
+
+
+def write_points3D_text(points3D, path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ if len(points3D) == 0:
+ mean_track_length = 0
+ else:
+ mean_track_length = sum(
+ (len(pt.image_ids) for _, pt in points3D.items())
+ ) / len(points3D)
+ HEADER = (
+ "# 3D point list with one line of data per point:\n"
+ + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
+ + "# Number of points: {}, mean track length: {}\n".format(
+ len(points3D), mean_track_length
+ )
+ )
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, pt in points3D.items():
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
+ fid.write(" ".join(map(str, point_header)) + " ")
+ track_strings = []
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
+ fid.write(" ".join(track_strings) + "\n")
+
+
+def write_points3D_binary(points3D, path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(points3D), "Q")
+ for _, pt in points3D.items():
+ write_next_bytes(fid, pt.id, "Q")
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
+ write_next_bytes(fid, pt.error, "d")
+ track_length = pt.image_ids.shape[0]
+ write_next_bytes(fid, track_length, "Q")
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
+
+
+def detect_model_format(path, ext):
+ if (
+ os.path.isfile(os.path.join(path, "cameras" + ext))
+ and os.path.isfile(os.path.join(path, "images" + ext))
+ and os.path.isfile(os.path.join(path, "points3D" + ext))
+ ):
+ print("Detected model format: '" + ext + "'")
+ return True
+
+ return False
+
+
+def read_model(path, ext=""):
+ # try to detect the extension automatically
+ if ext == "":
+ if detect_model_format(path, ".bin"):
+ ext = ".bin"
+ elif detect_model_format(path, ".txt"):
+ ext = ".txt"
+ else:
+ print("Provide model format: '.bin' or '.txt'")
+ return
+
+ if ext == ".txt":
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
+ images = read_images_text(os.path.join(path, "images" + ext))
+ points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
+ else:
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
+ images = read_images_binary(os.path.join(path, "images" + ext))
+ points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def write_model(cameras, images, points3D, path, ext=".bin"):
+ if ext == ".txt":
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
+ write_images_text(images, os.path.join(path, "images" + ext))
+ write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
+ else:
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
+ write_images_binary(images, os.path.join(path, "images" + ext))
+ write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
+ return cameras, images, points3D
+
+
+def qvec2rotmat(qvec):
+ return np.array(
+ [
+ [
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
+ ],
+ [
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
+ ],
+ [
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
+ ],
+ ]
+ )
+
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = (
+ np.array(
+ [
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
+ ]
+ )
+ / 3.0
+ )
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
+
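+# Illustrative sanity check (assumes a unit quaternion in [w, x, y, z] order):
+# qvec2rotmat and rotmat2qvec are mutual inverses up to the quaternion's sign.
+#
+#   q = np.array([0.5, 0.5, 0.5, 0.5])
+#   assert np.allclose(rotmat2qvec(qvec2rotmat(q)), q)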
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Read and write COLMAP binary and text models"
+ )
+ parser.add_argument("--input_model", help="path to input model folder")
+ parser.add_argument(
+ "--input_format",
+ choices=[".bin", ".txt"],
+ help="input model format",
+ default="",
+ )
+ parser.add_argument("--output_model", help="path to output model folder")
+ parser.add_argument(
+ "--output_format",
+ choices=[".bin", ".txt"],
+ help="outut model format",
+ default=".txt",
+ )
+ args = parser.parse_args()
+
+ cameras, images, points3D = read_model(
+ path=args.input_model, ext=args.input_format
+ )
+
+ print("num_cameras:", len(cameras))
+ print("num_images:", len(images))
+ print("num_points3D:", len(points3D))
+
+ if args.output_model is not None:
+ write_model(
+ cameras,
+ images,
+ points3D,
+ path=args.output_model,
+ ext=args.output_format,
+ )
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/sgm/data/dataset.py b/sgm/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b726149996591c6c3db69230e1bb68c07d2faa12
--- /dev/null
+++ b/sgm/data/dataset.py
@@ -0,0 +1,80 @@
+from typing import Optional
+
+import torchdata.datapipes.iter
+import webdataset as wds
+from omegaconf import DictConfig
+from pytorch_lightning import LightningDataModule
+
+try:
+ from sdata import create_dataset, create_dummy_dataset, create_loader
+except ImportError as e:
+ print("#" * 100)
+ print("Datasets not yet available")
+ print("to enable, we need to add stable-datasets as a submodule")
+ print("please use ``git submodule update --init --recursive``")
+ print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+ print("#" * 100)
+ exit(1)
+
+
+class StableDataModuleFromConfig(LightningDataModule):
+ def __init__(
+ self,
+ train: DictConfig,
+ validation: Optional[DictConfig] = None,
+ test: Optional[DictConfig] = None,
+ skip_val_loader: bool = False,
+ dummy: bool = False,
+ ):
+ super().__init__()
+ self.train_config = train
+ assert (
+ "datapipeline" in self.train_config and "loader" in self.train_config
+ ), "train config requires the fields `datapipeline` and `loader`"
+
+ self.val_config = validation
+ if not skip_val_loader:
+ if self.val_config is not None:
+ assert (
+ "datapipeline" in self.val_config and "loader" in self.val_config
+ ), "validation config requires the fields `datapipeline` and `loader`"
+ else:
+ print(
+ "Warning: No Validation datapipeline defined, using that one from training"
+ )
+ self.val_config = train
+
+ self.test_config = test
+ if self.test_config is not None:
+ assert (
+ "datapipeline" in self.test_config and "loader" in self.test_config
+ ), "test config requires the fields `datapipeline` and `loader`"
+
+ self.dummy = dummy
+ if self.dummy:
+ print("#" * 100)
+ print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+ print("#" * 100)
+
+ def setup(self, stage: str) -> None:
+ print("Preparing datasets")
+ if self.dummy:
+ data_fn = create_dummy_dataset
+ else:
+ data_fn = create_dataset
+
+ self.train_datapipeline = data_fn(**self.train_config.datapipeline)
+ if self.val_config:
+ self.val_datapipeline = data_fn(**self.val_config.datapipeline)
+ if self.test_config:
+ self.test_datapipeline = data_fn(**self.test_config.datapipeline)
+
+ def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
+ loader = create_loader(self.train_datapipeline, **self.train_config.loader)
+ return loader
+
+ def val_dataloader(self) -> wds.DataPipeline:
+ return create_loader(self.val_datapipeline, **self.val_config.loader)
+
+ def test_dataloader(self) -> wds.DataPipeline:
+ return create_loader(self.test_datapipeline, **self.test_config.loader)
diff --git a/sgm/data/joint3d.py b/sgm/data/joint3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..0569210466a2391bdbb3be358c5cd8f8477aeba1
--- /dev/null
+++ b/sgm/data/joint3d.py
@@ -0,0 +1,10 @@
+import torch
+from torch.utils.data import Dataset
+
+default_sub_data_config = {}
+
+
+class Joint3D(Dataset):
+ def __init__(self, sub_data_config: dict) -> None:
+ super().__init__()
+ self.sub_data_config = sub_data_config
diff --git a/sgm/data/json_index_dataset.py b/sgm/data/json_index_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..16f1dbf3bbae4fb6861f45703d1493914ffaf791
--- /dev/null
+++ b/sgm/data/json_index_dataset.py
@@ -0,0 +1,1080 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import functools
+import gzip
+import hashlib
+import json
+import logging
+import os
+import random
+import warnings
+from collections import defaultdict
+from itertools import islice
+from pathlib import Path
+from typing import (
+ Any,
+ ClassVar,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TYPE_CHECKING,
+ Union,
+)
+
+import numpy as np
+import torch
+from PIL import Image
+from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+from pytorch3d.io import IO
+from pytorch3d.renderer.camera_utils import join_cameras_as_batch
+from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.structures.pointclouds import Pointclouds
+from tqdm import tqdm
+
+from pytorch3d.implicitron.dataset import types
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
+from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar
+
+
+logger = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING:
+ from typing import TypedDict
+
+ class FrameAnnotsEntry(TypedDict):
+ subset: Optional[str]
+ frame_annotation: types.FrameAnnotation
+
+else:
+ FrameAnnotsEntry = dict
+
+
+@registry.register
+class JsonIndexDataset(DatasetBase, ReplaceableBase):
+ """
+ A dataset with annotations in json files like the Common Objects in 3D
+ (CO3D) dataset.
+
+ Args:
+ frame_annotations_file: A zipped json file containing metadata of the
+ frames in the dataset, serialized List[types.FrameAnnotation].
+ sequence_annotations_file: A zipped json file containing metadata of the
+ sequences in the dataset, serialized List[types.SequenceAnnotation].
+ subset_lists_file: A json file containing the lists of frames corresponding
+ to different subsets (e.g. train/val/test) of the dataset;
+ format: {subset: (sequence_name, frame_id, file_path)}.
+ subsets: Restrict frames/sequences only to the given list of subsets
+ as defined in subset_lists_file (see above).
+ limit_to: Limit the dataset to the first #limit_to frames (after other
+ filters have been applied).
+ limit_sequences_to: Limit the dataset to the first
+ #limit_sequences_to sequences (after other sequence filters have been
+ applied but before frame-based filters).
+ pick_sequence: A list of sequence names to restrict the dataset to.
+ exclude_sequence: A list of the names of the sequences to exclude.
+ limit_category_to: Restrict the dataset to the given list of categories.
+ dataset_root: The root folder of the dataset; all the paths in jsons are
+ specified relative to this root (but not json paths themselves).
+ load_images: Enable loading the frame RGB data.
+ load_depths: Enable loading the frame depth maps.
+ load_depth_masks: Enable loading the frame depth map masks denoting the
+ depth values used for evaluation (the points consistent across views).
+ load_masks: Enable loading frame foreground masks.
+ load_point_clouds: Enable loading sequence-level point clouds.
+ max_points: Cap on the number of loaded points in the point cloud;
+ if reached, they are randomly sampled without replacement.
+ mask_images: Whether to mask the images with the loaded foreground masks;
+ 0 value is used for background.
+ mask_depths: Whether to mask the depth maps with the loaded foreground
+ masks; 0 value is used for background.
+ image_height: The height of the returned images, masks, and depth maps;
+ aspect ratio is preserved during cropping/resizing.
+ image_width: The width of the returned images, masks, and depth maps;
+ aspect ratio is preserved during cropping/resizing.
+ box_crop: Enable cropping of the image around the bounding box inferred
+ from the foreground region of the loaded segmentation mask; masks
+ and depth maps are cropped accordingly; cameras are corrected.
+ box_crop_mask_thr: The threshold used to separate pixels into foreground
+ and background based on the foreground_probability mask; if no value
+ is greater than this threshold, the loader lowers it and repeats.
+ box_crop_context: The amount of additional padding added to each
+ dimension of the cropping bounding box, relative to box size.
+ remove_empty_masks: Removes the frames with no active foreground pixels
+ in the segmentation mask after thresholding (see box_crop_mask_thr).
+ n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
+ frames in each sequence uniformly without replacement if it has
+ more frames than that; applied before other frame-level filters.
+ seed: The seed of the random generator sampling #n_frames_per_sequence
+ random frames per sequence.
+ sort_frames: Enable sorting of frame annotations to group frames from the
+ same sequences together and order them by timestamps.
+ eval_batches: A list of batches that form the evaluation set;
+ list of batch-sized lists of indices corresponding to __getitem__
+ of this class, thus it can be used directly as a batch sampler.
+ eval_batch_index:
+ ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]] )
+ A list of batches of frames described as (sequence_name, frame_idx)
+ that can form the evaluation set, `eval_batches` will be set from this.
+
+ """
+
+ frame_annotations_type: ClassVar[
+ Type[types.FrameAnnotation]
+ ] = types.FrameAnnotation
+
+ path_manager: Any = None
+ frame_annotations_file: str = ""
+ sequence_annotations_file: str = ""
+ subset_lists_file: str = ""
+ subsets: Optional[List[str]] = None
+ limit_to: int = 0
+ limit_sequences_to: int = 0
+ pick_sequence: Tuple[str, ...] = ()
+ exclude_sequence: Tuple[str, ...] = ()
+ limit_category_to: Tuple[int, ...] = ()
+ dataset_root: str = ""
+ load_images: bool = True
+ load_depths: bool = True
+ load_depth_masks: bool = True
+ load_masks: bool = True
+ load_point_clouds: bool = False
+ max_points: int = 0
+ mask_images: bool = False
+ mask_depths: bool = False
+ image_height: Optional[int] = 800
+ image_width: Optional[int] = 800
+ box_crop: bool = True
+ box_crop_mask_thr: float = 0.4
+ box_crop_context: float = 0.3
+ remove_empty_masks: bool = True
+ n_frames_per_sequence: int = -1
+ seed: int = 0
+ sort_frames: bool = False
+ eval_batches: Any = None
+ eval_batch_index: Any = None
+ # frame_annots: List[FrameAnnotsEntry] = field(init=False)
+ # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
+
+ def __post_init__(self) -> None:
+ # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
+ self.subset_to_image_path = None
+ self._load_frames()
+ self._load_sequences()
+ if self.sort_frames:
+ self._sort_frames()
+ self._load_subset_lists()
+ self._filter_db() # also computes sequence indices
+ self._extract_and_set_eval_batches()
+ logger.info(str(self))
+
+ def _extract_and_set_eval_batches(self):
+ """
+ Sets eval_batches based on input eval_batch_index.
+ """
+ if self.eval_batch_index is not None:
+ if self.eval_batches is not None:
+ raise ValueError(
+ "Cannot define both eval_batch_index and eval_batches."
+ )
+ self.eval_batches = self.seq_frame_index_to_dataset_index(
+ self.eval_batch_index
+ )
+
+ def join(self, other_datasets: Iterable[DatasetBase]) -> None:
+ """
+ Join the dataset with other JsonIndexDataset objects.
+
+ Args:
+ other_datasets: A list of JsonIndexDataset objects to be joined
+ into the current dataset.
+ """
+ if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
+ raise ValueError("This function can only join a list of JsonIndexDataset")
+ # pyre-ignore[16]
+ self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
+ # pyre-ignore[16]
+ self.seq_annots.update(
+ # https://gist.github.com/treyhunner/f35292e676efa0be1728
+ functools.reduce(
+ lambda a, b: {**a, **b},
+ [d.seq_annots for d in other_datasets], # pyre-ignore[16]
+ )
+ )
+ all_eval_batches = [
+ self.eval_batches,
+ # pyre-ignore
+ *[d.eval_batches for d in other_datasets],
+ ]
+ if not (
+ all(ba is None for ba in all_eval_batches)
+ or all(ba is not None for ba in all_eval_batches)
+ ):
+ raise ValueError(
+ "When joining datasets, either all joined datasets have to have their"
+ " eval_batches defined, or all should have their eval batches undefined."
+ )
+ if self.eval_batches is not None:
+ self.eval_batches = sum(all_eval_batches, [])
+ self._invalidate_indexes(filter_seq_annots=True)
+
+ def is_filtered(self) -> bool:
+ """
+ Returns `True` in case the dataset has been filtered and thus some frame annotations
+ stored on the disk might be missing in the dataset object.
+
+ Returns:
+ is_filtered: `True` if the dataset has been filtered, else `False`.
+ """
+ return (
+ self.remove_empty_masks
+ or self.limit_to > 0
+ or self.limit_sequences_to > 0
+ or len(self.pick_sequence) > 0
+ or len(self.exclude_sequence) > 0
+ or len(self.limit_category_to) > 0
+ or self.n_frames_per_sequence > 0
+ )
+
+ def seq_frame_index_to_dataset_index(
+ self,
+ seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
+ allow_missing_indices: bool = False,
+ remove_missing_indices: bool = False,
+ suppress_missing_index_warning: bool = True,
+ ) -> List[List[Union[Optional[int], int]]]:
+ """
+ Obtain indices into the dataset object given a list of frame ids.
+
+ Args:
+ seq_frame_index: The list of frame ids specified as
+ `List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
+ image paths relative to the dataset_root can be specified as well:
+ `List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
+ allow_missing_indices: If `False`, throws an IndexError upon reaching the first
+ entry from `seq_frame_index` which is missing in the dataset.
+ Otherwise, depending on `remove_missing_indices`, either returns `None`
+ in place of missing entries or removes the indices of missing entries.
+ remove_missing_indices: Active when `allow_missing_indices=True`.
+ If `False`, returns `None` in place of `seq_frame_index` entries that
+ are not present in the dataset.
+ If `True` removes missing indices from the returned indices.
+ suppress_missing_index_warning:
+ Active if `allow_missing_indices==True`. Suppresses a warning message
+ in case an entry from `seq_frame_index` is missing in the dataset
+ (expected in certain cases - e.g. when setting
+ `self.remove_empty_masks=True`).
+
+ Returns:
+ dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
+ """
+ _dataset_seq_frame_n_index = {
+ seq: {
+ # pyre-ignore[16]
+ self.frame_annots[idx]["frame_annotation"].frame_number: idx
+ for idx in seq_idx
+ }
+ # pyre-ignore[16]
+ for seq, seq_idx in self._seq_to_idx.items()
+ }
+
+ def _get_dataset_idx(
+ seq_name: str, frame_no: int, path: Optional[str] = None
+ ) -> Optional[int]:
+ idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
+ idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
+ if idx is None:
+ msg = (
+ f"sequence_name={seq_name} / frame_number={frame_no}"
+ " not in the dataset!"
+ )
+ if not allow_missing_indices:
+ raise IndexError(msg)
+ if not suppress_missing_index_warning:
+ warnings.warn(msg)
+ return idx
+ if path is not None:
+ # Check that the loaded frame path is consistent
+ # with the one stored in self.frame_annots.
+ assert os.path.normpath(
+ # pyre-ignore[16]
+ self.frame_annots[idx]["frame_annotation"].image.path
+ ) == os.path.normpath(
+ path
+ ), f"Inconsistent frame indices {seq_name, frame_no, path}."
+ return idx
+
+ dataset_idx = [
+ [_get_dataset_idx(*b) for b in batch] # pyre-ignore [6]
+ for batch in seq_frame_index
+ ]
+
+ if allow_missing_indices and remove_missing_indices:
+ # remove all None indices, and also batches with only None entries
+ valid_dataset_idx = [
+ [b for b in batch if b is not None] for batch in dataset_idx
+ ]
+ return [ # pyre-ignore[7]
+ batch for batch in valid_dataset_idx if len(batch) > 0
+ ]
+
+ return dataset_idx
+
+ def subset_from_frame_index(
+ self,
+ frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
+ allow_missing_indices: bool = True,
+ ) -> "JsonIndexDataset":
+ """
+ Generate a dataset subset given the list of frames specified in `frame_index`.
+
+ Args:
+ frame_index: The list of frame identifiers (as stored in the metadata)
+ specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally,
+ image paths relative to the dataset_root can be specified as well:
+ `List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`;
+ in the latter case, if image_path does not match the stored path, an error
+ is raised.
+ allow_missing_indices: If `False`, throws an IndexError upon reaching the first
+ entry from `frame_index` which is missing in the dataset.
+ Otherwise, generates a subset consisting of frames entries that actually
+ exist in the dataset.
+ """
+ # Get the indices into the frame annots.
+ dataset_indices = self.seq_frame_index_to_dataset_index(
+ [frame_index],
+ allow_missing_indices=self.is_filtered() and allow_missing_indices,
+ )[0]
+ valid_dataset_indices = [i for i in dataset_indices if i is not None]
+
+ # Deep copy the whole dataset except frame_annots, which are large so we
+ # deep copy only the requested subset of frame_annots.
+ memo = {id(self.frame_annots): None} # pyre-ignore[16]
+ dataset_new = copy.deepcopy(self, memo)
+ dataset_new.frame_annots = copy.deepcopy(
+ [self.frame_annots[i] for i in valid_dataset_indices]
+ )
+
+ # This will kill all unneeded sequence annotations.
+ dataset_new._invalidate_indexes(filter_seq_annots=True)
+
+ # Finally annotate the frame annotations with the name of the subset
+ # stored in meta.
+ for frame_annot in dataset_new.frame_annots:
+ frame_annotation = frame_annot["frame_annotation"]
+ if frame_annotation.meta is not None:
+ frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)
+
+ # A sanity check - this will crash in case some entries from frame_index are missing
+ # in dataset_new.
+ valid_frame_index = [
+ fi for fi, di in zip(frame_index, dataset_indices) if di is not None
+ ]
+ dataset_new.seq_frame_index_to_dataset_index(
+ [valid_frame_index], allow_missing_indices=False
+ )
+
+ return dataset_new
+
+ def __str__(self) -> str:
+ # pyre-ignore[16]
+ return f"JsonIndexDataset #frames={len(self.frame_annots)}"
+
+ def __len__(self) -> int:
+ # pyre-ignore[16]
+ return len(self.frame_annots)
+
+ def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
+ return entry["subset"]
+
+ def get_all_train_cameras(self) -> CamerasBase:
+ """
+ Returns the cameras corresponding to all the known frames.
+ """
+ logger.info("Loading all train cameras.")
+ cameras = []
+ # pyre-ignore[16]
+ for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)):
+ frame_type = self._get_frame_type(frame_annot)
+ if frame_type is None:
+ raise ValueError("subsets not loaded")
+ if is_known_frame_scalar(frame_type):
+ cameras.append(self[frame_idx].camera)
+ return join_cameras_as_batch(cameras)
+
+ def __getitem__(self, index) -> FrameData:
+ # pyre-ignore[16]
+ if index >= len(self.frame_annots):
+ raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
+
+ entry = self.frame_annots[index]["frame_annotation"]
+ # pyre-ignore[16]
+ point_cloud = self.seq_annots[entry.sequence_name].point_cloud
+ frame_data = FrameData(
+ frame_number=_safe_as_tensor(entry.frame_number, torch.long),
+ frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
+ sequence_name=entry.sequence_name,
+ sequence_category=self.seq_annots[entry.sequence_name].category,
+ camera_quality_score=_safe_as_tensor(
+ self.seq_annots[entry.sequence_name].viewpoint_quality_score,
+ torch.float,
+ ),
+ point_cloud_quality_score=_safe_as_tensor(
+ point_cloud.quality_score, torch.float
+ )
+ if point_cloud is not None
+ else None,
+ )
+
+ # The rest of the fields are optional
+ frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
+
+ (
+ frame_data.fg_probability,
+ frame_data.mask_path,
+ frame_data.bbox_xywh,
+ clamp_bbox_xyxy,
+ frame_data.crop_bbox_xywh,
+ ) = self._load_crop_fg_probability(entry)
+
+ scale = 1.0
+ if self.load_images and entry.image is not None:
+ # original image size
+ frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
+
+ (
+ frame_data.image_rgb,
+ frame_data.image_path,
+ frame_data.mask_crop,
+ scale,
+ ) = self._load_crop_images(
+ entry, frame_data.fg_probability, clamp_bbox_xyxy
+ )
+
+ if self.load_depths and entry.depth is not None:
+ (
+ frame_data.depth_map,
+ frame_data.depth_path,
+ frame_data.depth_mask,
+ ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
+
+ if entry.viewpoint is not None:
+ frame_data.camera = self._get_pytorch3d_camera(
+ entry,
+ scale,
+ clamp_bbox_xyxy,
+ )
+
+ if self.load_point_clouds and point_cloud is not None:
+ pcl_path = self._fix_point_cloud_path(point_cloud.path)
+ frame_data.sequence_point_cloud = _load_pointcloud(
+ self._local_path(pcl_path), max_points=self.max_points
+ )
+ frame_data.sequence_point_cloud_path = pcl_path
+
+ return frame_data
+
+ def _fix_point_cloud_path(self, path: str) -> str:
+ """
+ Fix up a point cloud path from the dataset.
+ Some files in Co3Dv2 have an accidental absolute path stored.
+ """
+ unwanted_prefix = (
+ "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
+ )
+ if path.startswith(unwanted_prefix):
+ path = path[len(unwanted_prefix) :]
+ return os.path.join(self.dataset_root, path)
+
+ def _load_crop_fg_probability(
+ self, entry: types.FrameAnnotation
+ ) -> Tuple[
+ Optional[torch.Tensor],
+ Optional[str],
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ Optional[torch.Tensor],
+ ]:
+ fg_probability = None
+ full_path = None
+ bbox_xywh = None
+ clamp_bbox_xyxy = None
+ crop_box_xywh = None
+
+ if (self.load_masks or self.box_crop) and entry.mask is not None:
+ full_path = os.path.join(self.dataset_root, entry.mask.path)
+ mask = _load_mask(self._local_path(full_path))
+
+ if mask.shape[-2:] != entry.image.size:
+ raise ValueError(
+ f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
+ )
+
+ bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
+
+ if self.box_crop:
+ clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
+ _get_clamp_bbox(
+ bbox_xywh,
+ image_path=entry.image.path,
+ box_crop_context=self.box_crop_context,
+ ),
+ image_size_hw=tuple(mask.shape[-2:]),
+ )
+ crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
+
+ mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
+
+ fg_probability, _, _ = self._resize_image(mask, mode="nearest")
+
+ return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
+
+ def _load_crop_images(
+ self,
+ entry: types.FrameAnnotation,
+ fg_probability: Optional[torch.Tensor],
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
+ assert self.dataset_root is not None and entry.image is not None
+ path = os.path.join(self.dataset_root, entry.image.path)
+ image_rgb = _load_image(self._local_path(path))
+
+ if image_rgb.shape[-2:] != entry.image.size:
+ raise ValueError(
+ f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
+ )
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
+
+ image_rgb, scale, mask_crop = self._resize_image(image_rgb)
+
+ if self.mask_images:
+ assert fg_probability is not None
+ image_rgb *= fg_probability
+
+ return image_rgb, path, mask_crop, scale
+
+ def _load_mask_depth(
+ self,
+ entry: types.FrameAnnotation,
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ fg_probability: Optional[torch.Tensor],
+ ) -> Tuple[torch.Tensor, str, torch.Tensor]:
+ entry_depth = entry.depth
+ assert entry_depth is not None
+ path = os.path.join(self.dataset_root, entry_depth.path)
+ depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ depth_bbox_xyxy = _rescale_bbox(
+ clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
+ )
+ depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
+
+ depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
+
+ if self.mask_depths:
+ assert fg_probability is not None
+ depth_map *= fg_probability
+
+ if self.load_depth_masks:
+ assert entry_depth.mask_path is not None
+ mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
+ depth_mask = _load_depth_mask(self._local_path(mask_path))
+
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ depth_mask_bbox_xyxy = _rescale_bbox(
+ clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
+ )
+ depth_mask = _crop_around_box(
+ depth_mask, depth_mask_bbox_xyxy, mask_path
+ )
+
+ depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
+ else:
+ depth_mask = torch.ones_like(depth_map)
+
+ return depth_map, path, depth_mask
+
+ def _get_pytorch3d_camera(
+ self,
+ entry: types.FrameAnnotation,
+ scale: float,
+ clamp_bbox_xyxy: Optional[torch.Tensor],
+ ) -> PerspectiveCameras:
+ entry_viewpoint = entry.viewpoint
+ assert entry_viewpoint is not None
+ # principal point and focal length
+ principal_point = torch.tensor(
+ entry_viewpoint.principal_point, dtype=torch.float
+ )
+ focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
+
+ half_image_size_wh_orig = (
+ torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
+ )
+
+ # first, we convert from the dataset's NDC convention to pixels
+ format = entry_viewpoint.intrinsics_format
+ if format.lower() == "ndc_norm_image_bounds":
+ # this is e.g. currently used in CO3D for storing intrinsics
+ rescale = half_image_size_wh_orig
+ elif format.lower() == "ndc_isotropic":
+ rescale = half_image_size_wh_orig.min()
+ else:
+ raise ValueError(f"Unknown intrinsics format: {format}")
+
+ # principal point and focal length in pixels
+ principal_point_px = half_image_size_wh_orig - principal_point * rescale
+ focal_length_px = focal_length * rescale
+ if self.box_crop:
+ assert clamp_bbox_xyxy is not None
+ principal_point_px -= clamp_bbox_xyxy[:2]
+
+ # now, convert from pixels to PyTorch3D v0.5+ NDC convention
+ if self.image_height is None or self.image_width is None:
+ out_size = list(reversed(entry.image.size))
+ else:
+ out_size = [self.image_width, self.image_height]
+
+ half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
+ half_min_image_size_output = half_image_size_output.min()
+
+ # rescaled principal point and focal length in ndc
+ principal_point = (
+ half_image_size_output - principal_point_px * scale
+ ) / half_min_image_size_output
+ focal_length = focal_length_px * scale / half_min_image_size_output
+
+ return PerspectiveCameras(
+ focal_length=focal_length[None],
+ principal_point=principal_point[None],
+ R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
+ T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
+ )
+
+ def _load_frames(self) -> None:
+ logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.")
+ local_file = self._local_path(self.frame_annotations_file)
+ with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
+ frame_annots_list = types.load_dataclass(
+ zipfile, List[self.frame_annotations_type]
+ )
+ if not frame_annots_list:
+ raise ValueError("Empty dataset!")
+ # pyre-ignore[16]
+ self.frame_annots = [
+ FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
+ ]
+
+ def _load_sequences(self) -> None:
+ logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
+ local_file = self._local_path(self.sequence_annotations_file)
+ with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
+ seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
+ if not seq_annots:
+ raise ValueError("Empty sequences file!")
+ # pyre-ignore[16]
+ self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
+
+ def _load_subset_lists(self) -> None:
+ logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.")
+ if not self.subset_lists_file:
+ return
+
+ with open(self._local_path(self.subset_lists_file), "r") as f:
+ subset_to_seq_frame = json.load(f)
+
+ frame_path_to_subset = {
+ path: subset
+ for subset, frames in subset_to_seq_frame.items()
+ for _, _, path in frames
+ }
+ # pyre-ignore[16]
+ for frame in self.frame_annots:
+ frame["subset"] = frame_path_to_subset.get(
+ frame["frame_annotation"].image.path, None
+ )
+ if frame["subset"] is None:
+ warnings.warn(
+ "Subset lists are given but don't include "
+ + frame["frame_annotation"].image.path
+ )
+
+ def _sort_frames(self) -> None:
+ # Sort frames to have them grouped by sequence, ordered by timestamp
+ # pyre-ignore[16]
+ self.frame_annots = sorted(
+ self.frame_annots,
+ key=lambda f: (
+ f["frame_annotation"].sequence_name,
+ f["frame_annotation"].frame_timestamp or 0,
+ ),
+ )
+
+ def _filter_db(self) -> None:
+ if self.remove_empty_masks:
+ logger.info("Removing images with empty masks.")
+ # pyre-ignore[16]
+ old_len = len(self.frame_annots)
+
+ msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
+
+ def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
+ mask = frame_annot.mask
+ if mask is None:
+ return False
+ if mask.mass is None:
+ raise ValueError(msg)
+ return mask.mass > 1
+
+ self.frame_annots = [
+ frame
+ for frame in self.frame_annots
+ if positive_mass(frame["frame_annotation"])
+ ]
+ logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
+
+ # this has to be called after joining with categories!!
+ subsets = self.subsets
+ if subsets:
+ if not self.subset_lists_file:
+ raise ValueError(
+ "Subset filter is on but subset_lists_file was not given"
+ )
+
+ logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.")
+
+ # truncate the list of subsets to the valid one
+ self.frame_annots = [
+ entry for entry in self.frame_annots if entry["subset"] in subsets
+ ]
+ if len(self.frame_annots) == 0:
+ raise ValueError(f"There are no frames in the '{subsets}' subsets!")
+
+ self._invalidate_indexes(filter_seq_annots=True)
+
+ if len(self.limit_category_to) > 0:
+ logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
+ # pyre-ignore[16]
+ self.seq_annots = {
+ name: entry
+ for name, entry in self.seq_annots.items()
+ if entry.category in self.limit_category_to
+ }
+
+ # sequence filters
+ for prefix in ("pick", "exclude"):
+ orig_len = len(self.seq_annots)
+ attr = f"{prefix}_sequence"
+ arr = getattr(self, attr)
+ if len(arr) > 0:
+ logger.info(f"{attr}: {str(arr)}")
+ self.seq_annots = {
+ name: entry
+ for name, entry in self.seq_annots.items()
+ if (name in arr) == (prefix == "pick")
+ }
+ logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
+
+ if self.limit_sequences_to > 0:
+ self.seq_annots = dict(
+ islice(self.seq_annots.items(), self.limit_sequences_to)
+ )
+
+ # retain only frames from retained sequences
+ self.frame_annots = [
+ f
+ for f in self.frame_annots
+ if f["frame_annotation"].sequence_name in self.seq_annots
+ ]
+
+ self._invalidate_indexes()
+
+ if self.n_frames_per_sequence > 0:
+ logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
+ keep_idx = []
+ # pyre-ignore[16]
+ for seq, seq_indices in self._seq_to_idx.items():
+ # infer the seed from the sequence name, this is reproducible
+ # and makes the selection differ for different sequences
+ seed = _seq_name_to_seed(seq) + self.seed
+ seq_idx_shuffled = random.Random(seed).sample(
+ sorted(seq_indices), len(seq_indices)
+ )
+ keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
+
+ logger.info(
+ "... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))
+ )
+ self.frame_annots = [self.frame_annots[i] for i in keep_idx]
+ self._invalidate_indexes(filter_seq_annots=False)
+ # sequences are not decimated, so self.seq_annots is valid
+
+ if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
+ logger.info(
+ "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
+ )
+ self.frame_annots = self.frame_annots[: self.limit_to]
+ self._invalidate_indexes(filter_seq_annots=True)
+
+ def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
+ # update _seq_to_idx and filter seq_annots according to the frame_annots change
+ # if filter_seq_annots, also updates seq_annots based on the changed _seq_to_idx
+ self._invalidate_seq_to_idx()
+
+ if filter_seq_annots:
+ # pyre-ignore[16]
+ self.seq_annots = {
+ k: v
+ for k, v in self.seq_annots.items()
+ # pyre-ignore[16]
+ if k in self._seq_to_idx
+ }
+
+ def _invalidate_seq_to_idx(self) -> None:
+ seq_to_idx = defaultdict(list)
+ # pyre-ignore[16]
+ for idx, entry in enumerate(self.frame_annots):
+ seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
+ # pyre-ignore[16]
+ self._seq_to_idx = seq_to_idx
+
+ def _resize_image(
+ self, image, mode="bilinear"
+ ) -> Tuple[torch.Tensor, float, torch.Tensor]:
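+ # Rescale by the largest uniform factor that still fits inside
+ # (image_height, image_width), zero-pad the remainder, and return the padded
+ # image, the scale factor, and a mask marking the valid (non-padded) region.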
+ image_height, image_width = self.image_height, self.image_width
+ if image_height is None or image_width is None:
+ # skip the resizing
+ imre_ = torch.from_numpy(image)
+ return imre_, 1.0, torch.ones_like(imre_[:1])
+ # takes numpy array, returns pytorch tensor
+ minscale = min(
+ image_height / image.shape[-2],
+ image_width / image.shape[-1],
+ )
+ imre = torch.nn.functional.interpolate(
+ torch.from_numpy(image)[None],
+ scale_factor=minscale,
+ mode=mode,
+ align_corners=False if mode == "bilinear" else None,
+ recompute_scale_factor=True,
+ )[0]
+ # pyre-fixme[19]: Expected 1 positional argument.
+ imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
+ imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
+ # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
+ # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
+ mask = torch.zeros(1, self.image_height, self.image_width)
+ mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
+ return imre_, minscale, mask
+
+ def _local_path(self, path: str) -> str:
+ if self.path_manager is None:
+ return path
+ return self.path_manager.get_local_path(path)
+
+ def get_frame_numbers_and_timestamps(
+ self, idxs: Sequence[int]
+ ) -> List[Tuple[int, float]]:
+ out: List[Tuple[int, float]] = []
+ for idx in idxs:
+ # pyre-ignore[16]
+ frame_annotation = self.frame_annots[idx]["frame_annotation"]
+ out.append(
+ (frame_annotation.frame_number, frame_annotation.frame_timestamp)
+ )
+ return out
+
+ def category_to_sequence_names(self) -> Dict[str, List[str]]:
+ c2seq = defaultdict(list)
+ # pyre-ignore
+ for sequence_name, sa in self.seq_annots.items():
+ c2seq[sa.category].append(sequence_name)
+ return dict(c2seq)
+
+ def get_eval_batches(self) -> Optional[List[List[int]]]:
+ return self.eval_batches
+
+
+def _seq_name_to_seed(seq_name) -> int:
+ return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)
+
+
+def _load_image(path) -> np.ndarray:
+ with Image.open(path) as pil_im:
+ im = np.array(pil_im.convert("RGB"))
+ im = im.transpose((2, 0, 1))
+ im = im.astype(np.float32) / 255.0
+ return im
+
+
+def _load_16big_png_depth(depth_png) -> np.ndarray:
+ with Image.open(depth_png) as depth_pil:
+ # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
+ # we cast it to uint16, then reinterpret as float16, then cast to float32
+ depth = (
+ np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
+ .astype(np.float32)
+ .reshape((depth_pil.size[1], depth_pil.size[0]))
+ )
+ return depth
+
+
+def _load_1bit_png_mask(file: str) -> np.ndarray:
+ with Image.open(file) as pil_im:
+ mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
+ return mask
+
+
+def _load_depth_mask(path: str) -> np.ndarray:
+ if not path.lower().endswith(".png"):
+ raise ValueError('unsupported depth mask file name "%s"' % path)
+ m = _load_1bit_png_mask(path)
+ return m[None] # fake feature channel
+
+
+def _load_depth(path, scale_adjustment) -> np.ndarray:
+ if not path.lower().endswith(".png"):
+ raise ValueError('unsupported depth file name "%s"' % path)
+
+ d = _load_16big_png_depth(path) * scale_adjustment
+ d[~np.isfinite(d)] = 0.0
+ return d[None] # fake feature channel
+
+
+def _load_mask(path) -> np.ndarray:
+ with Image.open(path) as pil_im:
+ mask = np.array(pil_im)
+ mask = mask.astype(np.float32) / 255.0
+ return mask[None] # fake feature channel
+
+
+def _get_1d_bounds(arr) -> Tuple[int, int]:
+ nz = np.flatnonzero(arr)
+ return nz[0], nz[-1] + 1
+
+
+def _get_bbox_from_mask(
+ mask, thr, decrease_quant: float = 0.05
+) -> Tuple[int, int, int, int]:
+ # bbox in xywh
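+ # binarize at thr, lowering the threshold until more than one foreground
+ # pixel survives; if thr drops to zero, the warning fires and effectively
+ # the full image is used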
+ masks_for_box = np.zeros_like(mask)
+ while masks_for_box.sum() <= 1.0:
+ masks_for_box = (mask > thr).astype(np.float32)
+ thr -= decrease_quant
+ if thr <= 0.0:
+ warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")
+
+ x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
+ y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
+
+ return x0, y0, x1 - x0, y1 - y0
+
+
+def _get_clamp_bbox(
+ bbox: torch.Tensor,
+ box_crop_context: float = 0.0,
+ image_path: str = "",
+) -> torch.Tensor:
+ # box_crop_context: rate of expansion for bbox
+ # returns possibly expanded bbox xyxy as float
+
+ bbox = bbox.clone() # do not edit bbox in place
+
+ # increase box size
+ if box_crop_context > 0.0:
+ c = box_crop_context
+ bbox = bbox.float()
+ bbox[0] -= bbox[2] * c / 2
+ bbox[1] -= bbox[3] * c / 2
+ bbox[2] += bbox[2] * c
+ bbox[3] += bbox[3] * c
+
+ if (bbox[2:] <= 1.0).any():
+ raise ValueError(
+ f"squashed image {image_path}!! The bounding box contains no pixels."
+ )
+
+ bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
+ bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
+
+ return bbox_xyxy
+
+
+def _crop_around_box(tensor, bbox, impath: str = ""):
+ # bbox is xyxy, where the upper bound is corrected with +1
+ bbox = _clamp_box_to_image_bounds_and_round(
+ bbox,
+ image_size_hw=tensor.shape[-2:],
+ )
+ tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
+ assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
+ return tensor
+
+
+def _clamp_box_to_image_bounds_and_round(
+ bbox_xyxy: torch.Tensor,
+ image_size_hw: Tuple[int, int],
+) -> torch.LongTensor:
+ bbox_xyxy = bbox_xyxy.clone()
+ bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
+ bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
+ if not isinstance(bbox_xyxy, torch.LongTensor):
+ bbox_xyxy = bbox_xyxy.round().long()
+ return bbox_xyxy # pyre-ignore [7]
+
+
+def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
+ assert bbox is not None
+ assert np.prod(orig_res) > 1e-8
+ # average ratio of dimensions
+ rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
+ return bbox * rel_size
+
+
+def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
+ wh = xyxy[2:] - xyxy[:2]
+ xywh = torch.cat([xyxy[:2], wh])
+ return xywh
+
+
+def _bbox_xywh_to_xyxy(
+ xywh: torch.Tensor, clamp_size: Optional[int] = None
+) -> torch.Tensor:
+ xyxy = xywh.clone()
+ if clamp_size is not None:
+ xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
+ xyxy[2:] += xyxy[:2]
+ return xyxy
+
+
+def _safe_as_tensor(data, dtype):
+ if data is None:
+ return None
+ return torch.tensor(data, dtype=dtype)
+
+
+# NOTE this cache is per-worker; they are implemented as processes.
+# each batch is loaded and collated by a single worker;
+# since sequences tend to co-occur within batches, this is useful.
+@functools.lru_cache(maxsize=256)
+def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
+ pcl = IO().load_pointcloud(pcl_path)
+ if max_points > 0:
+ pcl = pcl.subsample(max_points)
+
+ return pcl
\ No newline at end of file
diff --git a/sgm/data/latent_objaverse.py b/sgm/data/latent_objaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..8819c1e7529efb1fcf44a6f95f92df3d73869517
--- /dev/null
+++ b/sgm/data/latent_objaverse.py
@@ -0,0 +1,52 @@
+import numpy as np
+from pathlib import Path
+from PIL import Image
+import json
+import torch
+from torch.utils.data import Dataset, DataLoader, default_collate
+from torchvision.transforms import ToTensor, Normalize, Compose, Resize
+from pytorch_lightning import LightningDataModule
+from einops import rearrange
+
+
+class LatentObjaverseSpiral(Dataset):
+ def __init__(
+ self,
+ root_dir,
+ split="train",
+ transform=None,
+ random_front=False,
+ max_item=None,
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ condition_on_elevation=False,
+ **unused_kwargs,
+ ):
+ print("Using LVIS subset with precomputed Latents")
+ self.root_dir = Path(root_dir)
+ self.split = split
+ self.random_front = random_front
+ self.transform = transform
+
+ self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
+
+ self.ids = json.load(open("./assets/lvis_uids.json", "r"))
+ self.n_views = 18
+ valid_ids = []
+ for idx in self.ids:
+ if (self.root_dir / idx).exists():
+ valid_ids.append(idx)
+ self.ids = valid_ids
+ print("=" * 30)
+ print("Number of valid ids: ", len(self.ids))
+ print("=" * 30)
+
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ self.condition_on_elevation = condition_on_elevation
+
+ if max_item is not None:
+ self.ids = self.ids[:max_item]
+
+ ## debug
+ self.ids = self.ids * 10000
diff --git a/sgm/data/mnist.py b/sgm/data/mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..dea4d7e670666bec80ecb22aa89603345e173d09
--- /dev/null
+++ b/sgm/data/mnist.py
@@ -0,0 +1,85 @@
+import pytorch_lightning as pl
+import torchvision
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class MNISTDataDictWrapper(Dataset):
+ def __init__(self, dset):
+ super().__init__()
+ self.dset = dset
+
+ def __getitem__(self, i):
+ x, y = self.dset[i]
+ return {"jpg": x, "cls": y}
+
+ def __len__(self):
+ return len(self.dset)
+
+
+class MNISTLoader(pl.LightningDataModule):
+ def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True):
+ super().__init__()
+
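+ # ToTensor maps pixels to [0, 1]; the Lambda rescales them to [-1, 1]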
+ transform = transforms.Compose(
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+ )
+
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.prefetch_factor = prefetch_factor if num_workers > 0 else 0
+ self.shuffle = shuffle
+ self.train_dataset = MNISTDataDictWrapper(
+ torchvision.datasets.MNIST(
+ root=".data/", train=True, download=True, transform=transform
+ )
+ )
+ self.test_dataset = MNISTDataDictWrapper(
+ torchvision.datasets.MNIST(
+ root=".data/", train=False, download=True, transform=transform
+ )
+ )
+
+ def prepare_data(self):
+ pass
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ )
+
+
+if __name__ == "__main__":
+ dset = MNISTDataDictWrapper(
+ torchvision.datasets.MNIST(
+ root=".data/",
+ train=False,
+ download=True,
+ transform=transforms.Compose(
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)]
+ ),
+ )
+ )
+ ex = dset[0]
diff --git a/sgm/data/mvimagenet.py b/sgm/data/mvimagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..b20c398c08dd976c8bef1455845022f181cfcb73
--- /dev/null
+++ b/sgm/data/mvimagenet.py
@@ -0,0 +1,408 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader, default_collate
+from pathlib import Path
+from PIL import Image
+from scipy.spatial.transform import Rotation
+import rembg
+from rembg import remove, new_session
+from einops import rearrange
+
+from torchvision.transforms import ToTensor, Normalize, Compose, Resize
+from torchvision.transforms.functional import to_tensor
+from pytorch_lightning import LightningDataModule
+
+from sgm.data.colmap import read_cameras_binary, read_images_binary
+from sgm.data.objaverse import video_collate_fn, FLATTEN_FIELDS, flatten_for_video
+
+
+def qvec2rotmat(qvec):
+ return np.array(
+ [
+ [
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
+ ],
+ [
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
+ ],
+ [
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
+ ],
+ ]
+ )
+
+
+def qt2c2w(q, t):
+ # NOTE: remember to convert to opengl coordinate system
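+ # Invert COLMAP's world-to-camera pose: R_c2w = R^T, t_c2w = -R^T t, then flip
+ # the camera y and z axes to go from the OpenCV to the OpenGL convention.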
+ # rot = Rotation.from_quat(q).as_matrix()
+ rot = qvec2rotmat(q)
+ c2w = np.eye(4)
+ c2w[:3, :3] = np.transpose(rot)
+ c2w[:3, 3] = -np.transpose(rot) @ t
+ c2w[..., 1:3] *= -1
+ return c2w
+
+
+def random_crop():
+ pass
+
+
+class MVImageNet(Dataset):
+ def __init__(
+ self,
+ root_dir,
+ split,
+ transform,
+ reso: int = 256,
+ mask_type: str = "random",
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ condition_on_elevation=False,
+ fps_id=0.0,
+ motion_bucket_id=300.0,
+ num_frames: int = 24,
+ use_mask: bool = True,
+ load_pixelnerf: bool = False,
+ scale_pose: bool = False,
+ max_n_cond: int = 1,
+ min_n_cond: int = 1,
+ cond_on_multi: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.root_dir = Path(root_dir)
+ self.split = split
+
+ avails = self.root_dir.glob("*/*")
+ self.ids = list(
+ map(
+ lambda x: str(x.relative_to(self.root_dir)),
+ filter(lambda x: x.is_dir(), avails),
+ )
+ )
+
+ self.transform = transform
+ self.reso = reso
+ self.num_frames = num_frames
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ self.condition_on_elevation = condition_on_elevation
+ self.fps_id = fps_id
+ self.motion_bucket_id = motion_bucket_id
+ self.mask_type = mask_type
+ self.use_mask = use_mask
+ self.load_pixelnerf = load_pixelnerf
+ self.scale_pose = scale_pose
+ self.max_n_cond = max_n_cond
+ self.min_n_cond = min_n_cond
+ self.cond_on_multi = cond_on_multi
+
+ if self.cond_on_multi:
+ assert self.min_n_cond == self.max_n_cond
+ self.session = new_session()
+
+ def __getitem__(self, index: int):
+ # mvimgnet starts with idx==1
+ idx_list = np.arange(0, self.num_frames)
+ this_image_dir = self.root_dir / self.ids[index] / "images"
+ this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
+
+ # while not this_camera_dir.exists():
+ # index = (index + 1) % len(self.ids)
+ # this_image_dir = self.root_dir / self.ids[index] / "images"
+ # this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
+ if not this_camera_dir.exists():
+ index = 0
+ this_image_dir = self.root_dir / self.ids[index] / "images"
+ this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
+
+ this_images = read_images_binary(this_camera_dir / "images.bin")
+ # filenames = list(map(lambda x: f"{x:03d}", this_images.keys()))
+ filenames = list(this_images.keys())
+
+ if len(filenames) == 0:
+ index = 0
+ this_image_dir = self.root_dir / self.ids[index] / "images"
+ this_camera_dir = self.root_dir / self.ids[index] / "sparse/0"
+ this_images = read_images_binary(this_camera_dir / "images.bin")
+ # filenames = list(map(lambda x: f"{x:03d}", this_images.keys()))
+ filenames = list(this_images.keys())
+
+ filenames = list(
+ filter(lambda x: (this_image_dir / this_images[x].name).exists(), filenames)
+ )
+
+ filenames = sorted(filenames, key=lambda x: this_images[x].name)
+
+ # # debug
+ # names = []
+ # for v in filenames:
+ # names.append(this_images[v].name)
+ # breakpoint()
+
+ while len(filenames) < self.num_frames:
+ num_surpass = self.num_frames - len(filenames)
+ filenames += list(reversed(filenames[-num_surpass:]))
+
+ if len(filenames) < self.num_frames:
+ print(f"\n\n{self.ids[index]}\n\n")
+
+ frames = []
+ cameras = []
+ downsampled_rgb = []
+ for view_idx in idx_list:
+ this_id = filenames[view_idx]
+ frame = Image.open(this_image_dir / this_images[this_id].name)
+ w, h = frame.size
+
+ if self.mask_type == "random":
+ image_size = min(h, w)
+ left = np.random.randint(0, w - image_size + 1)
+ right = left + image_size
+ top = np.random.randint(0, h - image_size + 1)
+ bottom = top + image_size
+ ## need to assign left, right, top, bottom, image_size
+ elif self.mask_type == "object":
+ pass
+ elif self.mask_type == "rembg":
+ image_size = min(h, w)
+ if (
+ cached := this_image_dir
+ / f"{this_images[this_id].name[:-4]}_rembg.png"
+ ).exists():
+ try:
+ mask = np.asarray(Image.open(cached, formats=["png"]))[..., 3]
+ except Exception:
+ mask = remove(frame, session=self.session)
+ mask.save(cached)
+ mask = np.asarray(mask)[..., 3]
+ else:
+ mask = remove(frame, session=self.session)
+ mask.save(cached)
+ mask = np.asarray(mask)[..., 3]
+ # in h,w order
+ y, x = np.array(mask.nonzero())
+ bbox_cx = x.mean()
+ bbox_cy = y.mean()
+
+ if bbox_cy - image_size / 2 < 0:
+ top = 0
+ elif bbox_cy + image_size / 2 > h:
+ top = h - image_size
+ else:
+ top = int(bbox_cy - image_size / 2)
+
+ if bbox_cx - image_size / 2 < 0:
+ left = 0
+ elif bbox_cx + image_size / 2 > w:
+ left = w - image_size
+ else:
+ left = int(bbox_cx - image_size / 2)
+
+ # top = max(int(bbox_cy - image_size / 2), 0)
+ # left = max(int(bbox_cx - image_size / 2), 0)
+ bottom = top + image_size
+ right = left + image_size
+ else:
+ raise ValueError(f"Unknown mask type: {self.mask_type}")
+
+ frame = frame.crop((left, top, right, bottom))
+ frame = frame.resize((self.reso, self.reso))
+ frames.append(self.transform(frame))
+
+ if self.load_pixelnerf:
+ # extrinsics
+ extrinsics = this_images[this_id]
+ c2w = qt2c2w(extrinsics.qvec, extrinsics.tvec)
+ # intrinsics
+ intrinsics = read_cameras_binary(this_camera_dir / "cameras.bin")
+ assert len(intrinsics) == 1
+ intrinsics = intrinsics[1]
+ f, cx, cy, _ = intrinsics.params
+ f *= 1 / image_size
+ cx -= left
+ cy -= top
+ cx *= 1 / image_size
+ cy *= 1 / image_size # all are relative values
+ intrinsics = np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]])
+
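+ # pack the camera as a 25-dim vector: flattened 4x4 camera-to-world (16
+ # values) followed by the flattened 3x3 crop-normalized intrinsics (9 values)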
+ this_camera = np.zeros(25)
+ this_camera[:16] = c2w.reshape(-1)
+ this_camera[16:] = intrinsics.reshape(-1)
+
+ cameras.append(this_camera)
+ downsampled = frame.resize((self.reso // 8, self.reso // 8))
+ downsampled_rgb.append((self.transform(downsampled) + 1.0) * 0.5)
+
+ data = dict()
+
+ cond_aug = np.exp(
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
+ )
+ frames = torch.stack(frames)
+ cond = frames[0]
+ # setting all things in data
+ data["frames"] = frames
+ data["cond_frames_without_noise"] = cond
+ data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames)
+ data["cond_frames"] = cond + cond_aug * torch.randn_like(cond)
+ data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames)
+ data["motion_bucket_id"] = torch.as_tensor(
+ [self.motion_bucket_id] * self.num_frames
+ )
+ data["num_video_frames"] = self.num_frames
+ data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames)
+
+ if self.load_pixelnerf:
+ # TODO: normalize camera poses
+ data["pixelnerf_input"] = dict()
+ data["pixelnerf_input"]["frames"] = frames
+ data["pixelnerf_input"]["rgb"] = torch.stack(downsampled_rgb)
+
+ cameras = torch.from_numpy(np.stack(cameras)).float()
+ if self.scale_pose:
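+ # recenter camera positions on their mean and rescale so the farthest
+ # camera lies at radius 1.5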
+ c2ws = cameras[..., :16].reshape(-1, 4, 4)
+ center = c2ws[:, :3, 3].mean(0)
+ radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
+ scale = 1.5 / radius
+ c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
+ cameras[..., :16] = c2ws.reshape(-1, 16)
+
+ # if self.max_n_cond > 1:
+ # # TODO implement this
+ # n_cond = np.random.randint(1, self.max_n_cond + 1)
+ # # debug
+ # source_index = [0]
+ # if n_cond > 1:
+ # source_index += np.random.choice(
+ # np.arange(1, self.num_frames),
+ # self.max_n_cond - 1,
+ # replace=False,
+ # ).tolist()
+ # data["pixelnerf_input"]["source_index"] = torch.as_tensor(
+ # source_index
+ # )
+ # data["pixelnerf_input"]["n_cond"] = n_cond
+ # data["pixelnerf_input"]["source_images"] = frames[source_index]
+ # data["pixelnerf_input"]["source_cameras"] = cameras[source_index]
+
+ data["pixelnerf_input"]["cameras"] = cameras
+
+ return data
+
+ def __len__(self):
+ return len(self.ids)
+
+ def collate_fn(self, batch):
+ # a hack to add source index and keep consistent within a batch
+ if self.max_n_cond > 1:
+ # TODO implement this
+ n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1)
+ # debug
+ # source_index = [0]
+ if n_cond > 1:
+ for b in batch:
+ source_index = [0] + np.random.choice(
+ np.arange(1, self.num_frames),
+ self.max_n_cond - 1,
+ replace=False,
+ ).tolist()
+ b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
+ b["pixelnerf_input"]["n_cond"] = n_cond
+ b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
+ b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
+ "cameras"
+ ][source_index]
+
+ if self.cond_on_multi:
+ b["cond_frames_without_noise"] = b["frames"][source_index]
+
+ ret = video_collate_fn(batch)
+
+ if self.cond_on_multi:
+ ret["cond_frames_without_noise"] = rearrange(ret["cond_frames_without_noise"], "b t ... -> (b t) ...")
+
+ return ret
+
+
+class MVImageNetFixedCond(MVImageNet):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+
+class MVImageNetDataset(LightningDataModule):
+ def __init__(
+ self,
+ root_dir,
+ batch_size=2,
+ shuffle=True,
+ num_workers=10,
+ prefetch_factor=2,
+ **kwargs,
+ ):
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.prefetch_factor = prefetch_factor
+ self.shuffle = shuffle
+
+ self.transform = Compose(
+ [
+ ToTensor(),
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ]
+ )
+
+ self.train_dataset = MVImageNet(
+ root_dir=root_dir,
+ split="train",
+ transform=self.transform,
+ **kwargs,
+ )
+
+ self.test_dataset = MVImageNet(
+ root_dir=root_dir,
+ split="test",
+ transform=self.transform,
+ **kwargs,
+ )
+
+ def train_dataloader(self):
+ def worker_init_fn(worker_id):
+ np.random.seed(np.random.get_state()[1][0])
+
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=self.train_dataset.collate_fn,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=self.test_dataset.collate_fn,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=video_collate_fn,
+ )
diff --git a/sgm/data/objaverse.py b/sgm/data/objaverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9ae0730ab09dc4e5ad87e3212b3f2ae22581934
--- /dev/null
+++ b/sgm/data/objaverse.py
@@ -0,0 +1,882 @@
+import numpy as np
+from pathlib import Path
+from PIL import Image
+import json
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader, default_collate
+from torchvision.transforms import ToTensor, Normalize, Compose, Resize
+from torchvision.transforms.functional import to_tensor
+from pytorch_lightning import LightningDataModule
+from einops import rearrange
+
+
+def read_camera_matrix_single(json_file):
+ # for gobjaverse
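+ # the json stores the camera frame as axis vectors x/y/z plus an origin;
+ # y and z are negated (OpenCV -> OpenGL), giving a 3x4 camera-to-world matrix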
+ with open(json_file, "r", encoding="utf8") as reader:
+ json_content = json.load(reader)
+
+ # negative sign for opencv to opengl
+ camera_matrix = torch.zeros(3, 4)
+ camera_matrix[:3, 0] = torch.tensor(json_content["x"])
+ camera_matrix[:3, 1] = -torch.tensor(json_content["y"])
+ camera_matrix[:3, 2] = -torch.tensor(json_content["z"])
+ camera_matrix[:3, 3] = torch.tensor(json_content["origin"])
+ """
+ camera_matrix = np.eye(4)
+ camera_matrix[:3, 0] = np.array(json_content['x'])
+ camera_matrix[:3, 1] = np.array(json_content['y'])
+ camera_matrix[:3, 2] = np.array(json_content['z'])
+ camera_matrix[:3, 3] = np.array(json_content['origin'])
+ # print(camera_matrix)
+ """
+
+ return camera_matrix
+
+
+def read_camera_instrinsics_single(json_file, h: int, w: int, scale: float = 1.0):
+ with open(json_file, "r", encoding="utf8") as reader:
+ json_content = json.load(reader)
+
+ h = int(h * scale)
+ w = int(w * scale)
+
+ y_fov = json_content["y_fov"]
+ x_fov = json_content["x_fov"]
+
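+ # pinhole focal lengths from the per-axis field of view: f = (size / 2) / tan(fov / 2)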
+ fy = h / 2 / np.tan(y_fov / 2)
+ fx = w / 2 / np.tan(x_fov / 2)
+
+ cx = w // 2
+ cy = h // 2
+
+ intrinsics = torch.tensor(
+ [
+ [fx, fy],
+ [cx, cy],
+ [w, h],
+ ],
+ dtype=torch.float32,
+ )
+ return intrinsics
+
+
+def compose_extrinsic_RT(RT: torch.Tensor):
+ """
+ Compose the standard form extrinsic matrix from RT.
+ Batched I/O.
+ """
+ return torch.cat(
+ [
+ RT,
+ torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(
+ RT.shape[0], 1, 1
+ ),
+ ],
+ dim=1,
+ )
+
+
+def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
+ """
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+ Return batched fx, fy, cx, cy
+ """
+ fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
+ cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
+ width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
+ fx, fy = fx / width, fy / height
+ cx, cy = cx / width, cy / height
+ return fx, fy, cx, cy
+
+
+def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor):
+ """
+ RT: (N, 3, 4)
+ intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
+ """
+ E = compose_extrinsic_RT(RT)
+ fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics)
+ I = torch.stack(
+ [
+ torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
+ torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
+ torch.tensor([[0, 0, 1]], dtype=torch.float32).repeat(RT.shape[0], 1),
+ ],
+ dim=1,
+ )
+ return torch.cat(
+ [
+ E.reshape(-1, 16),
+ I.reshape(-1, 9),
+ ],
+ dim=-1,
+ )
+
+
+def calc_elevation(c2w):
+ ## works for single or batched c2w
+ ## assume world up is (0, 0, 1)
+ pos = c2w[..., :3, 3]
+
+ return np.arcsin(pos[..., 2] / np.linalg.norm(pos, axis=-1, keepdims=False))
+
+
+def blend_white_bg(image):
+ new_image = Image.new("RGB", image.size, (255, 255, 255))
+ new_image.paste(image, mask=image.split()[3])
+
+ return new_image
+
+
+def flatten_for_video(input):
+ return input.flatten()
+
+
+FLATTEN_FIELDS = ["fps_id", "motion_bucket_id", "cond_aug", "elevation"]
+
+
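+# Collate a list of per-object dicts: scalar per-frame fields are flattened, and frame/latent
+# tensors are merged from (b, t, c, h, w) to (b*t, c, h, w) for the downstream video model.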
+def video_collate_fn(batch: list[dict], *args, **kwargs):
+ out = {}
+ for key in batch[0].keys():
+ if key in FLATTEN_FIELDS:
+ out[key] = default_collate([item[key] for item in batch])
+ out[key] = flatten_for_video(out[key])
+ elif key == "num_video_frames":
+ out[key] = batch[0][key]
+ elif key in ["frames", "latents", "rgb"]:
+ out[key] = default_collate([item[key] for item in batch])
+ out[key] = rearrange(out[key], "b t c h w -> (b t) c h w")
+ else:
+ out[key] = default_collate([item[key] for item in batch])
+
+ if "pixelnerf_input" in out:
+ out["pixelnerf_input"]["rgb"] = rearrange(
+ out["pixelnerf_input"]["rgb"], "b t c h w -> (b t) c h w"
+ )
+
+ return out
+
+
+class GObjaverse(Dataset):
+ def __init__(
+ self,
+ root_dir,
+ split="train",
+ transform=None,
+ random_front=False,
+ max_item=None,
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ condition_on_elevation=False,
+ fps_id=0.0,
+ motion_bucket_id=300.0,
+ use_latents=False,
+ load_caps=False,
+ front_view_selection="random",
+ load_pixelnerf=False,
+ debug_base_idx=None,
+ scale_pose: bool = False,
+ max_n_cond: int = 1,
+ **unused_kwargs,
+ ):
+ self.root_dir = Path(root_dir)
+ self.split = split
+ self.random_front = random_front
+ self.transform = transform
+ self.use_latents = use_latents
+
+ self.ids = json.load(open(self.root_dir / "valid_uids.json", "r"))
+ self.n_views = 24
+
+ self.load_caps = load_caps
+ if self.load_caps:
+ self.caps = json.load(open(self.root_dir / "text_captions_cap3d.json", "r"))
+
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ self.condition_on_elevation = condition_on_elevation
+ self.fps_id = fps_id
+ self.motion_bucket_id = motion_bucket_id
+ self.load_pixelnerf = load_pixelnerf
+ self.scale_pose = scale_pose
+ self.max_n_cond = max_n_cond
+
+ if self.use_latents:
+ self.latents_dir = self.root_dir / "latents256"
+ self.clip_dir = self.root_dir / "clip_emb256"
+
+ self.front_view_selection = front_view_selection
+ if self.front_view_selection == "random":
+ pass
+ elif self.front_view_selection == "fixed":
+ pass
+ elif self.front_view_selection.startswith("clip_score"):
+ self.clip_scores = torch.load(self.root_dir / "clip_score_per_view.pt")
+ self.ids = list(self.clip_scores.keys())
+ else:
+ raise ValueError(
+ f"Unknown front view selection method {self.front_view_selection}"
+ )
+
+ if max_item is not None:
+ self.ids = self.ids[:max_item]
+ ## debug
+ self.ids = self.ids * 10000
+
+ if debug_base_idx is not None:
+ print(f"debug mode with base idx: {debug_base_idx}")
+ self.debug_base_idx = debug_base_idx
+
+ def __getitem__(self, idx: int):
+ if hasattr(self, "debug_base_idx"):
+ idx = (idx + self.debug_base_idx) % len(self.ids)
+ data = {}
+ idx_list = np.arange(self.n_views)
+ # if self.random_front:
+ # roll_idx = np.random.randint(self.n_views)
+ # idx_list = np.roll(idx_list, roll_idx)
+ if self.front_view_selection == "random":
+ roll_idx = np.random.randint(self.n_views)
+ idx_list = np.roll(idx_list, roll_idx)
+ elif self.front_view_selection == "fixed":
+ pass
+ elif self.front_view_selection == "clip_score_softmax":
+ this_clip_score = (
+ F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
+ )
+ roll_idx = np.random.choice(idx_list, p=this_clip_score)
+ idx_list = np.roll(idx_list, roll_idx)
+ elif self.front_view_selection == "clip_score_max":
+ this_clip_score = (
+ F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy()
+ )
+ roll_idx = np.argmax(this_clip_score)
+ idx_list = np.roll(idx_list, roll_idx)
+ frames = []
+ if not self.use_latents:
+ try:
+ for view_idx in idx_list:
+ frame = Image.open(
+ self.root_dir
+ / "gobjaverse"
+ / self.ids[idx]
+ / f"{view_idx:05d}/{view_idx:05d}.png"
+ )
+ frames.append(self.transform(frame))
+            except Exception:
+ idx = 0
+ frames = []
+ for view_idx in idx_list:
+ frame = Image.open(
+ self.root_dir
+ / "gobjaverse"
+ / self.ids[idx]
+ / f"{view_idx:05d}/{view_idx:05d}.png"
+ )
+ frames.append(self.transform(frame))
+            # workaround for corrupted items in gobjaverse: fall back to idx=0; the resulting
+            # duplicates are resolved when gathering results, and the number of valid items is len(results)
+ frames = torch.stack(frames, dim=0)
+ cond = frames[0]
+
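+            # sample a log-normal noise-augmentation strength for the conditioning frame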
+ cond_aug = np.exp(
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
+ )
+
+ data.update(
+ {
+ "frames": frames,
+ "cond_frames_without_noise": cond,
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
+ "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
+ "motion_bucket_id": torch.as_tensor(
+ [self.motion_bucket_id] * self.n_views
+ ),
+ "num_video_frames": 24,
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
+ }
+ )
+ else:
+ latents = torch.load(self.latents_dir / f"{self.ids[idx]}.pt")[idx_list]
+ clip_emb = torch.load(self.clip_dir / f"{self.ids[idx]}.pt")[idx_list][0]
+
+ cond = latents[0]
+
+ cond_aug = np.exp(
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
+ )
+
+ data.update(
+ {
+ "latents": latents,
+ "cond_frames_without_noise": clip_emb,
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
+ "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
+ "motion_bucket_id": torch.as_tensor(
+ [self.motion_bucket_id] * self.n_views
+ ),
+ "num_video_frames": 24,
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
+ }
+ )
+
+ if self.condition_on_elevation:
+ sample_c2w = read_camera_matrix_single(
+ self.root_dir / self.ids[idx] / f"00000/00000.json"
+ )
+ elevation = calc_elevation(sample_c2w)
+ data["elevation"] = torch.as_tensor([elevation] * self.n_views)
+
+ if self.load_pixelnerf:
+ assert "frames" in data, f"pixelnerf cannot work with latents only mode"
+ data["pixelnerf_input"] = {}
+ RTs = []
+ intrinsics = []
+ for view_idx in idx_list:
+ meta = (
+ self.root_dir
+ / "gobjaverse"
+ / self.ids[idx]
+ / f"{view_idx:05d}/{view_idx:05d}.json"
+ )
+ RTs.append(read_camera_matrix_single(meta)[:3])
+ intrinsics.append(read_camera_instrinsics_single(meta, 256, 256))
+ RTs = torch.stack(RTs, dim=0)
+ intrinsics = torch.stack(intrinsics, dim=0)
+ cameras = build_camera_standard(RTs, intrinsics)
+ data["pixelnerf_input"]["cameras"] = cameras
+
+ downsampled = []
+ for view_idx in idx_list:
+ frame = Image.open(
+ self.root_dir
+ / "gobjaverse"
+ / self.ids[idx]
+ / f"{view_idx:05d}/{view_idx:05d}.png"
+ ).resize((32, 32))
+ downsampled.append(to_tensor(blend_white_bg(frame)))
+ data["pixelnerf_input"]["rgb"] = torch.stack(downsampled, dim=0)
+ data["pixelnerf_input"]["frames"] = data["frames"]
+ if self.scale_pose:
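+                # recenter camera positions and rescale so the farthest camera lies at radius 1.5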
+ c2ws = cameras[..., :16].reshape(-1, 4, 4)
+ center = c2ws[:, :3, 3].mean(0)
+ radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max()
+ scale = 1.5 / radius
+ c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale
+ cameras[..., :16] = c2ws.reshape(-1, 16)
+
+ if self.load_caps:
+ data["caption"] = self.caps[self.ids[idx]]
+ data["ids"] = self.ids[idx]
+
+ return data
+
+ def __len__(self):
+ return len(self.ids)
+
+ def collate_fn(self, batch):
+ if self.max_n_cond > 1:
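+            # randomly choose how many conditioning views to use; view 0 (the front view) is always included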
+ n_cond = np.random.randint(1, self.max_n_cond + 1)
+ if n_cond > 1:
+ for b in batch:
+ source_index = [0] + np.random.choice(
+ np.arange(1, self.n_views),
+ self.max_n_cond - 1,
+ replace=False,
+ ).tolist()
+ b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index)
+ b["pixelnerf_input"]["n_cond"] = n_cond
+ b["pixelnerf_input"]["source_images"] = b["frames"][source_index]
+ b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][
+ "cameras"
+ ][source_index]
+
+ return video_collate_fn(batch)
+
+
+class ObjaverseSpiral(Dataset):
+ def __init__(
+ self,
+ root_dir,
+ split="train",
+ transform=None,
+ random_front=False,
+ max_item=None,
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ condition_on_elevation=False,
+ **unused_kwargs,
+ ):
+ self.root_dir = Path(root_dir)
+ self.split = split
+ self.random_front = random_front
+ self.transform = transform
+
+ self.ids = json.load(open(self.root_dir / f"{split}_ids.json", "r"))
+ self.n_views = 24
+ valid_ids = []
+ for idx in self.ids:
+ if (self.root_dir / idx).exists():
+ valid_ids.append(idx)
+ self.ids = valid_ids
+
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ self.condition_on_elevation = condition_on_elevation
+
+ if max_item is not None:
+ self.ids = self.ids[:max_item]
+
+ ## debug
+ self.ids = self.ids * 10000
+
+ def __getitem__(self, idx: int):
+ frames = []
+ idx_list = np.arange(self.n_views)
+ if self.random_front:
+ roll_idx = np.random.randint(self.n_views)
+ idx_list = np.roll(idx_list, roll_idx)
+ for view_idx in idx_list:
+ frame = Image.open(
+ self.root_dir / self.ids[idx] / f"{view_idx:05d}/{view_idx:05d}.png"
+ )
+ frames.append(self.transform(frame))
+
+ # data = {"jpg": torch.stack(frames, dim=0)} # [T, C, H, W]
+ frames = torch.stack(frames, dim=0)
+ cond = frames[0]
+
+ cond_aug = np.exp(
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
+ )
+
+ data = {
+ "frames": frames,
+ "cond_frames_without_noise": cond,
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
+ "fps_id": torch.as_tensor([1.0] * self.n_views),
+ "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
+ "num_video_frames": 24,
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
+ }
+
+ if self.condition_on_elevation:
+ sample_c2w = read_camera_matrix_single(
+ self.root_dir / self.ids[idx] / f"00000/00000.json"
+ )
+ elevation = calc_elevation(sample_c2w)
+ data["elevation"] = torch.as_tensor([elevation] * self.n_views)
+
+ return data
+
+ def __len__(self):
+ return len(self.ids)
+
+
+class ObjaverseLVISSpiral(Dataset):
+ def __init__(
+ self,
+ root_dir,
+ split="train",
+ transform=None,
+ random_front=False,
+ max_item=None,
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ condition_on_elevation=False,
+ use_precomputed_latents=False,
+ **unused_kwargs,
+ ):
+ print("Using LVIS subset")
+ self.root_dir = Path(root_dir)
+ self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
+ self.split = split
+ self.random_front = random_front
+ self.transform = transform
+ self.use_precomputed_latents = use_precomputed_latents
+
+ self.ids = json.load(open("./assets/lvis_uids.json", "r"))
+ self.n_views = 18
+ valid_ids = []
+ for idx in self.ids:
+ if (self.root_dir / idx).exists():
+ valid_ids.append(idx)
+ self.ids = valid_ids
+ print("=" * 30)
+ print("Number of valid ids: ", len(self.ids))
+ print("=" * 30)
+
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ self.condition_on_elevation = condition_on_elevation
+
+ if max_item is not None:
+ self.ids = self.ids[:max_item]
+
+ ## debug
+ self.ids = self.ids * 10000
+
+ def __getitem__(self, idx: int):
+ frames = []
+ idx_list = np.arange(self.n_views)
+ if self.random_front:
+ roll_idx = np.random.randint(self.n_views)
+ idx_list = np.roll(idx_list, roll_idx)
+ for view_idx in idx_list:
+ frame = Image.open(
+ self.root_dir
+ / self.ids[idx]
+ / "elevations_0"
+ / f"colors_{view_idx * 2}.png"
+ )
+ frames.append(self.transform(frame))
+
+ frames = torch.stack(frames, dim=0)
+ cond = frames[0]
+
+ cond_aug = np.exp(
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
+ )
+
+ data = {
+ "frames": frames,
+ "cond_frames_without_noise": cond,
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
+ "fps_id": torch.as_tensor([0.0] * self.n_views),
+ "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
+ "num_video_frames": self.n_views,
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
+ }
+
+ if self.use_precomputed_latents:
+ data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")
+
+ if self.condition_on_elevation:
+ # sample_c2w = read_camera_matrix_single(
+ # self.root_dir / self.ids[idx] / f"00000/00000.json"
+ # )
+ # elevation = calc_elevation(sample_c2w)
+ # data["elevation"] = torch.as_tensor([elevation] * self.n_views)
+ assert False, "currently assumes elevation 0"
+
+ return data
+
+ def __len__(self):
+ return len(self.ids)
+
+
+class ObjaverseALLSpiral(ObjaverseLVISSpiral):
+ def __init__(
+ self,
+ root_dir,
+ split="train",
+ transform=None,
+ random_front=False,
+ max_item=None,
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ condition_on_elevation=False,
+ use_precomputed_latents=False,
+ **unused_kwargs,
+ ):
+ print("Using ALL objects in Objaverse")
+ self.root_dir = Path(root_dir)
+ self.split = split
+ self.random_front = random_front
+ self.transform = transform
+ self.use_precomputed_latents = use_precomputed_latents
+ self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
+
+ self.ids = json.load(open("./assets/all_ids.json", "r"))
+ self.n_views = 18
+ valid_ids = []
+ for idx in self.ids:
+ if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
+ valid_ids.append(idx)
+ self.ids = valid_ids
+ print("=" * 30)
+ print("Number of valid ids: ", len(self.ids))
+ print("=" * 30)
+
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ self.condition_on_elevation = condition_on_elevation
+
+ if max_item is not None:
+ self.ids = self.ids[:max_item]
+
+ ## debug
+ self.ids = self.ids * 10000
+
+
+class ObjaverseWithPose(Dataset):
+ def __init__(
+ self,
+ root_dir,
+ split="train",
+ transform=None,
+ random_front=False,
+ max_item=None,
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ condition_on_elevation=False,
+ use_precomputed_latents=False,
+ **unused_kwargs,
+ ):
+ print("Using Objaverse with poses")
+ self.root_dir = Path(root_dir)
+ self.split = split
+ self.random_front = random_front
+ self.transform = transform
+ self.use_precomputed_latents = use_precomputed_latents
+ self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512")
+
+ self.ids = json.load(open("./assets/all_ids.json", "r"))
+ self.n_views = 18
+ valid_ids = []
+ for idx in self.ids:
+ if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir():
+ valid_ids.append(idx)
+ self.ids = valid_ids
+ print("=" * 30)
+ print("Number of valid ids: ", len(self.ids))
+ print("=" * 30)
+
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ self.condition_on_elevation = condition_on_elevation
+
+ def __getitem__(self, idx: int):
+ frames = []
+ idx_list = np.arange(self.n_views)
+ if self.random_front:
+ roll_idx = np.random.randint(self.n_views)
+ idx_list = np.roll(idx_list, roll_idx)
+ for view_idx in idx_list:
+ frame = Image.open(
+ self.root_dir
+ / self.ids[idx]
+ / "elevations_0"
+ / f"colors_{view_idx * 2}.png"
+ )
+ frames.append(self.transform(frame))
+
+ frames = torch.stack(frames, dim=0)
+ cond = frames[0]
+
+ cond_aug = np.exp(
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
+ )
+
+ data = {
+ "frames": frames,
+ "cond_frames_without_noise": cond,
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
+ "fps_id": torch.as_tensor([0.0] * self.n_views),
+ "motion_bucket_id": torch.as_tensor([300.0] * self.n_views),
+ "num_video_frames": self.n_views,
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
+ }
+
+ if self.use_precomputed_latents:
+ data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt")
+
+ if self.condition_on_elevation:
+ assert False, "currently assumes elevation 0"
+
+        return data
+
+    def __len__(self):
+        return len(self.ids)
+
+
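+# Dataset variant that reads precomputed latents and CLIP embeddings from disk
+# instead of decoding rendered PNG frames.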
+class LatentObjaverse(Dataset):
+ def __init__(
+ self,
+ root_dir,
+ split="train",
+ random_front=False,
+ subset="lvis",
+ fps_id=1.0,
+ motion_bucket_id=300.0,
+ cond_aug_mean=-3.0,
+ cond_aug_std=0.5,
+ **unused_kwargs,
+ ):
+ self.root_dir = Path(root_dir)
+ self.split = split
+ self.random_front = random_front
+ self.ids = json.load(open(Path("./assets") / f"{subset}_ids.json", "r"))
+ self.clip_emb_dir = self.root_dir / ".." / "clip_emb512"
+ self.n_views = 18
+ self.fps_id = fps_id
+ self.motion_bucket_id = motion_bucket_id
+ self.cond_aug_mean = cond_aug_mean
+ self.cond_aug_std = cond_aug_std
+ if self.random_front:
+ print("Using a random view as front view")
+
+ valid_ids = []
+ for idx in self.ids:
+ if (self.root_dir / f"{idx}.pt").exists() and (
+ self.clip_emb_dir / f"{idx}.pt"
+ ).exists():
+ valid_ids.append(idx)
+ self.ids = valid_ids
+ print("=" * 30)
+ print("Number of valid ids: ", len(self.ids))
+ print("=" * 30)
+
+ def __getitem__(self, idx: int):
+ uid = self.ids[idx]
+ idx_list = torch.arange(self.n_views)
+ latents = torch.load(self.root_dir / f"{uid}.pt")
+ clip_emb = torch.load(self.clip_emb_dir / f"{uid}.pt")
+ if self.random_front:
+ idx_list = torch.roll(idx_list, np.random.randint(self.n_views))
+ latents = latents[idx_list]
+ clip_emb = clip_emb[idx_list][0]
+
+ cond_aug = np.exp(
+ np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean
+ )
+ cond = latents[0]
+
+ data = {
+ "latents": latents,
+ "cond_frames_without_noise": clip_emb,
+ "cond_frames": cond + cond_aug * torch.randn_like(cond),
+ "fps_id": torch.as_tensor([self.fps_id] * self.n_views),
+ "motion_bucket_id": torch.as_tensor([self.motion_bucket_id] * self.n_views),
+ "cond_aug": torch.as_tensor([cond_aug] * self.n_views),
+ "num_video_frames": self.n_views,
+ "image_only_indicator": torch.as_tensor([0.0] * self.n_views),
+ }
+
+ return data
+
+ def __len__(self):
+ return len(self.ids)
+
+
+class ObjaverseSpiralDataset(LightningDataModule):
+ def __init__(
+ self,
+ root_dir,
+ random_front=False,
+ batch_size=2,
+ num_workers=10,
+ prefetch_factor=2,
+ shuffle=True,
+ max_item=None,
+ dataset_cls="richdreamer",
+ reso: int = 256,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.prefetch_factor = prefetch_factor
+ self.shuffle = shuffle
+ self.max_item = max_item
+
+ self.transform = Compose(
+ [
+ blend_white_bg,
+ Resize((reso, reso)),
+ ToTensor(),
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ]
+ )
+
+ data_cls = {
+ "richdreamer": ObjaverseSpiral,
+ "lvis": ObjaverseLVISSpiral,
+ "shengshu_all": ObjaverseALLSpiral,
+ "latent": LatentObjaverse,
+ "gobjaverse": GObjaverse,
+ }[dataset_cls]
+
+ self.train_dataset = data_cls(
+ root_dir=root_dir,
+ split="train",
+ random_front=random_front,
+ transform=self.transform,
+ max_item=self.max_item,
+ **kwargs,
+ )
+ self.test_dataset = data_cls(
+ root_dir=root_dir,
+ split="val",
+ random_front=random_front,
+ transform=self.transform,
+ max_item=self.max_item,
+ **kwargs,
+ )
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=video_collate_fn
+ if not hasattr(self.train_dataset, "collate_fn")
+ else self.train_dataset.collate_fn,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=video_collate_fn
+ if not hasattr(self.test_dataset, "collate_fn")
+            else self.test_dataset.collate_fn,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ prefetch_factor=self.prefetch_factor,
+ collate_fn=video_collate_fn
+ if not hasattr(self.test_dataset, "collate_fn")
+            else self.test_dataset.collate_fn,
+ )
diff --git a/sgm/inference/api.py b/sgm/inference/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..a359a67bcd9740acc9e320d2f26dc6a3befb36e0
--- /dev/null
+++ b/sgm/inference/api.py
@@ -0,0 +1,385 @@
+import pathlib
+from dataclasses import asdict, dataclass
+from enum import Enum
+from typing import Optional
+
+from omegaconf import OmegaConf
+
+from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
+ do_sample)
+from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
+ DPMPP2SAncestralSampler,
+ EulerAncestralSampler,
+ EulerEDMSampler,
+ HeunEDMSampler,
+ LinearMultistepSampler)
+from sgm.util import load_model_from_config
+
+
+class ModelArchitecture(str, Enum):
+ SD_2_1 = "stable-diffusion-v2-1"
+ SD_2_1_768 = "stable-diffusion-v2-1-768"
+ SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
+ SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
+ SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
+ SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
+
+
+class Sampler(str, Enum):
+ EULER_EDM = "EulerEDMSampler"
+ HEUN_EDM = "HeunEDMSampler"
+ EULER_ANCESTRAL = "EulerAncestralSampler"
+ DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"
+ DPMPP2M = "DPMPP2MSampler"
+ LINEAR_MULTISTEP = "LinearMultistepSampler"
+
+
+class Discretization(str, Enum):
+ LEGACY_DDPM = "LegacyDDPMDiscretization"
+ EDM = "EDMDiscretization"
+
+
+class Guider(str, Enum):
+ VANILLA = "VanillaCFG"
+ IDENTITY = "IdentityGuider"
+
+
+class Thresholder(str, Enum):
+ NONE = "None"
+
+
+@dataclass
+class SamplingParams:
+ width: int = 1024
+ height: int = 1024
+ steps: int = 50
+ sampler: Sampler = Sampler.DPMPP2M
+ discretization: Discretization = Discretization.LEGACY_DDPM
+ guider: Guider = Guider.VANILLA
+ thresholder: Thresholder = Thresholder.NONE
+ scale: float = 6.0
+ aesthetic_score: float = 5.0
+ negative_aesthetic_score: float = 5.0
+ img2img_strength: float = 1.0
+ orig_width: int = 1024
+ orig_height: int = 1024
+ crop_coords_top: int = 0
+ crop_coords_left: int = 0
+ sigma_min: float = 0.0292
+ sigma_max: float = 14.6146
+ rho: float = 3.0
+ s_churn: float = 0.0
+ s_tmin: float = 0.0
+ s_tmax: float = 999.0
+ s_noise: float = 1.0
+ eta: float = 1.0
+ order: int = 4
+
+
+@dataclass
+class SamplingSpec:
+ width: int
+ height: int
+ channels: int
+ factor: int
+ is_legacy: bool
+ config: str
+ ckpt: str
+ is_guided: bool
+
+
+model_specs = {
+ ModelArchitecture.SD_2_1: SamplingSpec(
+ height=512,
+ width=512,
+ channels=4,
+ factor=8,
+ is_legacy=True,
+ config="sd_2_1.yaml",
+ ckpt="v2-1_512-ema-pruned.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SD_2_1_768: SamplingSpec(
+ height=768,
+ width=768,
+ channels=4,
+ factor=8,
+ is_legacy=True,
+ config="sd_2_1_768.yaml",
+ ckpt="v2-1_768-ema-pruned.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
+ height=1024,
+ width=1024,
+ channels=4,
+ factor=8,
+ is_legacy=False,
+ config="sd_xl_base.yaml",
+ ckpt="sd_xl_base_0.9.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
+ height=1024,
+ width=1024,
+ channels=4,
+ factor=8,
+ is_legacy=True,
+ config="sd_xl_refiner.yaml",
+ ckpt="sd_xl_refiner_0.9.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
+ height=1024,
+ width=1024,
+ channels=4,
+ factor=8,
+ is_legacy=False,
+ config="sd_xl_base.yaml",
+ ckpt="sd_xl_base_1.0.safetensors",
+ is_guided=True,
+ ),
+ ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
+ height=1024,
+ width=1024,
+ channels=4,
+ factor=8,
+ is_legacy=True,
+ config="sd_xl_refiner.yaml",
+ ckpt="sd_xl_refiner_1.0.safetensors",
+ is_guided=True,
+ ),
+}
+
+
+class SamplingPipeline:
+ def __init__(
+ self,
+ model_id: ModelArchitecture,
+ model_path="checkpoints",
+ config_path="configs/inference",
+ device="cuda",
+ use_fp16=True,
+ ) -> None:
+ if model_id not in model_specs:
+ raise ValueError(f"Model {model_id} not supported")
+ self.model_id = model_id
+ self.specs = model_specs[self.model_id]
+ self.config = str(pathlib.Path(config_path, self.specs.config))
+ self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
+ self.device = device
+ self.model = self._load_model(device=device, use_fp16=use_fp16)
+
+ def _load_model(self, device="cuda", use_fp16=True):
+ config = OmegaConf.load(self.config)
+ model = load_model_from_config(config, self.ckpt)
+ if model is None:
+ raise ValueError(f"Model {self.model_id} could not be loaded")
+ model.to(device)
+ if use_fp16:
+ model.conditioner.half()
+ model.model.half()
+ return model
+
+ def text_to_image(
+ self,
+ params: SamplingParams,
+ prompt: str,
+ negative_prompt: str = "",
+ samples: int = 1,
+ return_latents: bool = False,
+ ):
+ sampler = get_sampler_config(params)
+ value_dict = asdict(params)
+ value_dict["prompt"] = prompt
+ value_dict["negative_prompt"] = negative_prompt
+ value_dict["target_width"] = params.width
+ value_dict["target_height"] = params.height
+ return do_sample(
+ self.model,
+ sampler,
+ value_dict,
+ samples,
+ params.height,
+ params.width,
+ self.specs.channels,
+ self.specs.factor,
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
+ return_latents=return_latents,
+ filter=None,
+ )
+
+ def image_to_image(
+ self,
+ params: SamplingParams,
+ image,
+ prompt: str,
+ negative_prompt: str = "",
+ samples: int = 1,
+ return_latents: bool = False,
+ ):
+ sampler = get_sampler_config(params)
+
+ if params.img2img_strength < 1.0:
+ sampler.discretization = Img2ImgDiscretizationWrapper(
+ sampler.discretization,
+ strength=params.img2img_strength,
+ )
+ height, width = image.shape[2], image.shape[3]
+ value_dict = asdict(params)
+ value_dict["prompt"] = prompt
+ value_dict["negative_prompt"] = negative_prompt
+ value_dict["target_width"] = width
+ value_dict["target_height"] = height
+ return do_img2img(
+ image,
+ self.model,
+ sampler,
+ value_dict,
+ samples,
+ force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
+ return_latents=return_latents,
+ filter=None,
+ )
+
+ def refiner(
+ self,
+ params: SamplingParams,
+ image,
+ prompt: str,
+ negative_prompt: Optional[str] = None,
+ samples: int = 1,
+ return_latents: bool = False,
+ ):
+ sampler = get_sampler_config(params)
+ value_dict = {
+ "orig_width": image.shape[3] * 8,
+ "orig_height": image.shape[2] * 8,
+ "target_width": image.shape[3] * 8,
+ "target_height": image.shape[2] * 8,
+ "prompt": prompt,
+ "negative_prompt": negative_prompt,
+ "crop_coords_top": 0,
+ "crop_coords_left": 0,
+ "aesthetic_score": 6.0,
+ "negative_aesthetic_score": 2.5,
+ }
+
+ return do_img2img(
+ image,
+ self.model,
+ sampler,
+ value_dict,
+ samples,
+ skip_encode=True,
+ return_latents=return_latents,
+ filter=None,
+ )
+
+
+def get_guider_config(params: SamplingParams):
+ if params.guider == Guider.IDENTITY:
+ guider_config = {
+ "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+ }
+ elif params.guider == Guider.VANILLA:
+ scale = params.scale
+
+ thresholder = params.thresholder
+
+ if thresholder == Thresholder.NONE:
+ dyn_thresh_config = {
+ "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+ }
+ else:
+ raise NotImplementedError
+
+ guider_config = {
+ "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
+ "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
+ }
+ else:
+ raise NotImplementedError
+ return guider_config
+
+
+def get_discretization_config(params: SamplingParams):
+ if params.discretization == Discretization.LEGACY_DDPM:
+ discretization_config = {
+ "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
+ }
+ elif params.discretization == Discretization.EDM:
+ discretization_config = {
+ "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
+ "params": {
+ "sigma_min": params.sigma_min,
+ "sigma_max": params.sigma_max,
+ "rho": params.rho,
+ },
+ }
+ else:
+ raise ValueError(f"unknown discretization {params.discretization}")
+ return discretization_config
+
+
+def get_sampler_config(params: SamplingParams):
+ discretization_config = get_discretization_config(params)
+ guider_config = get_guider_config(params)
+ if params.sampler == Sampler.EULER_EDM:
+ return EulerEDMSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ s_churn=params.s_churn,
+ s_tmin=params.s_tmin,
+ s_tmax=params.s_tmax,
+ s_noise=params.s_noise,
+ verbose=True,
+ )
+ if params.sampler == Sampler.HEUN_EDM:
+ return HeunEDMSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ s_churn=params.s_churn,
+ s_tmin=params.s_tmin,
+ s_tmax=params.s_tmax,
+ s_noise=params.s_noise,
+ verbose=True,
+ )
+ if params.sampler == Sampler.EULER_ANCESTRAL:
+ return EulerAncestralSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ eta=params.eta,
+ s_noise=params.s_noise,
+ verbose=True,
+ )
+ if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
+ return DPMPP2SAncestralSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ eta=params.eta,
+ s_noise=params.s_noise,
+ verbose=True,
+ )
+ if params.sampler == Sampler.DPMPP2M:
+ return DPMPP2MSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ verbose=True,
+ )
+ if params.sampler == Sampler.LINEAR_MULTISTEP:
+ return LinearMultistepSampler(
+ num_steps=params.steps,
+ discretization_config=discretization_config,
+ guider_config=guider_config,
+ order=params.order,
+ verbose=True,
+ )
+
+ raise ValueError(f"unknown sampler {params.sampler}!")
diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..31b0ec3dc414bf522261e35f73805810cd35582d
--- /dev/null
+++ b/sgm/inference/helpers.py
@@ -0,0 +1,305 @@
+import math
+import os
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from einops import rearrange
+from imwatermark import WatermarkEncoder
+from omegaconf import ListConfig
+from PIL import Image
+from torch import autocast
+
+from sgm.util import append_dims
+
+
+class WatermarkEmbedder:
+ def __init__(self, watermark):
+ self.watermark = watermark
+ self.num_bits = len(WATERMARK_BITS)
+ self.encoder = WatermarkEncoder()
+ self.encoder.set_watermark("bits", self.watermark)
+
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
+ """
+ Adds a predefined watermark to the input image
+
+ Args:
+ image: ([N,] B, RGB, H, W) in range [0, 1]
+
+ Returns:
+ same as input but watermarked
+ """
+ squeeze = len(image.shape) == 4
+ if squeeze:
+ image = image[None, ...]
+ n = image.shape[0]
+ image_np = rearrange(
+ (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
+ ).numpy()[:, :, :, ::-1]
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
+        # watermarking library expects input as cv2 BGR format
+ for k in range(image_np.shape[0]):
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
+ image = torch.from_numpy(
+ rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
+ ).to(image.device)
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
+ if squeeze:
+ image = image[0]
+ return image
+
+
+# A fixed 48-bit message that was chosen at random
+# WATERMARK_MESSAGE = 0xB3EC907BB19E
+WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
+# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
+WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
+embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
+
+
+def get_unique_embedder_keys_from_conditioner(conditioner):
+ return list({x.input_key for x in conditioner.embedders})
+
+
+def perform_save_locally(save_path, samples):
+ os.makedirs(os.path.join(save_path), exist_ok=True)
+ base_count = len(os.listdir(os.path.join(save_path)))
+ samples = embed_watermark(samples)
+ for sample in samples:
+ sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
+ Image.fromarray(sample.astype(np.uint8)).save(
+ os.path.join(save_path, f"{base_count:09}.png")
+ )
+ base_count += 1
+
+
+class Img2ImgDiscretizationWrapper:
+ """
+ wraps a discretizer, and prunes the sigmas
+ params:
+ strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
+ """
+
+ def __init__(self, discretization, strength: float = 1.0):
+ self.discretization = discretization
+ self.strength = strength
+ assert 0.0 <= self.strength <= 1.0
+
+ def __call__(self, *args, **kwargs):
+        # sigmas start large and decrease monotonically
+ sigmas = self.discretization(*args, **kwargs)
+ print(f"sigmas after discretization, before pruning img2img: ", sigmas)
+ sigmas = torch.flip(sigmas, (0,))
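+        # keep only the lowest-noise fraction of the schedule (strength * num_steps), then restore descending order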
+        prune_index = max(int(self.strength * len(sigmas)), 1)
+        sigmas = sigmas[:prune_index]
+        print("prune index:", prune_index)
+ sigmas = torch.flip(sigmas, (0,))
+ print(f"sigmas after pruning: ", sigmas)
+ return sigmas
+
+
+def do_sample(
+ model,
+ sampler,
+ value_dict,
+ num_samples,
+ H,
+ W,
+ C,
+ F,
+ force_uc_zero_embeddings: Optional[List] = None,
+ batch2model_input: Optional[List] = None,
+ return_latents=False,
+ filter=None,
+ device="cuda",
+):
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
+ if batch2model_input is None:
+ batch2model_input = []
+
+ with torch.no_grad():
+ with autocast(device) as precision_scope:
+ with model.ema_scope():
+ num_samples = [num_samples]
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
+ value_dict,
+ num_samples,
+ )
+ for key in batch:
+ if isinstance(batch[key], torch.Tensor):
+ print(key, batch[key].shape)
+ elif isinstance(batch[key], list):
+ print(key, [len(l) for l in batch[key]])
+ else:
+ print(key, batch[key])
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in c:
+ if not k == "crossattn":
+ c[k], uc[k] = map(
+ lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc)
+ )
+
+ additional_model_inputs = {}
+ for k in batch2model_input:
+ additional_model_inputs[k] = batch[k]
+
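+                # draw the initial latent noise at the spatially downsampled (factor F) resolution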
+ shape = (math.prod(num_samples), C, H // F, W // F)
+ randn = torch.randn(shape).to(device)
+
+ def denoiser(input, sigma, c):
+ return model.denoiser(
+ model.model, input, sigma, c, **additional_model_inputs
+ )
+
+ samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+ samples_x = model.decode_first_stage(samples_z)
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if filter is not None:
+ samples = filter(samples)
+
+ if return_latents:
+ return samples, samples_z
+ return samples
+
+
+def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+ # Hardcoded demo setups; might undergo some changes in the future
+
+ batch = {}
+ batch_uc = {}
+
+ for key in keys:
+ if key == "txt":
+ batch["txt"] = (
+ np.repeat([value_dict["prompt"]], repeats=math.prod(N))
+ .reshape(N)
+ .tolist()
+ )
+ batch_uc["txt"] = (
+ np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
+ .reshape(N)
+ .tolist()
+ )
+ elif key == "original_size_as_tuple":
+ batch["original_size_as_tuple"] = (
+ torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+ elif key == "crop_coords_top_left":
+ batch["crop_coords_top_left"] = (
+ torch.tensor(
+ [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
+ )
+ .to(device)
+ .repeat(*N, 1)
+ )
+ elif key == "aesthetic_score":
+ batch["aesthetic_score"] = (
+ torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+ )
+ batch_uc["aesthetic_score"] = (
+ torch.tensor([value_dict["negative_aesthetic_score"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+
+ elif key == "target_size_as_tuple":
+ batch["target_size_as_tuple"] = (
+ torch.tensor([value_dict["target_height"], value_dict["target_width"]])
+ .to(device)
+ .repeat(*N, 1)
+ )
+ else:
+ batch[key] = value_dict[key]
+
+ for key in batch.keys():
+ if key not in batch_uc and isinstance(batch[key], torch.Tensor):
+ batch_uc[key] = torch.clone(batch[key])
+ return batch, batch_uc
+
+
+def get_input_image_tensor(image: Image.Image, device="cuda"):
+ w, h = image.size
+ print(f"loaded input image of size ({w}, {h})")
+ width, height = map(
+ lambda x: x - x % 64, (w, h)
+ ) # resize to integer multiple of 64
+ image = image.resize((width, height))
+ image_array = np.array(image.convert("RGB"))
+ image_array = image_array[None].transpose(0, 3, 1, 2)
+ image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
+ return image_tensor.to(device)
+
+
+def do_img2img(
+ img,
+ model,
+ sampler,
+ value_dict,
+ num_samples,
+ force_uc_zero_embeddings=[],
+ additional_kwargs={},
+ offset_noise_level: float = 0.0,
+ return_latents=False,
+ skip_encode=False,
+ filter=None,
+ device="cuda",
+):
+ with torch.no_grad():
+ with autocast(device) as precision_scope:
+ with model.ema_scope():
+ batch, batch_uc = get_batch(
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
+ value_dict,
+ [num_samples],
+ )
+ c, uc = model.conditioner.get_unconditional_conditioning(
+ batch,
+ batch_uc=batch_uc,
+ force_uc_zero_embeddings=force_uc_zero_embeddings,
+ )
+
+ for k in c:
+ c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
+
+ for k in additional_kwargs:
+ c[k] = uc[k] = additional_kwargs[k]
+ if skip_encode:
+ z = img
+ else:
+ z = model.encode_first_stage(img)
+ noise = torch.randn_like(z)
+ sigmas = sampler.discretization(sampler.num_steps)
+ sigma = sigmas[0].to(z.device)
+
+ if offset_noise_level > 0.0:
+ noise = noise + offset_noise_level * append_dims(
+ torch.randn(z.shape[0], device=z.device), z.ndim
+ )
+ noised_z = z + noise * append_dims(sigma, z.ndim)
+ noised_z = noised_z / torch.sqrt(
+ 1.0 + sigmas[0] ** 2.0
+ ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
+
+ def denoiser(x, sigma, c):
+ return model.denoiser(model.model, x, sigma, c)
+
+ samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
+ samples_x = model.decode_first_stage(samples_z)
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+ if filter is not None:
+ samples = filter(samples)
+
+ if return_latents:
+ return samples, samples_z
+ return samples
diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2f4d384c1fcaff0df13e0564450d3fa972ace42
--- /dev/null
+++ b/sgm/lr_scheduler.py
@@ -0,0 +1,135 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+
+ def __init__(
+ self,
+ warm_up_steps,
+ lr_min,
+ lr_max,
+ lr_start,
+ max_decay_steps,
+ verbosity_interval=0,
+ ):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (
+ self.lr_max - self.lr_start
+ ) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
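+            # cosine decay from lr_max down to lr_min over the remaining decay steps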
+ t = (n - self.lr_warm_up_steps) / (
+ self.lr_max_decay_steps - self.lr_warm_up_steps
+ )
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+
+ def __init__(
+ self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0
+ ):
+ assert (
+ len(warm_up_steps)
+ == len(f_min)
+ == len(f_max)
+ == len(f_start)
+ == len(cycle_lengths)
+ )
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.0
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (
+ self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]
+ )
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi)
+ )
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0:
+ print(
+ f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+ f"current cycle {cycle}"
+ )
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
+ cycle
+ ] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
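+            # linear decay from f_max down to f_min over the rest of the cycle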
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (
+ self.cycle_lengths[cycle] - n
+ ) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
diff --git a/sgm/models/__init__.py b/sgm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c410b3747afc208e4204c8f140170e0a7808eace
--- /dev/null
+++ b/sgm/models/__init__.py
@@ -0,0 +1,2 @@
+from .autoencoder import AutoencodingEngine
+from .diffusion import DiffusionEngine
diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2949b91011a2be7a6b8ca17ce260812f20ce8b75
--- /dev/null
+++ b/sgm/models/autoencoder.py
@@ -0,0 +1,615 @@
+import logging
+import math
+import re
+from abc import abstractmethod
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+from einops import rearrange
+from packaging import version
+
+from ..modules.autoencoding.regularizers import AbstractRegularizer
+from ..modules.ema import LitEma
+from ..util import (default, get_nested_attribute, get_obj_from_str,
+ instantiate_from_config)
+
+logpy = logging.getLogger(__name__)
+
+
+class AbstractAutoencoder(pl.LightningModule):
+ """
+ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
+ unCLIP models, etc. Hence, it is fairly general, and specific features
+ (e.g. discriminator training, encoding, decoding) must be implemented in subclasses.
+ """
+
+ def __init__(
+ self,
+ ema_decay: Union[None, float] = None,
+ monitor: Union[None, str] = None,
+ input_key: str = "jpg",
+ ):
+ super().__init__()
+
+ self.input_key = input_key
+ self.use_ema = ema_decay is not None
+ if monitor is not None:
+ self.monitor = monitor
+
+ if self.use_ema:
+ self.model_ema = LitEma(self, decay=ema_decay)
+ logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ self.automatic_optimization = False
+
+ def apply_ckpt(self, ckpt: Union[None, str, dict]):
+ if ckpt is None:
+ return
+ if isinstance(ckpt, str):
+ ckpt = {
+ "target": "sgm.modules.checkpoint.CheckpointEngine",
+ "params": {"ckpt_path": ckpt},
+ }
+ engine = instantiate_from_config(ckpt)
+ engine(self)
+
+ @abstractmethod
+ def get_input(self, batch) -> Any:
+ raise NotImplementedError()
+
+ def on_train_batch_end(self, *args, **kwargs):
+ # for EMA computation
+ if self.use_ema:
+ self.model_ema(self)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ logpy.info(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ logpy.info(f"{context}: Restored training weights")
+
+ @abstractmethod
+ def encode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("encode()-method of abstract base class called")
+
+ @abstractmethod
+ def decode(self, *args, **kwargs) -> torch.Tensor:
+ raise NotImplementedError("decode()-method of abstract base class called")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self) -> Any:
+ raise NotImplementedError()
+
+
+class AutoencodingEngine(AbstractAutoencoder):
+ """
+ Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL
+ (we also restore them explicitly as special cases for legacy reasons).
+ Regularizations such as KL or VQ are moved to the regularizer class.
+ """
+
+ def __init__(
+ self,
+ *args,
+ encoder_config: Dict,
+ decoder_config: Dict,
+ loss_config: Dict,
+ regularizer_config: Dict,
+ optimizer_config: Union[Dict, None] = None,
+ lr_g_factor: float = 1.0,
+ trainable_ae_params: Optional[List[List[str]]] = None,
+ ae_optimizer_args: Optional[List[dict]] = None,
+ trainable_disc_params: Optional[List[List[str]]] = None,
+ disc_optimizer_args: Optional[List[dict]] = None,
+ disc_start_iter: int = 0,
+ diff_boost_factor: float = 3.0,
+ ckpt_engine: Union[None, str, dict] = None,
+ ckpt_path: Optional[str] = None,
+ additional_decode_keys: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+ self.automatic_optimization = False # pytorch lightning
+
+ self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
+ self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
+ self.loss: torch.nn.Module = instantiate_from_config(loss_config)
+ self.regularization: AbstractRegularizer = instantiate_from_config(
+ regularizer_config
+ )
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.Adam"}
+ )
+ self.diff_boost_factor = diff_boost_factor
+ self.disc_start_iter = disc_start_iter
+ self.lr_g_factor = lr_g_factor
+ self.trainable_ae_params = trainable_ae_params
+ if self.trainable_ae_params is not None:
+ self.ae_optimizer_args = default(
+ ae_optimizer_args,
+ [{} for _ in range(len(self.trainable_ae_params))],
+ )
+ assert len(self.ae_optimizer_args) == len(self.trainable_ae_params)
+ else:
+            self.ae_optimizer_args = [{}] # makes type consistent
+
+ self.trainable_disc_params = trainable_disc_params
+ if self.trainable_disc_params is not None:
+ self.disc_optimizer_args = default(
+ disc_optimizer_args,
+ [{} for _ in range(len(self.trainable_disc_params))],
+ )
+ assert len(self.disc_optimizer_args) == len(self.trainable_disc_params)
+ else:
+            self.disc_optimizer_args = [{}] # makes type consistent
+
+ if ckpt_path is not None:
+ assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path"
+ logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead")
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
+ self.additional_decode_keys = set(default(additional_decode_keys, []))
+
+ def get_input(self, batch: Dict) -> torch.Tensor:
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in channels-first
+        # format (e.g., bchw instead of bhwc)
+ return batch[self.input_key]
+
+ def get_autoencoder_params(self) -> list:
+ params = []
+ if hasattr(self.loss, "get_trainable_autoencoder_parameters"):
+ params += list(self.loss.get_trainable_autoencoder_parameters())
+ if hasattr(self.regularization, "get_trainable_parameters"):
+ params += list(self.regularization.get_trainable_parameters())
+ params = params + list(self.encoder.parameters())
+ params = params + list(self.decoder.parameters())
+ return params
+
+ def get_discriminator_params(self) -> list:
+ if hasattr(self.loss, "get_trainable_parameters"):
+ params = list(self.loss.get_trainable_parameters()) # e.g., discriminator
+ else:
+ params = []
+ return params
+
+ def get_last_layer(self):
+ return self.decoder.get_last_layer()
+
+ def encode(
+ self,
+ x: torch.Tensor,
+ return_reg_log: bool = False,
+ unregularized: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ z = self.encoder(x)
+ if unregularized:
+ return z, dict()
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
+ x = self.decoder(z, **kwargs)
+ return x
+
+ def forward(
+ self, x: torch.Tensor, **additional_decode_kwargs
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ z, reg_log = self.encode(x, return_reg_log=True)
+ dec = self.decode(z, **additional_decode_kwargs)
+ return z, dec, reg_log
+
+ def inner_training_step(
+ self, batch: dict, batch_idx: int, optimizer_idx: int = 0
+ ) -> torch.Tensor:
+ x = self.get_input(batch)
+ additional_decode_kwargs = {
+ key: batch[key] for key in self.additional_decode_keys.intersection(batch)
+ }
+ z, xrec, regularization_log = self(x, **additional_decode_kwargs)
+ if hasattr(self.loss, "forward_keys"):
+ extra_info = {
+ "z": z,
+ "optimizer_idx": optimizer_idx,
+ "global_step": self.global_step,
+ "last_layer": self.get_last_layer(),
+ "split": "train",
+ "regularization_log": regularization_log,
+ "autoencoder": self,
+ }
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+ else:
+ extra_info = dict()
+
+ if optimizer_idx == 0:
+ # autoencode
+ out_loss = self.loss(x, xrec, **extra_info)
+ if isinstance(out_loss, tuple):
+ aeloss, log_dict_ae = out_loss
+ else:
+ # simple loss function
+ aeloss = out_loss
+ log_dict_ae = {"train/loss/rec": aeloss.detach()}
+
+ self.log_dict(
+ log_dict_ae,
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ sync_dist=False,
+ )
+ self.log(
+ "loss",
+ aeloss.mean().detach(),
+ prog_bar=True,
+ logger=False,
+ on_epoch=False,
+ on_step=True,
+ )
+ return aeloss
+ elif optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+ # -> discriminator always needs to return a tuple
+ self.log_dict(
+ log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True
+ )
+ return discloss
+ else:
+ raise NotImplementedError(f"Unknown optimizer {optimizer_idx}")
+
+ def training_step(self, batch: dict, batch_idx: int):
+ opts = self.optimizers()
+ if not isinstance(opts, list):
+ # Non-adversarial case
+ opts = [opts]
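+        # manual optimization: alternate between the autoencoder and the discriminator optimizer by batch index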
+ optimizer_idx = batch_idx % len(opts)
+ if self.global_step < self.disc_start_iter:
+ optimizer_idx = 0
+ opt = opts[optimizer_idx]
+ opt.zero_grad()
+ with opt.toggle_model():
+ loss = self.inner_training_step(
+ batch, batch_idx, optimizer_idx=optimizer_idx
+ )
+ self.manual_backward(loss)
+ opt.step()
+
+ def validation_step(self, batch: dict, batch_idx: int) -> Dict:
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ log_dict.update(log_dict_ema)
+ return log_dict
+
+ def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict:
+ x = self.get_input(batch)
+
+ z, xrec, regularization_log = self(x)
+ if hasattr(self.loss, "forward_keys"):
+ extra_info = {
+ "z": z,
+ "optimizer_idx": 0,
+ "global_step": self.global_step,
+ "last_layer": self.get_last_layer(),
+ "split": "val" + postfix,
+ "regularization_log": regularization_log,
+ "autoencoder": self,
+ }
+ extra_info = {k: extra_info[k] for k in self.loss.forward_keys}
+ else:
+ extra_info = dict()
+ out_loss = self.loss(x, xrec, **extra_info)
+ if isinstance(out_loss, tuple):
+ aeloss, log_dict_ae = out_loss
+ else:
+ # simple loss function
+ aeloss = out_loss
+ log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()}
+ full_log_dict = log_dict_ae
+
+ if "optimizer_idx" in extra_info:
+ extra_info["optimizer_idx"] = 1
+ discloss, log_dict_disc = self.loss(x, xrec, **extra_info)
+ full_log_dict.update(log_dict_disc)
+ self.log(
+ f"val{postfix}/loss/rec",
+ log_dict_ae[f"val{postfix}/loss/rec"],
+ sync_dist=True,
+ )
+ self.log_dict(full_log_dict, sync_dist=True)
+ return full_log_dict
+
+ def get_param_groups(
+ self, parameter_names: List[List[str]], optimizer_args: List[dict]
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ groups = []
+ num_params = 0
+ for names, args in zip(parameter_names, optimizer_args):
+ params = []
+ for pattern_ in names:
+ pattern_params = []
+ pattern = re.compile(pattern_)
+ for p_name, param in self.named_parameters():
+ if re.match(pattern, p_name):
+ pattern_params.append(param)
+ num_params += param.numel()
+ if len(pattern_params) == 0:
+ logpy.warn(f"Did not find parameters for pattern {pattern_}")
+ params.extend(pattern_params)
+ groups.append({"params": params, **args})
+ return groups, num_params
+
+ def configure_optimizers(self) -> List[torch.optim.Optimizer]:
+ if self.trainable_ae_params is None:
+ ae_params = self.get_autoencoder_params()
+ else:
+ ae_params, num_ae_params = self.get_param_groups(
+ self.trainable_ae_params, self.ae_optimizer_args
+ )
+ logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}")
+ if self.trainable_disc_params is None:
+ disc_params = self.get_discriminator_params()
+ else:
+ disc_params, num_disc_params = self.get_param_groups(
+ self.trainable_disc_params, self.disc_optimizer_args
+ )
+ logpy.info(
+ f"Number of trainable discriminator parameters: {num_disc_params:,}"
+ )
+ opt_ae = self.instantiate_optimizer_from_config(
+ ae_params,
+ default(self.lr_g_factor, 1.0) * self.learning_rate,
+ self.optimizer_config,
+ )
+ opts = [opt_ae]
+ if len(disc_params) > 0:
+ opt_disc = self.instantiate_optimizer_from_config(
+ disc_params, self.learning_rate, self.optimizer_config
+ )
+ opts.append(opt_disc)
+
+ return opts
+
+ @torch.no_grad()
+ def log_images(
+ self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs
+ ) -> dict:
+ log = dict()
+ additional_decode_kwargs = {}
+ x = self.get_input(batch)
+ additional_decode_kwargs.update(
+ {key: batch[key] for key in self.additional_decode_keys.intersection(batch)}
+ )
+
+ _, xrec, _ = self(x, **additional_decode_kwargs)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x)
+ diff.clamp_(0, 1.0)
+ log["diff"] = 2.0 * diff - 1.0
+ # diff_boost shows location of small errors, by boosting their
+ # brightness.
+ log["diff_boost"] = (
+ 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1
+ )
+ if hasattr(self.loss, "log_images"):
+ log.update(self.loss.log_images(x, xrec))
+ with self.ema_scope():
+ _, xrec_ema, _ = self(x, **additional_decode_kwargs)
+ log["reconstructions_ema"] = xrec_ema
+ diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x)
+ diff_ema.clamp_(0, 1.0)
+ log["diff_ema"] = 2.0 * diff_ema - 1.0
+ log["diff_boost_ema"] = (
+ 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1
+ )
+ if additional_log_kwargs:
+ additional_decode_kwargs.update(additional_log_kwargs)
+ _, xrec_add, _ = self(x, **additional_decode_kwargs)
+ log_str = "reconstructions-" + "-".join(
+ [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs]
+ )
+ log[log_str] = xrec_add
+ return log
+
+
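+# Compatibility wrapper around the first-generation autoencoder interface: it builds
+# the fixed Encoder/Decoder pair from `ddconfig`, adds the 1x1 quant/post-quant convs
+# around the latent, and optionally encodes/decodes in chunks of `max_batch_size`.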
+class AutoencodingEngineLegacy(AutoencodingEngine):
+ def __init__(self, embed_dim: int, **kwargs):
+ self.max_batch_size = kwargs.pop("max_batch_size", None)
+ ddconfig = kwargs.pop("ddconfig")
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ckpt_engine = kwargs.pop("ckpt_engine", None)
+ super().__init__(
+ encoder_config={
+ "target": "sgm.modules.diffusionmodules.model.Encoder",
+ "params": ddconfig,
+ },
+ decoder_config={
+ "target": "sgm.modules.diffusionmodules.model.Decoder",
+ "params": ddconfig,
+ },
+ **kwargs,
+ )
+ self.quant_conv = torch.nn.Conv2d(
+ (1 + ddconfig["double_z"]) * ddconfig["z_channels"],
+ (1 + ddconfig["double_z"]) * embed_dim,
+ 1,
+ )
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+
+ self.apply_ckpt(default(ckpt_path, ckpt_engine))
+
+ def get_autoencoder_params(self) -> list:
+ params = super().get_autoencoder_params()
+ return params
+
+ def encode(
+ self, x: torch.Tensor, return_reg_log: bool = False
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
+ if self.max_batch_size is None:
+ z = self.encoder(x)
+ z = self.quant_conv(z)
+ else:
+ N = x.shape[0]
+ bs = self.max_batch_size
+ n_batches = int(math.ceil(N / bs))
+ z = list()
+ for i_batch in range(n_batches):
+ z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
+ z_batch = self.quant_conv(z_batch)
+ z.append(z_batch)
+ z = torch.cat(z, 0)
+
+ z, reg_log = self.regularization(z)
+ if return_reg_log:
+ return z, reg_log
+ return z
+
+ def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
+ if self.max_batch_size is None:
+ dec = self.post_quant_conv(z)
+ dec = self.decoder(dec, **decoder_kwargs)
+ else:
+ N = z.shape[0]
+ bs = self.max_batch_size
+ n_batches = int(math.ceil(N / bs))
+ dec = list()
+ for i_batch in range(n_batches):
+ dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
+ dec_batch = self.decoder(dec_batch, **decoder_kwargs)
+ dec.append(dec_batch)
+ dec = torch.cat(dec, 0)
+
+ return dec
+
+
+class AutoencoderKL(AutoencodingEngineLegacy):
+ def __init__(self, **kwargs):
+ if "lossconfig" in kwargs:
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
+ super().__init__(
+ regularizer_config={
+ "target": (
+ "sgm.modules.autoencoding.regularizers"
+ ".DiagonalGaussianRegularizer"
+ )
+ },
+ **kwargs,
+ )
+
+
+class AutoencoderLegacyVQ(AutoencodingEngineLegacy):
+ def __init__(
+ self,
+ embed_dim: int,
+ n_embed: int,
+ sane_index_shape: bool = False,
+ **kwargs,
+ ):
+ if "lossconfig" in kwargs:
+ logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.")
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
+ super().__init__(
+ regularizer_config={
+ "target": (
+ "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer"
+ ),
+ "params": {
+ "n_e": n_embed,
+ "e_dim": embed_dim,
+ "sane_index_shape": sane_index_shape,
+ },
+ },
+ **kwargs,
+ )
+
+
+class IdentityFirstStage(AbstractAutoencoder):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def get_input(self, x: Any) -> Any:
+ return x
+
+ def encode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+ def decode(self, x: Any, *args, **kwargs) -> Any:
+ return x
+
+
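+# Exposes a quantized autoencoder through integer codebook indices: encode() returns
+# the flattened `min_encoding_indices` of the regularizer, decode() looks the indices
+# up in the codebook and runs them through the wrapped model's decoder.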
+class AEIntegerWrapper(nn.Module):
+ def __init__(
+ self,
+ model: nn.Module,
+ shape: Union[None, Tuple[int, int], List[int]] = (16, 16),
+ regularization_key: str = "regularization",
+ encoder_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ super().__init__()
+ self.model = model
+ assert hasattr(model, "encode") and hasattr(
+ model, "decode"
+ ), "Need AE interface"
+ self.regularization = get_nested_attribute(model, regularization_key)
+ self.shape = shape
+ self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True})
+
+ def encode(self, x) -> torch.Tensor:
+ assert (
+ not self.training
+ ), f"{self.__class__.__name__} only supports inference currently"
+ _, log = self.model.encode(x, **self.encoder_kwargs)
+ assert isinstance(log, dict)
+ inds = log["min_encoding_indices"]
+ return rearrange(inds, "b ... -> b (...)")
+
+ def decode(
+ self, inds: torch.Tensor, shape: Union[None, tuple, list] = None
+ ) -> torch.Tensor:
+ # expect inds shape (b, s) with s = h*w
+ shape = default(shape, self.shape) # Optional[(h, w)]
+ if shape is not None:
+ assert len(shape) == 2, f"Unhandled shape {shape}"
+ inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1])
+ h = self.regularization.get_codebook_entry(inds) # (b, h, w, c)
+ h = rearrange(h, "b h w c -> b c h w")
+ return self.model.decode(h)
+
+
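+# Same as AutoencoderKL, but the diagonal Gaussian regularizer is configured with
+# sample=False, so encoding returns the distribution mode instead of a sample.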
+class AutoencoderKLModeOnly(AutoencodingEngineLegacy):
+ def __init__(self, **kwargs):
+ if "lossconfig" in kwargs:
+ kwargs["loss_config"] = kwargs.pop("lossconfig")
+ super().__init__(
+ regularizer_config={
+ "target": (
+ "sgm.modules.autoencoding.regularizers"
+ ".DiagonalGaussianRegularizer"
+ ),
+ "params": {"sample": False},
+ },
+ **kwargs,
+ )
diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..41a0f4a7c6a7ed49e2d2538879d47d18ede16cba
--- /dev/null
+++ b/sgm/models/diffusion.py
@@ -0,0 +1,358 @@
+import math
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import pytorch_lightning as pl
+import torch
+from omegaconf import ListConfig, OmegaConf
+from safetensors.torch import load_file as load_safetensors
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange
+
+from ..modules import UNCONDITIONAL_CONFIG
+from ..modules.autoencoding.temporal_ae import VideoDecoder
+from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
+from ..modules.ema import LitEma
+from ..util import (
+ default,
+ disabled_train,
+ get_obj_from_str,
+ instantiate_from_config,
+ log_txt_as_img,
+)
+
+
+class DiffusionEngine(pl.LightningModule):
+ def __init__(
+ self,
+ network_config,
+ denoiser_config,
+ first_stage_config,
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ network_wrapper: Union[None, str] = None,
+ ckpt_path: Union[None, str] = None,
+ use_ema: bool = False,
+ ema_decay_rate: float = 0.9999,
+ scale_factor: float = 1.0,
+ disable_first_stage_autocast=False,
+ input_key: str = "jpg",
+ log_keys: Union[List, None] = None,
+ no_cond_log: bool = False,
+ compile_model: bool = False,
+ en_and_decode_n_samples_a_time: Optional[int] = None,
+ ):
+ super().__init__()
+ self.log_keys = log_keys
+ self.input_key = input_key
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.AdamW"}
+ )
+ model = instantiate_from_config(network_config)
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
+ model, compile_model=compile_model
+ )
+
+ self.denoiser = instantiate_from_config(denoiser_config)
+ self.sampler = (
+ instantiate_from_config(sampler_config)
+ if sampler_config is not None
+ else None
+ )
+ self.conditioner = instantiate_from_config(
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
+ )
+ self.scheduler_config = scheduler_config
+ self._init_first_stage(first_stage_config)
+
+ self.loss_fn = (
+ instantiate_from_config(loss_fn_config)
+ if loss_fn_config is not None
+ else None
+ )
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.scale_factor = scale_factor
+ self.disable_first_stage_autocast = disable_first_stage_autocast
+ self.no_cond_log = no_cond_log
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+
+ def init_from_ckpt(
+ self,
+ path: str,
+ ) -> None:
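+ # Load weights from a Lightning .ckpt (expects a "state_dict" entry) or a
+ # .safetensors file; loading is non-strict, so missing/unexpected keys are
+ # only reported, not treated as errors.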
+ if path.endswith("ckpt"):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ elif path.endswith("safetensors"):
+ sd = load_safetensors(path)
+ else:
+ raise NotImplementedError
+
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ def _init_first_stage(self, config):
+ model = instantiate_from_config(config).eval()
+ model.train = disabled_train
+ for param in model.parameters():
+ param.requires_grad = False
+ self.first_stage_model = model
+
+ def get_input(self, batch):
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in bchw format
+ return batch[self.input_key]
+
+ @torch.no_grad()
+ def decode_first_stage(self, z):
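+ # Undo the latent scaling, then decode in chunks of `en_and_decode_n_samples_a_time`
+ # (defaults to the full batch) to bound memory; a VideoDecoder additionally receives
+ # the number of timesteps contained in each chunk.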
+ z = 1.0 / self.scale_factor * z
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
+
+ n_rounds = math.ceil(z.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
+ else:
+ kwargs = {}
+ out = self.first_stage_model.decode(
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
+ )
+ all_out.append(out)
+ out = torch.cat(all_out, dim=0)
+ return out
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
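+ # Video batches of shape (b, t, c, h, w) are flattened to ((b t), c, h, w) before
+ # encoding in chunks, scaled by `scale_factor`, and reshaped back afterwards.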
+ bs = x.shape[0]
+ is_video_input = False
+ if x.dim() == 5:
+ is_video_input = True
+ # for video diffusion
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
+ n_rounds = math.ceil(x.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ out = self.first_stage_model.encode(
+ x[n * n_samples : (n + 1) * n_samples]
+ )
+ all_out.append(out)
+ z = torch.cat(all_out, dim=0)
+ z = self.scale_factor * z
+
+ if is_video_input:
+ z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
+
+ return z
+
+ def forward(self, x, batch):
+ loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch)
+ loss_mean = loss.mean()
+ loss_dict = {"loss": loss_mean}
+ return loss_mean, loss_dict
+
+ def shared_step(self, batch: Dict) -> Any:
+ x = self.get_input(batch)
+ x = self.encode_first_stage(x)
+ batch["global_step"] = self.global_step
+ loss, loss_dict = self(x, batch)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ self.log(
+ "global_step",
+ self.global_step,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ if self.scheduler_config is not None:
+ lr = self.optimizers().param_groups[0]["lr"]
+ self.log(
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ return loss
+
+ def on_train_start(self, *args, **kwargs):
+ if self.sampler is None or self.loss_fn is None:
+ raise ValueError("Sampler and loss function need to be set for training.")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ @contextmanager
+ def ema_scope(self, context=None):
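+ # Temporarily swap the online weights for their EMA copy; the original
+ # weights are restored when the context exits.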
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self):
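+ # Optimize the UNet parameters plus any conditioner embedders flagged as
+ # trainable; if a scheduler_config is given, attach a per-step LambdaLR schedule.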
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ for embedder in self.conditioner.embedders:
+ if embedder.is_trainable:
+ params = params + list(embedder.parameters())
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+ "interval": "step",
+ "frequency": 1,
+ }
+ ]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond: Dict,
+ uc: Union[Dict, None] = None,
+ batch_size: int = 16,
+ shape: Union[None, Tuple, List] = None,
+ **kwargs,
+ ):
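+ # `shape` is the per-sample latent shape and must be provided; Gaussian noise of
+ # that shape is drawn and handed to the configured sampler together with a
+ # denoiser closure that injects the model and any extra kwargs.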
+ randn = torch.randn(batch_size, *shape).to(self.device)
+
+ denoiser = lambda input, sigma, c: self.denoiser(
+ self.model, input, sigma, c, **kwargs
+ )
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
+ return samples
+
+ @torch.no_grad()
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
+ """
+ Defines heuristics to log different conditionings.
+ These can be lists of strings (text-to-image), tensors, ints, ...
+ """
+ image_h, image_w = batch[self.input_key].shape[2:]
+ log = dict()
+
+ for embedder in self.conditioner.embedders:
+ if (
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
+ ) and not self.no_cond_log:
+ x = batch[embedder.input_key][:n]
+ if isinstance(x, torch.Tensor):
+ if x.dim() == 1:
+ # class-conditional, convert integer to string
+ x = [str(x[i].item()) for i in range(x.shape[0])]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
+ elif x.dim() == 2:
+ # size and crop cond and the like
+ x = [
+ "x".join([str(xx) for xx in x[i].tolist()])
+ for i in range(x.shape[0])
+ ]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ elif isinstance(x, (List, ListConfig)):
+ if isinstance(x[0], str):
+ # strings
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+ log[embedder.input_key] = xc
+ return log
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch: Dict,
+ N: int = 8,
+ sample: bool = True,
+ ucg_keys: List[str] = None,
+ **kwargs,
+ ) -> Dict:
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
+ if ucg_keys:
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+ )
+ else:
+ ucg_keys = conditioner_input_keys
+ log = dict()
+
+ x = self.get_input(batch)
+
+ c, uc = self.conditioner.get_unconditional_conditioning(
+ batch,
+ force_uc_zero_embeddings=ucg_keys
+ if len(self.conditioner.embedders) > 0
+ else [],
+ )
+
+ sampling_kwargs = {}
+
+ N = min(x.shape[0], N)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+ z = self.encode_first_stage(x)
+ log["reconstructions"] = self.decode_first_stage(z)
+ log.update(self.log_conditionings(batch, N))
+
+ for k in c:
+ if isinstance(c[k], torch.Tensor):
+ c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
+
+ if sample:
+ with self.ema_scope("Plotting"):
+ samples = self.sample(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+ log["samples"] = samples
+ return log
diff --git a/sgm/models/video3d_diffusion.py b/sgm/models/video3d_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c4f97ec0c975937f4686471b1fa5698af013197
--- /dev/null
+++ b/sgm/models/video3d_diffusion.py
@@ -0,0 +1,524 @@
+import re
+import math
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+import torch
+from omegaconf import ListConfig, OmegaConf
+from safetensors.torch import load_file as load_safetensors
+from torch.optim.lr_scheduler import LambdaLR
+from torchvision.utils import make_grid
+from einops import rearrange, repeat
+
+from ..modules import UNCONDITIONAL_CONFIG
+from ..modules.autoencoding.temporal_ae import VideoDecoder
+from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
+from ..modules.ema import LitEma
+from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder
+from ..util import (
+ default,
+ disabled_train,
+ get_obj_from_str,
+ instantiate_from_config,
+ log_txt_as_img,
+ video_frames_as_grid,
+)
+
+
+def flatten_for_video(input):
+ return input.flatten()
+
+
+class Video3DDiffusionEngine(pl.LightningModule):
+ def __init__(
+ self,
+ network_config,
+ denoiser_config,
+ first_stage_config,
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ network_wrapper: Union[None, str] = None,
+ ckpt_path: Union[None, str] = None,
+ use_ema: bool = False,
+ ema_decay_rate: float = 0.9999,
+ scale_factor: float = 1.0,
+ disable_first_stage_autocast=False,
+ input_key: str = "frames", # for video inputs
+ log_keys: Union[List, None] = None,
+ no_cond_log: bool = False,
+ compile_model: bool = False,
+ en_and_decode_n_samples_a_time: Optional[int] = None,
+ ):
+ super().__init__()
+ self.log_keys = log_keys
+ self.input_key = input_key
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.AdamW"}
+ )
+ model = instantiate_from_config(network_config)
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
+ model, compile_model=compile_model
+ )
+
+ self.denoiser = instantiate_from_config(denoiser_config)
+ self.sampler = (
+ instantiate_from_config(sampler_config)
+ if sampler_config is not None
+ else None
+ )
+ self.conditioner = instantiate_from_config(
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
+ )
+ self.scheduler_config = scheduler_config
+ self._init_first_stage(first_stage_config)
+
+ self.loss_fn = (
+ instantiate_from_config(loss_fn_config)
+ if loss_fn_config is not None
+ else None
+ )
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.scale_factor = scale_factor
+ self.disable_first_stage_autocast = disable_first_stage_autocast
+ self.no_cond_log = no_cond_log
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path)
+
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+
+ def _load_last_embedder(self, original_state_dict):
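+ # Extract the sub-state-dict stored under "conditioner.embedders.3" in the original
+ # checkpoint and load it into whichever embedder of this model is a
+ # VideoPredictionEmbedderWithEncoder (the frame encoder).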
+ original_module_name = "conditioner.embedders.3"
+ state_dict = dict()
+ for k, v in original_state_dict.items():
+ m = re.match(rf"^{original_module_name}\.(.*)$", k)
+ if m is None:
+ continue
+ state_dict[m.group(1)] = v
+
+ idx = -1
+ for i in range(len(self.conditioner.embedders)):
+ if isinstance(
+ self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder
+ ):
+ idx = i
+
+ print(f"Embedder [{idx}] is the frame encoder, make sure this is expected")
+
+ self.conditioner.embedders[idx].load_state_dict(state_dict)
+
+ def init_from_ckpt(
+ self,
+ path: str,
+ ) -> None:
+ if path.endswith("ckpt"):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ elif path.endswith("safetensors"):
+ sd = load_safetensors(path)
+ else:
+ raise NotImplementedError
+
+ self_sd = self.state_dict()
+ input_keys = [
+ "model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight",
+ ]
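+ # The second name is the EMA copy of the first UNet conv (LitEma stores buffer
+ # names with the dots stripped). If the conv's input-channel count differs from
+ # the checkpoint, e.g. because extra conditioning channels were concatenated,
+ # zero-initialize the local weight and copy the checkpoint weights into the
+ # first 8 input channels expected by the original model.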
+ for input_key in input_keys:
+ if input_key not in sd or input_key not in self_sd:
+ continue
+
+ input_weight = self_sd[input_key]
+
+ if input_weight.shape != sd[input_key].shape:
+ print("Manual init: {}".format(input_key))
+ input_weight.zero_()
+ input_weight[:, :8, :, :].copy_(sd[input_key])
+
+ deleted_keys = []
+ for k, v in self.state_dict().items():
+ # resolve shape mismatches between the current model and the checkpoint
+ if k in sd:
+ if v.shape != sd[k].shape:
+ del sd[k]
+ deleted_keys.append(k)
+
+ if len(deleted_keys) > 0:
+ print(f"Deleted Keys: {deleted_keys}")
+
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+ if len(deleted_keys) > 0:
+ print(f"Deleted Keys: {deleted_keys}")
+
+ if len(missing) > 0 or len(unexpected) > 0:
+ # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id)
+ print("Modified embedder to support 3d spiral video inputs")
+ try:
+ self._load_last_embedder(sd)
+ except Exception:
+ print("Failed to load last embedder, make sure this is expected")
+
+ def _init_first_stage(self, config):
+ model = instantiate_from_config(config).eval()
+ model.train = disabled_train
+ for param in model.parameters():
+ param.requires_grad = False
+ self.first_stage_model = model
+
+ def get_input(self, batch):
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in bchw format
+ return batch[self.input_key]
+
+ @torch.no_grad()
+ def decode_first_stage(self, z):
+ z = 1.0 / self.scale_factor * z
+ is_video_input = False
+ bs = z.shape[0]
+ if z.dim() == 5:
+ is_video_input = True
+ # for video diffusion
+ z = rearrange(z, "b t c h w -> (b t) c h w")
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
+
+ n_rounds = math.ceil(z.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
+ else:
+ kwargs = {}
+ out = self.first_stage_model.decode(
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
+ )
+ all_out.append(out)
+ out = torch.cat(all_out, dim=0)
+
+ if is_video_input:
+ out = rearrange(out, "(b t) c h w -> b t c h w", b=bs)
+
+ return out
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ if self.input_key == "latents":
+ return x
+
+ bs = x.shape[0]
+ is_video_input = False
+ if x.dim() == 5:
+ is_video_input = True
+ # for video diffusion
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
+ n_rounds = math.ceil(x.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ out = self.first_stage_model.encode(
+ x[n * n_samples : (n + 1) * n_samples]
+ )
+ all_out.append(out)
+ z = torch.cat(all_out, dim=0)
+ z = self.scale_factor * z
+
+ # if is_video_input:
+ # z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
+
+ return z
+
+ def forward(self, x, batch):
+ loss, model_output = self.loss_fn(
+ self.model,
+ self.denoiser,
+ self.conditioner,
+ x,
+ batch,
+ return_model_output=True,
+ )
+ loss_mean = loss.mean()
+ loss_dict = {"loss": loss_mean, "model_output": model_output}
+ return loss_mean, loss_dict
+
+ def shared_step(self, batch: Dict) -> Any:
+ # TODO: move this preprocessing into the dataloader's collate_fn
+ # if "fps_id" in batch:
+ # batch["fps_id"] = flatten_for_video(batch["fps_id"])
+ # if "motion_bucket_id" in batch:
+ # batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"])
+ # if "cond_aug" in batch:
+ # batch["cond_aug"] = flatten_for_video(batch["cond_aug"])
+ x = self.get_input(batch)
+ x = self.encode_first_stage(x)
+ # ## debug
+ # x_recon = self.decode_first_stage(x)
+ # video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg")
+ # video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg")
+ # ## debug
+ batch["global_step"] = self.global_step
+ loss, loss_dict = self(x, batch)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ with torch.no_grad():
+ if "model_output" in loss_dict:
+ if batch_idx % 100 == 0:
+ if isinstance(self.logger, WandbLogger):
+ model_output = loss_dict["model_output"].detach()[
+ : batch["num_video_frames"]
+ ]
+ recons = (
+ (self.decode_first_stage(model_output) + 1.0) / 2.0
+ ).clamp(0.0, 1.0)
+ recon_grid = make_grid(recons, nrow=4)
+ self.logger.log_image(
+ key=f"train/model_output_recon",
+ images=[recon_grid],
+ step=self.global_step,
+ )
+ del loss_dict["model_output"]
+
+ if torch.isnan(loss).any():
+ print("Nan detected")
+ loss = None
+
+ self.log_dict(
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ self.log(
+ "global_step",
+ self.global_step,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ if self.scheduler_config is not None:
+ lr = self.optimizers().param_groups[0]["lr"]
+ self.log(
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ return loss
+
+ def on_train_start(self, *args, **kwargs):
+ if self.sampler is None or self.loss_fn is None:
+ raise ValueError("Sampler and loss function need to be set for training.")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ for embedder in self.conditioner.embedders:
+ if embedder.is_trainable:
+ params = params + list(embedder.parameters())
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+ "interval": "step",
+ "frequency": 1,
+ }
+ ]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond: Dict,
+ uc: Union[Dict, None] = None,
+ batch_size: int = 16,
+ shape: Union[None, Tuple, List] = None,
+ **kwargs,
+ ):
+ randn = torch.randn(batch_size, *shape).to(self.device)
+
+ denoiser = lambda input, sigma, c: self.denoiser(
+ self.model, input, sigma, c, **kwargs
+ )
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
+ return samples
+
+ @torch.no_grad()
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
+ """
+ Defines heuristics to log different conditionings.
+ These can be lists of strings (text-to-image), tensors, ints, ...
+ """
+ image_h, image_w = batch[self.input_key].shape[-2:]
+ log = dict()
+
+ for embedder in self.conditioner.embedders:
+ if (
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
+ ) and not self.no_cond_log:
+ x = batch[embedder.input_key][:n]
+ if isinstance(x, torch.Tensor):
+ if x.dim() == 1:
+ # class-conditional, convert integer to string
+ x = [str(x[i].item()) for i in range(x.shape[0])]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
+ elif x.dim() == 2:
+ # size and crop cond and the like
+ x = [
+ "x".join([str(xx) for xx in x[i].tolist()])
+ for i in range(x.shape[0])
+ ]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ elif x.dim() == 4:
+ # image
+ xc = x
+ else:
+ raise NotImplementedError()
+ elif isinstance(x, (List, ListConfig)):
+ if isinstance(x[0], str):
+ # strings
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+ log[embedder.input_key] = xc
+
+ return log
+
+ # for video diffusions will be logging frames of a video
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch: Dict,
+ N: int = 1,
+ sample: bool = True,
+ ucg_keys: List[str] = None,
+ **kwargs,
+ ) -> Dict:
+ # # debug
+ # return {}
+ # # debug
+ assert "num_video_frames" in batch, "num_video_frames must be in batch"
+ num_video_frames = batch["num_video_frames"]
+ # some embedders expose a single input_key, others a list of input_keys
+ conditioner_input_keys = []
+ for e in self.conditioner.embedders:
+ if e.input_key is not None:
+ conditioner_input_keys.append(e.input_key)
+ else:
+ conditioner_input_keys.extend(e.input_keys)
+ if ucg_keys:
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+ )
+ else:
+ ucg_keys = conditioner_input_keys
+ log = dict()
+
+ x = self.get_input(batch)
+
+ c, uc = self.conditioner.get_unconditional_conditioning(
+ batch,
+ force_uc_zero_embeddings=ucg_keys
+ if len(self.conditioner.embedders) > 0
+ else [],
+ )
+
+ sampling_kwargs = {"num_video_frames": num_video_frames}
+ n = min(x.shape[0] // num_video_frames, N)
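+ # the indicator is duplicated along the batch dimension so it covers both the
+ # conditional and the unconditional half that the sampler/guider batches together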
+ sampling_kwargs["image_only_indicator"] = torch.cat(
+ [batch["image_only_indicator"][:n]] * 2
+ )
+
+ N = min(x.shape[0] // num_video_frames, N) * num_video_frames
+ x = x.to(self.device)[:N]
+ # log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames)
+ log["inputs"] = x
+ z = self.encode_first_stage(x)
+ recon = self.decode_first_stage(z)
+ # log["reconstructions"] = rearrange(
+ # recon, "(b t) c h w -> b c h (t w)", t=num_video_frames
+ # )
+ log["reconstructions"] = recon
+ log.update(self.log_conditionings(batch, N))
+ log["pixelnerf_rgb"] = c["rgb"]
+
+ for k in ["crossattn", "concat", "vector"]:
+ if k in c:
+ c[k] = c[k][:N]
+ uc[k] = uc[k][:N]
+
+ # for k in c:
+ # if isinstance(c[k], torch.Tensor):
+ # if k == "vector":
+ # end = N
+ # else:
+ # end = n
+ # c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc))
+
+ # # for k in c:
+ # # print(c[k].shape)
+
+ # breakpoint()
+ # for k in ["crossattn", "concat"]:
+ # c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames)
+ # c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames)
+ # uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames)
+ # uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_video_frames)
+
+ # for k in c:
+ # print(c[k].shape)
+ if sample:
+ with self.ema_scope("Plotting"):
+ samples = self.sample(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+ log["samples"] = samples
+ return log
diff --git a/sgm/models/video_diffusion.py b/sgm/models/video_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dbaa4a6d99e44fb2662f13e7cb5ca3ff9b0939e
--- /dev/null
+++ b/sgm/models/video_diffusion.py
@@ -0,0 +1,503 @@
+import re
+import math
+from contextlib import contextmanager
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import WandbLogger
+import torch
+from omegaconf import ListConfig, OmegaConf
+from safetensors.torch import load_file as load_safetensors
+from torch.optim.lr_scheduler import LambdaLR
+from torchvision.utils import make_grid
+from einops import rearrange, repeat
+
+from ..modules import UNCONDITIONAL_CONFIG
+from ..modules.autoencoding.temporal_ae import VideoDecoder
+from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER
+from ..modules.ema import LitEma
+from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder
+from ..util import (
+ default,
+ disabled_train,
+ get_obj_from_str,
+ instantiate_from_config,
+ log_txt_as_img,
+ video_frames_as_grid,
+)
+
+
+def flatten_for_video(input):
+ return input.flatten()
+
+
+class DiffusionEngine(pl.LightningModule):
+ def __init__(
+ self,
+ network_config,
+ denoiser_config,
+ first_stage_config,
+ conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None,
+ network_wrapper: Union[None, str] = None,
+ ckpt_path: Union[None, str] = None,
+ use_ema: bool = False,
+ ema_decay_rate: float = 0.9999,
+ scale_factor: float = 1.0,
+ disable_first_stage_autocast=False,
+ input_key: str = "frames", # for video inputs
+ log_keys: Union[List, None] = None,
+ no_cond_log: bool = False,
+ compile_model: bool = False,
+ en_and_decode_n_samples_a_time: Optional[int] = None,
+ load_last_embedder: bool = False,
+ from_scratch: bool = False,
+ ):
+ super().__init__()
+ self.log_keys = log_keys
+ self.input_key = input_key
+ self.optimizer_config = default(
+ optimizer_config, {"target": "torch.optim.AdamW"}
+ )
+ model = instantiate_from_config(network_config)
+ self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
+ model, compile_model=compile_model
+ )
+
+ self.denoiser = instantiate_from_config(denoiser_config)
+ self.sampler = (
+ instantiate_from_config(sampler_config)
+ if sampler_config is not None
+ else None
+ )
+ self.conditioner = instantiate_from_config(
+ default(conditioner_config, UNCONDITIONAL_CONFIG)
+ )
+ self.scheduler_config = scheduler_config
+ self._init_first_stage(first_stage_config)
+
+ self.loss_fn = (
+ instantiate_from_config(loss_fn_config)
+ if loss_fn_config is not None
+ else None
+ )
+
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model, decay=ema_decay_rate)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.scale_factor = scale_factor
+ self.disable_first_stage_autocast = disable_first_stage_autocast
+ self.no_cond_log = no_cond_log
+
+ self.load_last_embedder = load_last_embedder
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, from_scratch)
+
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+
+ def _load_last_embedder(self, original_state_dict):
+ original_module_name = "conditioner.embedders.3"
+ state_dict = dict()
+ for k, v in original_state_dict.items():
+ m = re.match(rf"^{original_module_name}\.(.*)$", k)
+ if m is None:
+ continue
+ state_dict[m.group(1)] = v
+
+ idx = -1
+ for i in range(len(self.conditioner.embedders)):
+ if isinstance(
+ self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder
+ ):
+ idx = i
+
+ print(f"Embedder [{idx}] is the frame encoder, make sure this is expected")
+
+ self.conditioner.embedders[idx].load_state_dict(state_dict)
+
+ def init_from_ckpt(
+ self,
+ path: str,
+ from_scratch: bool = False,
+ ) -> None:
+ if path.endswith("ckpt"):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ elif path.endswith("safetensors"):
+ sd = load_safetensors(path)
+ else:
+ raise NotImplementedError
+
+ deleted_keys = []
+ for k, v in self.state_dict().items():
+ # resolve shape mismatches between the current model and the checkpoint
+ if k in sd:
+ if v.shape != sd[k].shape:
+ del sd[k]
+ deleted_keys.append(k)
+
+ if from_scratch:
+ new_sd = {}
+ for k in sd:
+ if "first_stage_model" in k:
+ new_sd[k] = sd[k]
+ sd = new_sd
+ print(sd.keys())
+
+ if len(deleted_keys) > 0:
+ print(f"Deleted Keys: {deleted_keys}")
+
+ missing, unexpected = self.load_state_dict(sd, strict=False)
+ print(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+ if len(deleted_keys) > 0:
+ print(f"Deleted Keys: {deleted_keys}")
+
+ if (len(missing) > 0 or len(unexpected) > 0) and self.load_last_embedder:
+ # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id)
+ print("Modified embedder to support 3d spiral video inputs")
+ self._load_last_embedder(sd)
+
+ def _init_first_stage(self, config):
+ model = instantiate_from_config(config).eval()
+ model.train = disabled_train
+ for param in model.parameters():
+ param.requires_grad = False
+ self.first_stage_model = model
+
+ def get_input(self, batch):
+ # assuming unified data format, dataloader returns a dict.
+ # image tensors should be scaled to -1 ... 1 and in bchw format
+ return batch[self.input_key]
+
+ @torch.no_grad()
+ def decode_first_stage(self, z):
+ z = 1.0 / self.scale_factor * z
+ is_video_input = False
+ bs = z.shape[0]
+ if z.dim() == 5:
+ is_video_input = True
+ # for video diffusion
+ z = rearrange(z, "b t c h w -> (b t) c h w")
+ n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0])
+
+ n_rounds = math.ceil(z.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ if isinstance(self.first_stage_model.decoder, VideoDecoder):
+ kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])}
+ else:
+ kwargs = {}
+ out = self.first_stage_model.decode(
+ z[n * n_samples : (n + 1) * n_samples], **kwargs
+ )
+ all_out.append(out)
+ out = torch.cat(all_out, dim=0)
+
+ if is_video_input:
+ out = rearrange(out, "(b t) c h w -> b t c h w", b=bs)
+
+ return out
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ if self.input_key == "latents":
+ return x * self.scale_factor
+
+ bs = x.shape[0]
+ is_video_input = False
+ if x.dim() == 5:
+ is_video_input = True
+ # for video diffusion
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+ n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0])
+ n_rounds = math.ceil(x.shape[0] / n_samples)
+ all_out = []
+ with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
+ for n in range(n_rounds):
+ out = self.first_stage_model.encode(
+ x[n * n_samples : (n + 1) * n_samples]
+ )
+ all_out.append(out)
+ z = torch.cat(all_out, dim=0)
+ z = self.scale_factor * z
+
+ # if is_video_input:
+ # z = rearrange(z, "(b t) c h w -> b t c h w", b=bs)
+
+ return z
+
+ def forward(self, x, batch):
+ loss, model_output = self.loss_fn(
+ self.model,
+ self.denoiser,
+ self.conditioner,
+ x,
+ batch,
+ return_model_output=True,
+ )
+ loss_mean = loss.mean()
+ loss_dict = {"loss": loss_mean, "model_output": model_output}
+ return loss_mean, loss_dict
+
+ def shared_step(self, batch: Dict) -> Any:
+ # TODO: move this preprocessing into the dataloader's collate_fn
+ # if "fps_id" in batch:
+ # batch["fps_id"] = flatten_for_video(batch["fps_id"])
+ # if "motion_bucket_id" in batch:
+ # batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"])
+ # if "cond_aug" in batch:
+ # batch["cond_aug"] = flatten_for_video(batch["cond_aug"])
+ x = self.get_input(batch)
+ x = self.encode_first_stage(x)
+ # ## debug
+ # x_recon = self.decode_first_stage(x)
+ # video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg")
+ # video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg")
+ # ## debug
+ batch["global_step"] = self.global_step
+ # breakpoint()
+ loss, loss_dict = self(x, batch)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ with torch.no_grad():
+ if "model_output" in loss_dict:
+ if batch_idx % 100 == 0:
+ if isinstance(self.logger, WandbLogger):
+ model_output = loss_dict["model_output"].detach()[
+ : batch["num_video_frames"]
+ ]
+ recons = (
+ (self.decode_first_stage(model_output) + 1.0) / 2.0
+ ).clamp(0.0, 1.0)
+ recon_grid = make_grid(recons, nrow=4)
+ self.logger.log_image(
+ key=f"train/model_output_recon",
+ images=[recon_grid],
+ step=self.global_step,
+ )
+ del loss_dict["model_output"]
+
+ self.log_dict(
+ loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ self.log(
+ "global_step",
+ self.global_step,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False,
+ )
+
+ if self.scheduler_config is not None:
+ lr = self.optimizers().param_groups[0]["lr"]
+ self.log(
+ "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False
+ )
+
+ return loss
+
+ def on_train_start(self, *args, **kwargs):
+ if self.sampler is None or self.loss_fn is None:
+ raise ValueError("Sampler and loss function need to be set for training.")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def instantiate_optimizer_from_config(self, params, lr, cfg):
+ return get_obj_from_str(cfg["target"])(
+ params, lr=lr, **cfg.get("params", dict())
+ )
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ for embedder in self.conditioner.embedders:
+ if embedder.is_trainable:
+ params = params + list(embedder.parameters())
+ opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
+ if self.scheduler_config is not None:
+ scheduler = instantiate_from_config(self.scheduler_config)
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
+ "interval": "step",
+ "frequency": 1,
+ }
+ ]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def sample(
+ self,
+ cond: Dict,
+ uc: Union[Dict, None] = None,
+ batch_size: int = 16,
+ shape: Union[None, Tuple, List] = None,
+ **kwargs,
+ ):
+ randn = torch.randn(batch_size, *shape).to(self.device)
+
+ denoiser = lambda input, sigma, c: self.denoiser(
+ self.model, input, sigma, c, **kwargs
+ )
+ samples = self.sampler(denoiser, randn, cond, uc=uc)
+ return samples
+
+ @torch.no_grad()
+ def log_conditionings(self, batch: Dict, n: int) -> Dict:
+ """
+ Defines heuristics to log different conditionings.
+ These can be lists of strings (text-to-image), tensors, ints, ...
+ """
+ image_h, image_w = batch[self.input_key].shape[-2:]
+ log = dict()
+
+ for embedder in self.conditioner.embedders:
+ if (
+ (self.log_keys is None) or (embedder.input_key in self.log_keys)
+ ) and not self.no_cond_log:
+ x = batch[embedder.input_key][:n]
+ if isinstance(x, torch.Tensor):
+ if x.dim() == 1:
+ # class-conditional, convert integer to string
+ x = [str(x[i].item()) for i in range(x.shape[0])]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4)
+ elif x.dim() == 2:
+ # size and crop cond and the like
+ x = [
+ "x".join([str(xx) for xx in x[i].tolist()])
+ for i in range(x.shape[0])
+ ]
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ elif x.dim() == 4:
+ # image
+ xc = x
+ else:
+ # unsupported conditioning dimensionality: skip logging this embedder
+ # instead of falling through with `xc` undefined
+ continue
+ elif isinstance(x, (List, ListConfig)):
+ if isinstance(x[0], str):
+ # strings
+ xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20)
+ else:
+ raise NotImplementedError()
+ else:
+ raise NotImplementedError()
+ log[embedder.input_key] = xc
+ return log
+
+ # for video diffusions will be logging frames of a video
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch: Dict,
+ N: int = 1,
+ sample: bool = True,
+ ucg_keys: List[str] = None,
+ **kwargs,
+ ) -> Dict:
+ # # debug
+ # return {}
+ # # debug
+ assert "num_video_frames" in batch, "num_video_frames must be in batch"
+ num_video_frames = batch["num_video_frames"]
+ conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
+ if ucg_keys:
+ assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
+ "Each defined ucg key for sampling must be in the provided conditioner input keys,"
+ f"but we have {ucg_keys} vs. {conditioner_input_keys}"
+ )
+ else:
+ ucg_keys = conditioner_input_keys
+ log = dict()
+
+ x = self.get_input(batch)
+
+ c, uc = self.conditioner.get_unconditional_conditioning(
+ batch,
+ force_uc_zero_embeddings=ucg_keys
+ if len(self.conditioner.embedders) > 0
+ else [],
+ )
+
+ sampling_kwargs = {"num_video_frames": num_video_frames}
+ n = min(x.shape[0] // num_video_frames, N)
+ sampling_kwargs["image_only_indicator"] = torch.cat(
+ [batch["image_only_indicator"][:n]] * 2
+ )
+
+ N = min(x.shape[0] // num_video_frames, N) * num_video_frames
+ x = x.to(self.device)[:N]
+ # log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames)
+ if self.input_key != "latents":
+ log["inputs"] = x
+ z = self.encode_first_stage(x)
+ recon = self.decode_first_stage(z)
+ # log["reconstructions"] = rearrange(
+ # recon, "(b t) c h w -> b c h (t w)", t=num_video_frames
+ # )
+ log["reconstructions"] = recon
+ log.update(self.log_conditionings(batch, N))
+
+ for k in c:
+ if isinstance(c[k], torch.Tensor):
+ if k == "vector":
+ end = N
+ else:
+ end = n
+ c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc))
+
+ # for k in c:
+ # print(c[k].shape)
+
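+ # Per-frame conditioning: image-level crossattn/concat tensors are repeated t times
+ # and flattened so they line up with the ((b t), ...) layout of the video latents.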
+ for k in ["crossattn", "concat"]:
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames)
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames)
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames)
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_video_frames)
+
+ # for k in c:
+ # print(c[k].shape)
+ if sample:
+ with self.ema_scope("Plotting"):
+ samples = self.sample(
+ c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs
+ )
+ samples = self.decode_first_stage(samples)
+ log["samples"] = samples
+ return log
diff --git a/sgm/modules/__init__.py b/sgm/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aa9ad360acf32dab22989d81630b3eb7978abb1
--- /dev/null
+++ b/sgm/modules/__init__.py
@@ -0,0 +1,6 @@
+from .encoders.modules import GeneralConditioner, ExtraConditioner
+
+UNCONDITIONAL_CONFIG = {
+ "target": "sgm.modules.GeneralConditioner",
+ "params": {"emb_models": []},
+}
diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3b60cabce854b52527f6dee85ea4f0cb0951eb6
--- /dev/null
+++ b/sgm/modules/attention.py
@@ -0,0 +1,764 @@
+import logging
+import math
+from inspect import isfunction
+from typing import Any, Optional
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from packaging import version
+from torch import nn
+
+# from torch.utils.checkpoint import checkpoint
+
+checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
+
+
+logpy = logging.getLogger(__name__)
+
+if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ SDP_IS_AVAILABLE = True
+ from torch.backends.cuda import SDPBackend, sdp_kernel
+
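+ # Map each SDP backend to the keyword flags understood by
+ # torch.backends.cuda.sdp_kernel; the `None` entry enables every backend and
+ # lets PyTorch pick the fastest available kernel.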
+ BACKEND_MAP = {
+ SDPBackend.MATH: {
+ "enable_math": True,
+ "enable_flash": False,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.FLASH_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": True,
+ "enable_mem_efficient": False,
+ },
+ SDPBackend.EFFICIENT_ATTENTION: {
+ "enable_math": False,
+ "enable_flash": False,
+ "enable_mem_efficient": True,
+ },
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
+ }
+else:
+ from contextlib import nullcontext
+
+ SDP_IS_AVAILABLE = False
+ sdp_kernel = nullcontext
+ BACKEND_MAP = {}
+ logpy.warn(
+ f"No SDP backend available, likely because you are running in pytorch "
+ f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
+ f"You might want to consider upgrading."
+ )
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except ImportError:
+ XFORMERS_IS_AVAILABLE = False
+ logpy.warn("no module 'xformers'. Processing without...")
+
+# from .diffusionmodules.util import mixed_checkpoint as checkpoint
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = (
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+ if not glu
+ else GEGLU(dim, inner_dim)
+ )
+
+ self.net = nn.Sequential(
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class LinearAttention(nn.Module):
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
+ )
+ k = k.softmax(dim=-1)
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
+ out = rearrange(
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
+ )
+ return self.to_out(out)
+
+
+class SelfAttention(nn.Module):
+ ATTENTION_MODES = ("xformers", "torch", "math")
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_scale: Optional[float] = None,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ attn_mode: str = "xformers",
+ ):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ assert attn_mode in self.ATTENTION_MODES
+ self.attn_mode = attn_mode
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, L, C = x.shape
+
+ qkv = self.qkv(x)
+ if self.attn_mode == "torch":
+ qkv = rearrange(
+ qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
+ ).float()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = rearrange(x, "B H L D -> B L (H D)")
+ elif self.attn_mode == "xformers":
+ qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
+ x = xformers.ops.memory_efficient_attention(q, k, v)
+ x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
+ elif self.attn_mode == "math":
+ qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+ q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, L, C)
+ else:
+ raise NotImplementedError(f"Unknown attention mode: {self.attn_mode}")
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, "b c h w -> b (h w) c")
+ k = rearrange(k, "b c h w -> b c (h w)")
+ w_ = torch.einsum("bij,bjk->bik", q, k)
+
+ w_ = w_ * (int(c) ** (-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, "b c h w -> b c (h w)")
+ w_ = rearrange(w_, "b i j -> b j i")
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.0,
+ backend=None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.backend = backend
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ h = self.heads
+
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
+ )
+
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
+
+ ## old
+ """
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ """
+ ## new
+ with sdp_kernel(**BACKEND_MAP[self.backend]):
+ # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
+ out = F.scaled_dot_product_attention(
+ q, k, v, attn_mask=mask
+ ) # scale is dim_head ** -0.5 per default
+
+ del q, k, v
+ out = rearrange(out, "b h n d -> b n (h d)", h=h)
+
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(
+ self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
+ ):
+ super().__init__()
+ logpy.debug(
+ f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
+ f"context_dim is {context_dim} and using {heads} heads with a "
+ f"dimension of {dim_head}."
+ )
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+ )
+ self.attention_op: Optional[Any] = None
+
+ def forward(
+ self,
+ x,
+ context=None,
+ mask=None,
+ additional_tokens=None,
+ n_times_crossframe_attn_in_self=0,
+ ):
+ if additional_tokens is not None:
+ # get the number of masked tokens at the beginning of the output sequence
+ n_tokens_to_mask = additional_tokens.shape[1]
+ # add additional token
+ x = torch.cat([additional_tokens, x], dim=1)
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ if n_times_crossframe_attn_in_self:
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
+ # n_cp = x.shape[0]//n_times_crossframe_attn_in_self
+ k = repeat(
+ k[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
+ v = repeat(
+ v[::n_times_crossframe_attn_in_self],
+ "b ... -> (b n) ...",
+ n=n_times_crossframe_attn_in_self,
+ )
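+ # each group of n consecutive batch elements now shares the keys/values of the
+ # first element in its group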
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ if version.parse(xformers.__version__) >= version.parse("0.0.21"):
+ # NOTE: workaround for
+ # https://github.com/facebookresearch/xformers/issues/845
+ max_bs = 32768
+ N = q.shape[0]
+ n_batches = math.ceil(N / max_bs)
+ out = list()
+ for i_batch in range(n_batches):
+ batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
+ out.append(
+ xformers.ops.memory_efficient_attention(
+ q[batch],
+ k[batch],
+ v[batch],
+ attn_bias=None,
+ op=self.attention_op,
+ )
+ )
+ out = torch.cat(out, 0)
+ else:
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ # TODO: Use this directly in the attention operation, as a bias
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ if additional_tokens is not None:
+ # remove additional token
+ out = out[:, n_tokens_to_mask:]
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attn_mode="softmax",
+ sdp_backend=None,
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
+ logpy.warn(
+ f"Attention mode '{attn_mode}' is not available. Falling "
+ f"back to native attention. This is not a problem in "
+ f"Pytorch >= 2.0. FYI, you are running with PyTorch "
+ f"version {torch.__version__}."
+ )
+ attn_mode = "softmax"
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
+ logpy.warn(
+ "We do not support vanilla attention anymore, as it is too "
+ "expensive. Sorry."
+ )
+ if not XFORMERS_IS_AVAILABLE:
+ assert (
+ False
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ else:
+ logpy.info("Falling back to xformers efficient attention.")
+ attn_mode = "softmax-xformers"
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
+ else:
+ assert sdp_backend is None
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None,
+ backend=sdp_backend,
+ ) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ backend=sdp_backend,
+ ) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ logpy.debug(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+ ):
+ kwargs = {"x": x}
+
+ if context is not None:
+ kwargs.update({"context": context})
+
+ if additional_tokens is not None:
+ kwargs.update({"additional_tokens": additional_tokens})
+
+ if n_times_crossframe_attn_in_self:
+ kwargs.update(
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
+ )
+
+ # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
+ if self.checkpoint:
+ # inputs = {"x": x, "context": context}
+ return checkpoint(self._forward, x, context)
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
+ else:
+ return self._forward(**kwargs)
+
+ def _forward(
+ self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
+ ):
+ x = (
+ self.attn1(
+ self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ additional_tokens=additional_tokens,
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
+ if not self.disable_self_attn
+ else 0,
+ )
+ + x
+ )
+ x = (
+ self.attn2(
+ self.norm2(x), context=context, additional_tokens=additional_tokens
+ )
+ + x
+ )
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class BasicTransformerSingleLayerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ attn_mode="softmax",
+ ):
+ super().__init__()
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ )
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ # inputs = {"x": x, "context": context}
+ # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
+ return checkpoint(self._forward, x, context)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context) + x
+ x = self.ff(self.norm2(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ context_dim=None,
+ disable_self_attn=False,
+ use_linear=False,
+ attn_type="softmax",
+ use_checkpoint=True,
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
+ sdp_backend=None,
+ ):
+ super().__init__()
+ logpy.debug(
+ f"constructing {self.__class__.__name__} of depth {depth} w/ "
+ f"{in_channels} channels and {n_heads} heads."
+ )
+
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ if exists(context_dim) and isinstance(context_dim, list):
+ if depth != len(context_dim):
+ logpy.warn(
+ f"{self.__class__.__name__}: Found context dims "
+ f"{context_dim} of depth {len(context_dim)}, which does not "
+ f"match the specified 'depth' of {depth}. Setting context_dim "
+ f"to {depth * [context_dim[0]]} now."
+ )
+ # depth does not match context dims.
+ assert all(
+ map(lambda x: x == context_dim[0], context_dim)
+ ), "need homogenous context_dim to match depth automatically"
+ context_dim = depth * [context_dim[0]]
+ elif context_dim is None:
+ context_dim = [None] * depth
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+ )
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn,
+ attn_mode=attn_type,
+ checkpoint=use_checkpoint,
+ sdp_backend=sdp_backend,
+ )
+ for d in range(depth)
+ ]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ )
+ else:
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
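+ # flatten the spatial grid into a sequence of h*w tokens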
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ if i > 0 and len(context) == 1:
+ i = 0 # use same context for each block
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
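+# Illustrative usage sketch (the shapes and context_dim below are assumptions for
+# illustration, not values taken from this codebase):
+#   st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40,
+#                           depth=1, context_dim=1024, use_linear=True)
+#   y = st(torch.randn(2, 320, 32, 32), context=torch.randn(2, 77, 1024))
+#   # y has the same shape as the input: (2, 320, 32, 32)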
+
+class SimpleTransformer(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ depth: int,
+ heads: int,
+ dim_head: int,
+ context_dim: Optional[int] = None,
+ dropout: float = 0.0,
+ checkpoint: bool = True,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ BasicTransformerBlock(
+ dim,
+ heads,
+ dim_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ attn_mode="softmax-xformers",
+ checkpoint=checkpoint,
+ )
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ for layer in self.layers:
+ x = layer(x, context)
+ return x
diff --git a/sgm/modules/autoencoding/__init__.py b/sgm/modules/autoencoding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b316c7aa6ea1c5e31a58987aa3b37b2933eb7e2
--- /dev/null
+++ b/sgm/modules/autoencoding/losses/__init__.py
@@ -0,0 +1,7 @@
+__all__ = [
+ "GeneralLPIPSWithDiscriminator",
+ "LatentLPIPS",
+]
+
+from .discriminator_loss import GeneralLPIPSWithDiscriminator
+from .lpips import LatentLPIPS
diff --git a/sgm/modules/autoencoding/losses/discriminator_loss.py b/sgm/modules/autoencoding/losses/discriminator_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..09b6829267bf8e4d98c3f29abdc19e58dcbcbe64
--- /dev/null
+++ b/sgm/modules/autoencoding/losses/discriminator_loss.py
@@ -0,0 +1,306 @@
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+from einops import rearrange
+from matplotlib import colormaps
+from matplotlib import pyplot as plt
+
+from ....util import default, instantiate_from_config
+from ..lpips.loss.lpips import LPIPS
+from ..lpips.model.model import weights_init
+from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
+
+
+class GeneralLPIPSWithDiscriminator(nn.Module):
+ def __init__(
+ self,
+ disc_start: int,
+ logvar_init: float = 0.0,
+ disc_num_layers: int = 3,
+ disc_in_channels: int = 3,
+ disc_factor: float = 1.0,
+ disc_weight: float = 1.0,
+ perceptual_weight: float = 1.0,
+ disc_loss: str = "hinge",
+ scale_input_to_tgt_size: bool = False,
+ dims: int = 2,
+ learn_logvar: bool = False,
+ regularization_weights: Union[None, Dict[str, float]] = None,
+ additional_log_keys: Optional[List[str]] = None,
+ discriminator_config: Optional[Dict] = None,
+ ):
+ super().__init__()
+ self.dims = dims
+ if self.dims > 2:
+ print(
+ f"running with dims={dims}. This means that for perceptual loss "
+ f"calculation, the LPIPS loss will be applied to each frame "
+ f"independently."
+ )
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ assert disc_loss in ["hinge", "vanilla"]
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ # output log variance
+ self.logvar = nn.Parameter(
+ torch.full((), logvar_init), requires_grad=learn_logvar
+ )
+ self.learn_logvar = learn_logvar
+
+ discriminator_config = default(
+ discriminator_config,
+ {
+ "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
+ "params": {
+ "input_nc": disc_in_channels,
+ "n_layers": disc_num_layers,
+ "use_actnorm": False,
+ },
+ },
+ )
+
+ self.discriminator = instantiate_from_config(discriminator_config).apply(
+ weights_init
+ )
+ self.discriminator_iter_start = disc_start
+ self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.regularization_weights = default(regularization_weights, {})
+
+ self.forward_keys = [
+ "optimizer_idx",
+ "global_step",
+ "last_layer",
+ "split",
+ "regularization_log",
+ ]
+
+ self.additional_log_keys = set(default(additional_log_keys, []))
+ self.additional_log_keys.update(set(self.regularization_weights.keys()))
+
+ def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
+ return self.discriminator.parameters()
+
+ def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
+ if self.learn_logvar:
+ yield self.logvar
+ yield from ()
+
+ @torch.no_grad()
+ def log_images(
+ self, inputs: torch.Tensor, reconstructions: torch.Tensor
+ ) -> Dict[str, torch.Tensor]:
+ # calc logits of real/fake
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ if len(logits_real.shape) < 4:
+ # Non patch-discriminator
+ return dict()
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ # -> (b, 1, h, w)
+
+ # parameters for colormapping
+ high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
+ cmap = colormaps["PiYG"] # diverging colormap
+
+ def to_colormap(logits: torch.Tensor) -> torch.Tensor:
+ """(b, 1, ...) -> (b, 3, ...)"""
+ logits = (logits + high) / (2 * high)
+ logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel
+ # -> (b, 1, ..., 3)
+ logits = torch.from_numpy(logits_np).to(logits.device)
+ return rearrange(logits, "b 1 ... c -> b c ...")
+
+ logits_real = torch.nn.functional.interpolate(
+ logits_real,
+ size=inputs.shape[-2:],
+ mode="nearest",
+ antialias=False,
+ )
+ logits_fake = torch.nn.functional.interpolate(
+ logits_fake,
+ size=reconstructions.shape[-2:],
+ mode="nearest",
+ antialias=False,
+ )
+
+ # alpha value of logits for overlay
+ alpha_real = torch.abs(logits_real) / high
+ alpha_fake = torch.abs(logits_fake) / high
+ # -> (b, 1, h, w) in range [0, 0.5]
+ # the alpha values of the grid lines don't really matter, since they are the same
+ # for both the images and the logits anyway
+ grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
+ grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
+ grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
+ # -> (1, h, w)
+ # blend logits and images together
+
+ # prepare logits for plotting
+ logits_real = to_colormap(logits_real)
+ logits_fake = to_colormap(logits_fake)
+ # resize logits
+ # -> (b, 3, h, w)
+
+ # make some grids
+ # add all logits to one plot
+ logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
+ logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
+ # note: torchvision's make_grid uses `nrow` for the number of images per row, i.e. the number of columns
+ grid_logits = torch.cat((logits_real, logits_fake), dim=1)
+ # -> (3, h, w)
+
+ grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
+ grid_images_fake = torchvision.utils.make_grid(
+ 0.5 * reconstructions + 0.5, nrow=4
+ )
+ grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
+ # -> (3, h, w) in range [0, 1]
+
+ grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images
+
+ # Create labeled colorbar
+ dpi = 100
+ height = 128 / dpi
+ width = grid_logits.shape[2] / dpi
+ fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
+ img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
+ plt.colorbar(
+ img,
+ cax=ax,
+ orientation="horizontal",
+ fraction=0.9,
+ aspect=width / height,
+ pad=0.0,
+ )
+ img.set_visible(False)
+ fig.tight_layout()
+ fig.canvas.draw()
+ # manually convert figure to numpy
+ cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
+ cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
+
+ # Add colorbar to plot
+ annotated_grid = torch.cat((grid_logits, cbar), dim=1)
+ blended_grid = torch.cat((grid_blend, cbar), dim=1)
+ return {
+ "vis_logits": 2 * annotated_grid[None, ...] - 1,
+ "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
+ }
+
+ def calculate_adaptive_weight(
+ self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
+ ) -> torch.Tensor:
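+ # Balance the generator (GAN) loss against the reconstruction loss by matching the
+ # gradient norms that both losses induce at the decoder's last layer (as in VQGAN).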
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ reconstructions: torch.Tensor,
+ *, # added because I changed the order here
+ regularization_log: Dict[str, torch.Tensor],
+ optimizer_idx: int,
+ global_step: int,
+ last_layer: torch.Tensor,
+ split: str = "train",
+ weights: Union[None, float, torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, dict]:
+ if self.scale_input_to_tgt_size:
+ inputs = torch.nn.functional.interpolate(
+ inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
+ )
+
+ if self.dims > 2:
+ inputs, reconstructions = map(
+ lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
+ (inputs, reconstructions),
+ )
+
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(
+ inputs.contiguous(), reconstructions.contiguous()
+ )
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+ nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if global_step >= self.discriminator_iter_start or not self.training:
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ g_loss = -torch.mean(logits_fake)
+ if self.training:
+ d_weight = self.calculate_adaptive_weight(
+ nll_loss, g_loss, last_layer=last_layer
+ )
+ else:
+ d_weight = torch.tensor(1.0)
+ else:
+ d_weight = torch.tensor(0.0)
+ g_loss = torch.tensor(0.0, requires_grad=True)
+
+ loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
+ log = dict()
+ for k in regularization_log:
+ if k in self.regularization_weights:
+ loss = loss + self.regularization_weights[k] * regularization_log[k]
+ if k in self.additional_log_keys:
+ log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()
+
+ log.update(
+ {
+ f"{split}/loss/total": loss.clone().detach().mean(),
+ f"{split}/loss/nll": nll_loss.detach().mean(),
+ f"{split}/loss/rec": rec_loss.detach().mean(),
+ f"{split}/loss/g": g_loss.detach().mean(),
+ f"{split}/scalars/logvar": self.logvar.detach(),
+ f"{split}/scalars/d_weight": d_weight.detach(),
+ }
+ )
+
+ return loss, log
+ elif optimizer_idx == 1:
+ # second pass for discriminator update
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+
+ if global_step >= self.discriminator_iter_start or not self.training:
+ d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
+ else:
+ d_loss = torch.tensor(0.0, requires_grad=True)
+
+ log = {
+ f"{split}/loss/disc": d_loss.clone().detach().mean(),
+ f"{split}/logits/real": logits_real.detach().mean(),
+ f"{split}/logits/fake": logits_fake.detach().mean(),
+ }
+ return d_loss, log
+ else:
+ raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
+
+ def get_nll_loss(
+ self,
+ rec_loss: torch.Tensor,
+ weights: Optional[Union[float, torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
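+ # Gaussian-style NLL: rec_loss / exp(logvar) + logvar with a single global (optionally
+ # learned) log-variance; summed over all elements and divided by the batch size below.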
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+ weighted_nll_loss = nll_loss
+ if weights is not None:
+ weighted_nll_loss = weights * nll_loss
+ weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+ nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+
+ return nll_loss, weighted_nll_loss
diff --git a/sgm/modules/autoencoding/losses/lpips.py b/sgm/modules/autoencoding/losses/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..b329fcc2ee9477f0122aa7d066866cdfe71ce521
--- /dev/null
+++ b/sgm/modules/autoencoding/losses/lpips.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+
+from ....util import default, instantiate_from_config
+from ..lpips.loss.lpips import LPIPS
+
+
+class LatentLPIPS(nn.Module):
+ def __init__(
+ self,
+ decoder_config,
+ perceptual_weight=1.0,
+ latent_weight=1.0,
+ scale_input_to_tgt_size=False,
+ scale_tgt_to_input_size=False,
+ perceptual_weight_on_inputs=0.0,
+ ):
+ super().__init__()
+ self.scale_input_to_tgt_size = scale_input_to_tgt_size
+ self.scale_tgt_to_input_size = scale_tgt_to_input_size
+ self.init_decoder(decoder_config)
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+ self.latent_weight = latent_weight
+ self.perceptual_weight_on_inputs = perceptual_weight_on_inputs
+
+ def init_decoder(self, config):
+ self.decoder = instantiate_from_config(config)
+ if hasattr(self.decoder, "encoder"):
+ del self.decoder.encoder
+
+ def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"):
+ log = dict()
+ loss = (latent_inputs - latent_predictions) ** 2
+ log[f"{split}/latent_l2_loss"] = loss.mean().detach()
+ image_reconstructions = None
+ if self.perceptual_weight > 0.0:
+ image_reconstructions = self.decoder.decode(latent_predictions)
+ image_targets = self.decoder.decode(latent_inputs)
+ perceptual_loss = self.perceptual_loss(
+ image_targets.contiguous(), image_reconstructions.contiguous()
+ )
+ loss = (
+ self.latent_weight * loss.mean()
+ + self.perceptual_weight * perceptual_loss.mean()
+ )
+ log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach()
+
+ if self.perceptual_weight_on_inputs > 0.0:
+ image_reconstructions = default(
+ image_reconstructions, self.decoder.decode(latent_predictions)
+ )
+ if self.scale_input_to_tgt_size:
+ image_inputs = torch.nn.functional.interpolate(
+ image_inputs,
+ image_reconstructions.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+ elif self.scale_tgt_to_input_size:
+ image_reconstructions = torch.nn.functional.interpolate(
+ image_reconstructions,
+ image_inputs.shape[2:],
+ mode="bicubic",
+ antialias=True,
+ )
+
+ perceptual_loss2 = self.perceptual_loss(
+ image_inputs.contiguous(), image_reconstructions.contiguous()
+ )
+ loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean()
+ log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach()
+ return loss, log
diff --git a/sgm/modules/autoencoding/lpips/__init__.py b/sgm/modules/autoencoding/lpips/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/lpips/loss/.gitignore b/sgm/modules/autoencoding/lpips/loss/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a92958a1cd4ffe005e1f5448ab3e6fd9c795a43a
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/loss/.gitignore
@@ -0,0 +1 @@
+vgg.pth
\ No newline at end of file
diff --git a/sgm/modules/autoencoding/lpips/loss/LICENSE b/sgm/modules/autoencoding/lpips/loss/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..924cfc85b8d63ef538f5676f830a2a8497932108
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/loss/LICENSE
@@ -0,0 +1,23 @@
+Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/sgm/modules/autoencoding/lpips/loss/__init__.py b/sgm/modules/autoencoding/lpips/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/lpips/loss/lpips.py b/sgm/modules/autoencoding/lpips/loss/lpips.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e34f3d083674f675a5ca024e9bd27fb77e2b6b5
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/loss/lpips.py
@@ -0,0 +1,147 @@
+"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from ..util import get_ckpt_path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss")
+ self.load_state_dict(
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
+ )
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(
+ torch.load(ckpt, map_location=torch.device("cpu")), strict=False
+ )
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(
+ outs1[kk]
+ )
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [
+ spatial_average(lins[kk].model(diffs[kk]), keepdim=True)
+ for kk in range(len(self.chns))
+ ]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer(
+ "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]
+ )
+ self.register_buffer(
+ "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]
+ )
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
+
+
+class NetLinLayer(nn.Module):
+ """A single linear layer which does a 1x1 conv"""
+
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = (
+ [
+ nn.Dropout(),
+ ]
+ if (use_dropout)
+ else []
+ )
+ layers += [
+ nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
+ ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple(
+ "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]
+ )
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x, eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
+ return x / (norm_factor + eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2, 3], keepdim=keepdim)
diff --git a/sgm/modules/autoencoding/lpips/model/LICENSE b/sgm/modules/autoencoding/lpips/model/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..4b356e66b5aa689b339f1a80a9f1b5ba378003bb
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/model/LICENSE
@@ -0,0 +1,58 @@
+Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+--------------------------- LICENSE FOR pix2pix --------------------------------
+BSD License
+
+For pix2pix software
+Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+----------------------------- LICENSE FOR DCGAN --------------------------------
+BSD License
+
+For dcgan.torch software
+
+Copyright (c) 2015, Facebook, Inc. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\ No newline at end of file
diff --git a/sgm/modules/autoencoding/lpips/model/__init__.py b/sgm/modules/autoencoding/lpips/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/autoencoding/lpips/model/model.py b/sgm/modules/autoencoding/lpips/model/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..66357d4e627f9a69a5abbbad15546c96fcd758fe
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/model/model.py
@@ -0,0 +1,88 @@
+import functools
+
+import torch.nn as nn
+
+from ..util import ActNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find("BatchNorm") != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ use_actnorm (bool) -- use ActNorm instead of BatchNorm as the normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if (
+ type(norm_layer) == functools.partial
+ ): # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
+ nn.LeakyReLU(0.2, True),
+ ]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n, 8)
+ sequence += [
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=2,
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2**n_layers, 8)
+ sequence += [
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=1,
+ padding=padw,
+ bias=use_bias,
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+ ] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
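+
+
+# Illustrative usage sketch (the input resolution is an assumption): with the defaults
+# (input_nc=3, ndf=64, n_layers=3), a (b, 3, 256, 256) batch maps to a (b, 1, 30, 30)
+# grid of per-patch real/fake logits:
+#   disc = NLayerDiscriminator().apply(weights_init)
+#   logits = disc(images)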
diff --git a/sgm/modules/autoencoding/lpips/util.py b/sgm/modules/autoencoding/lpips/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..49c76e370bf16888ab61f42844b3c9f14ad9014c
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/util.py
@@ -0,0 +1,128 @@
+import hashlib
+import os
+
+import requests
+import torch
+import torch.nn as nn
+from tqdm import tqdm
+
+URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
+
+CKPT_MAP = {"vgg_lpips": "vgg.pth"}
+
+MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+ pbar.update(chunk_size)
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
+
+
+class ActNorm(nn.Module):
+ def __init__(
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
+ ):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
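+ # data-dependent initialization (as in Glow): loc and scale are set from the first
+ # batch so that the normalized activations have zero mean and unit variance per channel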
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height * width * torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:, :, None, None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
diff --git a/sgm/modules/autoencoding/lpips/vqperceptual.py b/sgm/modules/autoencoding/lpips/vqperceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..6195f0a6ed7ee6fd32c1bccea071e6075e95ee43
--- /dev/null
+++ b/sgm/modules/autoencoding/lpips/vqperceptual.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn.functional as F
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real))
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
+ )
+ return d_loss
diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff2b1815a5ba88892375e8ec9bedacea49024113
--- /dev/null
+++ b/sgm/modules/autoencoding/regularizers/__init__.py
@@ -0,0 +1,31 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ....modules.distributions.distributions import \
+ DiagonalGaussianDistribution
+from .base import AbstractRegularizer
+
+
+class DiagonalGaussianRegularizer(AbstractRegularizer):
+ def __init__(self, sample: bool = True):
+ super().__init__()
+ self.sample = sample
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ log = dict()
+ posterior = DiagonalGaussianDistribution(z)
+ if self.sample:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ kl_loss = posterior.kl()
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+ log["kl_loss"] = kl_loss
+ return z, log
diff --git a/sgm/modules/autoencoding/regularizers/base.py b/sgm/modules/autoencoding/regularizers/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca681bb3c1f4818b57e956e31b98f76077ccb67
--- /dev/null
+++ b/sgm/modules/autoencoding/regularizers/base.py
@@ -0,0 +1,40 @@
+from abc import abstractmethod
+from typing import Any, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class AbstractRegularizer(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ raise NotImplementedError()
+
+ @abstractmethod
+ def get_trainable_parameters(self) -> Any:
+ raise NotImplementedError()
+
+
+class IdentityRegularizer(AbstractRegularizer):
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
+ return z, dict()
+
+ def get_trainable_parameters(self) -> Any:
+ yield from ()
+
+
+def measure_perplexity(
+ predicted_indices: torch.Tensor, num_centroids: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
+ # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
+ encodings = (
+ F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids)
+ )
+ avg_probs = encodings.mean(0)
+ perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
+ cluster_use = torch.sum(avg_probs > 0)
+ return perplexity, cluster_use
diff --git a/sgm/modules/autoencoding/regularizers/quantize.py b/sgm/modules/autoencoding/regularizers/quantize.py
new file mode 100644
index 0000000000000000000000000000000000000000..86a4dbdd10101b24f03bba134c4f8d2ab007f0db
--- /dev/null
+++ b/sgm/modules/autoencoding/regularizers/quantize.py
@@ -0,0 +1,487 @@
+import logging
+from abc import abstractmethod
+from typing import Dict, Iterator, Literal, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch import einsum
+
+from .base import AbstractRegularizer, measure_perplexity
+
+logpy = logging.getLogger(__name__)
+
+
+class AbstractQuantizer(AbstractRegularizer):
+ def __init__(self):
+ super().__init__()
+ # Define these in your init
+ # shape (N,)
+ self.used: Optional[torch.Tensor]
+ self.re_embed: int
+ self.unknown_index: Union[Literal["random"], int]
+
+ def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor:
+ assert self.used is not None, "You need to define used indices for remap"
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
+ device=new.device
+ )
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor:
+ assert self.used is not None, "You need to define used indices for remap"
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ @abstractmethod
+ def get_codebook_entry(
+ self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
+ ) -> torch.Tensor:
+ raise NotImplementedError()
+
+ def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]:
+ yield from self.parameters()
+
+
+class GumbelQuantizer(AbstractQuantizer):
+ """
+ credit to @karpathy:
+ https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+
+ def __init__(
+ self,
+ num_hiddens: int,
+ embedding_dim: int,
+ n_embed: int,
+ straight_through: bool = True,
+ kl_weight: float = 5e-4,
+ temp_init: float = 1.0,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ loss_key: str = "loss/vq",
+ ) -> None:
+ super().__init__()
+
+ self.loss_key = loss_key
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_embed
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+ f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ def forward(
+ self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False
+ ) -> Tuple[torch.Tensor, Dict]:
+ # force hard = True when we are in eval mode, as we must quantize.
+ # actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+ out_dict = {}
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:, self.used, ...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:, self.used, ...] = soft_one_hot
+ soft_one_hot = full_zeros
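+ # the quantized output is a convex combination of codebook embeddings weighted by the
+ # (Gumbel-)softmax over the code logits; with hard=True the weights are one-hot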
+ z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = (
+ self.kl_weight
+ * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+ )
+ out_dict[self.loss_key] = diff
+
+ ind = soft_one_hot.argmax(dim=1)
+ out_dict["indices"] = ind
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+
+ if return_logits:
+ out_dict["logits"] = logits
+
+ return z_q, out_dict
+
+ def get_codebook_entry(self, indices, shape):
+ # TODO: shape not yet optional
+ b, h, w, c = shape
+ assert b * h * w == indices.shape[0]
+ indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = (
+ F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ )
+ z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer(AbstractQuantizer):
+ """
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term,
+ beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ def __init__(
+ self,
+ n_e: int,
+ e_dim: int,
+ beta: float = 0.25,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ sane_index_shape: bool = False,
+ log_perplexity: bool = False,
+ embedding_weight_norm: bool = False,
+ loss_key: str = "loss/vq",
+ ):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.loss_key = loss_key
+
+ if not embedding_weight_norm:
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+ else:
+ self.embedding = torch.nn.utils.weight_norm(
+ nn.Embedding(self.n_e, self.e_dim), dim=1
+ )
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_e
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ self.sane_index_shape = sane_index_shape
+ self.log_perplexity = log_perplexity
+
+ def forward(
+ self,
+ z: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict]:
+ do_reshape = z.ndim == 4
+ if do_reshape:
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, "b c h w -> b h w c").contiguous()
+
+ else:
+ assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined"
+ z = z.contiguous()
+
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = (
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
+ + torch.sum(self.embedding.weight**2, dim=1)
+ - 2
+ * torch.einsum(
+ "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n")
+ )
+ )
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ loss_dict = {}
+ if self.log_perplexity:
+ perplexity, cluster_usage = measure_perplexity(
+ min_encoding_indices.detach(), self.n_e
+ )
+ loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage})
+
+ # compute loss for embedding
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
+ (z_q - z.detach()) ** 2
+ )
+ loss_dict[self.loss_key] = loss
+
+ # preserve gradients
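+ # (straight-through estimator: the forward pass uses the quantized z_q, the
+ # backward pass routes gradients directly to z)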
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ if do_reshape:
+ z_q = rearrange(z_q, "b h w c -> b c h w").contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z.shape[0], -1
+ ) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ if do_reshape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3]
+ )
+ else:
+ min_encoding_indices = rearrange(
+ min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0]
+ )
+
+ loss_dict["min_encoding_indices"] = min_encoding_indices
+
+ return z_q, loss_dict
+
+ def get_codebook_entry(
+ self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None
+ ) -> torch.Tensor:
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ assert shape is not None, "Need to give shape for remap"
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad=False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(
+ new_cluster_size, alpha=1 - self.decay
+ )
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
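+ # (additive smoothing keeps rarely used codes from getting zero counts)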
+ # normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(AbstractQuantizer):
+ def __init__(
+ self,
+ n_embed: int,
+ embedding_dim: int,
+ beta: float,
+ decay: float = 0.99,
+ eps: float = 1e-5,
+ remap: Optional[str] = None,
+ unknown_index: str = "random",
+ loss_key: str = "loss/vq",
+ ):
+ super().__init__()
+ self.codebook_dim = embedding_dim
+ self.num_tokens = n_embed
+ self.beta = beta
+ self.loss_key = loss_key
+
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ else:
+ self.used = None
+ self.re_embed = n_embed
+ if unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ else:
+ assert unknown_index == "random" or isinstance(
+ unknown_index, int
+ ), "unknown index needs to be 'random', 'extra' or any integer"
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.remap is not None:
+ logpy.info(
+                f"Remapping {self.num_tokens} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z, 'b c h w -> b h w c'
+ z = rearrange(z, "b c h w -> b h w c")
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (
+ z_flattened.pow(2).sum(dim=1, keepdim=True)
+ + self.embedding.weight.pow(2).sum(dim=1)
+ - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight)
+        )
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ # EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ # EMA embedding average
+ embed_sum = encodings.transpose(0, 1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ # normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ # z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, "b h w c -> b c h w")
+
+ out_dict = {
+ self.loss_key: loss,
+ "encodings": encodings,
+ "encoding_indices": encoding_indices,
+ "perplexity": perplexity,
+ }
+
+ return z_q, out_dict
+
+
+class VectorQuantizerWithInputProjection(VectorQuantizer):
+ def __init__(
+ self,
+ input_dim: int,
+ n_codes: int,
+ codebook_dim: int,
+ beta: float = 1.0,
+ output_dim: Optional[int] = None,
+ **kwargs,
+ ):
+ super().__init__(n_codes, codebook_dim, beta, **kwargs)
+ self.proj_in = nn.Linear(input_dim, codebook_dim)
+ self.output_dim = output_dim
+ if output_dim is not None:
+ self.proj_out = nn.Linear(codebook_dim, output_dim)
+ else:
+ self.proj_out = nn.Identity()
+
+ def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
+ rearr = False
+ in_shape = z.shape
+
+ if z.ndim > 3:
+ rearr = self.output_dim is not None
+ z = rearrange(z, "b c ... -> b (...) c")
+ z = self.proj_in(z)
+ z_q, loss_dict = super().forward(z)
+
+ z_q = self.proj_out(z_q)
+ if rearr:
+ if len(in_shape) == 4:
+ z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1])
+ elif len(in_shape) == 5:
+ z_q = rearrange(
+ z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2]
+ )
+ else:
+ raise NotImplementedError(
+ f"rearranging not available for {len(in_shape)}-dimensional input."
+ )
+
+ return z_q, loss_dict
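
The quantizer classes above all share the same core mechanics: pick the nearest codebook entry by squared Euclidean distance, add a beta-weighted commitment term, and keep gradients flowing through the discrete step with a straight-through estimator. A minimal, self-contained sketch of that pattern (codebook size, latent shape, and beta are illustrative only, not the classes in this diff):

```python
import torch
import torch.nn as nn

# Minimal sketch of nearest-codebook quantization with a straight-through
# estimator; sizes are examples only.
codebook = nn.Embedding(512, 4)          # 512 codes of dimension 4
z = torch.randn(2, 4, 8, 8)              # latents in (b, c, h, w)

z_perm = z.permute(0, 2, 3, 1).contiguous()   # (b, h, w, c)
z_flat = z_perm.reshape(-1, 4)                # (b*h*w, c)

# squared distances: |z|^2 + |e|^2 - 2 z.e
d = (
    z_flat.pow(2).sum(1, keepdim=True)
    + codebook.weight.pow(2).sum(1)
    - 2 * z_flat @ codebook.weight.t()
)
indices = d.argmin(dim=1)
z_q = codebook(indices).view(z_perm.shape)

# beta-weighted commitment loss plus codebook loss, as in the forward above
beta = 0.25
loss = beta * (z_q.detach() - z_perm).pow(2).mean() + (z_q - z_perm.detach()).pow(2).mean()

# straight-through: the backward pass treats quantization as identity
z_q = z_perm + (z_q - z_perm).detach()
z_q = z_q.permute(0, 3, 1, 2).contiguous()    # back to (b, c, h, w)
```
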
diff --git a/sgm/modules/autoencoding/temporal_ae.py b/sgm/modules/autoencoding/temporal_ae.py
new file mode 100644
index 0000000000000000000000000000000000000000..374373e2e4330846ffef28d9061dcc64f70d2722
--- /dev/null
+++ b/sgm/modules/autoencoding/temporal_ae.py
@@ -0,0 +1,349 @@
+from typing import Callable, Iterable, Union
+
+import torch
+from einops import rearrange, repeat
+
+from sgm.modules.diffusionmodules.model import (
+ XFORMERS_IS_AVAILABLE,
+ AttnBlock,
+ Decoder,
+ MemoryEfficientAttnBlock,
+ ResnetBlock,
+)
+from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
+from sgm.modules.video_attention import VideoTransformerBlock
+from sgm.util import partialclass
+
+
+class VideoResBlock(ResnetBlock):
+ def __init__(
+ self,
+ out_channels,
+ *args,
+ dropout=0.0,
+ video_kernel_size=3,
+ alpha=0.0,
+ merge_strategy="learned",
+ **kwargs,
+ ):
+ super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs)
+ if video_kernel_size is None:
+ video_kernel_size = [3, 1, 1]
+ self.time_stack = ResBlock(
+ channels=out_channels,
+ emb_channels=0,
+ dropout=dropout,
+ dims=3,
+ use_scale_shift_norm=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ kernel_size=video_kernel_size,
+ use_checkpoint=False,
+ skip_t_emb=True,
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def get_alpha(self, bs):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError()
+
+ def forward(self, x, temb, skip_video=False, timesteps=None):
+ if timesteps is None:
+ timesteps = self.timesteps
+
+ b, c, h, w = x.shape
+
+ x = super().forward(x, temb)
+
+ if not skip_video:
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+
+ x = self.time_stack(x, temb)
+
+ alpha = self.get_alpha(bs=b // timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix
+
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ return x
+
+
+class AE3DConv(torch.nn.Conv2d):
+ def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ if isinstance(video_kernel_size, Iterable):
+ padding = [int(k // 2) for k in video_kernel_size]
+ else:
+ padding = int(video_kernel_size // 2)
+
+ self.time_mix_conv = torch.nn.Conv3d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=video_kernel_size,
+ padding=padding,
+ )
+
+ def forward(self, input, timesteps, skip_video=False):
+ x = super().forward(input)
+ if skip_video:
+ return x
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps)
+ x = self.time_mix_conv(x)
+ return rearrange(x, "b c t h w -> (b t) c h w")
+
+
+class VideoBlock(AttnBlock):
+ def __init__(
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
+ ):
+ super().__init__(in_channels)
+ # no context, single headed, as in base class
+ self.time_mix_block = VideoTransformerBlock(
+ dim=in_channels,
+ n_heads=1,
+ d_head=in_channels,
+ checkpoint=False,
+ ff_in=True,
+ attn_mode="softmax",
+ )
+
+ time_embed_dim = self.in_channels * 4
+ self.video_time_embed = torch.nn.Sequential(
+ torch.nn.Linear(self.in_channels, time_embed_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(time_embed_dim, self.in_channels),
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def forward(self, x, timesteps, skip_video=False):
+ if skip_video:
+ return super().forward(x)
+
+ x_in = x
+ x = self.attention(x)
+ h, w = x.shape[2:]
+ x = rearrange(x, "b c h w -> b (h w) c")
+
+ x_mix = x
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
+ emb = self.video_time_embed(t_emb) # b, n_channels
+ emb = emb[:, None, :]
+ x_mix = x_mix + emb
+
+ alpha = self.get_alpha()
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
+
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+
+ return x_in + x
+
+ def get_alpha(
+ self,
+ ):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock):
+ def __init__(
+ self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned"
+ ):
+ super().__init__(in_channels)
+ # no context, single headed, as in base class
+ self.time_mix_block = VideoTransformerBlock(
+ dim=in_channels,
+ n_heads=1,
+ d_head=in_channels,
+ checkpoint=False,
+ ff_in=True,
+ attn_mode="softmax-xformers",
+ )
+
+ time_embed_dim = self.in_channels * 4
+ self.video_time_embed = torch.nn.Sequential(
+ torch.nn.Linear(self.in_channels, time_embed_dim),
+ torch.nn.SiLU(),
+ torch.nn.Linear(time_embed_dim, self.in_channels),
+ )
+
+ self.merge_strategy = merge_strategy
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif self.merge_strategy == "learned":
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def forward(self, x, timesteps, skip_time_block=False):
+ if skip_time_block:
+ return super().forward(x)
+
+ x_in = x
+ x = self.attention(x)
+ h, w = x.shape[2:]
+ x = rearrange(x, "b c h w -> b (h w) c")
+
+ x_mix = x
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False)
+ emb = self.video_time_embed(t_emb) # b, n_channels
+ emb = emb[:, None, :]
+ x_mix = x_mix + emb
+
+ alpha = self.get_alpha()
+ x_mix = self.time_mix_block(x_mix, timesteps=timesteps)
+ x = alpha * x + (1.0 - alpha) * x_mix # alpha merge
+
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ x = self.proj_out(x)
+
+ return x_in + x
+
+ def get_alpha(
+ self,
+ ):
+ if self.merge_strategy == "fixed":
+ return self.mix_factor
+ elif self.merge_strategy == "learned":
+ return torch.sigmoid(self.mix_factor)
+ else:
+ raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}")
+
+
+def make_time_attn(
+ in_channels,
+ attn_type="vanilla",
+ attn_kwargs=None,
+ alpha: float = 0,
+ merge_strategy: str = "learned",
+):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ ], f"attn_type {attn_type} not supported for spatio-temporal attention"
+ print(
+ f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels"
+ )
+ if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers":
+ print(
+ f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. "
+            f"This is not a problem in PyTorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}."
+ )
+ attn_type = "vanilla"
+
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return partialclass(
+ VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
+ )
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return partialclass(
+ MemoryEfficientVideoBlock,
+ in_channels,
+ alpha=alpha,
+ merge_strategy=merge_strategy,
+ )
+ else:
+        raise NotImplementedError()
+
+
+class Conv2DWrapper(torch.nn.Conv2d):
+ def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor:
+ return super().forward(input)
+
+
+class VideoDecoder(Decoder):
+ available_time_modes = ["all", "conv-only", "attn-only"]
+
+ def __init__(
+ self,
+ *args,
+ video_kernel_size: Union[int, list] = 3,
+ alpha: float = 0.0,
+ merge_strategy: str = "learned",
+ time_mode: str = "conv-only",
+ **kwargs,
+ ):
+ self.video_kernel_size = video_kernel_size
+ self.alpha = alpha
+ self.merge_strategy = merge_strategy
+ self.time_mode = time_mode
+ assert (
+ self.time_mode in self.available_time_modes
+ ), f"time_mode parameter has to be in {self.available_time_modes}"
+ super().__init__(*args, **kwargs)
+
+ def get_last_layer(self, skip_time_mix=False, **kwargs):
+ if self.time_mode == "attn-only":
+ raise NotImplementedError("TODO")
+ else:
+ return (
+ self.conv_out.time_mix_conv.weight
+ if not skip_time_mix
+ else self.conv_out.weight
+ )
+
+ def _make_attn(self) -> Callable:
+ if self.time_mode not in ["conv-only", "only-last-conv"]:
+ return partialclass(
+ make_time_attn,
+ alpha=self.alpha,
+ merge_strategy=self.merge_strategy,
+ )
+ else:
+ return super()._make_attn()
+
+ def _make_conv(self) -> Callable:
+ if self.time_mode != "attn-only":
+ return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size)
+ else:
+ return Conv2DWrapper
+
+ def _make_resblock(self) -> Callable:
+ if self.time_mode not in ["attn-only", "only-last-conv"]:
+ return partialclass(
+ VideoResBlock,
+ video_kernel_size=self.video_kernel_size,
+ alpha=self.alpha,
+ merge_strategy=self.merge_strategy,
+ )
+ else:
+ return super()._make_resblock()
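
All of the Video* blocks above follow the same recipe: run the pretrained 2D (per-frame) path, run a temporal path on the same activations reshaped to `(b, c, t, h, w)`, and blend the two with a single mix factor that goes through a sigmoid when `merge_strategy="learned"`. A toy sketch of just the merge, with dummy tensors standing in for the two paths:

```python
import torch
from einops import rearrange

# Toy alpha-merge as used by VideoResBlock / VideoBlock: blend the per-frame
# (spatial) result with a temporally mixed one via a learned mix factor.
b, t, c, h, w = 2, 5, 8, 16, 16
x_spatial = torch.randn(b * t, c, h, w)     # 2D path output, frames folded into batch
x_temporal = rearrange(x_spatial, "(b t) c h w -> b c t h w", t=t)
x_temporal = x_temporal + 0.1 * torch.randn_like(x_temporal)  # stand-in for the 3D conv / temporal attention
x_temporal = rearrange(x_temporal, "b c t h w -> (b t) c h w")

mix_factor = torch.nn.Parameter(torch.zeros(1))   # alpha = 0 -> sigmoid(alpha) = 0.5
alpha = torch.sigmoid(mix_factor)                 # "learned" merge strategy

x = alpha * x_temporal + (1.0 - alpha) * x_spatial
print(x.shape)  # torch.Size([10, 8, 16, 16])
```
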
diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/diffusionmodules/denoiser.py b/sgm/modules/diffusionmodules/denoiser.py
new file mode 100644
index 0000000000000000000000000000000000000000..d86e7a262d1f036139e41f500d8579a2b95071ef
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser.py
@@ -0,0 +1,75 @@
+from typing import Dict, Union
+
+import torch
+import torch.nn as nn
+
+from ...util import append_dims, instantiate_from_config
+from .denoiser_scaling import DenoiserScaling
+from .discretizer import Discretization
+
+
+class Denoiser(nn.Module):
+ def __init__(self, scaling_config: Dict):
+ super().__init__()
+
+ self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
+
+ def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
+ return sigma
+
+ def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
+ return c_noise
+
+ def forward(
+ self,
+ network: nn.Module,
+ input: torch.Tensor,
+ sigma: torch.Tensor,
+ cond: Dict,
+ **additional_model_inputs,
+ ) -> torch.Tensor:
+ sigma = self.possibly_quantize_sigma(sigma)
+ sigma_shape = sigma.shape
+ sigma = append_dims(sigma, input.ndim)
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma)
+ c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
+ return (
+ network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
+ + input * c_skip
+ )
+
+
+class DiscreteDenoiser(Denoiser):
+ def __init__(
+ self,
+ scaling_config: Dict,
+ num_idx: int,
+ discretization_config: Dict,
+ do_append_zero: bool = False,
+ quantize_c_noise: bool = True,
+ flip: bool = True,
+ ):
+ super().__init__(scaling_config)
+ self.discretization: Discretization = instantiate_from_config(
+ discretization_config
+ )
+ sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip)
+ self.register_buffer("sigmas", sigmas)
+ self.quantize_c_noise = quantize_c_noise
+ self.num_idx = num_idx
+
+ def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
+ dists = sigma - self.sigmas[:, None]
+ return dists.abs().argmin(dim=0).view(sigma.shape)
+
+ def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor:
+ return self.sigmas[idx]
+
+ def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
+ return self.idx_to_sigma(self.sigma_to_idx(sigma))
+
+ def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
+ if self.quantize_c_noise:
+ return self.sigma_to_idx(c_noise)
+ else:
+ return c_noise
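
`Denoiser.forward` wraps the raw network F in the usual preconditioning D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise), with the four coefficients supplied by a scaling object (see `denoiser_scaling.py` below). A standalone sketch with a dummy network and v-prediction scalings plus EDM-style c_noise; shapes are illustrative only:

```python
import torch

# Sketch of the preconditioning applied in Denoiser.forward:
#   D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise)
def dummy_network(x, c_noise, cond):
    return torch.zeros_like(x)  # stand-in for the UNet

x = torch.randn(4, 3, 32, 32)
sigma = torch.full((4,), 0.7)
sigma_bc = sigma.view(-1, 1, 1, 1)            # append_dims(sigma, x.ndim)

c_skip = 1.0 / (sigma_bc**2 + 1.0)
c_out = -sigma_bc / (sigma_bc**2 + 1.0) ** 0.5
c_in = 1.0 / (sigma_bc**2 + 1.0) ** 0.5
c_noise = 0.25 * sigma.log()

denoised = dummy_network(x * c_in, c_noise, cond={}) * c_out + x * c_skip
print(denoised.shape)  # torch.Size([4, 3, 32, 32])
```
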
diff --git a/sgm/modules/diffusionmodules/denoiser_scaling.py b/sgm/modules/diffusionmodules/denoiser_scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e287bfe8a82839a9a12fbd25c3446f43ab493b
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser_scaling.py
@@ -0,0 +1,59 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+import torch
+
+
+class DenoiserScaling(ABC):
+ @abstractmethod
+ def __call__(
+ self, sigma: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ pass
+
+
+class EDMScaling:
+ def __init__(self, sigma_data: float = 0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(
+ self, sigma: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
+ c_noise = 0.25 * sigma.log()
+ return c_skip, c_out, c_in, c_noise
+
+
+class EpsScaling:
+ def __call__(
+ self, sigma: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = torch.ones_like(sigma, device=sigma.device)
+ c_out = -sigma
+ c_in = 1 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VScaling:
+ def __call__(
+ self, sigma: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = 1.0 / (sigma**2 + 1.0)
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+ c_noise = sigma.clone()
+ return c_skip, c_out, c_in, c_noise
+
+
+class VScalingWithEDMcNoise(DenoiserScaling):
+ def __call__(
+ self, sigma: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ c_skip = 1.0 / (sigma**2 + 1.0)
+ c_out = -sigma / (sigma**2 + 1.0) ** 0.5
+ c_in = 1.0 / (sigma**2 + 1.0) ** 0.5
+ c_noise = 0.25 * sigma.log()
+ return c_skip, c_out, c_in, c_noise
diff --git a/sgm/modules/diffusionmodules/denoiser_weighting.py b/sgm/modules/diffusionmodules/denoiser_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00
--- /dev/null
+++ b/sgm/modules/diffusionmodules/denoiser_weighting.py
@@ -0,0 +1,24 @@
+import torch
+
+
+class UnitWeighting:
+ def __call__(self, sigma):
+ return torch.ones_like(sigma, device=sigma.device)
+
+
+class EDMWeighting:
+ def __init__(self, sigma_data=0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma):
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+
+class VWeighting(EDMWeighting):
+ def __init__(self):
+ super().__init__(sigma_data=1.0)
+
+
+class EpsWeighting:
+ def __call__(self, sigma):
+ return sigma**-2.0
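
As a sanity check on the formulas above: for EDM, the loss weight equals 1 / c_out(sigma)^2 with the c_out defined in `denoiser_scaling.py` (same `sigma_data` in both). A short numeric verification with example sigmas:

```python
import torch

# Check that EDMWeighting(sigma) == 1 / c_out(sigma)**2 for EDMScaling,
# using sigma_data = 0.5 in both (example sigmas only).
sigma = torch.tensor([0.1, 0.5, 2.0, 10.0])
sigma_data = 0.5

w = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2        # EDMWeighting
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5    # EDMScaling
print(torch.allclose(w, 1.0 / c_out**2))  # True
```
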
diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..02add6081c5e3164d4402619b44d5be235d3ec58
--- /dev/null
+++ b/sgm/modules/diffusionmodules/discretizer.py
@@ -0,0 +1,69 @@
+from abc import abstractmethod
+from functools import partial
+
+import numpy as np
+import torch
+
+from ...modules.diffusionmodules.util import make_beta_schedule
+from ...util import append_zero
+
+
+def generate_roughly_equally_spaced_steps(
+ num_substeps: int, max_step: int
+) -> np.ndarray:
+ return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
+
+
+class Discretization:
+ def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
+ sigmas = self.get_sigmas(n, device=device)
+ sigmas = append_zero(sigmas) if do_append_zero else sigmas
+ return sigmas if not flip else torch.flip(sigmas, (0,))
+
+ @abstractmethod
+ def get_sigmas(self, n, device):
+ pass
+
+
+class EDMDiscretization(Discretization):
+ def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.rho = rho
+
+ def get_sigmas(self, n, device="cpu"):
+ ramp = torch.linspace(0, 1, n, device=device)
+ min_inv_rho = self.sigma_min ** (1 / self.rho)
+ max_inv_rho = self.sigma_max ** (1 / self.rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
+ return sigmas
+
+
+class LegacyDDPMDiscretization(Discretization):
+ def __init__(
+ self,
+ linear_start=0.00085,
+ linear_end=0.0120,
+ num_timesteps=1000,
+ ):
+ super().__init__()
+ self.num_timesteps = num_timesteps
+ betas = make_beta_schedule(
+ "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
+ )
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ def get_sigmas(self, n, device="cpu"):
+ if n < self.num_timesteps:
+ timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
+ alphas_cumprod = self.alphas_cumprod[timesteps]
+ elif n == self.num_timesteps:
+ alphas_cumprod = self.alphas_cumprod
+ else:
+ raise ValueError
+
+ to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
+ sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ return torch.flip(sigmas, (0,))
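
`EDMDiscretization` produces the Karras et al. noise schedule by interpolating linearly in sigma^(1/rho) space, while `LegacyDDPMDiscretization` derives sigmas from the DDPM alpha-bar schedule instead. A standalone sketch of the EDM variant using its default parameters (the step count is an example):

```python
import torch

# Karras / EDM sigma schedule: interpolate between sigma_max and sigma_min
# in sigma**(1/rho) space, then raise back to the rho power.
sigma_min, sigma_max, rho, n = 0.002, 80.0, 7.0, 10

ramp = torch.linspace(0, 1, n)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(sigmas[0].item(), sigmas[-1].item())  # ~80.0 down to ~0.002 (descending)
```
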
diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b5775b6ca857b4706f65f8cf3187cc8e4506d8
--- /dev/null
+++ b/sgm/modules/diffusionmodules/guiders.py
@@ -0,0 +1,146 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from einops import rearrange, repeat
+
+from ...util import append_dims, default
+
+logpy = logging.getLogger(__name__)
+
+
+class Guider(ABC):
+ @abstractmethod
+ def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
+ pass
+
+ def prepare_inputs(
+ self, x: torch.Tensor, s: float, c: Dict, uc: Dict
+ ) -> Tuple[torch.Tensor, float, Dict]:
+ pass
+
+
+class VanillaCFG(Guider):
+ def __init__(self, scale: float):
+ self.scale = scale
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_u, x_c = x.chunk(2)
+ x_pred = x_u + self.scale * (x_c - x_u)
+ return x_pred
+
+ def prepare_inputs(self, x, s, c, uc):
+ c_out = dict()
+
+ for k in c:
+ if k in ["vector", "crossattn", "concat"]:
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class IdentityGuider(Guider):
+ def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
+ return x
+
+ def prepare_inputs(
+ self, x: torch.Tensor, s: float, c: Dict, uc: Dict
+ ) -> Tuple[torch.Tensor, float, Dict]:
+ c_out = dict()
+
+ for k in c:
+ c_out[k] = c[k]
+
+ return x, s, c_out
+
+
+class LinearPredictionGuider(Guider):
+ def __init__(
+ self,
+ max_scale: float,
+ num_frames: int,
+ min_scale: float = 1.0,
+ additional_cond_keys: Optional[Union[List[str], str]] = None,
+ ):
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.num_frames = num_frames
+ self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
+
+ additional_cond_keys = default(additional_cond_keys, [])
+ if isinstance(additional_cond_keys, str):
+ additional_cond_keys = [additional_cond_keys]
+ self.additional_cond_keys = additional_cond_keys
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_u, x_c = x.chunk(2)
+
+ x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
+ x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
+ scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
+ scale = append_dims(scale, x_u.ndim).to(x_u.device)
+
+ return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
+
+ def prepare_inputs(
+ self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ c_out = dict()
+
+ for k in c:
+ if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
+ else:
+ if k == "rgb":
+ continue
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
+
+
+class CentralPredictionGuider(Guider):
+ def __init__(
+ self,
+ max_scale: float,
+ num_frames: int,
+ min_scale: float = 1.0,
+ additional_cond_keys: Optional[Union[List[str], str]] = None,
+ ):
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.num_frames = num_frames
+ # self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
+ self.scale = torch.linspace(min_scale, 2 * max_scale, num_frames)
+ self.scale[num_frames // 2 :] = 2 * max_scale - self.scale[num_frames // 2 :]
+ self.scale = self.scale.unsqueeze(0)
+
+ additional_cond_keys = default(additional_cond_keys, [])
+ if isinstance(additional_cond_keys, str):
+ additional_cond_keys = [additional_cond_keys]
+ self.additional_cond_keys = additional_cond_keys
+
+ def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
+ x_u, x_c = x.chunk(2)
+
+ x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
+ x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
+ scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
+ scale = append_dims(scale, x_u.ndim).to(x_u.device)
+
+ return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
+
+ def prepare_inputs(
+ self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
+ ) -> Tuple[torch.Tensor, torch.Tensor, dict]:
+ c_out = dict()
+
+ for k in c:
+ if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys:
+ c_out[k] = torch.cat((uc[k], c[k]), 0)
+ else:
+ assert c[k] == uc[k]
+ c_out[k] = c[k]
+ return torch.cat([x] * 2), torch.cat([s] * 2), c_out
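
The guiders all implement classifier-free guidance: `prepare_inputs` stacks the unconditional and conditional batches, and `__call__` recombines the two model outputs. `LinearPredictionGuider` additionally ramps the guidance scale linearly across the frame axis. A compact sketch of that per-frame merge (dummy tensors; the min/max scales and frame count are example values):

```python
import torch
from einops import rearrange, repeat

# Per-frame classifier-free guidance as in LinearPredictionGuider.__call__:
# outputs for [uncond; cond] are stacked along the batch axis and merged with
# a guidance scale that ramps linearly over the frame index.
num_frames, b, c, h, w = 4, 2, 3, 8, 8
x = torch.randn(2 * b * num_frames, c, h, w)      # model output for [uncond; cond]

x_u, x_c = x.chunk(2)
x_u = rearrange(x_u, "(b t) ... -> b t ...", t=num_frames)
x_c = rearrange(x_c, "(b t) ... -> b t ...", t=num_frames)

scale = torch.linspace(1.0, 2.5, num_frames).unsqueeze(0)   # min_scale .. max_scale
scale = repeat(scale, "1 t -> b t", b=x_u.shape[0])
scale = scale[..., None, None, None]                        # broadcast like append_dims

guided = rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")
print(guided.shape)  # torch.Size([8, 3, 8, 8])
```
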
diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b2c437fab37bed10ea79c197560ade7bf511cad
--- /dev/null
+++ b/sgm/modules/diffusionmodules/loss.py
@@ -0,0 +1,187 @@
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from ...modules.autoencoding.lpips.loss.lpips import LPIPS
+from ...modules.encoders.modules import GeneralConditioner
+from ...util import append_dims, instantiate_from_config
+from .denoiser import Denoiser
+
+
+class StandardDiffusionLoss(nn.Module):
+ def __init__(
+ self,
+ sigma_sampler_config: dict,
+ loss_weighting_config: dict,
+ loss_type: str = "l2",
+ offset_noise_level: float = 0.0,
+ batch2model_keys: Optional[Union[str, List[str]]] = None,
+ ):
+ super().__init__()
+
+ assert loss_type in ["l2", "l1", "lpips"]
+
+ self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
+ self.loss_weighting = instantiate_from_config(loss_weighting_config)
+
+ self.loss_type = loss_type
+ self.offset_noise_level = offset_noise_level
+
+ if loss_type == "lpips":
+ self.lpips = LPIPS().eval()
+
+ if not batch2model_keys:
+ batch2model_keys = []
+
+ if isinstance(batch2model_keys, str):
+ batch2model_keys = [batch2model_keys]
+
+ self.batch2model_keys = set(batch2model_keys)
+
+ def get_noised_input(
+ self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor
+ ) -> torch.Tensor:
+ noised_input = input + noise * sigmas_bc
+ return noised_input
+
+ def forward(
+ self,
+ network: nn.Module,
+ denoiser: Denoiser,
+ conditioner: GeneralConditioner,
+ input: torch.Tensor,
+ batch: Dict,
+ return_model_output: bool = False,
+ ) -> torch.Tensor:
+ cond = conditioner(batch)
+ # for video diffusion
+ if "num_video_frames" in batch:
+ num_frames = batch["num_video_frames"]
+ for k in ["crossattn", "concat"]:
+ cond[k] = repeat(cond[k], "b ... -> b t ...", t=num_frames)
+ cond[k] = rearrange(cond[k], "b t ... -> (b t) ...", t=num_frames)
+ return self._forward(network, denoiser, cond, input, batch, return_model_output)
+
+ def _forward(
+ self,
+ network: nn.Module,
+ denoiser: Denoiser,
+ cond: Dict,
+ input: torch.Tensor,
+ batch: Dict,
+ return_model_output: bool = False,
+ ) -> Tuple[torch.Tensor, Dict]:
+ additional_model_inputs = {
+ key: batch[key] for key in self.batch2model_keys.intersection(batch)
+ }
+ sigmas = self.sigma_sampler(input.shape[0]).to(input)
+
+ noise = torch.randn_like(input)
+ if self.offset_noise_level > 0.0:
+ offset_shape = (
+ (input.shape[0], 1, input.shape[2])
+                if getattr(self, "n_frames", None) is not None
+ else (input.shape[0], input.shape[1])
+ )
+ noise = noise + self.offset_noise_level * append_dims(
+ torch.randn(offset_shape, device=input.device),
+ input.ndim,
+ )
+ sigmas_bc = append_dims(sigmas, input.ndim)
+ noised_input = self.get_noised_input(sigmas_bc, noise, input)
+
+ model_output = denoiser(
+ network, noised_input, sigmas, cond, **additional_model_inputs
+ )
+ w = append_dims(self.loss_weighting(sigmas), input.ndim)
+ if not return_model_output:
+ return self.get_loss(model_output, input, w)
+ else:
+ return self.get_loss(model_output, input, w), model_output
+
+ def get_loss(self, model_output, target, w):
+ if self.loss_type == "l2":
+ return torch.mean(
+ (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
+ )
+ elif self.loss_type == "l1":
+ return torch.mean(
+ (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
+ )
+ elif self.loss_type == "lpips":
+ loss = self.lpips(model_output, target).reshape(-1)
+ return loss
+ else:
+ raise NotImplementedError(f"Unknown loss type {self.loss_type}")
+
+
+class StandardDiffusionLossWithPixelNeRFLoss(StandardDiffusionLoss):
+ def __init__(
+ self,
+ sigma_sampler_config: Dict,
+ loss_weighting_config: Dict,
+ loss_type: str = "l2",
+ offset_noise_level: float = 0,
+        batch2model_keys: Optional[Union[str, List[str]]] = None,
+ pixelnerf_loss_weight: float = 1.0,
+ pixelnerf_loss_type: str = "l2",
+ ):
+ super().__init__(
+ sigma_sampler_config,
+ loss_weighting_config,
+ loss_type,
+ offset_noise_level,
+ batch2model_keys,
+ )
+ self.pixelnerf_loss_weight = pixelnerf_loss_weight
+ self.pixelnerf_loss_type = pixelnerf_loss_type
+
+ def get_pixelnerf_loss(self, model_output, target):
+ if self.pixelnerf_loss_type == "l2":
+ return torch.mean(
+ ((model_output - target) ** 2).reshape(target.shape[0], -1), 1
+ )
+ elif self.pixelnerf_loss_type == "l1":
+ return torch.mean(
+ ((model_output - target).abs()).reshape(target.shape[0], -1), 1
+ )
+ elif self.pixelnerf_loss_type == "lpips":
+ loss = self.lpips(model_output, target).reshape(-1)
+ return loss
+ else:
+ raise NotImplementedError(f"Unknown loss type {self.loss_type}")
+
+ def forward(
+ self,
+ network: nn.Module,
+ denoiser: Denoiser,
+ conditioner: GeneralConditioner,
+ input: torch.Tensor,
+ batch: Dict,
+ return_model_output: bool = False,
+ ) -> torch.Tensor:
+ cond = conditioner(batch)
+ return self._forward(network, denoiser, cond, input, batch, return_model_output)
+
+ def _forward(
+ self,
+ network: nn.Module,
+ denoiser: Denoiser,
+ cond: Dict,
+ input: torch.Tensor,
+ batch: Dict,
+ return_model_output: bool = False,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ loss = super()._forward(
+ network, denoiser, cond, input, batch, return_model_output
+ )
+ pixelnerf_loss = self.get_pixelnerf_loss(
+ cond["rgb"], batch["pixelnerf_input"]["rgb"]
+ )
+
+ if not return_model_output:
+ return loss + self.pixelnerf_loss_weight * pixelnerf_loss
+ else:
+ return loss[0] + self.pixelnerf_loss_weight * pixelnerf_loss, loss[1]
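
`StandardDiffusionLoss` boils down to a per-sample mean of w(sigma) * (D(x + sigma * eps, sigma) - x)^2, where w comes from one of the weighting classes in `loss_weighting.py`. A compact sketch of the weighted L2 reduction with dummy tensors (not the real training step):

```python
import torch

# Weighted L2 objective as in StandardDiffusionLoss.get_loss: per-sample mean
# of w(sigma) * (model_output - target)**2, reduced over all non-batch dims.
b = 4
target = torch.randn(b, 3, 16, 16)              # clean input
model_output = torch.randn(b, 3, 16, 16)        # denoiser prediction (dummy)
sigmas = torch.rand(b) + 0.1

sigma_data = 0.5
w = (sigmas**2 + sigma_data**2) / (sigmas * sigma_data) ** 2   # EDM weighting
w = w.view(b, 1, 1, 1)                                          # append_dims-style broadcast

loss = torch.mean((w * (model_output - target) ** 2).reshape(b, -1), 1)
print(loss.shape)  # torch.Size([4]) -- one value per sample, averaged later
```
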
diff --git a/sgm/modules/diffusionmodules/loss_weighting.py b/sgm/modules/diffusionmodules/loss_weighting.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12c0a76635435babd1af33969e82fa284525af8
--- /dev/null
+++ b/sgm/modules/diffusionmodules/loss_weighting.py
@@ -0,0 +1,32 @@
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class DiffusionLossWeighting(ABC):
+ @abstractmethod
+ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
+ pass
+
+
+class UnitWeighting(DiffusionLossWeighting):
+ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
+ return torch.ones_like(sigma, device=sigma.device)
+
+
+class EDMWeighting(DiffusionLossWeighting):
+ def __init__(self, sigma_data: float = 0.5):
+ self.sigma_data = sigma_data
+
+ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
+ return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+
+class VWeighting(EDMWeighting):
+ def __init__(self):
+ super().__init__(sigma_data=1.0)
+
+
+class EpsWeighting(DiffusionLossWeighting):
+ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
+ return sigma**-2.0
diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cf9d92140dee8443a0ea6b5cf218f2879ad88f4
--- /dev/null
+++ b/sgm/modules/diffusionmodules/model.py
@@ -0,0 +1,748 @@
+# pytorch_diffusion + derived encoder decoder
+import logging
+import math
+from typing import Any, Callable, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from packaging import version
+
+logpy = logging.getLogger(__name__)
+
+try:
+ import xformers
+ import xformers.ops
+
+ XFORMERS_IS_AVAILABLE = True
+except ImportError:
+ XFORMERS_IS_AVAILABLE = False
+ logpy.warning("no module 'xformers'. Processing without...")
+
+from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+    Build sinusoidal timestep embeddings.
+    This matches the implementation in Denoising Diffusion Probabilistic Models
+    (adapted from Fairseq) and in tensor2tensor, but differs slightly from the
+    description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+ )
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
+ )
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ b, c, h, w = q.shape
+ q, k, v = map(
+ lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
+ )
+        # compute attention
+        h_ = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v
+        )  # scale is dim ** -0.5 per default
+
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
+ )
+ self.attention_op: Optional[Any] = None
+
+ def attention(self, h_: torch.Tensor) -> torch.Tensor:
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(
+ q, k, v, attn_bias=None, op=self.attention_op
+ )
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C)
+
+ def forward(self, x, **kwargs):
+ h_ = x
+ h_ = self.attention(h_)
+ h_ = self.proj_out(h_)
+ return x + h_
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+ def forward(self, x, context=None, mask=None, **unused_kwargs):
+ b, c, h, w = x.shape
+ x = rearrange(x, "b c h w -> b (h w) c")
+ out = super().forward(x, context=context, mask=mask)
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c)
+ return x + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in [
+ "vanilla",
+ "vanilla-xformers",
+ "memory-efficient-cross-attn",
+ "linear",
+ "none",
+ ], f"attn_type {attn_type} unknown"
+ if (
+ version.parse(torch.__version__) < version.parse("2.0.0")
+ and attn_type != "none"
+ ):
+ assert XFORMERS_IS_AVAILABLE, (
+ f"We do not support vanilla attention in {torch.__version__} anymore, "
+ f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
+ )
+ attn_type = "vanilla-xformers"
+ logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ logpy.info(
+ f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
+ )
+ return MemoryEfficientAttnBlock(in_channels)
+    elif attn_type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Model(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList(
+ [
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ]
+ )
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(
+ in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x, t=None, context=None):
+ # assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb
+ )
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignore_kwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
+ )
+
+ curr_res = resolution
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(
+ block_in,
+ 2 * z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignorekwargs,
+ ):
+ super().__init__()
+ if use_linear_attn:
+ attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ logpy.info(
+ "Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)
+ )
+ )
+
+ make_attn_cls = self._make_attn()
+ make_resblock_cls = self._make_resblock()
+ make_conv_cls = self._make_conv()
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
+ )
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
+ self.mid.block_2 = make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ make_resblock_cls(
+ in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout,
+ )
+ )
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn_cls(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = make_conv_cls(
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
+ )
+
+ def _make_attn(self) -> Callable:
+ return make_attn
+
+ def _make_resblock(self) -> Callable:
+ return ResnetBlock
+
+ def _make_conv(self) -> Callable:
+ return torch.nn.Conv2d
+
+ def get_last_layer(self, **kwargs):
+ return self.conv_out.weight
+
+ def forward(self, z, **kwargs):
+ # assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb, **kwargs)
+ h = self.mid.attn_1(h, **kwargs)
+ h = self.mid.block_2(h, temb, **kwargs)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, **kwargs)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h, **kwargs)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
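
The `Encoder` downsamples by a factor of 2^(len(ch_mult) - 1) and, with `double_z=True`, emits 2 * z_channels feature maps (mean and log-variance); the `Decoder` mirrors the path back to pixel space. A small shape-roundtrip sketch, assuming the package is importable as `sgm` and PyTorch >= 2.0 (or xformers) is available for the attention block; the hyperparameters below are illustrative, not the training config:

```python
import torch
from sgm.modules.diffusionmodules.model import Decoder, Encoder

enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4, double_z=True)
dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2, 4), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4)

x = torch.randn(1, 3, 64, 64)
moments = enc(x)               # (1, 2 * z_channels, 16, 16): mean and logvar stacked
z = moments[:, :4]             # take the mean half as a stand-in for sampling
print(moments.shape, dec(z).shape)  # (1, 8, 16, 16) and (1, 3, 64, 64)
```
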
diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..e762e6823540def71743e27131e284ea28cdb56e
--- /dev/null
+++ b/sgm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,863 @@
+import logging
+import math
+from abc import abstractmethod
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from functools import partial
+
+# from torch.utils.checkpoint import checkpoint
+import torch.utils.checkpoint
+
+checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
+
+from ...modules.attention import SpatialTransformer
+from ...modules.diffusionmodules.util import (
+ avg_pool_nd,
+ conv_nd,
+ linear,
+ normalization,
+ timestep_embedding,
+ zero_module,
+)
+from ...modules.video_attention import SpatialVideoTransformer
+from ...util import exists
+
+logpy = logging.getLogger(__name__)
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
+ )
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ b, c, _ = x.shape
+ x = x.reshape(b, c, -1)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x: th.Tensor, emb: th.Tensor):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(
+ self,
+ x: th.Tensor,
+ emb: th.Tensor,
+ context: Optional[th.Tensor] = None,
+ image_only_indicator: Optional[th.Tensor] = None,
+ time_context: Optional[int] = None,
+ num_video_frames: Optional[int] = None,
+ ):
+ from ...modules.diffusionmodules.video_model import VideoResBlock
+
+ for layer in self:
+ module = layer
+
+ if isinstance(module, TimestepBlock) and not isinstance(
+ module, VideoResBlock
+ ):
+ x = layer(x, emb)
+ elif isinstance(module, VideoResBlock):
+ x = layer(x, emb, num_video_frames, image_only_indicator)
+ elif isinstance(module, SpatialVideoTransformer):
+ x = layer(
+ x,
+ context,
+ time_context,
+ num_video_frames,
+ image_only_indicator,
+ )
+ elif isinstance(module, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool,
+ dims: int = 2,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ third_up: bool = False,
+ kernel_size: int = 3,
+ scale_factor: int = 2,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ self.third_up = third_up
+ self.scale_factor = scale_factor
+ if use_conv:
+ self.conv = conv_nd(
+ dims, self.channels, self.out_channels, kernel_size, padding=padding
+ )
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ assert x.shape[1] == self.channels
+
+ if self.dims == 3:
+ t_factor = 1 if not self.third_up else self.scale_factor
+ x = F.interpolate(
+ x,
+ (
+ t_factor * x.shape[2],
+ x.shape[3] * self.scale_factor,
+ x.shape[4] * self.scale_factor,
+ ),
+ mode="nearest",
+ )
+ else:
+ x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ use_conv: bool,
+ dims: int = 2,
+ out_channels: Optional[int] = None,
+ padding: int = 1,
+ third_down: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
+ if use_conv:
+ logpy.info(f"Building a Downsample layer with {dims} dims.")
+ logpy.info(
+ f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
+ f"kernel-size: 3, stride: {stride}, padding: {padding}"
+ )
+ if dims == 3:
+ logpy.info(f" --> Downsampling third axis (time): {third_down}")
+ self.op = conv_nd(
+ dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding,
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ assert x.shape[1] == self.channels
+
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ emb_channels: int,
+ dropout: float,
+ out_channels: Optional[int] = None,
+ use_conv: bool = False,
+ use_scale_shift_norm: bool = False,
+ dims: int = 2,
+ use_checkpoint: bool = False,
+ up: bool = False,
+ down: bool = False,
+ kernel_size: int = 3,
+ exchange_temb_dims: bool = False,
+ skip_t_emb: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.exchange_temb_dims = exchange_temb_dims
+
+ if isinstance(kernel_size, Iterable):
+ padding = [k // 2 for k in kernel_size]
+ else:
+ padding = kernel_size // 2
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.skip_t_emb = skip_t_emb
+ self.emb_out_channels = (
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels
+ )
+ if self.skip_t_emb:
+ logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}")
+ assert not self.use_scale_shift_norm
+ self.emb_layers = None
+ self.exchange_temb_dims = False
+ else:
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ self.emb_out_channels,
+ ),
+ )
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(
+ dims,
+ self.out_channels,
+ self.out_channels,
+ kernel_size,
+ padding=padding,
+ )
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, kernel_size, padding=padding
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if self.use_checkpoint:
+ return checkpoint(self._forward, x, emb)
+ else:
+ return self._forward(x, emb)
+
+ def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor:
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+
+ if self.skip_t_emb:
+ emb_out = th.zeros_like(h)
+ else:
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ if self.exchange_temb_dims:
+ emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ num_heads: int = 1,
+ num_head_channels: int = -1,
+ use_checkpoint: bool = False,
+ use_new_attention_order: bool = False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x: th.Tensor, **kwargs) -> th.Tensor:
+ return checkpoint(self._forward, x)
+
+ def _forward(self, x: th.Tensor) -> th.Tensor:
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+ """
+
+ def __init__(self, n_heads: int):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv: th.Tensor) -> th.Tensor:
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads: int):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv: th.Tensor) -> th.Tensor:
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
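+
+    # Shape contract (sketch, with illustrative sizes): the packed qkv tensor
+    # has width 3 * H * C and the output has H * C channels; QKVAttention and
+    # QKVAttentionLegacy differ only in how heads and q/k/v are split apart.
+    #
+    #   attn = QKVAttention(n_heads=4)
+    #   qkv = th.randn(2, 3 * 4 * 16, 64)  # [N, 3*H*C, T] with H=4, C=16, T=64
+    #   out = attn(qkv)                    # -> [2, 64, 64] == [N, H*C, T]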
+
+
+class Timestep(nn.Module):
+ def __init__(self, dim: int):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, t: th.Tensor) -> th.Tensor:
+ return timestep_embedding(t, self.dim)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ model_channels: int,
+ out_channels: int,
+ num_res_blocks: int,
+        attention_resolutions: Union[List[int], Tuple[int, ...]],
+ dropout: float = 0.0,
+ channel_mult: Union[List, Tuple] = (1, 2, 4, 8),
+ conv_resample: bool = True,
+ dims: int = 2,
+ num_classes: Optional[Union[int, str]] = None,
+ use_checkpoint: bool = False,
+ num_heads: int = -1,
+ num_head_channels: int = -1,
+ num_heads_upsample: int = -1,
+ use_scale_shift_norm: bool = False,
+ resblock_updown: bool = False,
+ transformer_depth: int = 1,
+ context_dim: Optional[int] = None,
+ disable_self_attentions: Optional[List[bool]] = None,
+ num_attention_blocks: Optional[List[int]] = None,
+ disable_middle_self_attn: bool = False,
+ disable_middle_transformer: bool = False,
+ use_linear_in_transformer: bool = False,
+ spatial_transformer_attn_type: str = "softmax",
+ adm_in_channels: Optional[int] = None,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert (
+ num_head_channels != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ if num_head_channels == -1:
+ assert (
+ num_heads != -1
+ ), "Either num_heads or num_head_channels has to be set"
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ transformer_depth_middle = transformer_depth[-1]
+
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError(
+ "provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult"
+ )
+ self.num_res_blocks = num_res_blocks
+
+ if disable_self_attentions is not None:
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(
+ map(
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
+ range(len(num_attention_blocks)),
+ )
+ )
+ logpy.info(
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set."
+ )
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ logpy.info("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+                raise ValueError(f"unsupported num_classes: {self.num_classes}")
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ if context_dim is not None and exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or nr < num_attention_blocks[level]
+ ):
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ if not disable_middle_transformer
+ else th.nn.Identity(),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if (
+ not exists(num_attention_blocks)
+ or i < num_attention_blocks[level]
+ ):
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ disable_self_attn=disabled_sa,
+ use_linear=use_linear_in_transformer,
+ attn_type=spatial_transformer_attn_type,
+ use_checkpoint=use_checkpoint,
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+
+ def forward(
+ self,
+ x: th.Tensor,
+ timesteps: Optional[th.Tensor] = None,
+ context: Optional[th.Tensor] = None,
+ y: Optional[th.Tensor] = None,
+ **kwargs,
+ ) -> th.Tensor:
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+
+ return self.out(h)
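+
+
+# Usage sketch (hyper-parameters are illustrative, not those of any shipped
+# checkpoint): a small class-unconditional, cross-attention UNet.
+#
+#   unet = UNetModel(
+#       in_channels=4, model_channels=64, out_channels=4,
+#       num_res_blocks=2, attention_resolutions=[2, 4],
+#       channel_mult=(1, 2, 4), num_head_channels=32, context_dim=128,
+#   )
+#   x = th.randn(2, 4, 32, 32)
+#   t = th.randint(0, 1000, (2,))
+#   ctx = th.randn(2, 77, 128)
+#   eps = unet(x, timesteps=t, context=ctx)  # -> [2, 4, 32, 32]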
diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..6346829c86a76ab549ed69431f1704e01379535a
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sampling.py
@@ -0,0 +1,365 @@
+"""
+Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py
+"""
+
+
+from typing import Dict, Union
+
+import torch
+from omegaconf import ListConfig, OmegaConf
+from tqdm import tqdm
+
+from ...modules.diffusionmodules.sampling_utils import (
+ get_ancestral_step,
+ linear_multistep_coeff,
+ to_d,
+ to_neg_log_sigma,
+ to_sigma,
+)
+from ...util import append_dims, default, instantiate_from_config
+
+DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+
+
+class BaseDiffusionSampler:
+ def __init__(
+ self,
+ discretization_config: Union[Dict, ListConfig, OmegaConf],
+ num_steps: Union[int, None] = None,
+ guider_config: Union[Dict, ListConfig, OmegaConf, None] = None,
+ verbose: bool = False,
+ device: str = "cuda",
+ ):
+ self.num_steps = num_steps
+ self.discretization = instantiate_from_config(discretization_config)
+ self.guider = instantiate_from_config(
+ default(
+ guider_config,
+ DEFAULT_GUIDER,
+ )
+ )
+ self.verbose = verbose
+ self.device = device
+
+ def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
+ sigmas = self.discretization(
+ self.num_steps if num_steps is None else num_steps, device=self.device
+ )
+ uc = default(uc, cond)
+
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
+ num_sigmas = len(sigmas)
+
+ s_in = x.new_ones([x.shape[0]])
+
+ return x, s_in, sigmas, num_sigmas, cond, uc
+
+ def denoise(self, x, denoiser, sigma, cond, uc):
+ denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc))
+ denoised = self.guider(denoised, sigma)
+ return denoised
+
+ def get_sigma_gen(self, num_sigmas):
+ sigma_generator = range(num_sigmas - 1)
+ if self.verbose:
+ print("#" * 30, " Sampling setting ", "#" * 30)
+ print(f"Sampler: {self.__class__.__name__}")
+ print(f"Discretization: {self.discretization.__class__.__name__}")
+ print(f"Guider: {self.guider.__class__.__name__}")
+ sigma_generator = tqdm(
+ sigma_generator,
+ total=num_sigmas,
+ desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps",
+ )
+ return sigma_generator
+
+
+class SingleStepDiffusionSampler(BaseDiffusionSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs):
+ raise NotImplementedError
+
+ def euler_step(self, x, d, dt):
+ return x + dt * d
+
+
+class EDMSampler(SingleStepDiffusionSampler):
+ def __init__(
+ self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.s_churn = s_churn
+ self.s_tmin = s_tmin
+ self.s_tmax = s_tmax
+ self.s_noise = s_noise
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0):
+ sigma_hat = sigma * (gamma + 1.0)
+ if gamma > 0:
+ eps = torch.randn_like(x) * self.s_noise
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
+
+ denoised = self.denoise(x, denoiser, sigma_hat, cond, uc)
+ d = to_d(x, sigma_hat, denoised)
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
+
+ euler_step = self.euler_step(x, d, dt)
+ x = self.possible_correction_step(
+ euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ )
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ gamma = (
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
+ else 0.0
+ )
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ gamma,
+ )
+
+ return x
+
+
+class AncestralSampler(SingleStepDiffusionSampler):
+ def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self.eta = eta
+ self.s_noise = s_noise
+ self.noise_sampler = lambda x: torch.randn_like(x)
+
+ def ancestral_euler_step(self, x, denoised, sigma, sigma_down):
+ d = to_d(x, sigma, denoised)
+ dt = append_dims(sigma_down - sigma, x.ndim)
+
+ return self.euler_step(x, d, dt)
+
+ def ancestral_step(self, x, sigma, next_sigma, sigma_up):
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0,
+ x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim),
+ x,
+ )
+ return x
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ for i in self.get_sigma_gen(num_sigmas):
+ x = self.sampler_step(
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc,
+ )
+
+ return x
+
+
+class LinearMultistepSampler(BaseDiffusionSampler):
+ def __init__(
+ self,
+ order=4,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ self.order = order
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ ds = []
+ sigmas_cpu = sigmas.detach().cpu().numpy()
+ for i in self.get_sigma_gen(num_sigmas):
+ sigma = s_in * sigmas[i]
+ denoised = denoiser(
+ *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs
+ )
+ denoised = self.guider(denoised, sigma)
+ d = to_d(x, sigma, denoised)
+ ds.append(d)
+ if len(ds) > self.order:
+ ds.pop(0)
+ cur_order = min(i + 1, self.order)
+ coeffs = [
+ linear_multistep_coeff(cur_order, sigmas_cpu, i, j)
+ for j in range(cur_order)
+ ]
+ x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
+
+ return x
+
+
+class EulerEDMSampler(EDMSampler):
+ def possible_correction_step(
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ ):
+ return euler_step
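+
+
+# Usage sketch (the discretization target below follows the sgm config naming
+# convention and is an assumption; point it at whichever discretization is
+# actually registered in this repo):
+#
+#   sampler = EulerEDMSampler(
+#       num_steps=30,
+#       discretization_config={
+#           "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization"
+#       },
+#   )
+#   samples = sampler(denoiser, torch.randn(1, 4, 64, 64, device="cuda"), cond)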
+
+
+class HeunEDMSampler(EDMSampler):
+ def possible_correction_step(
+ self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc
+ ):
+ if torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ return euler_step
+ else:
+ denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc)
+ d_new = to_d(euler_step, next_sigma, denoised)
+ d_prime = (d + d_new) / 2.0
+
+ # apply correction if noise level is not 0
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x + d_prime * dt, euler_step
+ )
+ return x
+
+
+class EulerAncestralSampler(AncestralSampler):
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+
+ return x
+
+
+class DPMPP2SAncestralSampler(AncestralSampler):
+ def get_variables(self, sigma, sigma_down):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)]
+ h = t_next - t
+ s = t + 0.5 * h
+ return h, s, t, t_next
+
+ def get_mult(self, h, s, t, t_next):
+ mult1 = to_sigma(s) / to_sigma(t)
+ mult2 = (-0.5 * h).expm1()
+ mult3 = to_sigma(t_next) / to_sigma(t)
+ mult4 = (-h).expm1()
+
+ return mult1, mult2, mult3, mult4
+
+ def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs):
+ sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta)
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+ x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down)
+
+ if torch.sum(sigma_down) < 1e-14:
+ # Save a network evaluation if all noise levels are 0
+ x = x_euler
+ else:
+ h, s, t, t_next = self.get_variables(sigma, sigma_down)
+ mult = [
+ append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next)
+ ]
+
+ x2 = mult[0] * x - mult[1] * denoised
+ denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc)
+ x_dpmpp2s = mult[2] * x - mult[3] * denoised2
+
+ # apply correction if noise level is not 0
+ x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler)
+
+ x = self.ancestral_step(x, sigma, next_sigma, sigma_up)
+ return x
+
+
+class DPMPP2MSampler(BaseDiffusionSampler):
+ def get_variables(self, sigma, next_sigma, previous_sigma=None):
+ t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)]
+ h = t_next - t
+
+ if previous_sigma is not None:
+ h_last = t - to_neg_log_sigma(previous_sigma)
+ r = h_last / h
+ return h, r, t, t_next
+ else:
+ return h, None, t, t_next
+
+ def get_mult(self, h, r, t, t_next, previous_sigma):
+ mult1 = to_sigma(t_next) / to_sigma(t)
+ mult2 = (-h).expm1()
+
+ if previous_sigma is not None:
+ mult3 = 1 + 1 / (2 * r)
+ mult4 = 1 / (2 * r)
+ return mult1, mult2, mult3, mult4
+ else:
+ return mult1, mult2
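+
+    # The multipliers implement the DPM-Solver++(2M) update in data-prediction
+    # form: x_{i+1} = (sigma_{i+1} / sigma_i) * x_i - expm1(-h) * D_i, where
+    # D_i = (1 + 1/(2r)) * denoised_i - (1/(2r)) * denoised_{i-1} and
+    # r = h_{i-1} / h_i; on the first step (no history) only mult1/mult2 are
+    # returned and the one-step rule in sampler_step is used.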
+
+ def sampler_step(
+ self,
+ old_denoised,
+ previous_sigma,
+ sigma,
+ next_sigma,
+ denoiser,
+ x,
+ cond,
+ uc=None,
+ ):
+ denoised = self.denoise(x, denoiser, sigma, cond, uc)
+
+ h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma)
+ mult = [
+ append_dims(mult, x.ndim)
+ for mult in self.get_mult(h, r, t, t_next, previous_sigma)
+ ]
+
+ x_standard = mult[0] * x - mult[1] * denoised
+ if old_denoised is None or torch.sum(next_sigma) < 1e-14:
+ # Save a network evaluation if all noise levels are 0 or on the first step
+ return x_standard, denoised
+ else:
+ denoised_d = mult[2] * denoised - mult[3] * old_denoised
+ x_advanced = mult[0] * x - mult[1] * denoised_d
+
+ # apply correction if noise level is not 0 and not first step
+ x = torch.where(
+ append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard
+ )
+
+ return x, denoised
+
+ def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs):
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
+ x, cond, uc, num_steps
+ )
+
+ old_denoised = None
+ for i in self.get_sigma_gen(num_sigmas):
+ x, old_denoised = self.sampler_step(
+ old_denoised,
+ None if i == 0 else s_in * sigmas[i - 1],
+ s_in * sigmas[i],
+ s_in * sigmas[i + 1],
+ denoiser,
+ x,
+ cond,
+ uc=uc,
+ )
+
+ return x
diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce78527ea9052a8bfd0856ed2278901516fb9130
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sampling_utils.py
@@ -0,0 +1,43 @@
+import torch
+from scipy import integrate
+
+from ...util import append_dims
+
+
+def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
+ if order - 1 > i:
+ raise ValueError(f"Order {order} too high for step {i}")
+
+ def fn(tau):
+ prod = 1.0
+ for k in range(order):
+ if j == k:
+ continue
+ prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
+ return prod
+
+ return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
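+
+# The returned weight is the integral of the j-th Lagrange basis polynomial
+# over [t_i, t_{i+1}], i.e. the classic Adams-Bashforth coefficient used by
+# LinearMultistepSampler.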
+
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
+ if not eta:
+ return sigma_to, 0.0
+ sigma_up = torch.minimum(
+ sigma_to,
+ eta
+ * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
+ )
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+ return sigma_down, sigma_up
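+
+# The split satisfies sigma_down**2 + sigma_up**2 == sigma_to**2: sigma_down is
+# the deterministic part of the step and sigma_up the fresh noise, with eta
+# interpolating between a pure ODE step (0) and the fully ancestral step (1).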
+
+
+def to_d(x, sigma, denoised):
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def to_neg_log_sigma(sigma):
+ return sigma.log().neg()
+
+
+def to_sigma(neg_log_sigma):
+ return neg_log_sigma.neg().exp()
diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..d54724c6ef6a7b8067784a4192b0fe2f41123063
--- /dev/null
+++ b/sgm/modules/diffusionmodules/sigma_sampling.py
@@ -0,0 +1,31 @@
+import torch
+
+from ...util import default, instantiate_from_config
+
+
+class EDMSampling:
+ def __init__(self, p_mean=-1.2, p_std=1.2):
+ self.p_mean = p_mean
+ self.p_std = p_std
+
+ def __call__(self, n_samples, rand=None):
+ log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,)))
+ return log_sigma.exp()
+
+
+class DiscreteSampling:
+ def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True):
+ self.num_idx = num_idx
+ self.sigmas = instantiate_from_config(discretization_config)(
+ num_idx, do_append_zero=do_append_zero, flip=flip
+ )
+
+ def idx_to_sigma(self, idx):
+ return self.sigmas[idx]
+
+ def __call__(self, n_samples, rand=None):
+ idx = default(
+ rand,
+ torch.randint(0, self.num_idx, (n_samples,)),
+ )
+ return self.idx_to_sigma(idx)
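+
+
+# Usage sketch: EDMSampling draws training sigmas log-normally,
+# log(sigma) ~ N(p_mean, p_std**2), while DiscreteSampling picks uniformly
+# from a fixed discretization.
+#
+#   sigma = EDMSampling(p_mean=-1.2, p_std=1.2)(n_samples=4)  # tensor of 4 sigmas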
diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..389f0e449367b1b628d61dca105343d066dbefff
--- /dev/null
+++ b/sgm/modules/diffusionmodules/util.py
@@ -0,0 +1,369 @@
+"""
+partially adapted from
+https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+and
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+and
+https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+
+thanks!
+"""
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+def make_beta_schedule(
+ schedule,
+ n_timestep,
+ linear_start=1e-4,
+ linear_end=2e-2,
+):
+    if schedule == "linear":
+        betas = (
+            torch.linspace(
+                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+            )
+            ** 2
+        )
+        return betas.numpy()
+    raise NotImplementedError(f"unsupported beta schedule: {schedule}")
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def mixed_checkpoint(func, inputs: dict, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+    reduced memory at the expense of extra compute in the backward pass. This
+    differs from the original checkpoint function (borrowed from
+    https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py)
+    in that it also works with non-tensor inputs.
+ :param func: the function to evaluate.
+ :param inputs: the argument dictionary to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)]
+ tensor_inputs = [
+ inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor)
+ ]
+ non_tensor_keys = [
+ key for key in inputs if not isinstance(inputs[key], torch.Tensor)
+ ]
+ non_tensor_inputs = [
+ inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor)
+ ]
+ args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params)
+ return MixedCheckpointFunction.apply(
+ func,
+ len(tensor_inputs),
+ len(non_tensor_inputs),
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ )
+ else:
+ return func(**inputs)
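+
+
+# Usage sketch (`block` is a hypothetical module whose forward also takes a
+# non-tensor keyword argument): tensor inputs are checkpointed, the flag is
+# passed through untouched.
+#
+#   out = mixed_checkpoint(
+#       block.forward,
+#       {"x": x, "emb": emb, "use_fast_path": True},
+#       block.parameters(),
+#       flag=True,
+#   )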
+
+
+class MixedCheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ run_function,
+ length_tensors,
+ length_non_tensors,
+ tensor_keys,
+ non_tensor_keys,
+ *args,
+ ):
+ ctx.end_tensors = length_tensors
+ ctx.end_non_tensors = length_tensors + length_non_tensors
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ assert (
+ len(tensor_keys) == length_tensors
+ and len(non_tensor_keys) == length_non_tensors
+ )
+
+ ctx.input_tensors = {
+ key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors]))
+ }
+ ctx.input_non_tensors = {
+ key: val
+ for (key, val) in zip(
+ non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors])
+ )
+ }
+ ctx.run_function = run_function
+ ctx.input_params = list(args[ctx.end_non_tensors :])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(
+ **ctx.input_tensors, **ctx.input_non_tensors
+ )
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)}
+ ctx.input_tensors = {
+ key: ctx.input_tensors[key].detach().requires_grad_(True)
+ for key in ctx.input_tensors
+ }
+
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = {
+ key: ctx.input_tensors[key].view_as(ctx.input_tensors[key])
+ for key in ctx.input_tensors
+ }
+ # shallow_copies.update(additional_args)
+ output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ list(ctx.input_tensors.values()) + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (
+ (None, None, None, None, None)
+ + input_grads[: ctx.end_tensors]
+ + (None,) * (ctx.end_non_tensors - ctx.end_tensors)
+ + input_grads[ctx.end_tensors :]
+ )
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {
+ "enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled(),
+ }
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+ )
+ else:
+ embedding = repeat(timesteps, "b -> b d", d=dim)
+ return embedding
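+
+
+# Example: a batch of 8 timesteps embedded into 320 channels.
+#
+#   emb = timestep_embedding(torch.arange(8), dim=320)  # -> [8, 320]
+#
+# Odd dims are padded with a zero channel; repeat_only=True tiles the raw
+# timestep values instead of computing sinusoidal features.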
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class AlphaBlender(nn.Module):
+ strategies = ["learned", "fixed", "learned_with_images"]
+
+ def __init__(
+ self,
+ alpha: float,
+ merge_strategy: str = "learned_with_images",
+ rearrange_pattern: str = "b t -> (b t) 1 1",
+ ):
+ super().__init__()
+ self.merge_strategy = merge_strategy
+ self.rearrange_pattern = rearrange_pattern
+
+ assert (
+ merge_strategy in self.strategies
+ ), f"merge_strategy needs to be in {self.strategies}"
+
+ if self.merge_strategy == "fixed":
+ self.register_buffer("mix_factor", torch.Tensor([alpha]))
+ elif (
+ self.merge_strategy == "learned"
+ or self.merge_strategy == "learned_with_images"
+ ):
+ self.register_parameter(
+ "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
+ )
+ else:
+ raise ValueError(f"unknown merge strategy {self.merge_strategy}")
+
+ def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
+ if self.merge_strategy == "fixed":
+ alpha = self.mix_factor
+ elif self.merge_strategy == "learned":
+ alpha = torch.sigmoid(self.mix_factor)
+ elif self.merge_strategy == "learned_with_images":
+            assert (
+                image_only_indicator is not None
+            ), "merge_strategy 'learned_with_images' requires image_only_indicator"
+ alpha = torch.where(
+ image_only_indicator.bool(),
+ torch.ones(1, 1, device=image_only_indicator.device),
+ rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
+ )
+ alpha = rearrange(alpha, self.rearrange_pattern)
+ else:
+ raise NotImplementedError
+ return alpha
+
+ def forward(
+ self,
+ x_spatial: torch.Tensor,
+ x_temporal: torch.Tensor,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ alpha = self.get_alpha(image_only_indicator)
+ x = (
+ alpha.to(x_spatial.dtype) * x_spatial
+ + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
+ )
+ return x
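+
+
+# Usage sketch: blend spatial and temporal feature branches with a (learned)
+# scalar; with "learned_with_images" an image_only_indicator of shape [b, t]
+# forces alpha = 1 (pure spatial features) for image-only batch entries.
+#
+#   blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
+#   x = blender(x_spatial, x_temporal, image_only_indicator=indicator)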
diff --git a/sgm/modules/diffusionmodules/video_model.py b/sgm/modules/diffusionmodules/video_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff2d077c7d0c7ed1c4a2c21f14105c266abc4926
--- /dev/null
+++ b/sgm/modules/diffusionmodules/video_model.py
@@ -0,0 +1,493 @@
+from functools import partial
+from typing import List, Optional, Union
+
+from einops import rearrange
+
+from ...modules.diffusionmodules.openaimodel import *
+from ...modules.video_attention import SpatialVideoTransformer
+from ...util import default
+from .util import AlphaBlender
+
+
+class VideoResBlock(ResBlock):
+ def __init__(
+ self,
+ channels: int,
+ emb_channels: int,
+ dropout: float,
+ video_kernel_size: Union[int, List[int]] = 3,
+ merge_strategy: str = "fixed",
+ merge_factor: float = 0.5,
+ out_channels: Optional[int] = None,
+ use_conv: bool = False,
+ use_scale_shift_norm: bool = False,
+ dims: int = 2,
+ use_checkpoint: bool = False,
+ up: bool = False,
+ down: bool = False,
+ ):
+ super().__init__(
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=out_channels,
+ use_conv=use_conv,
+ use_scale_shift_norm=use_scale_shift_norm,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ up=up,
+ down=down,
+ )
+
+ self.time_stack = ResBlock(
+ default(out_channels, channels),
+ emb_channels,
+ dropout=dropout,
+ dims=3,
+ out_channels=default(out_channels, channels),
+ use_scale_shift_norm=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ kernel_size=video_kernel_size,
+ use_checkpoint=use_checkpoint,
+ exchange_temb_dims=True,
+ )
+ self.time_mixer = AlphaBlender(
+ alpha=merge_factor,
+ merge_strategy=merge_strategy,
+ rearrange_pattern="b t -> b 1 t 1 1",
+ )
+
+ def forward(
+ self,
+ x: th.Tensor,
+ emb: th.Tensor,
+ num_video_frames: int,
+ image_only_indicator: Optional[th.Tensor] = None,
+ ) -> th.Tensor:
+ x = super().forward(x, emb)
+
+ x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames)
+
+ x = self.time_stack(
+ x, rearrange(emb, "(b t) ... -> b t ...", t=num_video_frames)
+ )
+ x = self.time_mixer(
+ x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
+ )
+ x = rearrange(x, "b c t h w -> (b t) c h w")
+ return x
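+
+    # The block factorizes space and time: the inherited 2D ResBlock runs per
+    # frame, the 3D time_stack ResBlock mixes along the frame axis, and
+    # AlphaBlender merges both paths so image-only batch entries can bypass
+    # the temporal branch entirely.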
+
+
+class VideoUNet(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ model_channels: int,
+ out_channels: int,
+ num_res_blocks: int,
+ attention_resolutions: int,
+ dropout: float = 0.0,
+ channel_mult: List[int] = (1, 2, 4, 8),
+ conv_resample: bool = True,
+ dims: int = 2,
+        num_classes: Optional[Union[int, str]] = None,
+ use_checkpoint: bool = False,
+ num_heads: int = -1,
+ num_head_channels: int = -1,
+ num_heads_upsample: int = -1,
+ use_scale_shift_norm: bool = False,
+ resblock_updown: bool = False,
+ transformer_depth: Union[List[int], int] = 1,
+ transformer_depth_middle: Optional[int] = None,
+ context_dim: Optional[int] = None,
+ time_downup: bool = False,
+ time_context_dim: Optional[int] = None,
+ extra_ff_mix_layer: bool = False,
+ use_spatial_context: bool = False,
+ merge_strategy: str = "fixed",
+ merge_factor: float = 0.5,
+ spatial_transformer_attn_type: str = "softmax",
+ video_kernel_size: Union[int, List[int]] = 3,
+ use_linear_in_transformer: bool = False,
+ adm_in_channels: Optional[int] = None,
+ disable_temporal_crossattention: bool = False,
+ max_ddpm_temb_period: int = 10000,
+ ):
+ super().__init__()
+ assert context_dim is not None
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1
+
+ if num_head_channels == -1:
+ assert num_heads != -1
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(transformer_depth, int):
+ transformer_depth = len(channel_mult) * [transformer_depth]
+ transformer_depth_middle = default(
+ transformer_depth_middle, transformer_depth[-1]
+ )
+
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "timestep":
+ self.label_emb = nn.Sequential(
+ Timestep(model_channels),
+ nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ ),
+ )
+
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+                raise ValueError(f"unsupported num_classes: {self.num_classes}")
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+
+ def get_attention_layer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=1,
+ context_dim=None,
+ use_checkpoint=False,
+ disabled_sa=False,
+ ):
+ return SpatialVideoTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=depth,
+ context_dim=context_dim,
+ time_context_dim=time_context_dim,
+ dropout=dropout,
+ ff_in=extra_ff_mix_layer,
+ use_spatial_context=use_spatial_context,
+ merge_strategy=merge_strategy,
+ merge_factor=merge_factor,
+ checkpoint=use_checkpoint,
+ use_linear=use_linear_in_transformer,
+ attn_mode=spatial_transformer_attn_type,
+ disable_self_attn=disabled_sa,
+ disable_temporal_crossattention=disable_temporal_crossattention,
+ max_time_embed_period=max_ddpm_temb_period,
+ )
+
+ def get_resblock(
+ merge_factor,
+ merge_strategy,
+ video_kernel_size,
+ ch,
+ time_embed_dim,
+ dropout,
+ out_ch,
+ dims,
+ use_checkpoint,
+ use_scale_shift_norm,
+ down=False,
+ up=False,
+ ):
+ return VideoResBlock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ channels=ch,
+ emb_channels=time_embed_dim,
+ dropout=dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=down,
+ up=up,
+ )
+
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ out_ch=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ layers.append(
+ get_attention_layer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ use_checkpoint=use_checkpoint,
+ disabled_sa=False,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ ds *= 2
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ out_ch=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch,
+ conv_resample,
+ dims=dims,
+ out_channels=out_ch,
+ third_down=time_downup,
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ self.middle_block = TimestepEmbedSequential(
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ time_embed_dim=time_embed_dim,
+ out_ch=None,
+ dropout=dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ get_attention_layer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth_middle,
+ context_dim=context_dim,
+ use_checkpoint=use_checkpoint,
+ ),
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ out_ch=None,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch + ich,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ out_ch=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ layers.append(
+ get_attention_layer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth[level],
+ context_dim=context_dim,
+ use_checkpoint=use_checkpoint,
+ disabled_sa=False,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ ds //= 2
+ layers.append(
+ get_resblock(
+ merge_factor=merge_factor,
+ merge_strategy=merge_strategy,
+ video_kernel_size=video_kernel_size,
+ ch=ch,
+ time_embed_dim=time_embed_dim,
+ dropout=dropout,
+ out_ch=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(
+ ch,
+ conv_resample,
+ dims=dims,
+ out_channels=out_ch,
+ third_up=time_downup,
+ )
+ )
+
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+
+ def forward(
+ self,
+ x: th.Tensor,
+ timesteps: th.Tensor,
+ context: Optional[th.Tensor] = None,
+ y: Optional[th.Tensor] = None,
+ time_context: Optional[th.Tensor] = None,
+ num_video_frames: Optional[int] = None,
+ image_only_indicator: Optional[th.Tensor] = None,
+ ):
+        # TODO: relax this so y can be omitted even for class-conditional models
+        assert (y is not None) == (
+            self.num_classes is not None
+        ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x
+ for module in self.input_blocks:
+ h = module(
+ h,
+ emb,
+ context=context,
+ image_only_indicator=image_only_indicator,
+ time_context=time_context,
+ num_video_frames=num_video_frames,
+ )
+ hs.append(h)
+ h = self.middle_block(
+ h,
+ emb,
+ context=context,
+ image_only_indicator=image_only_indicator,
+ time_context=time_context,
+ num_video_frames=num_video_frames,
+ )
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(
+ h,
+ emb,
+ context=context,
+ image_only_indicator=image_only_indicator,
+ time_context=time_context,
+ num_video_frames=num_video_frames,
+ )
+ h = h.type(x.dtype)
+ return self.out(h)
diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..37449ea63e992b9f89856f1f47c18ba68be8e334
--- /dev/null
+++ b/sgm/modules/diffusionmodules/wrappers.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+from packaging import version
+
+OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"
+
+
+class IdentityWrapper(nn.Module):
+ def __init__(self, diffusion_model, compile_model: bool = False):
+ super().__init__()
+ compile = (
+ torch.compile
+ if (version.parse(torch.__version__) >= version.parse("2.0.0"))
+ and compile_model
+ else lambda x: x
+ )
+ self.diffusion_model = compile(diffusion_model)
+
+ def forward(self, *args, **kwargs):
+ return self.diffusion_model(*args, **kwargs)
+
+
+class OpenAIWrapper(IdentityWrapper):
+ def forward(
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
+ ) -> torch.Tensor:
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
+ return self.diffusion_model(
+ x,
+ timesteps=t,
+ context=c.get("crossattn", None),
+ y=c.get("vector", None),
+ **kwargs,
+ )
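+
+
+# Usage sketch: the wrapper concatenates any "concat" conditioning along the
+# channel axis and routes the remaining keys onto the UNet's keyword
+# arguments ("crossattn" -> context, "vector" -> y).
+#
+#   model = OpenAIWrapper(unet, compile_model=False)
+#   out = model(x, t, {"crossattn": ctx, "vector": adm_vec})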
diff --git a/sgm/modules/distributions/__init__.py b/sgm/modules/distributions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py
new file mode 100644
index 0000000000000000000000000000000000000000..016be35523187ea366db9ade391fe8ee276db60b
--- /dev/null
+++ b/sgm/modules/distributions/distributions.py
@@ -0,0 +1,102 @@
+import numpy as np
+import torch
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(
+ device=self.parameters.device
+ )
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
+ device=self.parameters.device
+ )
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
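
`DiagonalGaussianDistribution` expects the encoder output to carry `2*C` channels, which it splits into a per-pixel mean and log-variance; `sample()` is the usual reparameterisation trick and `kl()` without an argument reduces the KL to a standard normal over all non-batch dimensions. A minimal sketch on dummy data, assuming the module is importable under the path added above:

```python
import torch
from sgm.modules.distributions.distributions import DiagonalGaussianDistribution

params = torch.randn(2, 8, 16, 16)    # hypothetical encoder output: 2*C = 8 channels
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                # mean + std * eps, shape (2, 4, 16, 16)
kl = posterior.kl()                   # KL to N(0, I), summed over C, H, W -> shape (2,)
print(z.shape, kl.shape)
```
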
diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..97b5ae2b230f89b4dba57e44c4f851478ad86f68
--- /dev/null
+++ b/sgm/modules/ema.py
@@ -0,0 +1,86 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+    def __init__(self, model, decay=0.9999, use_num_updates=True):
+        super().__init__()
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.m_name2s_name = {}
+        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
+        self.register_buffer(
+            "num_updates",
+            torch.tensor(0, dtype=torch.int)
+            if use_num_updates
+ else torch.tensor(-1, dtype=torch.int),
+ )
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace(".", "")
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(
+ one_minus_decay * (shadow_params[sname] - m_param[key])
+ )
+ else:
+                    assert key not in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+                assert key not in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
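
The intended usage pattern for `LitEma` is: update the shadow weights after every optimizer step, then temporarily swap them in for validation or checkpointing via `store` / `copy_to` / `restore`. A minimal sketch, assuming the module is importable as `sgm.modules.ema`:

```python
import torch
from torch import nn
from sgm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)                       # update the shadow weights after each step

ema.store(model.parameters())        # stash the live weights
ema.copy_to(model)                   # evaluate or save with the EMA weights
ema.restore(model.parameters())      # switch back for further training
```
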
diff --git a/sgm/modules/encoders/__init__.py b/sgm/modules/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/sgm/modules/encoders/image_encoder.py b/sgm/modules/encoders/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..60d693245bc562987376b7d0fff80086fb936279
--- /dev/null
+++ b/sgm/modules/encoders/image_encoder.py
@@ -0,0 +1,349 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import importlib
+
+
+def class_for_name(module_name, class_name):
+ # load the module, will raise ImportError if module cannot be loaded
+ m = importlib.import_module(module_name)
+ return getattr(m, class_name)
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ padding_mode="reflect",
+ )
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ padding_mode="reflect",
+ )
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ stride=1,
+ downsample=None,
+ groups=1,
+ base_width=64,
+ dilation=1,
+ norm_layer=None,
+ ):
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ # norm_layer = nn.InstanceNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes, track_running_stats=False, affine=True)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes, track_running_stats=False, affine=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
+    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
+    # according to "Deep residual learning for image recognition", https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion = 4
+
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ stride=1,
+ downsample=None,
+ groups=1,
+ base_width=64,
+ dilation=1,
+ norm_layer=None,
+ ):
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ # norm_layer = nn.InstanceNorm2d
+ width = int(planes * (base_width / 64.0)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width, track_running_stats=False, affine=True)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width, track_running_stats=False, affine=True)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(
+ planes * self.expansion, track_running_stats=False, affine=True
+ )
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class conv(nn.Module):
+ def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
+ super(conv, self).__init__()
+ self.kernel_size = kernel_size
+ self.conv = nn.Conv2d(
+ num_in_layers,
+ num_out_layers,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=(self.kernel_size - 1) // 2,
+ padding_mode="reflect",
+ )
+ # self.bn = nn.InstanceNorm2d(
+ # num_out_layers, track_running_stats=False, affine=True
+ # )
+ self.bn = nn.BatchNorm2d(num_out_layers, track_running_stats=False, affine=True)
+ # self.bn = nn.LayerNorm(num_out_layers)
+
+ def forward(self, x):
+ return F.elu(self.bn(self.conv(x)), inplace=True)
+
+
+class upconv(nn.Module):
+ def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
+ super(upconv, self).__init__()
+ self.scale = scale
+ self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1)
+
+ def forward(self, x):
+ x = nn.functional.interpolate(
+ x, scale_factor=self.scale, align_corners=True, mode="bilinear"
+ )
+ return self.conv(x)
+
+
+class ResUNet(nn.Module):
+ def __init__(
+ self,
+ encoder="resnet34",
+ coarse_out_ch=32,
+ fine_out_ch=32,
+ norm_layer=None,
+ coarse_only=False,
+ ):
+ super(ResUNet, self).__init__()
+ assert encoder in [
+ "resnet18",
+ "resnet34",
+ "resnet50",
+ "resnet101",
+ "resnet152",
+ ], "Incorrect encoder type"
+ if encoder in ["resnet18", "resnet34"]:
+ filters = [64, 128, 256, 512]
+ else:
+ filters = [256, 512, 1024, 2048]
+ self.coarse_only = coarse_only
+ if self.coarse_only:
+ fine_out_ch = 0
+ self.coarse_out_ch = coarse_out_ch
+ self.fine_out_ch = fine_out_ch
+ out_ch = coarse_out_ch + fine_out_ch
+
+ # original
+ layers = [3, 4, 6, 3]
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ # norm_layer = nn.InstanceNorm2d
+ self._norm_layer = norm_layer
+ self.dilation = 1
+ block = BasicBlock
+ replace_stride_with_dilation = [False, False, False]
+ self.inplanes = 64
+ self.groups = 1
+ self.base_width = 64
+ self.conv1 = nn.Conv2d(
+ 3,
+ self.inplanes,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False,
+ padding_mode="reflect",
+ )
+ self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True)
+ self.relu = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(
+ block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
+ )
+ self.layer3 = self._make_layer(
+ block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
+ )
+
+ # decoder
+ self.upconv3 = upconv(filters[2], 128, 3, 2)
+ self.iconv3 = conv(filters[1] + 128, 128, 3, 1)
+ self.upconv2 = upconv(128, 64, 3, 2)
+ self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1)
+
+ # fine-level conv
+ self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(
+ planes * block.expansion, track_running_stats=False, affine=True
+ ),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ stride,
+ downsample,
+ self.groups,
+ self.base_width,
+ previous_dilation,
+ norm_layer,
+ )
+ )
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer,
+ )
+ )
+
+ return nn.Sequential(*layers)
+
+ def skipconnect(self, x1, x2):
+ diffY = x2.size()[2] - x1.size()[2]
+ diffX = x2.size()[3] - x1.size()[3]
+
+ x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
+
+ # for padding issues, see
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+
+ x = torch.cat([x2, x1], dim=1)
+ return x
+
+ def forward(self, x):
+ x = self.relu(self.bn1(self.conv1(x)))
+
+ x1 = self.layer1(x)
+ x2 = self.layer2(x1)
+ x3 = self.layer3(x2)
+
+ x = self.upconv3(x3)
+ x = self.skipconnect(x2, x)
+ x = self.iconv3(x)
+
+ x = self.upconv2(x)
+ x = self.skipconnect(x1, x)
+ x = self.iconv2(x)
+
+ x_out = self.out_conv(x)
+
+ return x_out
+
+ # if self.coarse_only:
+ # x_coarse = x_out
+ # x_fine = None
+ # else:
+ # x_coarse = x_out[:, : self.coarse_out_ch, :]
+ # x_fine = x_out[:, -self.fine_out_ch :, :]
+ # return x_coarse, x_fine
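
A quick sanity check of the encoder's output shape (a minimal sketch, assuming the module is importable as `sgm.modules.encoders.image_encoder`): with the defaults `coarse_out_ch=32` and `fine_out_ch=32`, `ResUNet` returns a 64-channel feature map at one quarter of the input resolution.

```python
import torch
from sgm.modules.encoders.image_encoder import ResUNet

encoder = ResUNet(encoder="resnet34", coarse_out_ch=32, fine_out_ch=32)
feats = encoder(torch.randn(1, 3, 256, 256))
print(feats.shape)   # torch.Size([1, 64, 64, 64]): coarse + fine channels at 1/4 resolution
```
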
diff --git a/sgm/modules/encoders/image_encoder_v2.py b/sgm/modules/encoders/image_encoder_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c782b3edee155fa4367e697a94d6b8b6b86b85
--- /dev/null
+++ b/sgm/modules/encoders/image_encoder_v2.py
@@ -0,0 +1,160 @@
+"""
+UNet Network in PyTorch, modified from https://github.com/milesial/Pytorch-UNet
+with architecture referenced from https://keras.io/examples/vision/depth_estimation
+for monocular depth estimation from RGB images, i.e. one output channel.
+"""
+
+import torch
+from torch import nn
+
+
+class UNet(nn.Module):
+ """
+ The overall UNet architecture.
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ self.downscale_blocks = nn.ModuleList(
+ [
+ DownBlock(16, 32),
+ DownBlock(32, 64),
+ DownBlock(64, 128),
+ DownBlock(128, 256),
+ ]
+ )
+ self.upscale_blocks = nn.ModuleList(
+ [
+ UpBlock(256, 128),
+ UpBlock(128, 64),
+ UpBlock(64, 32),
+ UpBlock(32, 16),
+ ]
+ )
+
+ self.input_conv = nn.Conv2d(3, 16, kernel_size=3, padding="same")
+ self.output_conv = nn.Conv2d(16, 1, kernel_size=1)
+ self.bridge = BottleNeckBlock(256)
+ self.activation = nn.Sigmoid()
+
+ def forward(self, x):
+ x = self.input_conv(x)
+
+ skip_features = []
+ for block in self.downscale_blocks:
+ c, x = block(x)
+ skip_features.append(c)
+
+ x = self.bridge(x)
+
+ skip_features.reverse()
+ for block, skip in zip(self.upscale_blocks, skip_features):
+ x = block(x, skip)
+
+ x = self.output_conv(x)
+ x = self.activation(x)
+ return x
+
+
+class DownBlock(nn.Module):
+ """
+ Module that performs downscaling with residual connections.
+ """
+
+ def __init__(self, in_channels, out_channels, padding="same", stride=1):
+ super().__init__()
+ self.conv1 = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.conv2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn1 = nn.BatchNorm2d(out_channels)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+ self.relu = nn.LeakyReLU(0.2)
+ self.maxpool = nn.MaxPool2d(2)
+
+ def forward(self, x):
+ d = self.conv1(x)
+ x = self.bn1(d)
+ x = self.relu(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+
+ x = x + d
+ p = self.maxpool(x)
+ return x, p
+
+
+class UpBlock(nn.Module):
+ """
+ Module that performs upscaling after concatenation with skip connections.
+ """
+
+ def __init__(self, in_channels, out_channels, padding="same", stride=1):
+ super().__init__()
+ self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
+ self.conv1 = nn.Conv2d(
+ in_channels * 2,
+ in_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.conv2 = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn1 = nn.BatchNorm2d(in_channels)
+ self.bn2 = nn.BatchNorm2d(out_channels)
+ self.relu = nn.LeakyReLU(0.2)
+
+ def forward(self, x, skip):
+ x = self.up(x)
+ x = torch.cat([x, skip], dim=1)
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.relu(x)
+ return x
+
+
+class BottleNeckBlock(nn.Module):
+ """
+ BottleNeckBlock that serves as the UNet bridge.
+ """
+
+ def __init__(self, channels, padding="same", strides=1):
+ super().__init__()
+ self.conv1 = nn.Conv2d(channels, channels, 3, 1, "same")
+ self.conv2 = nn.Conv2d(channels, channels, 3, 1, "same")
+ self.relu = nn.LeakyReLU(0.2)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.relu(x)
+ return x
\ No newline at end of file
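
This depth-style UNet maps a 3-channel image to a single sigmoid-activated channel at the input resolution; the side length should be divisible by 16 so the four pooling and upsampling stages line up. A minimal sketch, assuming the module is importable under the path added above:

```python
import torch
from sgm.modules.encoders.image_encoder_v2 import UNet

net = UNet()
depth = net(torch.randn(1, 3, 256, 256))
print(depth.shape)   # torch.Size([1, 1, 256, 256]); the sigmoid keeps values in (0, 1)
```
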
diff --git a/sgm/modules/encoders/math_utils.py b/sgm/modules/encoders/math_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..35e4f6f876ebb5878fa05e93bdf10488cb73e297
--- /dev/null
+++ b/sgm/modules/encoders/math_utils.py
@@ -0,0 +1,139 @@
+# MIT License
+
+# Copyright (c) 2022 Petr Kellnhofer
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import torch
+
+
+def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
+ """
+ Left-multiplies MxM @ NxM. Returns NxM.
+ """
+ res = torch.matmul(vectors4, matrix.T)
+ return res
+
+
+def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize vector lengths.
+ """
+ return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
+
+
+def torch_dot(x: torch.Tensor, y: torch.Tensor):
+ """
+ Dot product of two tensors.
+ """
+ return (x * y).sum(-1)
+
+
+def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
+ """
+ Author: Petr Kellnhofer
+    Intersects rays with an axis-aligned cube of side length `box_side_length` centred at the origin.
+    Returns the entry and exit distance along each ray.
+    Returns -1 (entry) and -2 (exit) for rays that miss the box.
+ https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
+ """
+ o_shape = rays_o.shape
+ rays_o = rays_o.detach().reshape(-1, 3)
+ rays_d = rays_d.detach().reshape(-1, 3)
+
+ bb_min = [
+ -1 * (box_side_length / 2),
+ -1 * (box_side_length / 2),
+ -1 * (box_side_length / 2),
+ ]
+ bb_max = [
+ 1 * (box_side_length / 2),
+ 1 * (box_side_length / 2),
+ 1 * (box_side_length / 2),
+ ]
+ bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
+ is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
+
+ # Precompute inverse for stability.
+ invdir = 1 / rays_d
+ sign = (invdir < 0).long()
+
+ # Intersect with YZ plane.
+ tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[
+ ..., 0
+ ]
+ tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[
+ ..., 0
+ ]
+
+ # Intersect with XZ plane.
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[
+ ..., 1
+ ]
+ tymax = (
+ bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]
+ ) * invdir[..., 1]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tymin)
+ tmax = torch.min(tmax, tymax)
+
+ # Intersect with XY plane.
+ tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[
+ ..., 2
+ ]
+ tzmax = (
+ bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]
+ ) * invdir[..., 2]
+
+ # Resolve parallel rays.
+ is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
+
+ # Use the shortest intersection.
+ tmin = torch.max(tmin, tzmin)
+ tmax = torch.min(tmax, tzmax)
+
+ # Mark invalid.
+ tmin[torch.logical_not(is_valid)] = -1
+ tmax[torch.logical_not(is_valid)] = -2
+
+ return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
+
+
+def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
+ """
+    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
+    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
+ """
+ # create a tensor of 'num' steps from 0 to 1
+ steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
+
+    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
+    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
+    #   "cannot statically infer the expected size of a list in this context", hence the code below
+ for i in range(start.ndim):
+ steps = steps.unsqueeze(-1)
+
+ # the output starts at 'start' and increments until 'stop' in each dimension
+ out = start[None] + steps * (stop - start)[None]
+
+ return out
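
`get_ray_limits_box` runs the standard slab test against an axis-aligned cube and returns per-ray entry and exit distances, and `linspace` is the batched analogue of `numpy.linspace` used to place samples between those limits. A quick shape check, assuming the module is importable as `sgm.modules.encoders.math_utils`:

```python
import torch
import torch.nn.functional as F
from sgm.modules.encoders.math_utils import get_ray_limits_box, linspace

rays_o = torch.zeros(2, 4, 3)                       # rays starting at the box centre
rays_d = F.normalize(torch.randn(2, 4, 3), dim=-1)  # unit directions
t_near, t_far = get_ray_limits_box(rays_o, rays_d, box_side_length=2.0)
print(t_near.shape, t_far.shape)                    # (2, 4, 1) each; misses get -1 / -2

depths = linspace(t_near, t_far, 8)                 # 8 evenly spaced samples per ray
print(depths.shape)                                 # torch.Size([8, 2, 4, 1])
```
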
diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..9860779362c766f4e9171d98c7411a2b178a842d
--- /dev/null
+++ b/sgm/modules/encoders/modules.py
@@ -0,0 +1,1189 @@
+import math
+from contextlib import nullcontext
+from functools import partial
+from typing import Dict, List, Optional, Tuple, Union
+
+import kornia
+import numpy as np
+import open_clip
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from omegaconf import ListConfig
+
+# from torch.utils.checkpoint import checkpoint
+
+checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
+
+from transformers import (
+ ByT5Tokenizer,
+ CLIPTextModel,
+ CLIPTokenizer,
+ T5EncoderModel,
+ T5Tokenizer,
+)
+
+from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer
+from ...modules.diffusionmodules.model import Encoder
+from ...modules.diffusionmodules.openaimodel import Timestep
+from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ...modules.distributions.distributions import DiagonalGaussianDistribution
+from ...util import (
+ append_dims,
+ autocast,
+ count_params,
+ default,
+ disabled_train,
+ expand_dims_like,
+ instantiate_from_config,
+)
+
+
+class AbstractEmbModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self._is_trainable = None
+ self._ucg_rate = None
+ self._input_key = None
+
+ @property
+ def is_trainable(self) -> bool:
+ return self._is_trainable
+
+ @property
+ def ucg_rate(self) -> Union[float, torch.Tensor]:
+ return self._ucg_rate
+
+ @property
+ def input_key(self) -> str:
+ return self._input_key
+
+ @is_trainable.setter
+ def is_trainable(self, value: bool):
+ self._is_trainable = value
+
+ @ucg_rate.setter
+ def ucg_rate(self, value: Union[float, torch.Tensor]):
+ self._ucg_rate = value
+
+ @input_key.setter
+ def input_key(self, value: str):
+ self._input_key = value
+
+ @is_trainable.deleter
+ def is_trainable(self):
+ del self._is_trainable
+
+ @ucg_rate.deleter
+ def ucg_rate(self):
+ del self._ucg_rate
+
+ @input_key.deleter
+ def input_key(self):
+ del self._input_key
+
+
+class GeneralConditioner(nn.Module):
+ OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
+ KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
+
+ def __init__(self, emb_models: Union[List, ListConfig]):
+ super().__init__()
+ embedders = []
+ for n, embconfig in enumerate(emb_models):
+ embedder = instantiate_from_config(embconfig)
+ assert isinstance(
+ embedder, AbstractEmbModel
+ ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
+ embedder.is_trainable = embconfig.get("is_trainable", False)
+ embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
+ if not embedder.is_trainable:
+ embedder.train = disabled_train
+ for param in embedder.parameters():
+ param.requires_grad = False
+ embedder.eval()
+ print(
+ f"Initialized embedder #{n}: {embedder.__class__.__name__} "
+ f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
+ )
+
+ if "input_key" in embconfig:
+ embedder.input_key = embconfig["input_key"]
+ elif "input_keys" in embconfig:
+ embedder.input_keys = embconfig["input_keys"]
+ else:
+ raise KeyError(
+ f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
+ )
+
+ embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
+ if embedder.legacy_ucg_val is not None:
+ embedder.ucg_prng = np.random.RandomState()
+
+ embedders.append(embedder)
+ self.embedders = nn.ModuleList(embedders)
+
+ def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict:
+ assert embedder.legacy_ucg_val is not None
+ p = embedder.ucg_rate
+ val = embedder.legacy_ucg_val
+ for i in range(len(batch[embedder.input_key])):
+ if embedder.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[embedder.input_key][i] = val
+ return batch
+
+ def forward(
+ self, batch: Dict, force_zero_embeddings: Optional[List] = None
+ ) -> Dict:
+ output = dict()
+ if force_zero_embeddings is None:
+ force_zero_embeddings = []
+ for embedder in self.embedders:
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
+ with embedding_context():
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
+ if embedder.legacy_ucg_val is not None:
+ batch = self.possibly_get_ucg_val(embedder, batch)
+ emb_out = embedder(batch[embedder.input_key])
+ elif hasattr(embedder, "input_keys"):
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
+ assert isinstance(
+ emb_out, (torch.Tensor, list, tuple)
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
+ if not isinstance(emb_out, (list, tuple)):
+ emb_out = [emb_out]
+ for emb in emb_out:
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
+ emb = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - embedder.ucg_rate)
+ * torch.ones(emb.shape[0], device=emb.device)
+ ),
+ emb,
+ )
+ * emb
+ )
+ if (
+ hasattr(embedder, "input_key")
+ and embedder.input_key in force_zero_embeddings
+ ):
+ emb = torch.zeros_like(emb)
+ if out_key in output:
+ output[out_key] = torch.cat(
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
+ )
+ else:
+ output[out_key] = emb
+
+ # if "num_video_frames" in batch:
+ # num_frames = batch["num_video_frames"]
+ # for k in ["crossattn", "concat"]:
+ # output[k] = repeat(output[k], "b ... -> b t ...", t=num_frames)
+ # output[k] = rearrange(output[k], "b t ... -> (b t) ...", t=num_frames)
+
+ return output
+
+ def get_unconditional_conditioning(
+ self,
+ batch_c: Dict,
+ batch_uc: Optional[Dict] = None,
+ force_uc_zero_embeddings: Optional[List[str]] = None,
+ force_cond_zero_embeddings: Optional[List[str]] = None,
+ ):
+ if force_uc_zero_embeddings is None:
+ force_uc_zero_embeddings = []
+ ucg_rates = list()
+ for embedder in self.embedders:
+ ucg_rates.append(embedder.ucg_rate)
+ embedder.ucg_rate = 0.0
+ c = self(batch_c, force_cond_zero_embeddings)
+ uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
+
+ for embedder, rate in zip(self.embedders, ucg_rates):
+ embedder.ucg_rate = rate
+ return c, uc
+
+
+class InceptionV3(nn.Module):
+ """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
+ port with an additional squeeze at the end"""
+
+ def __init__(self, normalize_input=False, **kwargs):
+ super().__init__()
+ from pytorch_fid import inception
+
+ kwargs["resize_input"] = True
+ self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs)
+
+ def forward(self, inp):
+ outp = self.model(inp)
+
+ if len(outp) == 1:
+ return outp[0].squeeze()
+
+ return outp
+
+
+class IdentityEncoder(AbstractEmbModel):
+ def encode(self, x):
+ return x
+
+ def forward(self, x):
+ return x
+
+
+class ClassEmbedder(AbstractEmbModel):
+ def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False):
+ super().__init__()
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.add_sequence_dim = add_sequence_dim
+
+ def forward(self, c):
+ c = self.embedding(c)
+ if self.add_sequence_dim:
+ c = c[:, None, :]
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = (
+ self.n_classes - 1
+        )  # the last class index is reserved for ucg, e.g. n_classes=1001 --> classes 0..999 plus ucg class 1000
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc.long()}
+ return uc
+
+
+class ClassEmbedderForMultiCond(ClassEmbedder):
+ def forward(self, batch, key=None, disable_dropout=False):
+ out = batch
+ key = default(key, self.key)
+ islist = isinstance(batch[key], list)
+ if islist:
+ batch[key] = batch[key][0]
+ c_out = super().forward(batch, key, disable_dropout)
+ out[key] = [c_out] if islist else c_out
+ return out
+
+
+class FrozenT5Embedder(AbstractEmbModel):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(
+ self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True
+    ):  # other checkpoints include google/t5-v1_1-base, google/t5-v1_1-large and google/t5-v1_1-xl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenByT5Embedder(AbstractEmbModel):
+ """
+ Uses the ByT5 transformer encoder for text. Is character-aware.
+ """
+
+ def __init__(
+ self, version="google/byt5-base", device="cuda", max_length=77, freeze=True
+    ):  # other checkpoints include google/byt5-small, google/byt5-large and google/byt5-xl
+ super().__init__()
+ self.tokenizer = ByT5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ with torch.autocast("cuda", enabled=False):
+ outputs = self.transformer(input_ids=tokens)
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEmbModel):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+
+ LAYERS = ["last", "pooled", "hidden"]
+
+ def __init__(
+ self,
+ version="openai/clip-vit-large-patch14",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ layer_idx=None,
+ always_return_pooled=False,
+ ): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ self.return_pooled = always_return_pooled
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, text):
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(
+ input_ids=tokens, output_hidden_states=self.layer == "hidden"
+ )
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ if self.return_pooled:
+ return z, outputs.pooler_output
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder2(AbstractEmbModel):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+
+ LAYERS = ["pooled", "last", "penultimate"]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ always_return_pooled=False,
+ legacy=True,
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device("cpu"),
+ pretrained=version,
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ self.return_pooled = always_return_pooled
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+ self.legacy = legacy
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ if not self.return_pooled and self.legacy:
+ return z
+ if self.return_pooled:
+ assert not self.legacy
+ return z[self.layer], z["pooled"]
+ return z[self.layer]
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ if self.legacy:
+ x = x[self.layer]
+ x = self.model.ln_final(x)
+ return x
+ else:
+ # x is a dict and will stay a dict
+ o = x["last"]
+ o = self.model.ln_final(o)
+ pooled = self.pool(o, text)
+ x["pooled"] = pooled
+ return x
+
+ def pool(self, x, text):
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = (
+ x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+ @ self.model.text_projection
+ )
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ outputs = {}
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - 1:
+ outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ outputs["last"] = x.permute(1, 0, 2) # LND -> NLD
+ return outputs
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEmbModel):
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate",
+ ]
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ ):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device("cpu"),
+ pretrained=version,
+ )
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if (
+ self.model.transformer.grad_checkpointing
+ and not torch.jit.is_scripting()
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(
+ self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ antialias=True,
+ ucg_rate=0.0,
+ unsqueeze_dim=False,
+ repeat_to_max_len=False,
+ num_image_crops=0,
+ output_tokens=False,
+ init_device=None,
+ ):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device(default(init_device, "cpu")),
+ pretrained=version,
+ )
+ del model.transformer
+ self.model = model
+ self.max_crops = num_image_crops
+ self.pad_to_max_len = self.max_crops > 0
+ self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+
+ self.antialias = antialias
+
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
+ self.ucg_rate = ucg_rate
+ self.unsqueeze_dim = unsqueeze_dim
+ self.stored_batch = None
+ self.model.visual.output_tokens = output_tokens
+ self.output_tokens = output_tokens
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ x = (x + 1.0) / 2.0
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ tokens = None
+ if self.output_tokens:
+ z, tokens = z[0], z[1]
+ z = z.to(image.dtype)
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
+ z = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
+ )[:, None]
+ * z
+ )
+ if tokens is not None:
+ tokens = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(tokens.shape[0], device=tokens.device)
+ ),
+ tokens,
+ )
+ * tokens
+ )
+ if self.unsqueeze_dim:
+ z = z[:, None, :]
+ if self.output_tokens:
+ assert not self.repeat_to_max_len
+ assert not self.pad_to_max_len
+ return tokens, z
+ if self.repeat_to_max_len:
+ if z.dim() == 2:
+ z_ = z[:, None, :]
+ else:
+ z_ = z
+ return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
+ elif self.pad_to_max_len:
+ assert z.dim() == 3
+ z_pad = torch.cat(
+ (
+ z,
+ torch.zeros(
+ z.shape[0],
+ self.max_length - z.shape[1],
+ z.shape[2],
+ device=z.device,
+ ),
+ ),
+ 1,
+ )
+ return z_pad, z_pad[:, 0, ...]
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ # if self.max_crops > 0:
+ # img = self.preprocess_by_cropping(img)
+ if img.dim() == 5:
+ assert self.max_crops == img.shape[1]
+ img = rearrange(img, "b n c h w -> (b n) c h w")
+ img = self.preprocess(img)
+ if not self.output_tokens:
+ assert not self.model.visual.output_tokens
+ x = self.model.visual(img)
+ tokens = None
+ else:
+ assert self.model.visual.output_tokens
+ x, tokens = self.model.visual(img)
+ if self.max_crops > 0:
+ x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
+ # drop out between 0 and all along the sequence axis
+ x = (
+ torch.bernoulli(
+ (1.0 - self.ucg_rate)
+ * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
+ )
+ * x
+ )
+ if tokens is not None:
+ tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
+ print(
+ f"You are running very experimental token-concat in {self.__class__.__name__}. "
+ f"Check what you are doing, and then remove this message."
+ )
+ if self.output_tokens:
+ return x, tokens
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEmbModel):
+ def __init__(
+ self,
+ clip_version="openai/clip-vit-large-patch14",
+ t5_version="google/t5-v1_1-xl",
+ device="cuda",
+ clip_max_length=77,
+ t5_max_length=77,
+ ):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(
+ clip_version, device, max_length=clip_max_length
+ )
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
+ )
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+class SpatialRescaler(nn.Module):
+ def __init__(
+ self,
+ n_stages=1,
+ method="bilinear",
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False,
+ wrap_video=False,
+ kernel_size=1,
+ remap_output=False,
+ ):
+ super().__init__()
+ self.n_stages = n_stages
+ assert self.n_stages >= 0
+ assert method in [
+ "nearest",
+ "linear",
+ "bilinear",
+ "trilinear",
+ "bicubic",
+ "area",
+ ]
+ self.multiplier = multiplier
+ self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
+ self.remap_output = out_channels is not None or remap_output
+ if self.remap_output:
+ print(
+ f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
+ )
+ self.channel_mapper = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ bias=bias,
+ padding=kernel_size // 2,
+ )
+ self.wrap_video = wrap_video
+
+ def forward(self, x):
+ if self.wrap_video and x.ndim == 5:
+ B, C, T, H, W = x.shape
+ x = rearrange(x, "b c t h w -> b t c h w")
+ x = rearrange(x, "b t c h w -> (b t) c h w")
+
+ for stage in range(self.n_stages):
+ x = self.interpolator(x, scale_factor=self.multiplier)
+
+ if self.wrap_video:
+ x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C)
+ x = rearrange(x, "b t c h w -> b c t h w")
+ if self.remap_output:
+ x = self.channel_mapper(x)
+ return x
+
+ def encode(self, x):
+ return self(x)
+
+
+class LowScaleEncoder(nn.Module):
+ def __init__(
+ self,
+ model_config,
+ linear_start,
+ linear_end,
+ timesteps=1000,
+ max_noise_level=250,
+ output_size=64,
+ scale_factor=1.0,
+ ):
+ super().__init__()
+ self.max_noise_level = max_noise_level
+ self.model = instantiate_from_config(model_config)
+ self.augmentation_schedule = self.register_schedule(
+ timesteps=timesteps, linear_start=linear_start, linear_end=linear_end
+ )
+ self.out_size = output_size
+ self.scale_factor = scale_factor
+
+ def register_schedule(
+ self,
+ beta_schedule="linear",
+ timesteps=1000,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ ):
+ betas = make_beta_schedule(
+ beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s,
+ )
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+ (timesteps,) = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert (
+ alphas_cumprod.shape[0] == self.num_timesteps
+ ), "alphas have to be defined for each timestep"
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer("betas", to_torch(betas))
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+ self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
+ )
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def forward(self, x):
+ z = self.model.encode(x)
+ if isinstance(z, DiagonalGaussianDistribution):
+ z = z.sample()
+ z = z * self.scale_factor
+ noise_level = torch.randint(
+ 0, self.max_noise_level, (x.shape[0],), device=x.device
+ ).long()
+ z = self.q_sample(z, noise_level)
+ if self.out_size is not None:
+ z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest")
+ return z, noise_level
+
+ def decode(self, z):
+ z = z / self.scale_factor
+ return self.model.decode(z)
+
+
+class ConcatTimestepEmbedderND(AbstractEmbModel):
+ """embeds each dimension independently and concatenates them"""
+
+ def __init__(self, outdim):
+ super().__init__()
+ self.timestep = Timestep(outdim)
+ self.outdim = outdim
+
+ def forward(self, x):
+ if x.ndim == 1:
+ x = x[:, None]
+ assert len(x.shape) == 2
+ b, dims = x.shape[0], x.shape[1]
+ x = rearrange(x, "b d -> (b d)")
+ emb = self.timestep(x)
+ emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
+ return emb
+
+
+class GaussianEncoder(Encoder, AbstractEmbModel):
+ def __init__(
+ self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ self.posterior = DiagonalGaussianRegularizer()
+ self.weight = weight
+ self.flatten_output = flatten_output
+
+ def forward(self, x) -> Tuple[Dict, torch.Tensor]:
+ z = super().forward(x)
+ z, log = self.posterior(z)
+ log["loss"] = log["kl_loss"]
+ log["weight"] = self.weight
+ if self.flatten_output:
+            z = rearrange(z, "b c h w -> b (h w) c")
+ return log, z
+
+
+class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
+ def __init__(
+ self,
+ n_cond_frames: int,
+ n_copies: int,
+ encoder_config: dict,
+ sigma_sampler_config: Optional[dict] = None,
+ sigma_cond_config: Optional[dict] = None,
+ is_ae: bool = False,
+ scale_factor: float = 1.0,
+ disable_encoder_autocast: bool = False,
+ en_and_decode_n_samples_a_time: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.n_cond_frames = n_cond_frames
+ self.n_copies = n_copies
+ self.encoder = instantiate_from_config(encoder_config)
+ self.sigma_sampler = (
+ instantiate_from_config(sigma_sampler_config)
+ if sigma_sampler_config is not None
+ else None
+ )
+ self.sigma_cond = (
+ instantiate_from_config(sigma_cond_config)
+ if sigma_cond_config is not None
+ else None
+ )
+ self.is_ae = is_ae
+ self.scale_factor = scale_factor
+ self.disable_encoder_autocast = disable_encoder_autocast
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+
+ def forward(
+ self, vid: torch.Tensor
+ ) -> Union[
+ torch.Tensor,
+ Tuple[torch.Tensor, torch.Tensor],
+ Tuple[torch.Tensor, dict],
+ Tuple[Tuple[torch.Tensor, torch.Tensor], dict],
+ ]:
+ if self.sigma_sampler is not None:
+ b = vid.shape[0] // self.n_cond_frames
+ sigmas = self.sigma_sampler(b).to(vid.device)
+ if self.sigma_cond is not None:
+ sigma_cond = self.sigma_cond(sigmas)
+ sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
+ sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
+ noise = torch.randn_like(vid)
+ vid = vid + noise * append_dims(sigmas, vid.ndim)
+
+ with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
+ n_samples = (
+ self.en_and_decode_n_samples_a_time
+ if self.en_and_decode_n_samples_a_time is not None
+ else vid.shape[0]
+ )
+ n_rounds = math.ceil(vid.shape[0] / n_samples)
+ all_out = []
+ for n in range(n_rounds):
+ if self.is_ae:
+ out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples])
+ else:
+ out = self.encoder(vid[n * n_samples : (n + 1) * n_samples])
+ all_out.append(out)
+
+ vid = torch.cat(all_out, dim=0)
+ vid *= self.scale_factor
+
+ vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
+ vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)
+ # modified for svd
+ # vid = repeat(vid, "b 1 c h w -> b t c h w", t=self.n_copies)
+
+ return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid
+
+ return return_val
+
+
+class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
+ def __init__(
+ self,
+ open_clip_embedding_config: Dict,
+ n_cond_frames: int,
+ n_copies: int,
+ ):
+ super().__init__()
+
+ self.n_cond_frames = n_cond_frames
+ self.n_copies = n_copies
+ self.open_clip = instantiate_from_config(open_clip_embedding_config)
+
+ def forward(self, vid):
+ vid = self.open_clip(vid)
+ vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames)
+ vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies)
+
+ return vid
+
+
+class PixelNeRFEmbedder(AbstractEmbModel):
+ def __init__(
+ self,
+ image_encoder_config: dict,
+ pixelnerf_encoder_config: dict,
+ render_size: int,
+ num_video_frames: int,
+ ):
+ super().__init__()
+ self.render_size = render_size
+ self.num_video_frames = num_video_frames
+ self.image_encoder = instantiate_from_config(image_encoder_config)
+ self.pixelnerf_encoder = instantiate_from_config(pixelnerf_encoder_config)
+
+ def forward(self, pixelnerf_input):
+ if "source_index" not in pixelnerf_input:
+ source_images = pixelnerf_input["frames"][:, 0]
+ image_feats = self.image_encoder(source_images)
+ image_feats = image_feats[:, None]
+ source_cameras = pixelnerf_input["cameras"][:, :1]
+ else:
+ # source_images = pixelnerf_input["frames"][
+ # :, pixelnerf_input["source_index"]
+ # ]
+ source_images = pixelnerf_input["source_images"]
+ n_source_images = source_images.shape[1]
+ source_images = rearrange(source_images, "b t c h w -> (b t) c h w")
+ image_feats = self.image_encoder(source_images)
+ image_feats = rearrange(
+ image_feats, "(b t) c h w -> b t c h w", t=n_source_images
+ )
+ source_cameras = pixelnerf_input["source_cameras"]
+ cameras = pixelnerf_input["cameras"]
+ target_cameras = cameras[:, :]
+ # source_images = source_images[:, None, ...]
+ source_c2ws = source_cameras[..., :16].reshape(*source_cameras.shape[:-1], 4, 4)
+ source_intrinsics = source_cameras[..., 16:].reshape(
+ *source_cameras.shape[:-1], 3, 3
+ )
+ target_c2ws = target_cameras[..., :16].reshape(*target_cameras.shape[:-1], 4, 4)
+ target_intrinsics = target_cameras[..., 16:].reshape(
+ *target_cameras.shape[:-1], 3, 3
+ )
+
+ rgb, feats = self.pixelnerf_encoder(
+ image_feats,
+ source_c2ws,
+ source_intrinsics,
+ target_c2ws,
+ target_intrinsics,
+ self.render_size,
+ )
+
+ rgb = rearrange(rgb, "b t c h w -> (b t) c h w")
+ feats = rearrange(feats, "b t c h w -> (b t) c h w")
+
+ return rgb, feats
+
+
+class ExtraConditioner(GeneralConditioner):
+    def forward(self, batch: Dict, force_zero_embeddings: Optional[List] = None) -> Dict:
+ bs = batch["frames"].shape[0]
+ num_frames = batch["num_video_frames"]
+ output = dict()
+ if force_zero_embeddings is None:
+ force_zero_embeddings = []
+ for embedder in self.embedders:
+ embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
+ with embedding_context():
+ if hasattr(embedder, "input_key") and (embedder.input_key is not None):
+ if embedder.legacy_ucg_val is not None:
+ batch = self.possibly_get_ucg_val(embedder, batch)
+ emb_out = embedder(batch[embedder.input_key])
+ elif hasattr(embedder, "input_keys"):
+ emb_out = embedder(*[batch[k] for k in embedder.input_keys])
+ assert isinstance(
+ emb_out, (torch.Tensor, list, tuple)
+ ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
+ if not isinstance(emb_out, (list, tuple)):
+ emb_out = [emb_out]
+ if isinstance(embedder, PixelNeRFEmbedder):
+ # a hack for pixelnerf input
+ output["rgb"] = emb_out[0]
+ emb_out = emb_out[1:]
+ for emb in emb_out:
+ out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
+ if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
+ emb = (
+ expand_dims_like(
+ torch.bernoulli(
+ (1.0 - embedder.ucg_rate)
+ * torch.ones(emb.shape[0], device=emb.device)
+ ),
+ emb,
+ )
+ * emb
+ )
+ if (
+ hasattr(embedder, "input_key")
+ and embedder.input_key in force_zero_embeddings
+ ):
+ emb = torch.zeros_like(emb)
+ if out_key in output:
+ output[out_key] = torch.cat(
+ (output[out_key], emb), self.KEY2CATDIM[out_key]
+ )
+ else:
+ output[out_key] = emb
+
+ if out_key in ["crossattn", "concat"]:
+ if output[out_key].shape[0] != bs:
+ output[out_key] = repeat(
+ output[out_key], "b ... -> (b t) ...", t=num_frames
+ )
+ return output
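
Several of these embedders turn scalar conditionings into fixed-size vectors. `ConcatTimestepEmbedderND`, for instance, applies a sinusoidal `Timestep` embedding to every scalar dimension independently and concatenates the results, which `GeneralConditioner` then routes into the `vector` slot via `OUTPUT_DIM2KEYS`. A small sketch (the `(fps, motion score)` pairing is only illustrative, and importing this module pulls in the repo's dependencies such as open_clip, kornia and transformers):

```python
import torch
from sgm.modules.encoders.modules import ConcatTimestepEmbedderND

embedder = ConcatTimestepEmbedderND(outdim=256)
cond = torch.tensor([[6.0, 127.0], [6.0, 63.0]])   # e.g. (fps, motion score) per sample
emb = embedder(cond)
print(emb.shape)   # torch.Size([2, 512]): each scalar -> 256-d sinusoidal embedding, concatenated
```
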
diff --git a/sgm/modules/encoders/pixelnerf.py b/sgm/modules/encoders/pixelnerf.py
new file mode 100644
index 0000000000000000000000000000000000000000..515699c3aa52097e27ddde98c3491547c2e3a0b7
--- /dev/null
+++ b/sgm/modules/encoders/pixelnerf.py
@@ -0,0 +1,368 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.autograd.profiler as profiler
+import numpy as np
+from einops import rearrange, repeat, einsum
+
+from .math_utils import get_ray_limits_box, linspace
+
+from ...modules.diffusionmodules.openaimodel import Timestep
+
+
+class ImageEncoder(nn.Module):
+ def __init__(self, output_dim: int = 64) -> None:
+ super().__init__()
+ self.output_dim = output_dim
+
+ def forward(self, image):
+ return image
+
+
+class PositionalEncoding(torch.nn.Module):
+ """
+ Implement NeRF's positional encoding
+ """
+
+ def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True):
+ super().__init__()
+ self.num_freqs = num_freqs
+ self.d_in = d_in
+ self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs)
+ self.d_out = self.num_freqs * 2 * d_in
+ self.include_input = include_input
+ if include_input:
+ self.d_out += d_in
+ # f1 f1 f2 f2 ... to multiply x by
+ self.register_buffer(
+ "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1)
+ )
+ # 0 pi/2 0 pi/2 ... so that
+ # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...)
+ _phases = torch.zeros(2 * self.num_freqs)
+ _phases[1::2] = np.pi * 0.5
+ self.register_buffer("_phases", _phases.view(1, -1, 1))
+
+ def forward(self, x):
+ """
+ Apply positional encoding (new implementation)
+ :param x (batch, self.d_in)
+ :return (batch, self.d_out)
+ """
+ with profiler.record_function("positional_enc"):
+ # embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1)
+ embed = repeat(x, "... C -> ... N C", N=self.num_freqs * 2)
+ embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs))
+ embed = rearrange(embed, "... N C -> ... (N C)")
+ if self.include_input:
+ embed = torch.cat((x, embed), dim=-1)
+ return embed
+
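+# Quick check (illustrative, assuming the defaults above): with num_freqs=6, d_in=3 and
+# include_input=True, d_out = 6 * 2 * 3 + 3 = 39, e.g.
+#   PositionalEncoding()(torch.randn(4, 3)).shape  # -> torch.Size([4, 39])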
+
+class RayGenerator(torch.nn.Module):
+ """
+ from camera pose and intrinsics to ray origins and directions
+ """
+
+ def __init__(self):
+ super().__init__()
+ (
+ self.ray_origins_h,
+ self.ray_directions,
+ self.depths,
+ self.image_coords,
+ self.rendering_options,
+ ) = (None, None, None, None, None)
+
+ def forward(self, cam2world_matrix, intrinsics, render_size):
+ """
+ Create batches of rays and return origins and directions.
+
+ cam2world_matrix: (N, 4, 4)
+ intrinsics: (N, 3, 3)
+ render_size: int
+
+ ray_origins: (N, M, 3)
+        ray_dirs: (N, M, 3)
+ """
+
+ N, M = cam2world_matrix.shape[0], render_size**2
+ cam_locs_world = cam2world_matrix[:, :3, 3]
+ fx = intrinsics[:, 0, 0]
+ fy = intrinsics[:, 1, 1]
+ cx = intrinsics[:, 0, 2]
+ cy = intrinsics[:, 1, 2]
+ sk = intrinsics[:, 0, 1]
+
+ uv = torch.stack(
+ torch.meshgrid(
+ torch.arange(
+ render_size, dtype=torch.float32, device=cam2world_matrix.device
+ ),
+ torch.arange(
+ render_size, dtype=torch.float32, device=cam2world_matrix.device
+ ),
+ indexing="ij",
+ )
+ )
+ uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
+ uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
+
+ x_cam = uv[:, :, 0].view(N, -1) * (1.0 / render_size) + (0.5 / render_size)
+ y_cam = uv[:, :, 1].view(N, -1) * (1.0 / render_size) + (0.5 / render_size)
+ z_cam = torch.ones((N, M), device=cam2world_matrix.device)
+
+ x_lift = (
+ (
+ x_cam
+ - cx.unsqueeze(-1)
+ + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
+ - sk.unsqueeze(-1) * y_cam / fy.unsqueeze(-1)
+ )
+ / fx.unsqueeze(-1)
+ * z_cam
+ )
+ y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
+
+ cam_rel_points = torch.stack(
+ (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1
+ )
+
+ # NOTE: this should be named _blender2opencv
+ _opencv2blender = (
+ torch.tensor(
+ [
+ [1, 0, 0, 0],
+ [0, -1, 0, 0],
+ [0, 0, -1, 0],
+ [0, 0, 0, 1],
+ ],
+ dtype=torch.float32,
+ device=cam2world_matrix.device,
+ )
+ .unsqueeze(0)
+ .repeat(N, 1, 1)
+ )
+
+ cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
+
+ world_rel_points = torch.bmm(
+ cam2world_matrix, cam_rel_points.permute(0, 2, 1)
+ ).permute(0, 2, 1)[:, :, :3]
+
+ ray_dirs = world_rel_points - cam_locs_world[:, None, :]
+ ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
+
+ ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
+
+ return ray_origins, ray_dirs
+
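+# Shape note (illustrative): for N cameras and render_size R there are M = R * R rays, e.g.
+#   origins, dirs = RayGenerator()(c2w, K, 64)   # c2w: (N, 4, 4), K: (N, 3, 3) -- assumed inputs
+#   origins.shape, dirs.shape                    # -> (N, 4096, 3), (N, 4096, 3)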
+
+class RaySampler(torch.nn.Module):
+ def __init__(
+ self,
+ num_samples_per_ray,
+ bbox_length=1.0,
+ near=0.5,
+ far=10000.0,
+ disparity=False,
+ ):
+ super().__init__()
+ self.num_samples_per_ray = num_samples_per_ray
+ self.bbox_length = bbox_length
+ self.near = near
+ self.far = far
+ self.disparity = disparity
+
+ def forward(self, ray_origins, ray_directions):
+ if not self.disparity:
+ t_start, t_end = get_ray_limits_box(
+ ray_origins, ray_directions, 2 * self.bbox_length
+ )
+ else:
+ t_start = torch.full_like(ray_origins, self.near)
+ t_end = torch.full_like(ray_origins, self.far)
+ is_ray_valid = t_end > t_start
+ if torch.any(is_ray_valid).item():
+ t_start[~is_ray_valid] = t_start[is_ray_valid].min()
+ t_end[~is_ray_valid] = t_start[is_ray_valid].max()
+
+ if not self.disparity:
+ depths = linspace(t_start, t_end, self.num_samples_per_ray)
+ depths += (
+ torch.rand_like(depths)
+ * (t_end - t_start)
+ / (self.num_samples_per_ray - 1)
+ )
+ else:
+ step = 1.0 / self.num_samples_per_ray
+ z_steps = torch.linspace(
+ 0, 1 - step, self.num_samples_per_ray, device=ray_origins.device
+ )
+ z_steps += torch.rand_like(z_steps) * step
+ depths = 1 / (1 / self.near * (1 - z_steps) + 1 / self.far * z_steps)
+ depths = depths[..., None, None, None]
+
+ return ray_origins[None] + ray_directions[None] * depths
+
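+# Shape note (illustrative): depths broadcast against the rays, so the returned samples are
+# (num_samples_per_ray, N, M, 3) world-space points; PixelNeRF.forward below rearranges them
+# to (B, N, H*W, num_samples_per_ray, 3).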
+
+class PixelNeRF(torch.nn.Module):
+ def __init__(
+ self,
+ num_samples_per_ray: int = 128,
+ feature_dim: int = 64,
+ interp: str = "bilinear",
+ padding: str = "border",
+ disparity: bool = False,
+ near: float = 0.5,
+ far: float = 10000.0,
+ use_feats_std: bool = False,
+ use_pos_emb: bool = False,
+ ) -> None:
+ super().__init__()
+ # self.positional_encoder = Timestep(3) # TODO
+ self.num_samples_per_ray = num_samples_per_ray
+ self.ray_generator = RayGenerator()
+ self.ray_sampler = RaySampler(
+ num_samples_per_ray, near=near, far=far, disparity=disparity
+ ) # TODO
+ self.interp = interp
+ self.padding = padding
+
+ self.positional_encoder = PositionalEncoding()
+
+ # self.feature_aggregator = nn.Linear(128, 129) # TODO
+ self.use_feats_std = use_feats_std
+ self.use_pos_emb = use_pos_emb
+ d_in = feature_dim
+ if use_feats_std:
+ d_in += feature_dim
+ if use_pos_emb:
+ d_in += self.positional_encoder.d_out
+ self.feature_aggregator = nn.Sequential(
+ nn.Linear(d_in, 128),
+ nn.ReLU(),
+ nn.Linear(128, 128),
+ nn.ReLU(),
+ nn.Linear(128, 129),
+ )
+
+ # self.decoder = nn.Linear(128, 131) # TODO
+ self.decoder = nn.Sequential(
+ nn.Linear(128, 128),
+ nn.ReLU(),
+ nn.Linear(128, 128),
+ nn.ReLU(),
+ nn.Linear(128, 131),
+ )
+
+    def project(self, ray_samples, source_c2ws, source_intrinsics):
+ # S for number of source cameras
+ # ray_samples: [B, N, H * W, N_sample, 3]
+ # source_c2ws: [B, S, 4, 4]
+ # source_intrinsics: [B, S, 3, 3]
+ # return [B, S, N, H * W, N_sample, 2]
+ S = source_c2ws.shape[1]
+ B = ray_samples.shape[0]
+ N = ray_samples.shape[1]
+ HW = ray_samples.shape[2]
+ ray_samples = repeat(
+ ray_samples,
+ "B N HW N_sample C -> B S N HW N_sample C",
+ S=source_c2ws.shape[1],
+ )
+ padding = torch.ones((B, S, N, HW, self.num_samples_per_ray, 1)).to(ray_samples)
+ ray_samples_homo = torch.cat([ray_samples, padding], dim=-1)
+ source_c2ws = repeat(source_c2ws, "B S C1 C2 -> B S N 1 1 C1 C2", N=N)
+        source_intrinsics = repeat(source_intrinsics, "B S C1 C2 -> B S N 1 1 C1 C2", N=N)
+ source_w2c = source_c2ws.inverse()
+ projected_samples = einsum(
+ source_w2c, ray_samples_homo, "... i j, ... j -> ... i"
+ )[..., :3]
+ # NOTE: assumes opengl convention
+ projected_samples = -1 * projected_samples[..., :2] / projected_samples[..., 2:]
+ # NOTE: intrinsics here are normalized by resolution
+        fx = source_intrinsics[..., 0, 0]
+        fy = source_intrinsics[..., 1, 1]
+        cx = source_intrinsics[..., 0, 2]
+        cy = source_intrinsics[..., 1, 2]
+        x = projected_samples[..., 0] * fx + cx
+        # the negative sign comes from the OpenGL camera convention; F.grid_sample follows the OpenCV convention
+        y = -projected_samples[..., 1] * fy + cy
+
+ return torch.stack([x, y], dim=-1)
+
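+    # Note (illustrative): project() returns pixel coordinates normalized to [0, 1]
+    # (the intrinsics are resolution-normalized), which is why forward() maps them to
+    # [-1, 1] before calling F.grid_sample.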
+ def forward(
+ self, image_feats, source_c2ws, source_intrinsics, c2ws, intrinsics, render_size
+ ):
+ # image_feats: [B S C H W]
+ B = c2ws.shape[0]
+ T = c2ws.shape[1]
+ ray_origins, ray_directions = self.ray_generator(
+ c2ws.reshape(-1, 4, 4), intrinsics.reshape(-1, 3, 3), render_size
+ ) # [B * N, H * W, 3]
+ # breakpoint()
+
+ ray_samples = self.ray_sampler(
+ ray_origins, ray_directions
+ ) # [N_sample, B * N, H * W, 3]
+ ray_samples = rearrange(ray_samples, "Ns (B N) HW C -> B N HW Ns C", B=B)
+
+ projected_samples = self.project(ray_samples, source_c2ws, source_intrinsics)
+ # # debug
+ # p = projected_samples[:, :, 0, :, 0, :]
+ # p = p.reshape(p.shape[0] * p.shape[1], *p.shape[2:])
+
+ # breakpoint()
+
+ # image_feats = repeat(image_feats, "B S C H W -> (B S N) C H W", N=T)
+ image_feats = rearrange(image_feats, "B S C H W -> (B S) C H W")
+ projected_samples = rearrange(
+ projected_samples, "B S N HW Ns xy -> (B S) (N Ns) HW xy"
+ )
+        # map the projected samples from [0, 1] to [-1, 1], as required by F.grid_sample
+ joint = F.grid_sample(
+ image_feats,
+ projected_samples * 2.0 - 1.0,
+ padding_mode=self.padding,
+ mode=self.interp,
+ align_corners=True,
+ )
+ # print("image_feats", image_feats.max(), image_feats.min())
+ # print("samples", projected_samples.max(), projected_samples.min())
+ joint = rearrange(
+ joint,
+ "(B S) C (N Ns) HW -> B S N HW Ns C",
+ B=B,
+ Ns=self.num_samples_per_ray,
+ )
+
+ reduced = torch.mean(joint, dim=1) # reduce on source dimension
+ if self.use_feats_std:
+            if joint.shape[1] != 1:
+ reduced = torch.cat((reduced, joint.std(dim=1)), dim=-1)
+ else:
+ reduced = torch.cat((reduced, torch.zeros_like(reduced)), dim=-1)
+
+ if self.use_pos_emb:
+ reduced = torch.cat((reduced, self.positional_encoder(ray_samples)), dim=-1)
+ reduced = self.feature_aggregator(reduced)
+
+ feats, weights = reduced.split([reduced.shape[-1] - 1, 1], dim=-1)
+ # feats: [B, N, H * W, N_samples, N_c]
+ # weights: [B, N, H * W, N_samples, 1]
+ weights = F.softmax(weights, dim=-2)
+
+ feats = torch.sum(feats * weights, dim=-2)
+
+ rgb, feats = self.decoder(feats).split([3, 128], dim=-1)
+
+ rgb = F.sigmoid(rgb)
+ rgb = rearrange(rgb, "B N (H W) C -> B N C H W", H=render_size)
+ feats = rearrange(feats, "B N (H W) C -> B N C H W", H=render_size)
+
+ # print(rgb.max(), rgb.min())
+ # print(feats.max(), feats.min())
+
+ return rgb, feats
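+
+
+# Minimal usage sketch (illustrative, not part of the original file; shapes follow the
+# comments above, tensor values are dummies):
+#   model = PixelNeRF(num_samples_per_ray=32, feature_dim=64)
+#   feats = torch.randn(1, 2, 64, 32, 32)                    # B=1, S=2 source views
+#   src_c2w = torch.eye(4).expand(1, 2, 4, 4)
+#   src_K = torch.eye(3).expand(1, 2, 3, 3)                  # resolution-normalized intrinsics
+#   tgt_c2w = torch.eye(4).expand(1, 3, 4, 4)                # N=3 target views
+#   tgt_K = torch.eye(3).expand(1, 3, 3, 3)
+#   rgb, feat = model(feats, src_c2w, src_K, tgt_c2w, tgt_K, render_size=32)
+#   # rgb: (1, 3, 3, 32, 32), feat: (1, 3, 128, 32, 32)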
diff --git a/sgm/modules/video_attention.py b/sgm/modules/video_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..783395aa554144936766b57380f35dab29c093c3
--- /dev/null
+++ b/sgm/modules/video_attention.py
@@ -0,0 +1,301 @@
+import torch
+
+from ..modules.attention import *
+from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
+
+
+class TimeMixSequential(nn.Sequential):
+ def forward(self, x, context=None, timesteps=None):
+ for layer in self:
+ x = layer(x, context, timesteps)
+
+ return x
+
+
+class VideoTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention,
+ "softmax-xformers": MemoryEfficientCrossAttention,
+ }
+
+ def __init__(
+ self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.0,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ timesteps=None,
+ ff_in=False,
+ inner_dim=None,
+ attn_mode="softmax",
+ disable_self_attn=False,
+ disable_temporal_crossattention=False,
+ switch_temporal_ca_to_sa=False,
+ ):
+ super().__init__()
+
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+
+ self.ff_in = ff_in or inner_dim is not None
+ if inner_dim is None:
+ inner_dim = dim
+
+ assert int(n_heads * d_head) == inner_dim
+
+ self.is_res = inner_dim == dim
+
+ if self.ff_in:
+ self.norm_in = nn.LayerNorm(dim)
+ self.ff_in = FeedForward(
+ dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff
+ )
+
+ self.timesteps = timesteps
+ self.disable_self_attn = disable_self_attn
+ if self.disable_self_attn:
+ self.attn1 = attn_cls(
+ query_dim=inner_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ context_dim=context_dim,
+ dropout=dropout,
+ ) # is a cross-attention
+ else:
+ self.attn1 = attn_cls(
+ query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+
+ self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff)
+
+ if disable_temporal_crossattention:
+ if switch_temporal_ca_to_sa:
+ raise ValueError
+ else:
+ self.attn2 = None
+ else:
+ self.norm2 = nn.LayerNorm(inner_dim)
+ if switch_temporal_ca_to_sa:
+ self.attn2 = attn_cls(
+ query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout
+ ) # is a self-attention
+ else:
+ self.attn2 = attn_cls(
+ query_dim=inner_dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ ) # is self-attn if context is none
+
+ self.norm1 = nn.LayerNorm(inner_dim)
+ self.norm3 = nn.LayerNorm(inner_dim)
+ self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa
+
+ self.checkpoint = checkpoint
+ if self.checkpoint:
+ print(f"{self.__class__.__name__} is using checkpointing")
+
+ def forward(
+ self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None
+ ) -> torch.Tensor:
+ if self.checkpoint:
+ return checkpoint(self._forward, x, context, timesteps)
+ else:
+ return self._forward(x, context, timesteps=timesteps)
+
+ def _forward(self, x, context=None, timesteps=None):
+ assert self.timesteps or timesteps
+ assert not (self.timesteps and timesteps) or self.timesteps == timesteps
+ timesteps = self.timesteps or timesteps
+ B, S, C = x.shape
+ x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps)
+
+ if self.ff_in:
+ x_skip = x
+ x = self.ff_in(self.norm_in(x))
+ if self.is_res:
+ x += x_skip
+
+ if self.disable_self_attn:
+ x = self.attn1(self.norm1(x), context=context) + x
+ else:
+ x = self.attn1(self.norm1(x)) + x
+
+ if self.attn2 is not None:
+ if self.switch_temporal_ca_to_sa:
+ x = self.attn2(self.norm2(x)) + x
+ else:
+ x = self.attn2(self.norm2(x), context=context) + x
+ x_skip = x
+ x = self.ff(self.norm3(x))
+ if self.is_res:
+ x += x_skip
+
+ x = rearrange(
+ x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
+ )
+ return x
+
+ def get_last_layer(self):
+ return self.ff.net[-1].weight
+
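+# Note (illustrative): the block receives spatially-flattened tokens of shape ((b * t), s, c)
+# and regroups them to ((b * s), t, c) so attention runs over the t frames of each spatial
+# location; e.g. with b=2, t=8, s=256 the input (16, 256, c) becomes (512, 8, c) inside
+# _forward and is restored on return.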
+
+class SpatialVideoTransformer(SpatialTransformer):
+ def __init__(
+ self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.0,
+ use_linear=False,
+ context_dim=None,
+ use_spatial_context=False,
+ timesteps=None,
+ merge_strategy: str = "fixed",
+ merge_factor: float = 0.5,
+ time_context_dim=None,
+ ff_in=False,
+ checkpoint=False,
+ time_depth=1,
+ attn_mode="softmax",
+ disable_self_attn=False,
+ disable_temporal_crossattention=False,
+ max_time_embed_period: int = 10000,
+ ):
+ super().__init__(
+ in_channels,
+ n_heads,
+ d_head,
+ depth=depth,
+ dropout=dropout,
+ attn_type=attn_mode,
+ use_checkpoint=checkpoint,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ disable_self_attn=disable_self_attn,
+ )
+ self.time_depth = time_depth
+ self.depth = depth
+ self.max_time_embed_period = max_time_embed_period
+
+ time_mix_d_head = d_head
+ n_time_mix_heads = n_heads
+
+ time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads)
+
+ inner_dim = n_heads * d_head
+ if use_spatial_context:
+ time_context_dim = context_dim
+
+ self.time_stack = nn.ModuleList(
+ [
+ VideoTransformerBlock(
+ inner_dim,
+ n_time_mix_heads,
+ time_mix_d_head,
+ dropout=dropout,
+ context_dim=time_context_dim,
+ timesteps=timesteps,
+ checkpoint=checkpoint,
+ ff_in=ff_in,
+ inner_dim=time_mix_inner_dim,
+ attn_mode=attn_mode,
+ disable_self_attn=disable_self_attn,
+ disable_temporal_crossattention=disable_temporal_crossattention,
+ )
+ for _ in range(self.depth)
+ ]
+ )
+
+ assert len(self.time_stack) == len(self.transformer_blocks)
+
+ self.use_spatial_context = use_spatial_context
+ self.in_channels = in_channels
+
+ time_embed_dim = self.in_channels * 4
+ self.time_pos_embed = nn.Sequential(
+ linear(self.in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, self.in_channels),
+ )
+
+ self.time_mixer = AlphaBlender(
+ alpha=merge_factor, merge_strategy=merge_strategy
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: Optional[torch.Tensor] = None,
+ time_context: Optional[torch.Tensor] = None,
+ timesteps: Optional[int] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ _, _, h, w = x.shape
+ x_in = x
+ spatial_context = None
+ if exists(context):
+ spatial_context = context
+
+ if self.use_spatial_context:
+ assert (
+ context.ndim == 3
+ ), f"n dims of spatial context should be 3 but are {context.ndim}"
+
+ time_context = context
+ time_context_first_timestep = time_context[::timesteps]
+ time_context = repeat(
+ time_context_first_timestep, "b ... -> (b n) ...", n=h * w
+ )
+ elif time_context is not None and not self.use_spatial_context:
+ time_context = repeat(time_context, "b ... -> (b n) ...", n=h * w)
+ if time_context.ndim == 2:
+ time_context = rearrange(time_context, "b c -> b 1 c")
+
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, "b c h w -> b (h w) c")
+ if self.use_linear:
+ x = self.proj_in(x)
+
+ num_frames = torch.arange(timesteps, device=x.device)
+ num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps)
+ num_frames = rearrange(num_frames, "b t -> (b t)")
+ t_emb = timestep_embedding(
+ num_frames,
+ self.in_channels,
+ repeat_only=False,
+ max_period=self.max_time_embed_period,
+ )
+ emb = self.time_pos_embed(t_emb)
+ emb = emb[:, None, :]
+
+ for it_, (block, mix_block) in enumerate(
+ zip(self.transformer_blocks, self.time_stack)
+ ):
+ x = block(
+ x,
+ context=spatial_context,
+ )
+
+ x_mix = x
+ x_mix = x_mix + emb
+
+ x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps)
+ x = self.time_mixer(
+ x_spatial=x,
+ x_temporal=x_mix,
+ image_only_indicator=image_only_indicator,
+ )
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+ if not self.use_linear:
+ x = self.proj_out(x)
+ out = x + x_in
+ return out
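+
+
+# Usage sketch (illustrative, not part of the original file): SpatialVideoTransformer is a
+# drop-in replacement for SpatialTransformer inside a video UNet, e.g.
+#   block = SpatialVideoTransformer(in_channels=320, n_heads=8, d_head=40,
+#                                   context_dim=1024, use_linear=True)
+#   y = block(x, context=ctx, timesteps=T, image_only_indicator=indicator)
+# with x shaped ((b * T), 320, h, w); the AlphaBlender mixes the spatial and temporal
+# branches according to merge_strategy.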
diff --git a/sgm/sampling/hier.py b/sgm/sampling/hier.py
new file mode 100644
index 0000000000000000000000000000000000000000..375261c89b9f2fb38b2b853af8872ef4f0f500af
--- /dev/null
+++ b/sgm/sampling/hier.py
@@ -0,0 +1 @@
+# hierarchical sampling (autoregressive sampling like GeNVS)
diff --git a/sgm/util.py b/sgm/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..49cc0df0e14326087e1adaf515b76137c2977fbe
--- /dev/null
+++ b/sgm/util.py
@@ -0,0 +1,310 @@
+import functools
+import importlib
+import os
+from functools import partial
+from inspect import isfunction
+
+import fsspec
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+from safetensors.torch import load_file as load_safetensors
+from einops import rearrange
+from mediapy import write_image
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def get_string_from_tuple(s):
+ try:
+ # Check if the string starts and ends with parentheses
+ if s[0] == "(" and s[-1] == ")":
+ # Convert the string to a tuple
+ t = eval(s)
+ # Check if the type of t is tuple
+ if type(t) == tuple:
+ return t[0]
+ else:
+ pass
+    except Exception:
+ pass
+ return s
+
+
+def is_power_of_two(n):
+ """
+ chat.openai.com/chat
+ Return True if n is a power of 2, otherwise return False.
+
+ The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
+ The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
+ If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
+ Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
+
+ """
+ if n <= 0:
+ return False
+ return (n & (n - 1)) == 0
+
+
+def autocast(f, enabled=True):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=enabled,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled(),
+ ):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
+def load_partial_from_config(config):
+ return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+ nc = int(40 * (wh[0] / 256))
+ if isinstance(xc[bi], list):
+ text_seq = xc[bi][0]
+ else:
+ text_seq = xc[bi]
+ lines = "\n".join(
+ text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
+ )
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def partialclass(cls, *args, **kwargs):
+ class NewCls(cls):
+ __init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
+
+ return NewCls
+
+
+def make_path_absolute(path):
+ fs, p = fsspec.core.url_to_fs(path)
+ if fs.protocol == "file":
+ return os.path.abspath(p)
+ return path
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def isheatmap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+
+ return x.ndim == 2
+
+
+def isneighbors(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def expand_dims_like(x, y):
+ while x.dim() != y.dim():
+ x = x.unsqueeze(-1)
+ return x
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == "__is_first_stage__":
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
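+# Example (illustrative): configs are mappings with a dotted import path under "target" and
+# constructor kwargs under "params", e.g.
+#   cfg = {"target": "torch.nn.Linear", "params": {"in_features": 4, "out_features": 2}}
+#   layer = instantiate_from_config(cfg)   # -> torch.nn.Linear(4, 2)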
+
+def get_obj_from_str(string, reload=False, invalidate_cache=True):
+ module, cls = string.rsplit(".", 1)
+ if invalidate_cache:
+ importlib.invalidate_caches()
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def append_zero(x):
+ return torch.cat([x, x.new_zeros([1])])
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
+ )
+ return x[(...,) + (None,) * dims_to_append]
+
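+# Example (illustrative): append_dims(torch.randn(4), 3).shape == torch.Size([4, 1, 1]),
+# handy for broadcasting per-sample scalars against image-shaped tensors.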
+
+def load_model_from_config(config, ckpt, verbose=True, freeze=True):
+ print(f"Loading model from {ckpt}")
+ if ckpt.endswith("ckpt"):
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ elif ckpt.endswith("safetensors"):
+ sd = load_safetensors(ckpt)
+ else:
+ raise NotImplementedError
+
+ model = instantiate_from_config(config.model)
+
+ m, u = model.load_state_dict(sd, strict=False)
+
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if freeze:
+ for param in model.parameters():
+ param.requires_grad = False
+
+ model.eval()
+ return model
+
+
+def get_configs_path() -> str:
+ """
+ Get the `configs` directory.
+ For a working copy, this is the one in the root of the repository,
+ but for an installed copy, it's in the `sgm` package (see pyproject.toml).
+ """
+ this_dir = os.path.dirname(__file__)
+ candidates = (
+ os.path.join(this_dir, "configs"),
+ os.path.join(this_dir, "..", "configs"),
+ )
+ for candidate in candidates:
+ candidate = os.path.abspath(candidate)
+ if os.path.isdir(candidate):
+ return candidate
+ raise FileNotFoundError(f"Could not find SGM configs in {candidates}")
+
+
+def get_nested_attribute(obj, attribute_path, depth=None, return_key=False):
+ """
+ Will return the result of a recursive get attribute call.
+ E.g.:
+ a.b.c
+ = getattr(getattr(a, "b"), "c")
+ = get_nested_attribute(a, "b.c")
+ If any part of the attribute call is an integer x with current obj a, will
+ try to call a[x] instead of a.x first.
+ """
+ attributes = attribute_path.split(".")
+ if depth is not None and depth > 0:
+ attributes = attributes[:depth]
+ assert len(attributes) > 0, "At least one attribute should be selected"
+ current_attribute = obj
+ current_key = None
+ for level, attribute in enumerate(attributes):
+ current_key = ".".join(attributes[: level + 1])
+ try:
+ id_ = int(attribute)
+ current_attribute = current_attribute[id_]
+ except ValueError:
+ current_attribute = getattr(current_attribute, attribute)
+
+ return (current_attribute, current_key) if return_key else current_attribute
+
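+# Example (illustrative, hypothetical object): if obj.model.blocks is a list,
+#   get_nested_attribute(obj, "model.blocks.0.weight")
+# resolves to obj.model.blocks[0].weight, since integer segments fall back to indexing.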
+
+def video_frames_as_grid(frames, save_path):
+ # frames: [T, C, H, W]
+ frames = frames.detach().cpu()
+ frames = rearrange(frames, "t c h w -> h (t w) c")
+ write_image(save_path, frames)
+
+
+def server_safe_call(keep_trying: bool = False):
+    """Decorator for server calls. If the call raises and keep_trying is True, it retries
+    until it succeeds; otherwise it gives up after the first failure and returns None.
+
+    Args:
+        keep_trying (bool, optional): whether to call again if a try fails. Defaults to False.
+    """
+
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            ret = None
+            success = False
+            while not success:
+                try:
+                    ret = func(*args, **kwargs)
+                    success = True
+                except KeyboardInterrupt:
+                    raise
+                except Exception:
+                    if not keep_trying:
+                        break
+            return ret
+
+ return wrapper
+
+ return decorator
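+
+
+# Usage sketch (illustrative, hypothetical function): wrap a flaky network call so it retries
+# until it succeeds, or returns None after the first failure when keep_trying is False:
+#   @server_safe_call(keep_trying=True)
+#   def fetch(url):
+#       return requests.get(url).json()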