import sys
import copy
from typing import List

import numpy as np
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer import look_at_view_transform
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
import json

sys.path.append('./custom-diffusion360/')
from sgm.util import instantiate_from_config, load_safetensors

choices = []


def load_base_model(config, ckpt=None, verbose=True):
    config = OmegaConf.load(config)

    # load model
    config.model.params.network_config.params.far = 3
    config.model.params.first_stage_config.params.ckpt_path = "pretrained-models/sdxl_vae.safetensors"
    guider_config = {'target': 'sgm.modules.diffusionmodules.guiders.ScheduledCFGImgTextRef',
                     'params': {'scale': 7.5, 'scale_im': 3.5}}
    config.model.params.sampler_config.params.guider_config = guider_config
    model = instantiate_from_config(config.model)

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
            if 'modifier_token' in config.data.params:
                del sd['conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight']
                del sd['conditioner.embedders.1.model.token_embedding.weight']
        else:
            raise NotImplementedError
        m, u = model.load_state_dict(sd, strict=False)

    model.eval()
    return model


def load_delta_model(model, delta_ckpt=None, verbose=True, freeze=True):
    """
    model is the preloaded base stable diffusion model
    """
    msg = None
    if delta_ckpt is not None:
        pl_sd_delta = torch.load(delta_ckpt, map_location="cpu")
        sd_delta = pl_sd_delta["delta_state_dict"]
        # TODO: add new delta loading embedding stuff?
        # store the pre-computed reference features from the delta checkpoint
        # as buffers on the cross-reference attention blocks
        for name, module in model.model.diffusion_model.named_modules():
            if len(name.split('.')) > 1 and name.split('.')[-2] == 'transformer_blocks':
                if hasattr(module, 'pose_emb_layers'):
                    module.register_buffer('references', sd_delta[f'model.diffusion_model.{name}.references'])
                    del sd_delta[f'model.diffusion_model.{name}.references']

        m, u = model.load_state_dict(sd_delta, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
        if len(u) > 0 and verbose:
            print("unexpected keys:")

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    model.eval()
    return model, msg


def get_unique_embedder_keys_from_conditioner(conditioner):
    p = [x.input_keys for x in conditioner.embedders]
    return list(set([item for sublist in p for item in sublist])) + ['jpg_ref']


def customforward(self, x, xr, context=None, contextr=None, pose=None, mask_ref=None,
                  prev_weights=None, timesteps=None, drop_im=None):
    # note: if no context is given, cross-attention defaults to self-attention
    if not isinstance(context, list):
        context = [context]
    b, c, h, w = x.shape
    x_in = x
    fg_masks = []
    alphas = []
    rgbs = []

    x = self.norm(x)
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, "b c h w -> b (h w) c").contiguous()
    if self.use_linear:
        x = self.proj_in(x)

    prev_weights = None
    counter = 0
    for i, block in enumerate(self.transformer_blocks):
        if i > 0 and len(context) == 1:
            i = 0  # use same context for each block
        if self.image_cross and (counter % self.poscontrol_interval == 0):
            x, fg_mask, weights, alpha, rgb = block(x, context=context[i], context_ref=x, pose=pose,
                                                    mask_ref=mask_ref, prev_weights=prev_weights, drop_im=drop_im)
            prev_weights = weights
            fg_masks.append(fg_mask)
            if alpha is not None:
                alphas.append(alpha)
            if rgb is not None:
                rgbs.append(rgb)
        else:
            x, _, _, _, _ = block(x, context=context[i], drop_im=drop_im)
        counter += 1

    if self.use_linear:
        x = self.proj_out(x)
    x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
    if not self.use_linear:
        x = self.proj_out(x)

    if len(fg_masks) > 0:
        if len(rgbs) <= 0:
            rgbs = None
        if len(alphas) <= 0:
            alphas = None
        return x + x_in, None, fg_masks, prev_weights, alphas, rgbs
    else:
        return x + x_in, None, None, prev_weights, None, None


def _customforward(
    self, x, context=None, context_ref=None, pose=None, mask_ref=None, prev_weights=None,
    drop_im=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
):
    if context_ref is not None:
        global choices
        batch_size = x.size(0)
        # IP2P like sampling or default sampling
        if batch_size % 3 == 0:
            batch_size = batch_size // 3
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1),
                                     context_ref, context_ref], dim=0)
        else:
            batch_size = batch_size // 2
            context_ref = torch.stack([self.references[:-1][y] for y in choices]).unsqueeze(0).expand(batch_size, -1, -1, -1)
            context_ref = torch.cat([self.references[-1:].unsqueeze(0).expand(batch_size, context_ref.size(1), -1, -1),
                                     context_ref], dim=0)

    fg_mask = None
    weights = None
    alphas = None
    predicted_rgb = None

    x = (
        self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self if not self.disable_self_attn else 0,
        )
        + x
    )

    x = (
        self.attn2(
            self.norm2(x),
            context=context,
            additional_tokens=additional_tokens,
        )
        + x
    )

    if context_ref is not None:
        if self.rendered_feat is not None:
            # reuse the cached rendered reference features
            x = self.pose_emb_layers(torch.cat([x, self.rendered_feat], dim=-1))
        else:
            xref, fg_mask, weights, alphas, predicted_rgb = self.reference_attn(x, context_ref, context, pose,
                                                                                prev_weights, mask_ref)
            self.rendered_feat = xref
            x = self.pose_emb_layers(torch.cat([x, xref], -1))

    x = self.ff(self.norm3(x)) + x
    return x, fg_mask, weights, alphas, predicted_rgb


def log_images(
    model,
    batch,
    N: int = 1,
    noise=None,
    scale_im=3.5,
    num_steps: int = 10,
    ucg_keys: List[str] = None,
    **kwargs,
):
    log = dict()
    conditioner_input_keys = [e.input_keys for e in model.conditioner.embedders]
    ucg_keys = conditioner_input_keys
    pose = batch['pose']

    c, uc = model.conditioner.get_unconditional_conditioning(
        batch,
        force_uc_zero_embeddings=ucg_keys if len(model.conditioner.embedders) > 0 else [],
        force_ref_zero_embeddings=True,
    )

    _, n = 1, len(pose) - 1
    sampling_kwargs = {}

    # duplicate the camera poses for each classifier-free-guidance branch
    if scale_im > 0:
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N] * 3
            else:
                pose = torch.cat([pose[:N]] * 3)
    else:
        if uc is not None:
            if isinstance(pose, list):
                pose = pose[:N] * 2
            else:
                pose = torch.cat([pose[:N]] * 2)

    sampling_kwargs['pose'] = pose
    sampling_kwargs['drop_im'] = None
    sampling_kwargs['mask_ref'] = None

    for k in c:
        if isinstance(c[k], torch.Tensor):
            c[k], uc[k] = map(lambda y: y[k][:(n + 1) * N].to('cuda'), (c, uc))

    import time
    st = time.time()
    with model.ema_scope("Plotting"):
        samples = model.sample(
            c, shape=noise.shape[1:], uc=uc, batch_size=N, num_steps=num_steps, noise=noise, **sampling_kwargs
        )
    model.clear_rendered_feat()
    samples = model.decode_first_stage(samples)
    print("Time taken for sampling", time.time() - st)

    log["samples"] = samples.cpu()
    return log


def process_camera_json(camera_json, example_cam):
    # replace all single quotes in the camera_json with double quotes
    camera_json = camera_json.replace("'", "\"")
    print("input camera json")
    print(camera_json)
    camera_dict = json.loads(camera_json)["scene.camera"]

    eye = torch.tensor([camera_dict["eye"]["x"], camera_dict["eye"]["y"], camera_dict["eye"]["z"]],
                       dtype=torch.float32).unsqueeze(0)
    up = torch.tensor([camera_dict["up"]["x"], camera_dict["up"]["y"], camera_dict["up"]["z"]],
                      dtype=torch.float32).unsqueeze(0)
    center = torch.tensor([camera_dict["center"]["x"], camera_dict["center"]["y"], camera_dict["center"]["z"]],
                          dtype=torch.float32).unsqueeze(0)

    new_R, new_T = look_at_view_transform(eye=eye, at=center, up=up)

    print("focal length", example_cam.focal_length)
    print("principal point", example_cam.principal_point)
    newcam = PerspectiveCameras(R=new_R, T=new_T,
                                focal_length=example_cam.focal_length,
                                principal_point=example_cam.principal_point,
                                image_size=512)
    print("input pose")
    print(newcam.get_world_to_view_transform().get_matrix())
    return newcam


def load_and_return_model_and_data(config, model,
                                   ckpt="pretrained-models/sd_xl_base_1.0.safetensors",
                                   delta_ckpt=None,
                                   train=False,
                                   valid=False,
                                   far=3,
                                   num_images=1,
                                   num_ref=8,
                                   max_images=20,
                                   ):
    config = OmegaConf.load(config)

    # load data
    data = None
    # config.data.params.jitter = False
    # config.data.params.addreg = False
    # config.data.params.bbox = False
    # data = instantiate_from_config(config.data)
    # data = data.train_dataset
    # single_id = data.single_id
    # if hasattr(data, 'rotations'):
    #     total_images = len(data.rotations[data.sequence_list[single_id]])
    # else:
    #     total_images = len(data.annotations['chair'])
    # print(f"Total images in dataset: {total_images}")

    model, msg = load_delta_model(model, delta_ckpt)
    model = model.cuda()

    # change forward methods to store rendered features and use the pre-calculated reference features
    def register_recr(net_):
        if net_.__class__.__name__ == 'SpatialTransformer':
            print(net_.__class__.__name__, "adding control")
            bound_method = customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr(net__)
        return

    def register_recr2(net_):
        if net_.__class__.__name__ == 'BasicTransformerBlock':
            print(net_.__class__.__name__, "adding control")
            bound_method = _customforward.__get__(net_, net_.__class__)
            setattr(net_, 'forward', bound_method)
            return
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                register_recr2(net__)
        return

    sub_nets = model.model.diffusion_model.named_children()
    for net in sub_nets:
        register_recr(net[1])
        register_recr2(net[1])

    # start sampling
    model.clear_rendered_feat()
    return model, data


def sample(model, data, num_images=1, prompt="", appendpath="", camera_json=None,
           train=False, scale=7.5, scale_im=3.5, beta=1.0, num_ref=8,
           skipreflater=False, num_steps=10, valid=False, max_images=20, seed=42,
           camera_path="pretrained-models/car0/camera.bin",
           ):
    """
    Only works with num_images=1 (because of camera_json processing)
    """
    if num_images != 1:
        print("forcing num_images to be 1")
        num_images = 1

    # set guidance scales
    model.sampler.guider.scale_im = scale_im
    model.sampler.guider.scale = scale

    seed_everything(seed)

    # load cameras
    cameras_val, cameras_train = torch.load(camera_path)

    # select num_ref evenly spaced reference cameras from the training views
    global choices
    num_ref = 8
    max_diff = len(cameras_train) / num_ref
    choices = [int(x) for x in torch.linspace(0, len(cameras_train) - max_diff, num_ref)]
    cameras_train_final = [cameras_train[i] for i in choices]

    # start sampling
    model.clear_rendered_feat()

    if prompt == "":
        prompt = None

    noise = torch.randn(1, 4, 64, 64).to('cuda').repeat(num_images, 1, 1, 1)

    # randomly sample camera poses
    pose_ids = np.random.choice(len(cameras_val), num_images, replace=False)
    print(pose_ids)
    pose_ids[0] = 21
    pose = [cameras_val[i] for i in pose_ids]
    print("example camera")
    print(pose[0].R)
    print(pose[0].T)
    print(pose[0].focal_length)
    print(pose[0].principal_point)

    # prepare batches [if translating then call required functions on the target pose]
    batches = []
    for i in range(num_images):
        batch = {'pose': [pose[i]] + cameras_train_final,
                 "original_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left": torch.tensor([0, 0]).reshape(-1, 2),
                 "original_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "target_size_as_tuple_ref": torch.tensor([512, 512]).reshape(-1, 2),
                 "crop_coords_top_left_ref": torch.tensor([0, 0]).reshape(-1, 2),
                 }
        batch_ = copy.deepcopy(batch)
        batch_["pose"][0] = process_camera_json(camera_json, pose[0])
        batch_["pose"] = [join_cameras_as_batch(batch_["pose"])]
        # print('batched')
        # print(batch_["pose"][0].get_world_to_view_transform().get_matrix())
        batches.append(batch_)

    print(f'len batches: {len(batches)}')

    image = None
    with torch.no_grad():
        for batch in batches:
            for key in batch.keys():
                if isinstance(batch[key], torch.Tensor):
                    batch[key] = batch[key].to('cuda')
                elif 'pose' in key:
                    batch[key] = [x.to('cuda') for x in batch[key]]
                else:
                    pass
            if prompt is not None:
                batch["txt"] = [prompt for _ in range(1)]
                batch["txt_ref"] = [prompt for _ in range(len(batch["pose"]) - 1)]
                print(batch["txt"])

            N = 1
            log_ = log_images(model, batch, N=N, noise=noise.clone()[:N], num_steps=num_steps, scale_im=scale_im)
            image = log_["samples"]
            torch.cuda.empty_cache()

    model.clear_rendered_feat()
    print("generation done")
    return image
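

# Example usage (a minimal sketch of how the functions above compose).
# The config path, delta checkpoint, and camera_json values here are placeholder
# assumptions; only the SDXL base checkpoint and camera.bin defaults come from
# the function signatures above.
if __name__ == "__main__":
    config_path = "custom-diffusion360/configs/sdxl_finetune.yaml"  # assumed config path
    delta_ckpt = "pretrained-models/car0/delta.ckpt"                # assumed delta checkpoint

    base_model = load_base_model(config_path, ckpt="pretrained-models/sd_xl_base_1.0.safetensors")
    model, data = load_and_return_model_and_data(config_path, base_model, delta_ckpt=delta_ckpt)

    # camera_json uses the plotly-style "scene.camera" layout that process_camera_json expects
    camera_json = ("{'scene.camera': {'eye': {'x': 0.0, 'y': 0.5, 'z': 2.0}, "
                   "'up': {'x': 0.0, 'y': 1.0, 'z': 0.0}, "
                   "'center': {'x': 0.0, 'y': 0.0, 'z': 0.0}}}")

    image = sample(model, data, prompt="a photo of a car", camera_json=camera_json)
    print(image.shape)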