ACE-Plus / modules / checkpoint.py
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import torch
import os.path as osp
import warnings
from collections import OrderedDict
from safetensors.torch import save_file
from scepter.modules.solver.hooks import CheckpointHook, BackwardHook
from scepter.modules.solver.hooks.registry import HOOKS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS

_DEFAULT_CHECKPOINT_PRIORITY = 300


def convert_to_comfyui_lora(ori_sd, prefix="lora_unet"):
    # Rename SWIFT-style LoRA keys to the ComfyUI convention and attach an
    # ".alpha" tensor per LoRA weight; alpha is set to the LoRA rank, read
    # from the weight shape.
    new_ckpt = OrderedDict()
    for k, v in ori_sd.items():
        new_k = k.replace(".lora_A.0_SwiftLoRA.", ".lora_down.").replace(
            ".lora_B.0_SwiftLoRA.", ".lora_up.")
        new_k = prefix + "_" + new_k.split(".lora")[0].replace(
            "model.", "").replace(".", "_") + ".lora" + new_k.split(".lora")[1]
        alpha_k = new_k.split(".lora")[0] + ".alpha"
        new_ckpt[new_k] = v
        if "lora_up" in new_k:
            alpha = v.shape[-1]  # rank = last dim of the up-projection weight
        elif "lora_down" in new_k:
            alpha = v.shape[0]   # rank = first dim of the down-projection weight
        new_ckpt[alpha_k] = torch.tensor(float(alpha)).to(v)
    return new_ckpt
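
# Illustrative key mapping (the attention-block path below is hypothetical; real key
# names depend on the tuned model): convert_to_comfyui_lora turns a SWIFT entry such as
#     "model.blocks.0.attn.qkv.lora_A.0_SwiftLoRA.weight"   # shape [rank, in_features]
# into the ComfyUI-style pair
#     "lora_unet_blocks_0_attn_qkv.lora_down.weight"
#     "lora_unet_blocks_0_attn_qkv.alpha"                    # == rank, so alpha / rank == 1.0
# and the matching ".lora_B.0_SwiftLoRA." weight becomes "...lora_up.weight".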


@HOOKS.register_class()
class ACECheckpointHook(CheckpointHook):
    """ Checkpoint resume or save hook.
    Args:
        interval (int): Save interval, by epoch.
        save_best (bool): Save the best checkpoint by a metric key, default is False.
        save_best_by (str): How to pick the best checkpoint by the metric key, default is ''.
            + means higher is better (default).
            - means lower is better.
            E.g. +acc@1, -err@1, acc@5 (same as +acc@5)
    """
    def __init__(self, cfg, logger=None):
        super(ACECheckpointHook, self).__init__(cfg, logger=logger)
    def after_iter(self, solver):
        super().after_iter(solver)
        if solver.total_iter != 0 and (
                (solver.total_iter + 1) % self.interval == 0
                or solver.total_iter == solver.max_steps - 1):
            from swift import SwiftModel
            if isinstance(solver.model, SwiftModel) or (
                    hasattr(solver.model, 'module')
                    and isinstance(solver.model.module, SwiftModel)):
                save_path = osp.join(
                    solver.work_dir,
                    'checkpoints/{}-{}'.format(self.save_name_prefix,
                                               solver.total_iter + 1))
                if we.rank == 0:
                    # Convert the SWIFT LoRA adapter written by the parent hook
                    # into a ComfyUI-compatible safetensors file alongside it.
                    tuner_model = os.path.join(save_path, '0_SwiftLoRA',
                                               'adapter_model.bin')
                    save_model = os.path.join(save_path, '0_SwiftLoRA',
                                              'comfyui_model.safetensors')
                    if FS.exists(tuner_model):
                        with FS.get_from(tuner_model) as local_file:
                            swift_lora_sd = torch.load(local_file,
                                                       weights_only=True)
                        safetensor_lora_sd = convert_to_comfyui_lora(
                            swift_lora_sd)
                        with FS.put_to(save_model) as local_file:
                            save_file(safetensor_lora_sd, local_file)

    @staticmethod
    def get_config_template():
        return dict_to_yaml('hook',
                            __class__.__name__,
                            ACECheckpointHook.para_dict,
                            set_name=True)
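

# Hedged usage sketch (illustrative, not executed by the training pipeline): after a
# save interval, the converted LoRA written by ACECheckpointHook can be inspected with
# the standard safetensors API. The path below is an example of the layout used in
# after_iter above; the actual work dir, prefix and step come from the solver config.
#
#     from safetensors.torch import load_file
#     sd = load_file('<work_dir>/checkpoints/<save_name_prefix>-<step>/'
#                    '0_SwiftLoRA/comfyui_model.safetensors')
#     for name, tensor in sd.items():
#         print(name, tuple(tensor.shape))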


@HOOKS.register_class()
class ACEBackwardHook(BackwardHook):
    def grad_clip(self, optimizer):
        # Clip the global norm of all trainable parameters in each param group.
        for params_group in optimizer.param_groups:
            train_params = []
            for param in params_group['params']:
                if param.requires_grad:
                    train_params.append(param)
            torch.nn.utils.clip_grad_norm_(parameters=train_params,
                                           max_norm=self.gradient_clip)
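
    # Worked example for grad_clip (numbers are illustrative, not from the source):
    # with gradient_clip = 1.0 and trainable gradients whose combined L2 norm is 4.0,
    # clip_grad_norm_ rescales every gradient by roughly 1.0 / 4.0 = 0.25; gradients
    # whose total norm is already below the threshold are left unchanged.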
    def after_iter(self, solver):
        if solver.optimizer is not None and solver.is_train_mode:
            if solver.loss is None:
                warnings.warn(
                    'solver.loss should not be None in train mode, remember to call solver._reduce_scalar()!'
                )
                return
            if solver.scaler is not None:
                solver.scaler.scale(solver.loss /
                                    self.accumulate_step).backward()
                self.current_step += 1
                # The profiler is assumed to run after backward, so
                # backward_prev_step should point to the step just before
                # this backward step.
                if self.current_step % self.accumulate_step == 0:
                    solver.scaler.unscale_(solver.optimizer)
                    if self.gradient_clip > 0:
                        self.grad_clip(solver.optimizer)
                    self.profile(solver)
                    solver.scaler.step(solver.optimizer)
                    solver.scaler.update()
                    solver.optimizer.zero_grad()
            else:
                (solver.loss / self.accumulate_step).backward()
                self.current_step += 1
                # The profiler is assumed to run after backward, so
                # backward_prev_step should point to the step just before
                # this backward step.
                if self.current_step % self.accumulate_step == 0:
                    if self.gradient_clip > 0:
                        self.grad_clip(solver.optimizer)
                    self.profile(solver)
                    solver.optimizer.step()
                    solver.optimizer.zero_grad()
            if solver.lr_scheduler:
                if self.current_step % self.accumulate_step == 0:
                    solver.lr_scheduler.step()
            if self.current_step % self.accumulate_step == 0:
                setattr(solver, 'backward_step', True)
                self.current_step = 0
            else:
                setattr(solver, 'backward_step', False)
            solver.loss = None
        if self.empty_cache_step > 0 and solver.total_iter % self.empty_cache_step == 0:
            torch.cuda.empty_cache()
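
    # Illustrative accumulation arithmetic (the step count is an assumption, not from
    # the source): with accumulate_step = 4, every iteration backpropagates loss / 4
    # and only every 4th iteration runs the optimizer step, so the applied update uses
    # the average gradient over the 4 micro-batches and solver.backward_step is True
    # only on those iterations.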

    @staticmethod
    def get_config_template():
        return dict_to_yaml('hook',
                            __class__.__name__,
                            ACEBackwardHook.para_dict,
                            set_name=True)