Spaces:
Runtime error
Runtime error
import os | |
import os.path as osp | |
import argparse | |
import torch | |
import time | |
from torch.utils.data import DataLoader | |
from mmengine.utils import mkdir_or_exist | |
from mmengine.config import Config, DictAction | |
from mmengine.logging import MMLogger | |
from estimator.utils import RunnerInfo, setup_env, log_env, fix_random_seed | |
from estimator.models.builder import build_model | |
from estimator.datasets.builder import build_dataset | |
from estimator.tester import Tester | |
from estimator.models.patchfusion import PatchFusion | |
from mmengine import print_log | |
from transformers import PretrainedConfig | |
def parse_args():
    """Parse the command-line arguments of the checkpoint export script.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes ``config``,
        ``ckp_path``, ``save_path``, ``cfg_options``, ``launcher`` and
        ``local_rank``.

    Side effects:
        Sets ``os.environ['LOCAL_RANK']`` from ``--local_rank`` when that
        variable is not already present in the environment.
    """
    parser = argparse.ArgumentParser(
        description='Export a training checkpoint via save_pretrained')
    parser.add_argument('config', help='train config file path')
    # Both paths are used unconditionally by main(), so fail early with a
    # clear argparse error instead of crashing later on a None path.
    parser.add_argument(
        '--ckp-path',
        type=str,
        required=True,
        help='path of the checkpoint file to load')
    parser.add_argument(
        '--save-path',
        type=str,
        required=True,
        help='directory the exported model and config.json are written to')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # When using PyTorch version >= 2.0.0, `torch.distributed.launch` passes
    # `--local-rank` to the launched script instead of `--local_rank`, so
    # both spellings are accepted.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()

    # Keep a LOCAL_RANK that the launcher already exported; otherwise fall
    # back to the command-line value.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    return args
def main():
    """Load a training checkpoint into the configured model and export it.

    Builds the model described by ``cfg.model``, loads the
    ``model_state_dict`` entry of the checkpoint given by ``--ckp-path``,
    then writes the model with ``save_pretrained`` and its config as
    ``config.json`` into the ``--save-path`` directory.
    """
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    # Apply command-line overrides; without this, --cfg-options is accepted
    # by the parser but silently has no effect.
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    cfg.ckp_path = args.ckp_path

    # build model
    model = build_model(cfg.model)
    print_log('Checkpoint Path: {}'.format(cfg.ckp_path), logger='current')

    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    # map_location='cpu' lets a checkpoint saved from GPU tensors be loaded
    # on a machine without CUDA; the model is not moved to a GPU here.
    checkpoint = torch.load(cfg.ckp_path, map_location='cpu')
    state_dict = checkpoint['model_state_dict']

    # Prefer the model's own loader when it provides one; otherwise use the
    # standard strict state-dict loading.
    if hasattr(model, 'load_dict'):
        load_result = model.load_dict(state_dict)
    else:
        load_result = model.load_state_dict(state_dict, strict=True)
    print_log(load_result, logger='current')

    model.eval()
    model.save_pretrained(args.save_path)
    model.config.to_json_file(os.path.join(args.save_path, "config.json"))
    # NOTE(review): the exported directory is presumably meant to be
    # reloaded with PatchFusion.from_pretrained(...) - confirm.
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()