# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import random
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch import nn

from scepter.modules.model.network.ldm import LatentDiffusion
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we

from ..utils.basic_utils import (
    check_list_of_list,
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
    to_device,
    unpack_tensor_into_imagelist
)


class TextEmbedding(nn.Module):
    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))


@MODELS.register_class()
class LdmACE(LatentDiffusion):
    para_dict = LatentDiffusion.para_dict
    para_dict['DECODER_BIAS'] = {
        'value': 0,
        'description': 'Bias (0-255 scale) added to decoded images before '
                       'clamping to [0, 1].'
    }

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        # decoder bias declared in para_dict; applied to decoded samples in forward_test
        self.decoder_bias = cfg.get('DECODER_BIAS', 0)
        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)

        self.text_indentifers = cfg.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',
                                               False)
        if self.use_text_pos_embeddings:
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False)
        else:
            self.text_position_embeddings = None

        self.logger.info(self.model)

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        return [
            self.scale_factor *
            self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))
            for i in x
        ]

    @torch.no_grad()
    def decode_first_stage(self, z):
        return [
            self.first_stage_model._decode(1. / self.scale_factor *
                                           i.to(torch.float16)) for i in z
        ]

    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, identifier_cont_mask = getattr(
                self.cond_stage_model, 'encode')(self.text_indentifers,
                                                 return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})

        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError
        return cont_, cont_mask_

    def limit_batch_data(self, batch_data_list, log_num):
        if log_num and log_num > 0:
            batch_data_list_limited = []
            for sub_data in batch_data_list:
                if sub_data is not None:
                    sub_data = sub_data[:log_num]
                batch_data_list_limited.append(sub_data)
            return batch_data_list_limited
        else:
            return batch_data_list

    def forward_train(self,
                      edit_image=[],
                      edit_image_mask=[],
                      image=None,
                      image_mask=None,
                      noise=None,
                      prompt=[],
                      **kwargs):
        '''
        Args:
            edit_image: list of list of edit_image
            edit_image_mask: list of list of edit_image_mask
            image: target image
            image_mask: target image mask
            noise: default is None, generated automatically
            prompt: list of list of text
            **kwargs:
        Returns:
        '''
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        gc_seg = kwargs.pop('gc_seg', [])
        gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
        context = {}

        # process image
        image = to_device(image)
        x_start = self.encode_first_stage(image, **kwargs)
        x_start, x_shapes = pack_imagelist_into_tensor(x_start)  # B, C, L
        n, _, _ = x_start.shape
        t = torch.randint(0,
                          self.num_timesteps, (n, ),
                          device=x_start.device).long()
        context['x_shapes'] = x_shapes

        # process image mask
        image_mask = to_device(image_mask, strict=False)
        context['x_mask'] = ([self.interpolate_func(i) for i in image_mask]
                             if image_mask is not None else [None] * n)

        # process text
        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        try:
            cont, cont_mask = getattr(self.cond_stage_model,
                                      'encode_list')(prompt_,
                                                     return_mask=True)
        except Exception as e:
            print(e, prompt_)
            # re-raise: cont / cont_mask would otherwise be undefined below
            raise
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        context['crossattn'] = cont

        # process edit image & edit image mask
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if m is None:
                m = [None] * len(u) if u is not None else [None]
            e_img.append(
                self.encode_first_stage(u, **kwargs) if u is not None else u)
            e_mask.append([
                self.interpolate_func(i) if i is not None else None for i in m
            ])
        context['edit'], context['edit_mask'] = e_img, e_mask

        # process loss
        loss = self.diffusion.loss(
            x_0=x_start,
            t=t,
            noise=noise,
            model=self.model,
            model_kwargs={
                'cond': context,
                'mask': cont_mask,
                'gc_seg': gc_seg,
                'text_position_embeddings':
                self.text_position_embeddings.pos if hasattr(
                    self.text_position_embeddings, 'pos') else None
            },
            **kwargs)
        loss = loss.mean()
        ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
        return ret

    @torch.no_grad()
    def forward_test(self,
                     edit_image=[],
                     edit_image_mask=[],
                     image=None,
                     image_mask=None,
                     prompt=[],
                     n_prompt=[],
                     sampler='ddim',
                     sample_steps=20,
                     guide_scale=4.5,
                     guide_rescale=0.5,
                     log_num=-1,
                     seed=2024,
                     **kwargs):
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        # gc_seg is unused
        kwargs.pop('gc_seg', -1)

        # prepare data
        context, null_context = {}, {}
        prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(
            [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],
            log_num)
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(seed)
        # negative prompt: copy of the positive prompt with the last instruction blanked
        n_prompt = copy.deepcopy(prompt)
        for nn_p_id, nn_p in enumerate(n_prompt):
            if isinstance(nn_p, str):
                n_prompt[nn_p_id] = ['']
            elif isinstance(nn_p, list):
                n_prompt[nn_p_id][-1] = ''
            else:
                raise NotImplementedError

        # process image
        image = to_device(image)
        x = self.encode_first_stage(image, **kwargs)
        noise = [
            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
            for i in x
        ]
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        context['x_shapes'] = null_context['x_shapes'] = x_shapes

        # process image mask
        image_mask = to_device(image_mask, strict=False)
        cond_mask = ([self.interpolate_func(i) for i in image_mask]
                     if image_mask is not None else [None] * len(image))
        context['x_mask'] = null_context['x_mask'] = cond_mask

        # process text
        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        cont, cont_mask = getattr(self.cond_stage_model,
                                  'encode_list')(prompt_, return_mask=True)
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        null_cont, null_cont_mask = getattr(self.cond_stage_model,
                                            'encode_list')(n_prompt,
                                                           return_mask=True)
        null_cont, null_cont_mask = self.cond_stage_embeddings(
            prompt, edit_image, null_cont, null_cont_mask)
        context['crossattn'] = cont
        null_context['crossattn'] = null_cont

        # process edit image & edit image mask
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if u is None:
                continue
            if m is None:
                m = [None] * len(u)
            e_img.append(self.encode_first_stage(u, **kwargs))
            e_mask.append([self.interpolate_func(i) for i in m])
        null_context['edit'] = context['edit'] = e_img
        null_context['edit_mask'] = context['edit_mask'] = e_mask

        # process sample
        model = self.model_ema if self.use_ema and self.eval_ema else self.model
        embedding_context = model.no_sync if isinstance(
            model,
            torch.distributed.fsdp.FullyShardedDataParallel) else nullcontext
        with embedding_context():
            samples = self.diffusion.sample(
                sampler=sampler,
                noise=noise,
                model=model,
                model_kwargs=[{
                    'cond': context,
                    'mask': cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                }, {
                    'cond': null_context,
                    'mask': null_cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                }] if guide_scale is not None and guide_scale > 1 else {
                    'cond': context,
                    'mask': cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                },
                steps=sample_steps,
                guide_scale=guide_scale,
                guide_rescale=guide_rescale,
                show_progress=True,
                **kwargs)

        samples = unpack_tensor_into_imagelist(samples, x_shapes)
        x_samples = self.decode_first_stage(samples)

        outputs = list()
        for i in range(len(prompt)):
            rec_img = torch.clamp(
                (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,
                min=0.0,
                max=1.0)
            rec_img = rec_img.squeeze(0)
            edit_imgs, edit_img_masks = [], []
            if edit_image is not None and edit_image[i] is not None:
                if edit_image_mask[i] is None:
                    edit_image_mask[i] = [None] * len(edit_image[i])
                for edit_img, edit_mask in zip(edit_image[i],
                                               edit_image_mask[i]):
                    edit_img = torch.clamp((edit_img + 1.0) / 2.0,
                                           min=0.0,
                                           max=1.0)
                    edit_imgs.append(edit_img.squeeze(0))
                    if edit_mask is None:
                        edit_mask = torch.ones_like(edit_img[[0], :, :])
                    edit_img_masks.append(edit_mask)
            one_tup = {
                'reconstruct_image': rec_img,
                'instruction': prompt[i],
                'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
                'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
            }
            if image is not None:
                if image_mask is None:
                    image_mask = [None] * len(image)
                ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
                one_tup['target_image'] = ori_img.squeeze(0)
                one_tup['target_mask'] = image_mask[i] if image_mask[
                    i] is not None else torch.ones_like(ori_img[[0], :, :])
            outputs.append(one_tup)

        return outputs

    @staticmethod
    def get_config_template():
        return dict_to_yaml('MODEL',
                            __class__.__name__,
                            LdmACE.para_dict,
                            set_name=True)
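

# Usage sketch (illustrative, not part of the original module): this class is
# normally not instantiated directly but registered with the MODELS registry
# and built from a YAML config that also specifies the first-stage VAE, the
# text conditioner, the backbone and the diffusion schedule. Assuming a
# standard scepter environment where the registry exposes a build() helper,
# the flow looks roughly like:
#
#     from scepter.modules.model.registry import MODELS
#     print(LdmACE.get_config_template())          # dump the YAML parameter template
#     # model = MODELS.build(cfg, logger=logger)   # cfg: a MODEL config following the template
#
# The exact construction path depends on the surrounding training/inference
# pipeline; MODELS.build() above is an assumption about that setup, while
# get_config_template() is defined in this file.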