Spaces:
Build error
Build error
import importlib | |
from argparse import ArgumentParser | |
from omegaconf import OmegaConf | |
from os.path import join as pjoin | |
import os | |
import glob | |
def get_module_config(cfg, filepath="./configs"):
    """
    Load yaml config files from subfolders

    Every file matching ``<filepath>/<subfolder>/<name>.yaml`` is loaded and
    stored in ``cfg`` under the dotted key ``<subfolder>.<name>``.

    Args:
        cfg: config object to update in place (passed to ``OmegaConf.update``).
        filepath: root folder whose immediate subfolders hold the yaml files.

    Returns:
        The same ``cfg`` object, after all matching files were merged in.
    """
    for yaml_path in glob.glob(pjoin(filepath, '*', '*.yaml')):
        # Build the key from the path relative to `filepath`, so the result
        # does not depend on a trailing slash or on `filepath` happening to
        # occur elsewhere inside the path.
        relative = os.path.relpath(yaml_path, filepath)
        stem, _ = os.path.splitext(relative)
        node = stem.replace(os.sep, '.')
        # Load the file that was actually found; the folder used to be
        # hard-coded to './configs', which ignored the `filepath` argument.
        OmegaConf.update(cfg, node, OmegaConf.load(yaml_path))
    return cfg
def get_obj_from_str(string, reload=False):
    """
    Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute of that module. When ``reload`` is true
    the module is re-imported with ``importlib.reload`` before the lookup.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    loaded_module = importlib.import_module(module_path, package=None)
    return getattr(loaded_module, attr_name)
def instantiate_from_config(config):
    """
    Build an object from a config mapping.

    ``config["target"]`` is a dotted path resolved with ``get_obj_from_str``;
    the resulting callable is invoked with ``config["params"]`` as keyword
    arguments (no arguments when ``params`` is absent).

    Raises:
        KeyError: if ``config`` has no ``target`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
def resume_config(cfg: OmegaConf):
    """
    Point the config at a previous run's checkpoint and wandb run id.

    Nothing happens when ``cfg.TRAIN.RESUME`` is empty. Otherwise it is
    treated as a run folder: ``cfg.TRAIN.PRETRAINED`` is set to
    ``<resume>/checkpoints/last.ckpt`` and ``cfg.LOGGER.WANDB.params.id`` is
    derived from the first entry of ``<resume>/wandb/latest-run`` whose name
    contains ``"run-"``.

    Raises:
        ValueError: if the resume folder does not exist.
        IndexError: if no ``run-`` entry is found in the wandb folder.
    """
    resume_dir = cfg.TRAIN.RESUME
    if not resume_dir:
        return cfg

    if not os.path.exists(resume_dir):
        raise ValueError("Resume path is not right.")

    # Checkpoints
    cfg.TRAIN.PRETRAINED = pjoin(resume_dir, "checkpoints", "last.ckpt")

    # Wandb: the run id is the file name without its "run-" / ".wandb" parts
    latest_run_entries = os.listdir(pjoin(resume_dir, "wandb", "latest-run"))
    run_entries = list(filter(lambda name: "run-" in name, latest_run_entries))
    run_file = run_entries[0]
    run_id = run_file.replace("run-", "").replace(".wandb", "")
    cfg.LOGGER.WANDB.params.id = run_id
    return cfg
def _add_phase_arguments(group, phase):
    """Register on ``group`` the command-line options that apply to ``phase``."""
    # Assets
    group.add_argument(
        "--cfg_assets",
        type=str,
        required=False,
        default="./configs/assets.yaml",
        help="config file for asset paths",
    )

    # Default config. Phases without a dedicated file (e.g. "demo") fall back
    # to the default config instead of hitting an unbound local variable.
    default_cfg_by_phase = {
        "train": "./configs/default.yaml",
        "test": "./configs/default.yaml",
        "render": "./configs/render.yaml",
        "webui": "./configs/webui.yaml",
    }
    group.add_argument(
        "--cfg",
        type=str,
        required=False,
        default=default_cfg_by_phase.get(phase, "./configs/default.yaml"),
        help="config file",
    )

    # Parse for each phase
    if phase in ["train", "test"]:
        group.add_argument("--batch_size",
                           type=int,
                           required=False,
                           help="training batch size")
        group.add_argument("--num_nodes",
                           type=int,
                           required=False,
                           help="number of nodes")
        group.add_argument("--device",
                           type=int,
                           nargs="+",
                           required=False,
                           help="training device")
        group.add_argument("--task",
                           type=str,
                           required=False,
                           help="evaluation task type")
        group.add_argument("--nodebug",
                           action="store_true",
                           required=False,
                           help="debug or not")

    if phase == "demo":
        # --render and --frame_rate are read when the demo overrides are
        # applied, so they have to be registered here.
        # NOTE(review): confirm the intended flags/defaults with the demo script.
        group.add_argument(
            "--render",
            action="store_true",
            required=False,
            help="render the demo results",
        )
        group.add_argument(
            "--frame_rate",
            type=float,
            required=False,
            default=None,
            help="frame rate of the demo motion",
        )
        group.add_argument(
            "--example",
            type=str,
            required=False,
            help="input text and lengths with txt format",
        )
        group.add_argument(
            "--out_dir",
            type=str,
            required=False,
            help="output dir",
        )
        group.add_argument("--task",
                           type=str,
                           required=False,
                           help="evaluation task type")

    if phase == "render":
        group.add_argument("--npy",
                           type=str,
                           required=False,
                           default=None,
                           help="npy motion files")
        group.add_argument("--dir",
                           type=str,
                           required=False,
                           default=None,
                           help="npy motion folder")
        group.add_argument("--fps",
                           type=int,
                           required=False,
                           default=30,
                           help="render fps")
        group.add_argument(
            "--mode",
            type=str,
            required=False,
            default="sequence",
            help="render target: video, sequence, frame",
        )


def _apply_cli_overrides(cfg, params, phase):
    """Write the parsed command-line values for ``phase`` into ``cfg`` in place."""
    if phase in ["train", "test"]:
        # Only values that were actually given override the yaml config.
        if params.batch_size:
            cfg.TRAIN.BATCH_SIZE = params.batch_size
        if params.device:
            cfg.DEVICE = params.device
        if params.num_nodes:
            cfg.NUM_NODES = params.num_nodes
        if params.task:
            cfg.model.params.task = params.task
        # --nodebug is a store_true flag and therefore always a bool: DEBUG is
        # on unless --nodebug is passed, regardless of the value in the yaml.
        cfg.DEBUG = not params.nodebug

        # Force no debug in test
        if phase == "test":
            cfg.DEBUG = False
            cfg.DEVICE = [0]
            print("Force no debugging and one gpu when testing")

    if phase == "demo":
        # Keep the yaml values unless the flags were given explicitly.
        if params.render:
            cfg.DEMO.RENDER = True
        if params.frame_rate is not None:
            cfg.DEMO.FRAME_RATE = params.frame_rate
        cfg.DEMO.EXAMPLE = params.example
        cfg.DEMO.TASK = params.task
        cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER
        os.makedirs(cfg.TEST.FOLDER, exist_ok=True)

    if phase == "render":
        if params.npy:
            cfg.RENDER.NPY = params.npy
            cfg.RENDER.INPUT_MODE = "npy"
        if params.dir:
            cfg.RENDER.DIR = params.dir
            cfg.RENDER.INPUT_MODE = "dir"
        if params.fps:
            cfg.RENDER.FPS = float(params.fps)
        cfg.RENDER.MODE = params.mode


def parse_args(phase="train"):
    """
    Parse arguments and load config files

    Args:
        phase: which entry point is running. "train", "test", "demo" and
            "render" each register their own command-line options; "render"
            and "webui" use their own default ``--cfg`` file, every other
            phase defaults to ``./configs/default.yaml``.

    Returns:
        The merged config: ``default.yaml`` from the asset config folder,
        overlaid with the ``--cfg`` file, the per-module yaml files (unless
        ``FULL_CONFIG`` is set), the asset config, the command-line
        overrides, the debug-mode tweaks and the resume settings.
    """
    parser = ArgumentParser()
    group = parser.add_argument_group("Training options")
    _add_phase_arguments(group, phase)
    params = parser.parse_args()

    # Load yaml config files
    # NOTE: the "eval" resolver runs Python expressions found in the yaml
    # files, so only trusted config files should be loaded.
    # Registering the same resolver twice raises, so guard repeated calls.
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    cfg_assets = OmegaConf.load(params.cfg_assets)
    cfg_base = OmegaConf.load(pjoin(cfg_assets.CONFIG_FOLDER, 'default.yaml'))
    cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))
    if not cfg_exp.FULL_CONFIG:
        cfg_exp = get_module_config(cfg_exp, cfg_assets.CONFIG_FOLDER)
    cfg = OmegaConf.merge(cfg_exp, cfg_assets)

    # Update config with arguments
    _apply_cli_overrides(cfg, params, phase)

    # Debug mode
    if cfg.DEBUG:
        cfg.NAME = "debug--" + cfg.NAME
        cfg.LOGGER.WANDB.params.offline = True
        cfg.LOGGER.VAL_EVERY_STEPS = 1

    # Resume config
    cfg = resume_config(cfg)
    return cfg