# Provenance: MONAI "medical" model package, commit 618f7d3
# ("complete the model package", by katielink); original file size 21.5 kB.
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import random
import sys
import time
from datetime import datetime
from typing import Sequence, Union
import monai
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
from monai import transforms
from monai.bundle import ConfigParser
from monai.data import ThreadDataLoader, partition_dataset
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.utils import set_determinism
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
def _collect_existing_pairs(items, base_dir):
    """Return ``{"image": ..., "label": ...}`` dicts for entries whose files both exist.

    Each entry of *items* provides ``"image"`` and ``"label"`` paths relative to
    *base_dir*; entries whose image or label file is missing are skipped silently.
    """
    pairs = []
    for item in items:
        image_path = os.path.join(base_dir, item["image"])
        label_path = os.path.join(base_dir, item["label"])
        if os.path.exists(image_path) and os.path.exists(label_path):
            pairs.append({"image": image_path, "label": label_path})
    return pairs


def _segmentation_loss(loss_func, outputs, labels, output_classes):
    """Apply *loss_func*; for 2 classes, flip the channel axis and invert the labels first."""
    if output_classes == 2:
        return loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
    return loss_func(outputs, labels)


def _set_requires_grad(params, flag):
    """Set ``requires_grad`` to *flag* on every parameter in *params*."""
    for param in params:
        param.requires_grad = flag


def run(config_file: Union[str, Sequence[str]]):
    """Run a DiNTS architecture search configured by a MONAI bundle config.

    For every batch of the weight split, the network weights are updated with
    SGD while the architecture parameters are frozen. After ``num_epochs_warmup``
    epochs, each weight step is followed by an architecture step: Adam on
    ``dints_space.log_alpha_a`` / ``log_alpha_c``, using a batch from the
    disjoint architecture split. Validation runs every ``num_epochs_per_validation``
    epochs and on the final epoch. Whenever the mean foreground Dice improves,
    the decoded architecture is saved to ``<arch_ckpt_path>/search_code_<iter>.pt``
    and ``progress.yaml`` is rewritten.

    When more than one CUDA device is visible, one process per GPU is expected,
    with the ``env://`` rendezvous variables set; NCCL is used for communication.

    Args:
        config_file: path (or list of paths) of the config file(s) read by
            ``monai.bundle.ConfigParser``.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    parser = ConfigParser()
    parser.read_config(config_file)

    arch_ckpt_path = parser["arch_ckpt_path"]
    amp = parser["amp"]
    data_file_base_dir = parser["data_file_base_dir"]
    data_list_file_path = parser["data_list_file_path"]
    determ = parser["determ"]
    learning_rate = parser["learning_rate"]
    learning_rate_arch = parser["learning_rate_arch"]
    learning_rate_milestones = np.array(parser["learning_rate_milestones"])
    num_images_per_batch = parser["num_images_per_batch"]
    num_epochs = parser["num_epochs"]  # around 20k iterations
    num_epochs_per_validation = parser["num_epochs_per_validation"]
    num_epochs_warmup = parser["num_epochs_warmup"]
    num_sw_batch_size = parser["num_sw_batch_size"]
    output_classes = parser["output_classes"]
    overlap_ratio = parser["overlap_ratio"]
    patch_size_valid = parser["patch_size_valid"]
    ram_cost_factor = parser["ram_cost_factor"]
    print("[info] GPU RAM cost factor:", ram_cost_factor)

    train_transforms = parser.get_parsed_content("transform_train")
    val_transforms = parser.get_parsed_content("transform_validation")

    # deterministic training
    if determ:
        set_determinism(seed=0)

    print("[info] number of GPUs:", torch.cuda.device_count())
    distributed = torch.cuda.device_count() > 1
    if distributed:
        # initialize the distributed training process, every GPU runs in a process
        dist.init_process_group(backend="nccl", init_method="env://")
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        world_size = 1
        rank = 0
    is_main = rank == 0
    print("[info] world_size:", world_size)

    with open(data_list_file_path, "r") as f:
        json_data = json.load(f)
    train_files = _collect_existing_pairs(json_data["training"], data_file_base_dir)
    val_files = _collect_existing_pairs(json_data["validation"], data_file_base_dir)

    # Split training data into disjoint halves: "w" for weight updates, "a" for
    # architecture updates. All ranks must agree on this split before sharding,
    # so distributed runs shuffle with a fixed seed rather than per-rank RNG state.
    if distributed:
        random.Random(0).shuffle(train_files)
    else:
        random.shuffle(train_files)
    half = len(train_files) // 2
    train_files_w = train_files[:half]
    train_files_a = train_files[half:]
    if distributed:
        train_files_w = partition_dataset(
            data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
        )[rank]
        train_files_a = partition_dataset(
            data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
        )[rank]
        val_files = partition_dataset(
            data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False
        )[rank]
    print("train_files_w:", len(train_files_w))
    print("train_files_a:", len(train_files_a))
    print("val_files:", len(val_files))

    device = torch.device(f"cuda:{rank}")  # rank is 0 when not distributed
    torch.cuda.set_device(device)

    # Multi-GPU runs cache every (already sharded) sample; single-GPU runs cache 1/8.
    cache_rate = 1.0 if distributed else 0.125
    train_ds_a = monai.data.CacheDataset(
        data=train_files_a, transform=train_transforms, cache_rate=cache_rate, num_workers=8
    )
    train_ds_w = monai.data.CacheDataset(
        data=train_files_w, transform=train_transforms, cache_rate=cache_rate, num_workers=8
    )
    val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=cache_rate, num_workers=2)
    train_loader_a = ThreadDataLoader(train_ds_a, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)

    # network architecture
    model = parser.get_parsed_content("network")
    dints_space = parser.get_parsed_content("dints_space")
    model = model.to(device)
    if distributed:
        # SyncBatchNorm needs an initialized process group (older PyTorch releases
        # fail in training without one), so only convert when running distributed.
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = transforms.Compose(
        [transforms.EnsureType(), transforms.AsDiscrete(argmax=True, to_onehot=output_classes)]
    )
    post_label = transforms.Compose([transforms.EnsureType(), transforms.AsDiscrete(to_onehot=output_classes)])

    # loss function
    loss_func = parser.get_parsed_content("loss")

    # optimizers: network weights (SGD) and the two architecture tensors (Adam);
    # learning rates are scaled linearly with world_size
    optimizer = torch.optim.SGD(
        model.weight_parameters(), lr=learning_rate * world_size, momentum=0.9, weight_decay=0.00004
    )
    arch_optimizer_a = torch.optim.Adam(
        [dints_space.log_alpha_a], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
    )
    arch_optimizer_c = torch.optim.Adam(
        [dints_space.log_alpha_c], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
    )

    # Keep the unwrapped network: DistributedDataParallel does not expose the
    # custom weight_parameters() method, and this avoids guessing from world_size.
    search_model = model
    if distributed:
        model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

    # amp
    if amp:
        from torch.cuda.amp import GradScaler, autocast

        scaler = GradScaler()
        if is_main:
            print("[info] amp enabled")

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    idx_iter = 0

    if is_main:
        writer = SummaryWriter(log_dir=os.path.join(arch_ckpt_path, "Events"))
        with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)
    start_time = time.time()
    for epoch in range(num_epochs):
        # Halve the LR for every milestone (a fraction of the post-warmup schedule)
        # already passed.
        decay = 0.5 ** np.sum(
            [(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones]
        )
        lr = learning_rate * decay * world_size
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if is_main:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        # [sum of losses, number of steps]; summed across ranks at the end of the epoch
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)

        epoch_len = len(train_loader_w)
        step = 0
        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)

            # --- weight step: train network weights with the architecture frozen ---
            _set_requires_grad(search_model.weight_parameters(), True)
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()
            if amp:
                with autocast():
                    outputs = model(inputs)
                    loss = _segmentation_loss(loss_func, outputs, labels, output_classes)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                loss = _segmentation_loss(loss_func, outputs, labels, output_classes)
                loss.backward()
                optimizer.step()

            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            idx_iter += 1

            if is_main:
                print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

            if epoch < num_epochs_warmup:
                continue

            # --- architecture step on a batch from the disjoint "a" split ---
            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = sample_a["image"].to(device), sample_a["label"].to(device)

            _set_requires_grad(search_model.weight_parameters(), False)
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            # entropy of both architecture distributions, topology entropy, and the
            # deviation of the relative RAM cost from ram_cost_factor
            probs_a, _ = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(
                F.softmax(dints_space.log_alpha_c, dim=-1) * F.log_softmax(dints_space.log_alpha_c, dim=-1)
            ).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)
            # RAM cost is estimated for the batch this step actually runs on
            ram_cost_full = dints_space.get_ram_cost_usage(inputs_search.shape, full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs_search.shape)
            ram_cost_loss = torch.abs(ram_cost_factor - ram_cost_usage / ram_cost_full)

            # linear increase of the entropy, topology and RAM terms over the
            # post-warmup epochs; shared by the AMP and full-precision paths
            combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)
            arch_regularization = combination_weights * (
                (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
            )

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()
            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    loss = _segmentation_loss(loss_func, outputs_search, labels_search, output_classes)
                    loss = loss + arch_regularization
                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                loss = _segmentation_loss(loss_func, outputs_search, labels_search, output_classes)
                loss = loss + arch_regularization
                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0
            if is_main:
                print(
                    "[{0}] ".format(str(datetime.now())[:19])
                    + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}"
                )
                writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step)

        # synchronizes all processes and reduce results (both weight and arch losses)
        if distributed:
            dist.barrier()
            dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)
            dist.all_reduce(loss_torch_arch, op=torch.distributed.ReduceOp.SUM)
        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if is_main:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, "
                f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )
            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, "
                    f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )

        if (epoch + 1) % num_epochs_per_validation == 0 or (epoch + 1) == num_epochs:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                # interleaved [dice_sum, valid_count] for each foreground class
                metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device)
                for _index, val_data in enumerate(val_loader):
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)
                    with torch.cuda.amp.autocast(enabled=bool(amp)):
                        pred = sliding_window_inference(
                            val_images,
                            patch_size_valid,
                            num_sw_batch_size,
                            model,
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = post_pred(pred[0, ...])[None, ...]
                    val_labels = post_label(val_labels[0, ...])[None, ...]
                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)
                    print(_index + 1, "/", len(val_loader), value)

                    for _c in range(output_classes - 1):
                        class_dice = value[0, _c]
                        # count this class only where ITS Dice is defined (not NaN)
                        is_valid = 1.0 - torch.isnan(class_dice).float()
                        metric[2 * _c] += torch.nan_to_num(class_dice, nan=0.0) * is_valid
                        metric[2 * _c + 1] += is_valid

            # synchronizes all processes and reduce results
            if distributed:
                dist.barrier()
                dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)
            metric = metric.tolist()
            if is_main:
                class_scores = [metric[2 * _c] / metric[2 * _c + 1] for _c in range(output_classes - 1)]
                for _c, score in enumerate(class_scores):
                    print("evaluation metric - class {0:d}:".format(_c + 1), score)
                avg_metric = sum(class_scores) / float(output_classes - 1)
                print("avg_metric", avg_metric)

                if avg_metric > best_metric:
                    best_metric = avg_metric
                    best_metric_epoch = epoch + 1
                    best_metric_iterations = idx_iter
                    node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d = dints_space.decode()
                    torch.save(
                        {
                            "node_a": node_a_d,
                            "arch_code_a": arch_code_a_d,
                            "arch_code_a_max": arch_code_a_max_d,
                            "arch_code_c": arch_code_c_d,
                            "iter_num": idx_iter,
                            "epochs": epoch + 1,
                            "best_dsc": best_metric,
                            "best_path": best_metric_iterations,
                        },
                        os.path.join(arch_ckpt_path, "search_code_" + str(idx_iter) + ".pt"),
                    )
                    print("saved new best metric model")

                    dict_file = {
                        "best_avg_dice_score": float(best_metric),
                        "best_avg_dice_score_epoch": int(best_metric_epoch),
                        "best_avg_dice_score_iteration": int(idx_iter),
                    }
                    with open(os.path.join(arch_ckpt_path, "progress.yaml"), "w") as out_file:
                        yaml.dump(dict_file, stream=out_file)

                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, avg_metric, best_metric, best_metric_epoch
                    )
                )

                elapsed_time = (time.time() - start_time) / 60.0
                with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
                    f.write(
                        "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(
                            epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter
                        )
                    )

            if distributed:
                dist.barrier()

        torch.cuda.empty_cache()

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    if is_main:
        writer.close()

    if distributed:
        dist.destroy_process_group()
if __name__ == "__main__":
    # Expose this module's functions (e.g. ``run``) as a command-line interface.
    # `fire` is an optional dependency, loaded through MONAI's optional_import helper.
    from monai.utils import optional_import

    fire_module, _fire_available = optional_import("fire")
    fire_module.Fire()