# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import random
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch import nn

from scepter.modules.model.network.ldm import LatentDiffusion
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we

from ..utils.basic_utils import (
    check_list_of_list,
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
    to_device,
    unpack_tensor_into_imagelist
)

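# `TextEmbedding` is a frozen, zero-initialized table of per-identifier text
# position embeddings; it is lazily filled from the text encoder on the first
# forward pass (see `cond_stage_embeddings` below).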
class TextEmbedding(nn.Module):
    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))


@MODELS.register_class()
class LdmACE(LatentDiffusion):
    para_dict = LatentDiffusion.para_dict
    para_dict['DECODER_BIAS'] = {
        'value': 0,
        'description': 'Constant bias (in 0-255 pixel units) added to '
                       'decoded images.'
    }

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        # Downsamples a pixel-space mask to the latent resolution.
        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)
        self.text_identifiers = cfg.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',
                                               False)
        if self.use_text_pos_embeddings:
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False)
        else:
            self.text_position_embeddings = None
        self.logger.info(self.model)

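    # The first-stage (VAE) encode/decode run per image because samples in a
    # batch may have different resolutions; latents are scaled by
    # `scale_factor`, as is standard for latent diffusion. The float16 casts
    # assume the first-stage model holds fp16 weights.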
    def encode_first_stage(self, x, **kwargs):
        return [
            self.scale_factor *
            self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))
            for i in x
        ]

    def decode_first_stage(self, z):
        return [
            self.first_stage_model._decode(1. / self.scale_factor *
                                           i.to(torch.float16)) for i in z
        ]

    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, identifier_cont_mask = self.cond_stage_model.encode(
                self.text_identifiers, return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})

        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError
        return cont_, cont_mask_

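    # Truncates every entry of `batch_data_list` to its first `log_num`
    # elements (None entries pass through), e.g.
    # limit_batch_data([[a, b, c], None], 2) -> [[a, b], None].
    # Used in forward_test to cap how much data is decoded for logging.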
    def limit_batch_data(self, batch_data_list, log_num):
        if log_num and log_num > 0:
            batch_data_list_limited = []
            for sub_data in batch_data_list:
                if sub_data is not None:
                    sub_data = sub_data[:log_num]
                batch_data_list_limited.append(sub_data)
            return batch_data_list_limited
        else:
            return batch_data_list

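    # Illustrative input structure for forward_train / forward_test (the
    # prompt strings and tensor names here are hypothetical; the [-1, 1]
    # value range is inferred from the (x + 1) / 2 rescaling in forward_test):
    #   prompt:          [['make the sky sunset'], ...]  # one list per sample
    #   edit_image:      [[img_a, img_b], None, ...]     # CHW tensors
    #   edit_image_mask: [[mask_a, None], None, ...]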
    def forward_train(self,
                      edit_image=[],
                      edit_image_mask=[],
                      image=None,
                      image_mask=None,
                      noise=None,
                      prompt=[],
                      **kwargs):
        '''
        Args:
            edit_image: list of lists of edit images (one inner list per sample)
            edit_image_mask: list of lists of edit-image masks
            image: list of target images
            image_mask: list of target-image masks
            noise: defaults to None, in which case it is generated automatically
            prompt: list of lists of text prompts
            **kwargs:
        Returns:
            dict with the scalar diffusion 'loss' and 'probe_data' (the prompts).
        '''
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        gc_seg = kwargs.pop('gc_seg', [])
        gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
        context = {}

        # process image
        image = to_device(image)
        x_start = self.encode_first_stage(image, **kwargs)
        x_start, x_shapes = pack_imagelist_into_tensor(x_start)  # B, C, L
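        # Variable-size latents are packed into a single (B, C, L) tensor;
        # `x_shapes` records each sample's original shape so the packing can
        # later be inverted by `unpack_tensor_into_imagelist`.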
        n, _, _ = x_start.shape
        t = torch.randint(0, self.num_timesteps, (n, ),
                          device=x_start.device).long()
        context['x_shapes'] = x_shapes

        # process image mask
        image_mask = to_device(image_mask, strict=False)
        context['x_mask'] = [self.interpolate_func(i) for i in image_mask
                             ] if image_mask is not None else [None] * n

        # process text
        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        try:
            cont, cont_mask = self.cond_stage_model.encode_list(
                prompt_, return_mask=True)
        except Exception as e:
            # Log the offending prompts and re-raise; swallowing the error
            # would leave `cont` undefined below.
            self.logger.error(f'{e} {prompt_}')
            raise
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        context['crossattn'] = cont
        # process edit image & edit image mask
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if m is None:
                m = [None] * len(u) if u is not None else [None]
            e_img.append(
                self.encode_first_stage(u, **kwargs) if u is not None else u)
            e_mask.append([
                self.interpolate_func(i) if i is not None else None for i in m
            ])
        context['edit'], context['edit_mask'] = e_img, e_mask
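        # `context['edit']` / `context['edit_mask']` are lists over the batch;
        # each entry is itself a list with one latent / mask per edit image,
        # or None when that sample has no edit inputs.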
        # compute diffusion loss
        loss = self.diffusion.loss(
            x_0=x_start,
            t=t,
            noise=noise,
            model=self.model,
            model_kwargs={
                'cond': context,
                'mask': cont_mask,
                'gc_seg': gc_seg,
                'text_position_embeddings':
                self.text_position_embeddings.pos
                if self.text_position_embeddings is not None else None
            },
            **kwargs)
        loss = loss.mean()
        ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
        return ret

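    # Inference mirrors forward_train's conditioning, adds a negative-prompt
    # branch, and samples with optional classifier-free guidance.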
    def forward_test(self,
                     edit_image=[],
                     edit_image_mask=[],
                     image=None,
                     image_mask=None,
                     prompt=[],
                     n_prompt=[],
                     sampler='ddim',
                     sample_steps=20,
                     guide_scale=4.5,
                     guide_rescale=0.5,
                     log_num=-1,
                     seed=2024,
                     **kwargs):
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        # gc_seg is unused at inference
        kwargs.pop('gc_seg', -1)

        # prepare data
        context, null_context = {}, {}
        prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(
            [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],
            log_num)
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(seed)

        # build negative prompts: blank only the final instruction
        n_prompt = copy.deepcopy(prompt)
        for nn_p_id, nn_p in enumerate(n_prompt):
            if isinstance(nn_p, str):
                n_prompt[nn_p_id] = ['']
            elif isinstance(nn_p, list):
                n_prompt[nn_p_id][-1] = ''
            else:
                raise NotImplementedError
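        # Only the final (target) instruction is blanked for the negative
        # branch; earlier contextual prompts are kept, presumably so the
        # unconditional pass still sees the edit history.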
        # process image
        image = to_device(image)
        x = self.encode_first_stage(image, **kwargs)
        noise = [
            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
            for i in x
        ]
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        context['x_shapes'] = null_context['x_shapes'] = x_shapes

        # process image mask
        image_mask = to_device(image_mask, strict=False)
        cond_mask = [self.interpolate_func(i) for i in image_mask
                     ] if image_mask is not None else [None] * len(image)
        context['x_mask'] = null_context['x_mask'] = cond_mask

        # process text
        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        cont, cont_mask = self.cond_stage_model.encode_list(prompt_,
                                                            return_mask=True)
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        null_cont, null_cont_mask = self.cond_stage_model.encode_list(
            n_prompt, return_mask=True)
        null_cont, null_cont_mask = self.cond_stage_embeddings(
            prompt, edit_image, null_cont, null_cont_mask)
        context['crossattn'] = cont
        null_context['crossattn'] = null_cont
        # process edit image & edit image mask
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if u is None:
                continue
            if m is None:
                m = [None] * len(u)
            e_img.append(self.encode_first_stage(u, **kwargs))
            e_mask.append([self.interpolate_func(i) for i in m])
        null_context['edit'] = context['edit'] = e_img
        null_context['edit_mask'] = context['edit_mask'] = e_mask
        # process sample
        model = self.model_ema if self.use_ema and self.eval_ema else self.model
        embedding_context = model.no_sync if isinstance(
            model, torch.distributed.fsdp.FullyShardedDataParallel
        ) else nullcontext
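        # Under FSDP, sampling runs inside `no_sync` (presumably to suppress
        # communication hooks during the repeated forward passes); otherwise a
        # nullcontext keeps the code path identical.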
        text_pos_emb = (self.text_position_embeddings.pos
                        if self.text_position_embeddings is not None else None)
        cond_kwargs = {
            'cond': context,
            'mask': cont_mask,
            'text_position_embeddings': text_pos_emb
        }
        null_kwargs = {
            'cond': null_context,
            'mask': null_cont_mask,
            'text_position_embeddings': text_pos_emb
        }
        with embedding_context():
            samples = self.diffusion.sample(
                sampler=sampler,
                noise=noise,
                model=model,
                # a [cond, uncond] pair enables classifier-free guidance;
                # a single dict disables it
                model_kwargs=[cond_kwargs, null_kwargs]
                if guide_scale is not None and guide_scale > 1 else
                cond_kwargs,
                steps=sample_steps,
                guide_scale=guide_scale,
                guide_rescale=guide_rescale,
                show_progress=True,
                **kwargs)
        samples = unpack_tensor_into_imagelist(samples, x_shapes)
        x_samples = self.decode_first_stage(samples)
        outputs = list()
        for i in range(len(prompt)):
            rec_img = torch.clamp(
                (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,
                min=0.0,
                max=1.0)
            rec_img = rec_img.squeeze(0)
            edit_imgs, edit_img_masks = [], []
            if edit_image is not None and edit_image[i] is not None:
                if edit_image_mask[i] is None:
                    edit_image_mask[i] = [None] * len(edit_image[i])
                for edit_img, edit_mask in zip(edit_image[i],
                                               edit_image_mask[i]):
                    edit_img = torch.clamp((edit_img + 1.0) / 2.0,
                                           min=0.0,
                                           max=1.0)
                    edit_imgs.append(edit_img.squeeze(0))
                    if edit_mask is None:
                        edit_mask = torch.ones_like(edit_img[[0], :, :])
                    edit_img_masks.append(edit_mask)
            one_tup = {
                'reconstruct_image': rec_img,
                'instruction': prompt[i],
                'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
                'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
            }
            if image is not None:
                if image_mask is None:
                    image_mask = [None] * len(image)
                ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
                one_tup['target_image'] = ori_img.squeeze(0)
                one_tup['target_mask'] = image_mask[i] if image_mask[
                    i] is not None else torch.ones_like(ori_img[[0], :, :])
            outputs.append(one_tup)
        return outputs

    @staticmethod
    def get_config_template():
        return dict_to_yaml('MODEL',
                            __class__.__name__,
                            LdmACE.para_dict,
                            set_name=True)
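
# ---------------------------------------------------------------------------
# Illustrative only: a minimal sketch of the pack/unpack contract this module
# assumes from pack_imagelist_into_tensor / unpack_tensor_into_imagelist. The
# toy helpers are hypothetical stand-ins, not the real implementations in
# ..utils.basic_utils, which may pad or record shapes differently. Because of
# the relative imports above, run this module with `python -m`, not directly.
# ---------------------------------------------------------------------------
if __name__ == '__main__':

    def toy_pack(latents):
        # Pad (1, C, L_i) latents to a shared length and remember each L_i.
        lengths = [x.shape[-1] for x in latents]
        packed = torch.zeros(len(latents), latents[0].shape[1], max(lengths))
        for i, x in enumerate(latents):
            packed[i, :, :x.shape[-1]] = x[0]
        return packed, lengths

    def toy_unpack(packed, lengths):
        # Invert toy_pack: slice each row back to its original length.
        return [packed[i:i + 1, :, :n] for i, n in enumerate(lengths)]

    lats = [torch.randn(1, 4, 1024), torch.randn(1, 4, 256)]
    packed_t, lens = toy_pack(lats)
    assert all(
        torch.equal(a, b) for a, b in zip(lats, toy_unpack(packed_t, lens)))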