MotionGPT / mGPT /config.py
bill-jiang's picture
Init
4409449
raw
history blame
6.9 kB
import importlib
from argparse import ArgumentParser
from omegaconf import OmegaConf
from os.path import join as pjoin
import os
import glob
def get_module_config(cfg, filepath="./configs"):
    """
    Load yaml config files from subfolders

    Every ``<filepath>/<subfolder>/<name>.yaml`` file (exactly one subfolder
    level deep) is loaded and merged into *cfg* at the dotted key
    ``<subfolder>.<name>`` via ``OmegaConf.update``.

    Args:
        cfg: config object to update in place.
        filepath: root config folder to scan.

    Returns:
        The updated *cfg*.
    """
    # Sorted so the update order does not depend on filesystem listing order.
    for yaml_path in sorted(glob.glob(pjoin(filepath, '*', '*.yaml'))):
        # Derive the key from the path relative to ``filepath``, e.g.
        # "<filepath>/sub/x.yaml" -> "sub.x". Splitting on os.sep (after
        # normpath) keeps this correct with Windows separators as well.
        rel_stem = os.path.splitext(os.path.relpath(yaml_path, filepath))[0]
        nodes = '.'.join(os.path.normpath(rel_stem).split(os.sep))
        # Load the file that glob actually found. Re-prefixing with a fixed
        # './configs' would ignore ``filepath`` (and mangle the path when
        # ``filepath`` ends with a slash).
        OmegaConf.update(cfg, nodes, OmegaConf.load(yaml_path))
    return cfg
def get_obj_from_str(string, reload=False):
    """
    Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    Args:
        string: fully qualified name; everything before the last dot is the
            module path, the final component is the attribute looked up on it.
        reload: if True, re-execute the module with ``importlib.reload``
            before the attribute lookup.

    Returns:
        The attribute named by the last component of *string*.

    Raises:
        ValueError: if *string* contains no dot.
        ModuleNotFoundError / AttributeError: if the module or attribute
            does not exist.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    target_module = importlib.import_module(module_path, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config):
    """
    Instantiate object from config

    Args:
        config: mapping with a required ``"target"`` key holding a dotted
            import path (resolved by ``get_obj_from_str``) and an optional
            ``"params"`` mapping passed as keyword arguments to the target.

    Returns:
        The result of calling the resolved target with ``**params``.

    Raises:
        KeyError: if ``"target"`` is missing from *config*.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    params = config.get("params")
    # A YAML entry written as ``params:`` with no value loads as None; treat
    # it as "no keyword arguments" rather than failing on ``**None``.
    if params is None:
        params = {}
    return get_obj_from_str(config["target"])(**params)
def resume_config(cfg: "OmegaConf"):
    """
    Resume model and wandb

    When ``cfg.TRAIN.RESUME`` is set (truthy) it must name an existing
    experiment folder. This then:

    * points ``cfg.TRAIN.PRETRAINED`` at ``<RESUME>/checkpoints/last.ckpt``;
    * sets ``cfg.LOGGER.WANDB.params.id`` to the run id parsed from the
      ``run-<id>.wandb`` file in ``<RESUME>/wandb/latest-run``.

    Args:
        cfg: merged experiment config; modified in place.

    Returns:
        The same *cfg* object.

    Raises:
        ValueError: if ``cfg.TRAIN.RESUME`` is set but the path does not exist.
        FileNotFoundError: if ``<RESUME>/wandb/latest-run`` is missing or
            contains no ``run-<id>.wandb`` file.
    """
    resume = cfg.TRAIN.RESUME
    if not resume:
        return cfg
    if not os.path.exists(resume):
        raise ValueError("Resume path is not right.")

    # Checkpoints
    cfg.TRAIN.PRETRAINED = pjoin(resume, "checkpoints", "last.ckpt")

    # Wandb: match the exact ``run-<id>.wandb`` file name. A substring test
    # for "run-" also matches the ``run-<id>.wandb.synced`` marker written by
    # ``wandb sync``, which would produce a bogus id like ``<id>.synced``.
    prefix, suffix = "run-", ".wandb"
    run_dir = pjoin(resume, "wandb", "latest-run")
    run_files = sorted(
        name for name in os.listdir(run_dir)
        if name.startswith(prefix)
        and name.endswith(suffix)
        and len(name) > len(prefix) + len(suffix)
    )
    if not run_files:
        raise FileNotFoundError(
            f"No '{prefix}<id>{suffix}' file found in {run_dir}")
    cfg.LOGGER.WANDB.params.id = run_files[0][len(prefix):-len(suffix)]
    return cfg
def _default_cfg_path(phase):
    """Return the default ``--cfg`` path for *phase*."""
    if phase == "render":
        return "./configs/render.yaml"
    if phase == "webui":
        return "./configs/webui.yaml"
    # "train"/"test" use the shared default config. Any other phase (notably
    # "demo") falls back to it too, instead of leaving the default unbound
    # and crashing with UnboundLocalError.
    return "./configs/default.yaml"


def _add_phase_arguments(group, phase):
    """Register the phase-specific command-line options on *group*."""
    if phase in ["train", "test"]:
        group.add_argument("--batch_size",
                           type=int,
                           required=False,
                           help="training batch size")
        group.add_argument("--num_nodes",
                           type=int,
                           required=False,
                           help="number of nodes")
        group.add_argument("--device",
                           type=int,
                           nargs="+",
                           required=False,
                           help="training device")
        group.add_argument("--task",
                           type=str,
                           required=False,
                           help="evaluation task type")
        group.add_argument("--nodebug",
                           action="store_true",
                           required=False,
                           help="debug or not")

    if phase == "demo":
        group.add_argument(
            "--example",
            type=str,
            required=False,
            help="input text and lengths with txt format",
        )
        group.add_argument(
            "--out_dir",
            type=str,
            required=False,
            help="output dir",
        )
        group.add_argument("--task",
                           type=str,
                           required=False,
                           help="evaluation task type")
        # The demo overrides read params.render / params.frame_rate, so these
        # options must exist or parsing succeeds and then fails with
        # AttributeError.
        group.add_argument("--render",
                           action="store_true",
                           required=False,
                           help="render the demo outputs")
        group.add_argument("--frame_rate",
                           type=float,
                           required=False,
                           default=None,
                           help="demo frame rate (defaults to DEMO.FRAME_RATE in the config)")

    if phase == "render":
        group.add_argument("--npy",
                           type=str,
                           required=False,
                           default=None,
                           help="npy motion files")
        group.add_argument("--dir",
                           type=str,
                           required=False,
                           default=None,
                           help="npy motion folder")
        group.add_argument("--fps",
                           type=int,
                           required=False,
                           default=30,
                           help="render fps")
        group.add_argument(
            "--mode",
            type=str,
            required=False,
            default="sequence",
            help="render target: video, sequence, frame",
        )


def _apply_cli_overrides(cfg, params, phase):
    """Write the parsed command-line values for *phase* into *cfg* (in place)."""
    if phase in ["train", "test"]:
        cfg.TRAIN.BATCH_SIZE = params.batch_size if params.batch_size else cfg.TRAIN.BATCH_SIZE
        cfg.DEVICE = params.device if params.device else cfg.DEVICE
        cfg.NUM_NODES = params.num_nodes if params.num_nodes else cfg.NUM_NODES
        cfg.model.params.task = params.task if params.task else cfg.model.params.task
        # --nodebug is a store_true flag, so it is always a bool: debug mode is
        # on for train/test unless --nodebug is passed, and cfg.DEBUG from the
        # YAML is not consulted here.
        cfg.DEBUG = not params.nodebug
        # Force no debug in test
        if phase == "test":
            cfg.DEBUG = False
            cfg.DEVICE = [0]
            print("Force no debugging and one gpu when testing")

    if phase == "demo":
        cfg.DEMO.RENDER = params.render
        cfg.DEMO.FRAME_RATE = (params.frame_rate if params.frame_rate is not None
                               else cfg.DEMO.get("FRAME_RATE"))
        cfg.DEMO.EXAMPLE = params.example
        cfg.DEMO.TASK = params.task
        cfg.TEST.FOLDER = params.out_dir if params.out_dir else cfg.TEST.FOLDER
        os.makedirs(cfg.TEST.FOLDER, exist_ok=True)

    if phase == "render":
        if params.npy:
            cfg.RENDER.NPY = params.npy
            cfg.RENDER.INPUT_MODE = "npy"
        if params.dir:
            cfg.RENDER.DIR = params.dir
            cfg.RENDER.INPUT_MODE = "dir"
        if params.fps:
            cfg.RENDER.FPS = float(params.fps)
        cfg.RENDER.MODE = params.mode
    return cfg


def parse_args(phase="train"):
    """
    Parse arguments and load config files

    The accepted options depend on *phase* ("train", "test", "demo",
    "render", "webui"). The final config is built as follows:

    1. Load ``--cfg_assets``.
    2. Load ``default.yaml`` from its ``CONFIG_FOLDER``.
    3. Merge in the experiment config given by ``--cfg``.
    4. If ``FULL_CONFIG`` is false, also merge the per-module subfolder
       configs.
    5. Merge the assets config on top.
    6. Apply command-line overrides.
    7. Adjust names/logging for debug mode.
    8. Resolve resume paths.

    Args:
        phase: which entry point is parsing arguments.

    Returns:
        The merged OmegaConf config.
    """
    parser = ArgumentParser()
    group = parser.add_argument_group("Training options")
    # Assets
    group.add_argument(
        "--cfg_assets",
        type=str,
        required=False,
        default="./configs/assets.yaml",
        help="config file for asset paths",
    )
    # Default config
    group.add_argument(
        "--cfg",
        type=str,
        required=False,
        default=_default_cfg_path(phase),
        help="config file",
    )
    _add_phase_arguments(group, phase)
    params = parser.parse_args()

    # Load yaml config files.
    # NOTE: the "eval" resolver runs Python eval() on ${eval:...} expressions
    # in the YAML, so only load trusted config files. Registering it twice
    # raises, hence the guard (parse_args may be called more than once).
    if not OmegaConf.has_resolver("eval"):
        OmegaConf.register_new_resolver("eval", eval)
    cfg_assets = OmegaConf.load(params.cfg_assets)
    cfg_base = OmegaConf.load(pjoin(cfg_assets.CONFIG_FOLDER, 'default.yaml'))
    cfg_exp = OmegaConf.merge(cfg_base, OmegaConf.load(params.cfg))
    if not cfg_exp.FULL_CONFIG:
        cfg_exp = get_module_config(cfg_exp, cfg_assets.CONFIG_FOLDER)
    cfg = OmegaConf.merge(cfg_exp, cfg_assets)

    # Update config with arguments
    cfg = _apply_cli_overrides(cfg, params, phase)

    # Debug mode
    if cfg.DEBUG:
        cfg.NAME = "debug--" + cfg.NAME
        cfg.LOGGER.WANDB.params.offline = True
        cfg.LOGGER.VAL_EVERY_STEPS = 1

    # Resume config
    return resume_config(cfg)