# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from scepter.modules.solver import LatentDiffusionSolver
from scepter.modules.solver.registry import SOLVERS
from scepter.modules.utils.data import transfer_data_to_cuda
from scepter.modules.utils.distribute import we
from scepter.modules.utils.probe import ProbeData
from tqdm import tqdm


@SOLVERS.register_class()
class FormalACEPlusSolver(LatentDiffusionSolver):
    """Latent-diffusion solver for ACE-Plus style editing models.

    Extends ``LatentDiffusionSolver`` with:
      * eval/test loops that render sampled images into probe data
        (HTML image grids) via ``save_results``;
      * a per-sample validation-loss step (``run_step_val``);
      * training-time probing of the current batch and of configured
        prompts (``probe_data``).
    """

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        # Optional list of prompts rendered during training for probing.
        self.probe_prompt = cfg.get("PROBE_PROMPT", None)
        # [height, width] of the blank canvas used for prompt probing.
        self.probe_hw = cfg.get("PROBE_HW", [])

    def _run_inference(self, probe_prefix):
        """Run the current mode's dataloader end-to-end and register probes.

        Shared body of ``run_eval`` and ``run_test``; the two differ only in
        the solver mode (set by the caller) and the probe-key prefix
        (``'eval'`` or ``'test'``).
        """
        self.before_all_iter(self.hooks_dict[self._mode])
        all_results = []
        for batch_idx, batch_data in tqdm(
                enumerate(self.datas[self._mode].dataloader)):
            self.before_iter(self.hooks_dict[self._mode])
            if self.sample_args:
                batch_data.update(self.sample_args.get_lowercase_dict())
            with torch.autocast(device_type='cuda',
                                enabled=self.use_amp,
                                dtype=self.dtype):
                results = self.run_step_eval(transfer_data_to_cuda(batch_data),
                                             batch_idx,
                                             step=self.total_iter,
                                             rank=we.rank)
                all_results.extend(results)
            self.after_iter(self.hooks_dict[self._mode])
        log_data, log_label = self.save_results(all_results)
        self.register_probe({f'{probe_prefix}_label': log_label})
        self.register_probe({
            f'{probe_prefix}_image':
            ProbeData(log_data,
                      is_image=True,
                      build_html=True,
                      build_label=log_label)
        })
        self.after_all_iter(self.hooks_dict[self._mode])

    @torch.no_grad()
    def run_eval(self):
        """Evaluate on the eval dataloader and register 'eval_*' probes."""
        self.eval_mode()
        self._run_inference('eval')

    @torch.no_grad()
    def run_test(self):
        """Evaluate on the test dataloader and register 'test_*' probes."""
        self.test_mode()
        self._run_inference('test')

    def run_step_val(self, batch_data, batch_idx=0, step=None, rank=None):
        """Compute the training loss on one validation batch.

        Returns a dict mapping every ``sample_id`` in the batch to the batch
        loss as a numpy value. Note the loss is NOT split per sample — each
        id receives the same (whole-batch) loss.
        """
        sample_id_list = batch_data['sample_id']
        loss_dict = {}
        with torch.autocast(device_type='cuda',
                            enabled=self.use_amp,
                            dtype=self.dtype):
            results = self.model.forward_train(**batch_data)
            loss = results['loss']
            for sample_id in sample_id_list:
                loss_dict[sample_id] = loss.detach().cpu().numpy()
        return loss_dict

    def save_results(self, results):
        """Convert eval result dicts into (image list, label list) pairs.

        Each entry of ``results`` may carry optional image tensors
        (``edit_image``, ``modify_image``, ``edit_mask``, ``target_image``,
        ``target_mask``, ``image``, ``reconstruct_image``); each present
        tensor — assumed CHW float in [0, 1] (TODO confirm against
        ``run_step_eval``) — is converted to an HWC uint8 array.

        Returns:
            tuple(list, list): parallel ``(log_data, log_label)`` lists, one
            image-list/label-list pair per result.
        """

        def _to_uint8(img):
            # CHW float tensor in [0, 1] -> HWC uint8 numpy array.
            return (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

        log_data, log_label = [], []
        for result in results:
            ret_images, ret_labels = [], []
            edit_image = result.get('edit_image', None)
            modify_image = result.get('modify_image', None)
            edit_mask = result.get('edit_mask', None)
            if edit_image is not None:
                for i, edit_img in enumerate(edit_image):
                    if edit_img is None:
                        continue
                    ret_images.append(_to_uint8(edit_img))
                    ret_labels.append(f'edit_image{i}; ')
                    # Bug fix: the original indexed modify_image[i]
                    # unconditionally and crashed when a result carried
                    # edit_image without modify_image (edit_mask was guarded,
                    # so the guard here was clearly intended).
                    if modify_image is not None:
                        ret_images.append(_to_uint8(modify_image[i]))
                        ret_labels.append(f'modify_image{i}; ')
                    if edit_mask is not None:
                        # Assumes edit_mask aligns index-for-index with
                        # edit_image — TODO confirm in the dataset pipeline.
                        ret_images.append(_to_uint8(edit_mask[i]))
                        ret_labels.append(f'edit_mask{i}; ')
            target_image = result.get('target_image', None)
            target_mask = result.get('target_mask', None)
            if target_image is not None:
                ret_images.append(_to_uint8(target_image))
                ret_labels.append('target_image; ')
            if target_mask is not None:
                ret_images.append(_to_uint8(target_mask))
                ret_labels.append('target_mask; ')
            teacher_image = result.get('image', None)
            if teacher_image is not None:
                ret_images.append(_to_uint8(teacher_image))
                ret_labels.append('teacher_image')
            reconstruct_image = result.get('reconstruct_image', None)
            if reconstruct_image is not None:
                ret_images.append(_to_uint8(reconstruct_image))
                # The reconstruction is labeled with its text instruction.
                ret_labels.append(f"{result['instruction']}")
            log_data.append(ret_images)
            log_label.append(ret_labels)
        return log_data, log_label

    @property
    def probe_data(self):
        """Register training-time image probes, then return base probe data.

        In non-debug training mode this samples the current training batch
        (registered as ``train_image``/``train_label``) and, when
        ``PROBE_PROMPT`` is configured, renders each prompt on a blank
        ``PROBE_HW`` canvas (registered as ``probe_image``). Finally defers
        to the grandparent's ``probe_data`` — ``LatentDiffusionSolver``'s own
        property is deliberately skipped.
        """
        if not we.debug and self.mode == 'train':
            batch_data = transfer_data_to_cuda(
                self.current_batch_data[self.mode])
            self.eval_mode()
            with torch.autocast(device_type='cuda',
                                enabled=self.use_amp,
                                dtype=self.dtype):
                batch_data['log_num'] = self.log_train_num
                batch_data.update(self.sample_args.get_lowercase_dict())
                results = self.run_step_eval(batch_data)
            self.train_mode()
            log_data, log_label = self.save_results(results)
            self.register_probe({
                'train_image':
                ProbeData(log_data,
                          is_image=True,
                          build_html=True,
                          build_label=log_label)
            })
            self.register_probe({'train_label': log_label})
            if self.probe_prompt:
                self.eval_mode()
                all_results = []
                for prompt in self.probe_prompt:
                    with torch.autocast(device_type='cuda',
                                        enabled=self.use_amp,
                                        dtype=self.dtype):
                        # Blank image + full mask: generate purely from the
                        # prompt, with no source/edit conditioning.
                        batch_data = {
                            "prompt": [[prompt]],
                            "image": [
                                torch.zeros(3, self.probe_hw[0],
                                            self.probe_hw[1])
                            ],
                            "image_mask": [
                                torch.ones(1, self.probe_hw[0],
                                           self.probe_hw[1])
                            ],
                            "src_image_list": [[]],
                            "modify_image_list": [[]],
                            "src_mask_list": [[]],
                            "edit_id": [[]],
                            "height": self.probe_hw[0],
                            "width": self.probe_hw[1]
                        }
                        batch_data.update(
                            self.sample_args.get_lowercase_dict())
                        results = self.run_step_eval(batch_data)
                        all_results.extend(results)
                self.train_mode()
                log_data, log_label = self.save_results(all_results)
                self.register_probe({
                    'probe_image':
                    ProbeData(log_data,
                              is_image=True,
                              build_html=True,
                              build_label=log_label)
                })
        return super(LatentDiffusionSolver, self).probe_data