# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import warnings
from collections import OrderedDict

import torch
from safetensors.torch import save_file

from scepter.modules.solver.hooks import BackwardHook, CheckpointHook
from scepter.modules.solver.hooks.registry import HOOKS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS

_DEFAULT_CHECKPOINT_PRIORITY = 300


def convert_to_comfyui_lora(ori_sd, prefix='lora_unet'):
    """Convert a SwiftLoRA state dict to the key layout expected by ComfyUI."""
    new_ckpt = OrderedDict()
    for k, v in ori_sd.items():
        # Map Swift's lora_A/lora_B naming onto ComfyUI's lora_down/lora_up naming.
        new_k = k.replace('.lora_A.0_SwiftLoRA.', '.lora_down.').replace(
            '.lora_B.0_SwiftLoRA.', '.lora_up.')
        # Flatten the module path into the ComfyUI key format:
        # <prefix>_<module_path_with_underscores>.lora_{down,up}.<suffix>
        new_k = prefix + '_' + new_k.split('.lora')[0].replace(
            'model.', '').replace('.', '_') + '.lora' + new_k.split('.lora')[1]
        alpha_k = new_k.split('.lora')[0] + '.alpha'
        new_ckpt[new_k] = v
        # The LoRA rank doubles as the alpha value: lora_up weights are
        # [out_features, rank], lora_down weights are [rank, in_features].
        if 'lora_up' in new_k:
            alpha = v.shape[-1]
        elif 'lora_down' in new_k:
            alpha = v.shape[0]
        else:
            # Entries that are neither up nor down carry no rank information.
            continue
        new_ckpt[alpha_k] = torch.tensor(float(alpha)).to(v)
    return new_ckpt
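

# A hedged usage sketch, not part of the original module: converting a saved
# SwiftLoRA adapter into a ComfyUI-loadable .safetensors file offline with the
# helper above. The default file names and the `_example_` function itself are
# illustrative assumptions, not ACE/scepter API.
def _example_convert_adapter_to_comfyui(adapter_path='adapter_model.bin',
                                        output_path='comfyui_model.safetensors'):
    # Load the Swift adapter weights on CPU, remap the keys, and save them in
    # the safetensors format that ComfyUI expects.
    swift_lora_sd = torch.load(adapter_path, map_location='cpu',
                               weights_only=True)
    comfyui_sd = convert_to_comfyui_lora(swift_lora_sd)
    save_file(comfyui_sd, output_path)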


@HOOKS.register_class()
class ACECheckpointHook(CheckpointHook):
""" Checkpoint resume or save hook. | |
Args: | |
interval (int): Save interval, by epoch. | |
save_best (bool): Save the best checkpoint by a metric key, default is False. | |
save_best_by (str): How to get the best the checkpoint by the metric key, default is ''. | |
+ means the higher the best (default). | |
- means the lower the best. | |
E.g. +acc@1, -err@1, acc@5(same as +acc@5) | |
""" | |
def __init__(self, cfg, logger=None): | |
super(ACECheckpointHook, self).__init__(cfg, logger=logger) | |

    def after_iter(self, solver):
        super().after_iter(solver)
        if solver.total_iter != 0 and (
                (solver.total_iter + 1) % self.interval == 0
                or solver.total_iter == solver.max_steps - 1):
            from swift import SwiftModel
            if isinstance(solver.model, SwiftModel) or (
                    hasattr(solver.model, 'module')
                    and isinstance(solver.model.module, SwiftModel)):
                save_path = osp.join(
                    solver.work_dir,
                    'checkpoints/{}-{}'.format(self.save_name_prefix,
                                               solver.total_iter + 1))
                # Only the rank-0 process converts and writes the ComfyUI checkpoint.
                if we.rank == 0:
                    tuner_model = os.path.join(save_path, '0_SwiftLoRA',
                                               'adapter_model.bin')
                    save_model = os.path.join(save_path, '0_SwiftLoRA',
                                              'comfyui_model.safetensors')
                    if FS.exists(tuner_model):
                        # Fetch the Swift adapter weights, remap the keys, and
                        # write them back as a ComfyUI-loadable safetensors file.
                        with FS.get_from(tuner_model) as local_file:
                            swift_lora_sd = torch.load(local_file,
                                                       weights_only=True)
                        safetensor_lora_sd = convert_to_comfyui_lora(swift_lora_sd)
                        with FS.put_to(save_model) as local_file:
                            save_file(safetensor_lora_sd, local_file)

    @staticmethod
    def get_config_template():
        return dict_to_yaml('hook',
                            __class__.__name__,
                            ACECheckpointHook.para_dict,
                            set_name=True)
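

# A hedged sketch, not part of the original module: once the hook is registered
# in HOOKS, its YAML config template (built from the inherited CheckpointHook
# para_dict) can be rendered to see which fields a solver config may set. The
# `_example_` function name is an illustrative assumption, not scepter API.
def _example_print_checkpoint_hook_template():
    print(ACECheckpointHook.get_config_template())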


@HOOKS.register_class()
class ACEBackwardHook(BackwardHook):

    def grad_clip(self, optimizer):
        # Clip the global gradient norm of the trainable parameters in each
        # parameter group.
        for params_group in optimizer.param_groups:
            train_params = []
            for param in params_group['params']:
                if param.requires_grad:
                    train_params.append(param)
            torch.nn.utils.clip_grad_norm_(parameters=train_params,
                                           max_norm=self.gradient_clip)

    def after_iter(self, solver):
        if solver.optimizer is not None and solver.is_train_mode:
            if solver.loss is None:
                warnings.warn(
                    'solver.loss should not be None in train mode, '
                    'remember to call solver._reduce_scalar()!')
                return
            if solver.scaler is not None:
                # AMP path: scale the loss, accumulate gradients, and only
                # unscale/clip/step every `accumulate_step` iterations.
                solver.scaler.scale(solver.loss /
                                    self.accumulate_step).backward()
                self.current_step += 1
                # The profiler is assumed to run right after backward, so profile
                # just before the optimizer step.
                if self.current_step % self.accumulate_step == 0:
                    solver.scaler.unscale_(solver.optimizer)
                    if self.gradient_clip > 0:
                        self.grad_clip(solver.optimizer)
                    self.profile(solver)
                    solver.scaler.step(solver.optimizer)
                    solver.scaler.update()
                    solver.optimizer.zero_grad()
            else:
                # FP32 path: same accumulation logic without the grad scaler.
                (solver.loss / self.accumulate_step).backward()
                self.current_step += 1
                if self.current_step % self.accumulate_step == 0:
                    if self.gradient_clip > 0:
                        self.grad_clip(solver.optimizer)
                    self.profile(solver)
                    solver.optimizer.step()
                    solver.optimizer.zero_grad()
            if solver.lr_scheduler:
                if self.current_step % self.accumulate_step == 0:
                    solver.lr_scheduler.step()
            if self.current_step % self.accumulate_step == 0:
                # Mark that an optimizer step happened this iteration and reset
                # the accumulation counter.
                setattr(solver, 'backward_step', True)
                self.current_step = 0
            else:
                setattr(solver, 'backward_step', False)
            solver.loss = None
        if self.empty_cache_step > 0 and solver.total_iter % self.empty_cache_step == 0:
            torch.cuda.empty_cache()

    @staticmethod
    def get_config_template():
        return dict_to_yaml('hook',
                            __class__.__name__,
                            ACEBackwardHook.para_dict,
                            set_name=True)
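

# A hedged, self-contained sketch, not part of the original module: the gradient
# accumulation pattern ACEBackwardHook implements, shown with plain PyTorch. The
# loss is divided by `accumulate_step`, the gradient norm is clipped, and the
# optimizer only steps (and zeroes grads) once per accumulation window. The toy
# model, data, and `_example_` function are illustrative assumptions.
def _example_accumulation_loop(accumulate_step=4, gradient_clip=1.0, iters=8):
    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    for i in range(iters):
        x = torch.randn(2, 4)
        loss = model(x).pow(2).mean()
        # Mirror (solver.loss / self.accumulate_step).backward() above so the
        # accumulated gradient equals the average over the window.
        (loss / accumulate_step).backward()
        if (i + 1) % accumulate_step == 0:
            if gradient_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               max_norm=gradient_clip)
            optimizer.step()
            optimizer.zero_grad()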