# Distributed training script for a depth-conditioned Adapter on top of Stable Diffusion.
import argparse
import logging
import os
import os.path as osp

import torch
from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
                           scandir)
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
from basicsr.utils.options import copy_opt_file, dict2str
from omegaconf import OmegaConf

from ldm.data.dataset_depth import DepthDataset
from ldm.modules.encoders.adapter import Adapter
from ldm.util import load_model_from_config
def mkdir_and_rename(path):
    """Create a fresh experiment folder tree at ``path``.

    If ``path`` already exists, it is renamed to ``<path>_archived_<timestamp>``
    so previous results are preserved, then a new tree with the standard
    ``models``, ``training_states`` and ``visualization`` subfolders is created.

    Args:
        path (str): Experiment folder path.
    """
    if osp.exists(path):
        new_name = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)
    # exist_ok=True throughout: a partially-created tree left behind by an
    # interrupted run must not crash the next launch with FileExistsError.
    for sub in ('models', 'training_states', 'visualization'):
        os.makedirs(osp.join(path, sub), exist_ok=True)
def load_resume_state(opt):
    """Find and load the most recent training state for auto-resume.

    When ``opt.auto_resume`` is set, scans ``experiments/<name>/training_states``
    for ``*.state`` files, picks the one with the highest iteration number, and
    loads it onto the current CUDA device. Also records the chosen path on
    ``opt.resume_state_path``. Returns ``None`` when there is nothing to resume.
    """
    resume_state_path = None
    if opt.auto_resume:
        state_dir = osp.join('experiments', opt.name, 'training_states')
        if osp.isdir(state_dir):
            state_files = list(scandir(state_dir, suffix='state', recursive=False, full_path=False))
            if state_files:
                # filenames are "<iter>.state"; pick the largest iteration
                iters = [float(name.split('.state')[0]) for name in state_files]
                resume_state_path = osp.join(state_dir, f'{max(iters):.0f}.state')
                opt.resume_state_path = resume_state_path

    if resume_state_path is None:
        return None
    gpu_id = torch.cuda.current_device()
    return torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(gpu_id))
def parsr_args(args=None):
    """Parse command-line options for adapter training.

    Args:
        args (list[str] | None): Argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, preserving the original behavior.

    Returns:
        argparse.Namespace: Parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--bsize",
        type=int,
        default=8,
        help="training batch size per process",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10000,
        help="number of training epochs",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=8,
        help="number of dataloader workers",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--auto_resume",
        action='store_true',
        # original help text was a copy-paste of --plms ("use plms sampling")
        help="automatically resume from the latest training state",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="models/sd-v1-4.ckpt",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/stable-diffusion/sd-v1-train.yaml",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="train_depth",
        help="experiment name",
    )
    parser.add_argument(
        "--print_fq",
        type=int,
        default=100,
        # original help text was a copy-paste of --config
        help="logging frequency, in iterations",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=50,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for each given prompt. A.k.a. batch size",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=7.5,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--gpus",
        default=[0, 1, 2, 3],
        help="gpu idx",
    )
    parser.add_argument(
        '--local_rank',
        default=0,
        type=int,
        help='node rank for distributed training'
    )
    parser.add_argument(
        '--launcher',
        default='pytorch',
        type=str,
        # original help text was a copy-paste of --local_rank
        help='job launcher for distributed training'
    )
    opt = parser.parse_args(args)
    return opt
def main():
    """Distributed training loop for the depth Adapter.

    Builds the depth dataset and dataloader, loads the Stable Diffusion model
    and the trainable Adapter, wraps both in DistributedDataParallel,
    optionally resumes from the newest saved state, then trains the adapter
    (only its parameters are optimized) and periodically checkpoints on rank 0.
    """
    opt = parsr_args()
    config = OmegaConf.load(f"{opt.config}")

    # distributed setting: one process per GPU; local_rank selects this
    # process's device before any CUDA allocation happens
    init_dist(opt.launcher)
    torch.backends.cudnn.benchmark = True
    device = 'cuda'
    torch.cuda.set_device(opt.local_rank)

    # dataset: each rank sees a disjoint shard via DistributedSampler
    train_dataset = DepthDataset('datasets/laion_depth_meta_v1.txt')
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.bsize,
        shuffle=(train_sampler is None),  # sampler is always set here, so shuffle is effectively False
        num_workers=opt.num_workers,
        pin_memory=True,
        sampler=train_sampler)

    # stable diffusion
    model = load_model_from_config(config, f"{opt.ckpt}").to(device)
    # depth encoder; cin=3*64 presumably matches a pixel-unshuffled 3-channel
    # depth map (8x8 unshuffle) — TODO confirm against DepthDataset output
    model_ad = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(
        device)

    # to gpus
    model_ad = torch.nn.parallel.DistributedDataParallel(
        model_ad,
        device_ids=[opt.local_rank],
        output_device=opt.local_rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[opt.local_rank],
        output_device=opt.local_rank)

    # optimizer: only the adapter's parameters are optimized; the SD weights
    # receive no optimizer step (though model.zero_grad() below clears any
    # gradients accumulated on them)
    params = list(model_ad.parameters())
    optimizer = torch.optim.AdamW(params, lr=config['training']['lr'])

    experiments_root = osp.join('experiments', opt.name)

    # resume state: None on a fresh run, otherwise the dict saved below
    # (keys: 'epoch', 'iter', 'optimizers')
    resume_state = load_resume_state(opt)
    if resume_state is None:
        mkdir_and_rename(experiments_root)
        start_epoch = 0
        current_iter = 0
        # WARNING: should not use get_root_logger in the above codes, including the called functions
        # Otherwise the logger will not be properly initialized
        log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log")
        logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
        logger.info(get_env_info())
        logger.info(dict2str(config))
    else:
        # WARNING: should not use get_root_logger in the above codes, including the called functions
        # Otherwise the logger will not be properly initialized
        log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log")
        logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
        logger.info(get_env_info())
        logger.info(dict2str(config))
        resume_optimizers = resume_state['optimizers']
        optimizer.load_state_dict(resume_optimizers)
        logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
        start_epoch = resume_state['epoch']
        current_iter = resume_state['iter']

    # copy the yml file to the experiment root
    copy_opt_file(opt.config, experiments_root)

    # training
    logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
    for epoch in range(start_epoch, opt.epochs):
        # reseed the sampler so each epoch shuffles shards differently
        train_dataloader.sampler.set_epoch(epoch)
        # train
        for _, data in enumerate(train_dataloader):
            current_iter += 1
            with torch.no_grad():
                # frozen SD parts: text conditioning and VAE latent encoding;
                # images are mapped from [0, 1] to [-1, 1] before encoding
                c = model.module.get_learned_conditioning(data['sentence'])
                z = model.module.encode_first_stage((data['im'] * 2 - 1.).to(device))
                z = model.module.get_first_stage_encoding(z)
            optimizer.zero_grad()
            model.zero_grad()
            features_adapter = model_ad(data['depth'].to(device))
            l_pixel, loss_dict = model(z, c=c, features_adapter=features_adapter)
            l_pixel.backward()
            optimizer.step()

            if (current_iter + 1) % opt.print_fq == 0:
                logger.info(loss_dict)

            # save checkpoint (rank 0 only writes files)
            rank, _ = get_dist_info()
            if (rank == 0) and ((current_iter + 1) % config['training']['save_freq'] == 0):
                save_filename = f'model_ad_{current_iter + 1}.pth'
                save_path = os.path.join(experiments_root, 'models', save_filename)
                save_dict = {}
                state_dict = model_ad.state_dict()
                for key, param in state_dict.items():
                    if key.startswith('module.'):  # remove unnecessary 'module.'
                        key = key[7:]
                    # move to CPU so the checkpoint loads without the training GPUs
                    save_dict[key] = param.cpu()
                torch.save(save_dict, save_path)
                # save state
                # NOTE(review): 'iter' is saved as current_iter + 1 while resume
                # assigns it straight to current_iter — off-by-one on resume; confirm intended
                state = {'epoch': epoch, 'iter': current_iter + 1, 'optimizers': optimizer.state_dict()}
                save_filename = f'{current_iter + 1}.state'
                save_path = os.path.join(experiments_root, 'training_states', save_filename)
                torch.save(state, save_path)
if __name__ == '__main__':
    # Entry point; expected to be launched once per GPU by a distributed
    # launcher (opt.launcher defaults to 'pytorch').
    main()