ACE-Chat / modules /solver /ace_solver.py
pan-yl's picture
update file
2a00960
raw
history blame
6.26 kB
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from tqdm import tqdm
from scepter.modules.utils.data import transfer_data_to_cuda
from scepter.modules.utils.distribute import we
from scepter.modules.utils.probe import ProbeData
from scepter.modules.solver.registry import SOLVERS
from scepter.modules.solver.diffusion_solver import LatentDiffusionSolver
@SOLVERS.register_class()
class ACESolverV1(LatentDiffusionSolver):
    """Latent-diffusion solver for ACE that logs sample images as probes.

    Extends ``LatentDiffusionSolver`` with eval/test loops that collect the
    per-sample result dicts returned by ``run_step_eval`` and register them as
    HTML image probes, plus a training-time ``probe_data`` that renders the
    current training batch through ``run_step_eval``.
    """

    def __init__(self, cfg, logger=None):
        """Build the solver.

        Args:
            cfg: solver config. ``LOG_TRAIN_NUM`` (default ``-1``) is read here
                and forwarded as ``batch_data['log_num']`` to ``run_step_eval``
                when probing training batches.
            logger: optional logger, forwarded to the base solver.
        """
        super().__init__(cfg, logger=logger)
        self.log_train_num = cfg.get('LOG_TRAIN_NUM', -1)

    @staticmethod
    def _to_uint8_image(tensor):
        """Convert a 3-D tensor to a ``uint8`` NumPy array with dim 0 moved last.

        Presumably the input is a CHW image/mask with values in ``[0, 1]``
        (TODO confirm against ``run_step_eval``). Values are scaled by 255 and
        clipped to ``[0, 255]`` before the cast, so slightly out-of-range
        values saturate instead of wrapping around in the integer conversion.
        ``detach``/``float`` let ``.numpy()`` work on tensors that require grad
        or use dtypes NumPy cannot represent (e.g. bfloat16 under autocast).
        """
        array = tensor.detach().permute(1, 2, 0).cpu().float().numpy()
        return np.clip(array * 255, 0, 255).astype(np.uint8)

    def save_results(self, results):
        """Turn ``run_step_eval`` outputs into image/label lists for probes.

        Each ``result`` dict may contain:
          * ``edit_image``: a sequence of tensors; ``None`` entries are skipped.
          * ``edit_mask``: indexed in parallel with ``edit_image``.
          * ``target_image`` / ``target_mask``: single tensors.
          * ``reconstruct_image``: a single tensor, labelled with the
            sample's ``instruction``.
        Missing or ``None`` entries are simply not logged.

        Returns:
            (log_data, log_label): for each result, a list of ``uint8`` images
            and a parallel list of label strings.
        """
        log_data, log_label = [], []
        for result in results:
            ret_images, ret_labels = [], []

            edit_image = result.get('edit_image', None)
            edit_mask = result.get('edit_mask', None)
            if edit_image is not None:
                for i, edit_img in enumerate(edit_image):
                    if edit_img is None:
                        continue
                    ret_images.append(self._to_uint8_image(edit_img))
                    ret_labels.append(f'edit_image{i}; ')
                    # Guard against a mask list that is shorter than the image
                    # list or that has None entries, instead of crashing.
                    if (edit_mask is not None and i < len(edit_mask)
                            and edit_mask[i] is not None):
                        ret_images.append(self._to_uint8_image(edit_mask[i]))
                        ret_labels.append(f'edit_mask{i}; ')

            target_image = result.get('target_image', None)
            target_mask = result.get('target_mask', None)
            if target_image is not None:
                ret_images.append(self._to_uint8_image(target_image))
                ret_labels.append('target_image; ')
            if target_mask is not None:
                ret_images.append(self._to_uint8_image(target_mask))
                ret_labels.append('target_mask; ')

            reconstruct_image = result.get('reconstruct_image', None)
            if reconstruct_image is not None:
                ret_images.append(self._to_uint8_image(reconstruct_image))
                ret_labels.append(f"{result['instruction']}")

            log_data.append(ret_images)
            log_label.append(ret_labels)
        return log_data, log_label

    @torch.no_grad()
    def _run_inference_and_probe(self, prefix):
        """Run ``run_step_eval`` over the current mode's dataloader and log it.

        The caller must already have switched the solver into the desired mode
        (``eval_mode``/``test_mode``); ``self._mode`` selects both the hooks and
        the dataloader. Registers ``f'{prefix}_label'`` and
        ``f'{prefix}_image'`` probes, then runs the after-all-iter hooks.
        """
        self.before_all_iter(self.hooks_dict[self._mode])
        all_results = []
        for batch_idx, batch_data in tqdm(
                enumerate(self.datas[self._mode].dataloader)):
            self.before_iter(self.hooks_dict[self._mode])
            # Configured sampling arguments override per-batch values.
            if self.sample_args:
                batch_data.update(self.sample_args.get_lowercase_dict())
            with torch.autocast(device_type='cuda',
                                enabled=self.use_amp,
                                dtype=self.dtype):
                results = self.run_step_eval(transfer_data_to_cuda(batch_data),
                                             batch_idx,
                                             step=self.total_iter,
                                             rank=we.rank)
            all_results.extend(results)
            self.after_iter(self.hooks_dict[self._mode])
        log_data, log_label = self.save_results(all_results)
        self.register_probe({f'{prefix}_label': log_label})
        self.register_probe({
            f'{prefix}_image':
            ProbeData(log_data,
                      is_image=True,
                      build_html=True,
                      build_label=log_label)
        })
        self.after_all_iter(self.hooks_dict[self._mode])

    @torch.no_grad()
    def run_eval(self):
        """Evaluate on the eval dataloader and register eval image probes."""
        self.eval_mode()
        self._run_inference_and_probe('eval')

    @torch.no_grad()
    def run_test(self):
        """Run the test dataloader and register test image probes."""
        self.test_mode()
        self._run_inference_and_probe('test')

    @property
    def probe_data(self):
        """Register training-sample probes, then return the base probe data.

        In training mode (and when ``we.debug`` is off), the current training
        batch is run through ``run_step_eval`` in eval mode and logged as
        ``train_image`` / ``train_label`` probes.
        """
        if not we.debug and self.mode == 'train':
            batch_data = transfer_data_to_cuda(
                self.current_batch_data[self.mode])
            self.eval_mode()
            try:
                with torch.autocast(device_type='cuda',
                                    enabled=self.use_amp,
                                    dtype=self.dtype):
                    batch_data['log_num'] = self.log_train_num
                    results = self.run_step_eval(batch_data)
            finally:
                # Always return to training mode, even if sampling fails.
                self.train_mode()
            log_data, log_label = self.save_results(results)
            self.register_probe({
                'train_image':
                ProbeData(log_data,
                          is_image=True,
                          build_html=True,
                          build_label=log_label)
            })
            self.register_probe({'train_label': log_label})
        # Deliberately resolves probe_data starting *after*
        # LatentDiffusionSolver in the MRO, bypassing its own override
        # (presumably so it does not register duplicate training probes).
        return super(LatentDiffusionSolver, self).probe_data