# ACE-Plus / modules / ace_plus_solver.py
# Last commit: "modify ace plus" (d1a539d) by chaojiemao; file size 8.69 kB.
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from scepter.modules.solver import LatentDiffusionSolver
from scepter.modules.solver.registry import SOLVERS
from scepter.modules.utils.data import transfer_data_to_cuda
from scepter.modules.utils.distribute import we
from scepter.modules.utils.probe import ProbeData
from tqdm import tqdm
@SOLVERS.register_class()
class FormalACEPlusSolver(LatentDiffusionSolver):
    """ACE-Plus solver that adds image probes to eval, test and training.

    On top of ``LatentDiffusionSolver`` this class:

    * runs ``run_step_eval`` over the eval/test dataloader and registers the
      collected images and labels as ``<mode>_image`` / ``<mode>_label`` probes;
    * computes a per-sample validation loss in ``run_step_val``;
    * while training, samples from the current training batch and, when
      ``PROBE_PROMPT`` is configured, from each configured prompt, registering
      ``train_image`` / ``train_label`` / ``probe_image`` probes.
    """

    def __init__(self, cfg, logger=None):
        """Build the solver and read the optional probe settings from ``cfg``.

        Args:
            cfg: Solver config. ``PROBE_PROMPT`` (default ``None``) is iterated
                as a collection of prompts in ``probe_data``; ``PROBE_HW``
                (default ``[]``) is indexed as ``[height, width]``.
            logger: Forwarded unchanged to the parent constructor.
        """
        super().__init__(cfg, logger=logger)
        self.probe_prompt = cfg.get("PROBE_PROMPT", None)
        self.probe_hw = cfg.get("PROBE_HW", [])

    @torch.no_grad()
    def run_eval(self):
        """Sample over the eval dataloader; register 'eval_label'/'eval_image'."""
        self.eval_mode()
        self._run_inference_loop('eval')

    @torch.no_grad()
    def run_test(self):
        """Sample over the test dataloader; register 'test_label'/'test_image'."""
        self.test_mode()
        self._run_inference_loop('test')

    def _run_inference_loop(self, probe_prefix):
        """Shared body of ``run_eval`` and ``run_test``.

        Iterates ``self.datas[self._mode].dataloader``, calls ``run_step_eval``
        on every batch under autocast, then registers the gathered images and
        labels as ``f'{probe_prefix}_image'`` and ``f'{probe_prefix}_label'``.

        Args:
            probe_prefix: Prefix of the registered probe keys
                (``'eval'`` or ``'test'``).
        """
        self.before_all_iter(self.hooks_dict[self._mode])
        all_results = []
        # tqdm wraps the dataloader itself rather than the enumerate object:
        # enumerate has no length, so wrapping it hides the total from tqdm.
        for batch_idx, batch_data in enumerate(
                tqdm(self.datas[self._mode].dataloader)):
            self.before_iter(self.hooks_dict[self._mode])
            if self.sample_args:
                batch_data.update(self.sample_args.get_lowercase_dict())
            with torch.autocast(device_type='cuda',
                                enabled=self.use_amp,
                                dtype=self.dtype):
                results = self.run_step_eval(
                    transfer_data_to_cuda(batch_data),
                    batch_idx,
                    step=self.total_iter,
                    rank=we.rank)
            all_results.extend(results)
            self.after_iter(self.hooks_dict[self._mode])
        log_data, log_label = self.save_results(all_results)
        self.register_probe({f'{probe_prefix}_label': log_label})
        self.register_probe({
            f'{probe_prefix}_image':
            ProbeData(log_data,
                      is_image=True,
                      build_html=True,
                      build_label=log_label)
        })
        self.after_all_iter(self.hooks_dict[self._mode])

    def run_step_val(self, batch_data, batch_idx=0, step=None, rank=None):
        """Compute the training loss for one validation batch.

        Args:
            batch_data: Batch dict; must contain ``'sample_id'`` and is passed
                as keyword arguments to ``self.model.forward_train``.
            batch_idx: Unused here; kept for interface compatibility.
            step: Unused here; kept for interface compatibility.
            rank: Unused here; kept for interface compatibility.

        Returns:
            Dict mapping every sample id of the batch to the batch loss as a
            NumPy value. All ids map to the same array object.
        """
        sample_id_list = batch_data['sample_id']
        with torch.autocast(device_type='cuda',
                            enabled=self.use_amp,
                            dtype=self.dtype):
            results = self.model.forward_train(**batch_data)
        # Copy the loss to the host once instead of once per sample id.
        loss_value = results['loss'].detach().cpu().numpy()
        return {sample_id: loss_value for sample_id in sample_id_list}

    @staticmethod
    def _to_uint8_hwc(tensor):
        """Convert an image tensor to a uint8 NumPy array for logging.

        Applies ``permute(1, 2, 0)``, moves to CPU, scales by 255 and casts to
        ``np.uint8``.
        """
        # NOTE(review): presumably the input is (C, H, W) with values in
        # [0, 1]; values outside that range wrap around in the uint8 cast.
        return (tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    @staticmethod
    def _item_at(seq, idx):
        """Return ``seq[idx]``, or ``None`` if ``seq`` is None or too short."""
        if seq is None or idx >= len(seq):
            return None
        return seq[idx]

    def save_results(self, results):
        """Turn ``run_step_eval`` results into loggable images and labels.

        For each result dict the following optional entries are emitted, in
        this order: every non-None ``edit_image[i]`` followed by its
        ``modify_image[i]`` and ``edit_mask[i]`` when present, then
        ``target_image`` (plus ``target_mask``), ``image`` (labelled
        ``teacher_image``) and ``reconstruct_image`` (labelled with the
        result's ``instruction``).

        Args:
            results: Iterable of result dicts.

        Returns:
            ``(log_data, log_label)``: two parallel lists holding, per result,
            the list of uint8 images and the list of their labels.
        """
        log_data, log_label = [], []
        for result in results:
            ret_images, ret_labels = [], []

            edit_image = result.get('edit_image', None)
            modify_image = result.get('modify_image', None)
            edit_mask = result.get('edit_mask', None)
            if edit_image is not None:
                for i, edit_img in enumerate(edit_image):
                    if edit_img is None:
                        continue
                    ret_images.append(self._to_uint8_hwc(edit_img))
                    ret_labels.append(f'edit_image{i}; ')
                    # Companion entries may be missing entirely or per index;
                    # skip them instead of failing on None.
                    modify_img = self._item_at(modify_image, i)
                    if modify_img is not None:
                        ret_images.append(self._to_uint8_hwc(modify_img))
                        ret_labels.append(f'modify_image{i}; ')
                    mask = self._item_at(edit_mask, i)
                    if mask is not None:
                        ret_images.append(self._to_uint8_hwc(mask))
                        ret_labels.append(f'edit_mask{i}; ')

            target_image = result.get('target_image', None)
            target_mask = result.get('target_mask', None)
            if target_image is not None:
                ret_images.append(self._to_uint8_hwc(target_image))
                ret_labels.append('target_image; ')
                if target_mask is not None:
                    ret_images.append(self._to_uint8_hwc(target_mask))
                    ret_labels.append('target_mask; ')

            teacher_image = result.get('image', None)
            if teacher_image is not None:
                ret_images.append(self._to_uint8_hwc(teacher_image))
                ret_labels.append('teacher_image')

            reconstruct_image = result.get('reconstruct_image', None)
            if reconstruct_image is not None:
                ret_images.append(self._to_uint8_hwc(reconstruct_image))
                # A missing instruction yields an empty label, not a KeyError.
                ret_labels.append(f"{result.get('instruction', '')}")

            log_data.append(ret_images)
            log_label.append(ret_labels)
        return log_data, log_label

    def _probe_step(self, batch_data):
        """Run one ``run_step_eval`` call for probing.

        Merges ``self.sample_args`` into ``batch_data`` when it is set (same
        guard as the eval/test loop) and runs under ``no_grad`` and autocast.
        """
        if self.sample_args:
            batch_data.update(self.sample_args.get_lowercase_dict())
        # save_results calls .numpy() without detaching, so outputs can never
        # carry gradients; no_grad matches run_eval/run_test.
        with torch.no_grad(), torch.autocast(device_type='cuda',
                                             enabled=self.use_amp,
                                             dtype=self.dtype):
            return self.run_step_eval(batch_data)

    @property
    def probe_data(self):
        """Register training-time image probes, then return the parent probes.

        When not in debug mode and ``self.mode == 'train'``:

        * samples from the current training batch and registers
          ``train_image`` / ``train_label``;
        * if ``PROBE_PROMPT`` is set, samples once per prompt at the
          ``PROBE_HW`` resolution and registers ``probe_image``.

        The solver is always switched back with ``train_mode()``, even if
        sampling raises.

        Raises:
            ValueError: ``PROBE_PROMPT`` is set but ``PROBE_HW`` does not
                provide both height and width.
        """
        if not we.debug and self.mode == 'train':
            batch_data = transfer_data_to_cuda(
                self.current_batch_data[self.mode])
            batch_data['log_num'] = self.log_train_num
            self.eval_mode()
            try:
                results = self._probe_step(batch_data)
            finally:
                self.train_mode()
            log_data, log_label = self.save_results(results)
            self.register_probe({
                'train_image':
                ProbeData(log_data,
                          is_image=True,
                          build_html=True,
                          build_label=log_label)
            })
            self.register_probe({'train_label': log_label})

            if self.probe_prompt:
                if len(self.probe_hw) < 2:
                    raise ValueError(
                        'PROBE_HW must be [height, width] when PROBE_PROMPT '
                        'is set, got {!r}'.format(self.probe_hw))
                height, width = self.probe_hw[0], self.probe_hw[1]
                all_results = []
                self.eval_mode()
                try:
                    for prompt in self.probe_prompt:
                        # Text-only request: blank image, full mask and no
                        # source/edit inputs.
                        prompt_batch = {
                            "prompt": [[prompt]],
                            "image": [torch.zeros(3, height, width)],
                            "image_mask": [torch.ones(1, height, width)],
                            "src_image_list": [[]],
                            "modify_image_list": [[]],
                            "src_mask_list": [[]],
                            "edit_id": [[]],
                            "height": height,
                            "width": width
                        }
                        all_results.extend(self._probe_step(prompt_batch))
                finally:
                    self.train_mode()
                log_data, log_label = self.save_results(all_results)
                self.register_probe({
                    'probe_image':
                    ProbeData(log_data,
                              is_image=True,
                              build_html=True,
                              build_label=log_label)
                })
        # Starts the lookup after LatentDiffusionSolver in the MRO, so that
        # class's own probe_data is skipped.
        return super(LatentDiffusionSolver, self).probe_data