|
import argparse |
|
|
|
import torch |
|
import yaml |
|
|
|
from pytorch3dunet.unet3d import utils |
|
|
|
logger = utils.get_logger('ConfigLoader') |
|
|
|
|
|
def load_config(): |
|
parser = argparse.ArgumentParser(description='UNet3D') |
|
parser.add_argument('--config', type=str, help='Path to the YAML config file', required=True) |
|
args = parser.parse_args() |
|
config = yaml.safe_load(open(args.config, 'r')) |
|
|
|
device = config.get('device', None) |
|
if device == 'cpu': |
|
logger.warning('CPU mode forced in config, this will likely result in slow training/prediction') |
|
config['device'] = 'cpu' |
|
return config |
|
|
|
if torch.cuda.is_available(): |
|
config['device'] = 'cuda' |
|
else: |
|
logger.warning('CUDA not available, using CPU') |
|
config['device'] = 'cpu' |
|
return config |
|
|
|
|
|
def _load_config_yaml(config_file): |
|
return yaml.safe_load(open(config_file, 'r')) |
|
|