# -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import numpy as np import torch from tqdm import tqdm from scepter.modules.utils.data import transfer_data_to_cuda from scepter.modules.utils.distribute import we from scepter.modules.utils.probe import ProbeData from scepter.modules.solver.registry import SOLVERS from scepter.modules.solver.diffusion_solver import LatentDiffusionSolver @SOLVERS.register_class() class ACESolverV1(LatentDiffusionSolver): def __init__(self, cfg, logger=None): super().__init__(cfg, logger=logger) self.log_train_num = cfg.get('LOG_TRAIN_NUM', -1) def save_results(self, results): log_data, log_label = [], [] for result in results: ret_images, ret_labels = [], [] edit_image = result.get('edit_image', None) edit_mask = result.get('edit_mask', None) if edit_image is not None: for i, edit_img in enumerate(result['edit_image']): if edit_img is None: continue ret_images.append( (edit_img.permute(1, 2, 0).cpu().numpy() * 255).astype( np.uint8)) ret_labels.append(f'edit_image{i}; ') if edit_mask is not None: ret_images.append( (edit_mask[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) ret_labels.append(f'edit_mask{i}; ') target_image = result.get('target_image', None) target_mask = result.get('target_mask', None) if target_image is not None: ret_images.append( (target_image.permute(1, 2, 0).cpu().numpy() * 255).astype( np.uint8)) ret_labels.append('target_image; ') if target_mask is not None: ret_images.append( (target_mask.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) ret_labels.append('target_mask; ') reconstruct_image = result.get('reconstruct_image', None) if reconstruct_image is not None: ret_images.append( (reconstruct_image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) ret_labels.append(f"{result['instruction']}") log_data.append(ret_images) log_label.append(ret_labels) return log_data, log_label @torch.no_grad() def run_eval(self): self.eval_mode() self.before_all_iter(self.hooks_dict[self._mode]) all_results = [] for batch_idx, batch_data in tqdm( enumerate(self.datas[self._mode].dataloader)): self.before_iter(self.hooks_dict[self._mode]) if self.sample_args: batch_data.update(self.sample_args.get_lowercase_dict()) with torch.autocast(device_type='cuda', enabled=self.use_amp, dtype=self.dtype): results = self.run_step_eval(transfer_data_to_cuda(batch_data), batch_idx, step=self.total_iter, rank=we.rank) all_results.extend(results) self.after_iter(self.hooks_dict[self._mode]) log_data, log_label = self.save_results(all_results) self.register_probe({'eval_label': log_label}) self.register_probe({ 'eval_image': ProbeData(log_data, is_image=True, build_html=True, build_label=log_label) }) self.after_all_iter(self.hooks_dict[self._mode]) @torch.no_grad() def run_test(self): self.test_mode() self.before_all_iter(self.hooks_dict[self._mode]) all_results = [] for batch_idx, batch_data in tqdm( enumerate(self.datas[self._mode].dataloader)): self.before_iter(self.hooks_dict[self._mode]) if self.sample_args: batch_data.update(self.sample_args.get_lowercase_dict()) with torch.autocast(device_type='cuda', enabled=self.use_amp, dtype=self.dtype): results = self.run_step_eval(transfer_data_to_cuda(batch_data), batch_idx, step=self.total_iter, rank=we.rank) all_results.extend(results) self.after_iter(self.hooks_dict[self._mode]) log_data, log_label = self.save_results(all_results) self.register_probe({'test_label': log_label}) self.register_probe({ 'test_image': ProbeData(log_data, is_image=True, build_html=True, build_label=log_label) }) self.after_all_iter(self.hooks_dict[self._mode]) @property def probe_data(self): if not we.debug and self.mode == 'train': batch_data = transfer_data_to_cuda( self.current_batch_data[self.mode]) self.eval_mode() with torch.autocast(device_type='cuda', enabled=self.use_amp, dtype=self.dtype): batch_data['log_num'] = self.log_train_num results = self.run_step_eval(batch_data) self.train_mode() log_data, log_label = self.save_results(results) self.register_probe({ 'train_image': ProbeData(log_data, is_image=True, build_html=True, build_label=log_label) }) self.register_probe({'train_label': log_label}) return super(LatentDiffusionSolver, self).probe_data