|
import torch.utils.data |
|
from dataset import Offline_Dataset |
|
import yaml |
|
from sgmnet.match_model import matcher as SGM_Model |
|
from superglue.match_model import matcher as SG_Model |
|
import torch.distributed as dist |
|
import torch |
|
import os |
|
from collections import namedtuple |
|
from train import train |
|
from config import get_config, print_usage |
|
|
|
|
|
def main(config, model_config):
    """Build the requested matcher, wrap it in DDP, and run training.

    Args:
        config: parsed command-line configuration (from ``get_config``). This
            function reads ``model_name`` ('SGM' or 'SG'), ``local_rank`` and
            ``train_batch_size``; the object is also forwarded to
            ``Offline_Dataset`` and ``train``.
        model_config: model options loaded from ``config.config_path``; passed
            to the model constructor and to ``train``.

    Raises:
        NotImplementedError: if ``config.model_name`` is not 'SGM' or 'SG'.
        ValueError: if ``config.train_batch_size`` is smaller than the number
            of processes, which would leave a per-process batch size of 0.
    """
    if config.model_name == 'SGM':
        model = SGM_Model(model_config)
    elif config.model_name == 'SG':
        model = SG_Model(model_config)
    else:
        raise NotImplementedError(
            f"unsupported model_name {config.model_name!r}; expected 'SGM' or 'SG'"
        )

    # Bind this process to its own GPU before creating the NCCL process group.
    torch.cuda.set_device(config.local_rank)
    device = torch.device(f'cuda:{config.local_rank}')
    model.to(device)

    # 'env://' reads rendezvous info (MASTER_ADDR, MASTER_PORT, RANK,
    # WORLD_SIZE) from environment variables set by the launcher.
    dist.init_process_group(backend='nccl', init_method='env://')
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.local_rank])

    if config.local_rank == 0:
        os.system('nvidia-smi')

    world_size = dist.get_world_size()
    # train_batch_size is the global batch; each process loads its share of it.
    per_process_batch = config.train_batch_size // world_size
    if per_process_batch < 1:
        raise ValueError(
            f'train_batch_size={config.train_batch_size} is smaller than the '
            f'world size {world_size}; per-process batch size would be 0'
        )
    # A budget of 8 loader workers is split across processes (0 means the
    # data is loaded in the main process).
    num_workers = 8 // world_size

    train_dataset = Offline_Dataset(config, 'train')
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=per_process_batch,
        num_workers=num_workers,
        pin_memory=False,
        sampler=train_sampler,
        collate_fn=train_dataset.collate_fn,
    )

    valid_dataset = Offline_Dataset(config, 'valid')
    valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, shuffle=False)
    # NOTE(review): unlike training, validation uses the full train_batch_size
    # per process rather than dividing it by world_size -- confirm intended.
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.train_batch_size,
        num_workers=num_workers,
        pin_memory=False,
        collate_fn=valid_dataset.collate_fn,
        sampler=valid_sampler,
    )

    if config.local_rank == 0:
        print('start training .....')
    train(model, train_loader, valid_loader, config, model_config)
|
|
|
if __name__ == "__main__":
    import sys

    config, unparsed = get_config()
    # Reject unknown command-line arguments before touching the filesystem.
    if unparsed:
        print_usage()
        sys.exit(1)

    with open(config.config_path, 'r') as f:
        # yaml.load without an explicit Loader is deprecated since PyYAML 5.1
        # and a TypeError in PyYAML >= 6.0; safe_load also refuses arbitrary
        # Python object tags.
        model_config = yaml.safe_load(f)
    if not isinstance(model_config, dict):
        raise ValueError(
            f'{config.config_path} must contain a YAML mapping of model options'
        )
    # Freeze the options into a namedtuple so they are read as attributes.
    model_config = namedtuple('model_config', model_config.keys())(*model_config.values())

    main(config, model_config)
|
|