import json
import logging
import os
import random
import sys
import time
from datetime import datetime
from typing import Sequence, Union

import monai
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
from monai import transforms
from monai.bundle import ConfigParser
from monai.data import ThreadDataLoader, partition_dataset
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.utils import set_determinism
from torch.nn.parallel import DistributedDataParallel
from torch.utils.tensorboard import SummaryWriter
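# DiNTS architecture search: the training data are split in half so that network weights
# and the architecture parameters (dints_space.log_alpha_a / log_alpha_c) are updated on
# disjoint batches; supports AMP and multi-GPU training via DistributedDataParallel.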
def run(config_file: Union[str, Sequence[str]]):
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    parser = ConfigParser()
    parser.read_config(config_file)

    arch_ckpt_path = parser["arch_ckpt_path"]
    amp = parser["amp"]
    data_file_base_dir = parser["data_file_base_dir"]
    data_list_file_path = parser["data_list_file_path"]
    determ = parser["determ"]
    learning_rate = parser["learning_rate"]
    learning_rate_arch = parser["learning_rate_arch"]
    learning_rate_milestones = np.array(parser["learning_rate_milestones"])
    num_images_per_batch = parser["num_images_per_batch"]
    num_epochs = parser["num_epochs"]
    num_epochs_per_validation = parser["num_epochs_per_validation"]
    num_epochs_warmup = parser["num_epochs_warmup"]
    num_sw_batch_size = parser["num_sw_batch_size"]
    output_classes = parser["output_classes"]
    overlap_ratio = parser["overlap_ratio"]
    patch_size_valid = parser["patch_size_valid"]
    ram_cost_factor = parser["ram_cost_factor"]
    print("[info] GPU RAM cost factor:", ram_cost_factor)

    train_transforms = parser.get_parsed_content("transform_train")
    val_transforms = parser.get_parsed_content("transform_validation")

    if determ:
        set_determinism(seed=0)

    print("[info] number of GPUs:", torch.cuda.device_count())
    if torch.cuda.device_count() > 1:
        dist.init_process_group(backend="nccl", init_method="env://")
        world_size = dist.get_world_size()
    else:
        world_size = 1
    print("[info] world_size:", world_size)

    with open(data_list_file_path, "r") as f:
        json_data = json.load(f)

    list_train = json_data["training"]
    list_valid = json_data["validation"]

    files = []
    for _i in range(len(list_train)):
        str_img = os.path.join(data_file_base_dir, list_train[_i]["image"])
        str_seg = os.path.join(data_file_base_dir, list_train[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    train_files = files
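    # shuffle the training list and split it in half: the "w" half drives the
    # network-weight updates and the "a" half drives the architecture-parameter updates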
    random.shuffle(train_files)

    train_files_w = train_files[: len(train_files) // 2]
    if torch.cuda.device_count() > 1:
        train_files_w = partition_dataset(
            data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
        )[dist.get_rank()]
    print("train_files_w:", len(train_files_w))

    train_files_a = train_files[len(train_files) // 2 :]
    if torch.cuda.device_count() > 1:
        train_files_a = partition_dataset(
            data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
        )[dist.get_rank()]
    print("train_files_a:", len(train_files_a))

    files = []
    for _i in range(len(list_valid)):
        str_img = os.path.join(data_file_base_dir, list_valid[_i]["image"])
        str_seg = os.path.join(data_file_base_dir, list_valid[_i]["label"])

        if (not os.path.exists(str_img)) or (not os.path.exists(str_seg)):
            continue

        files.append({"image": str_img, "label": str_seg})
    val_files = files

    if torch.cuda.device_count() > 1:
        val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[
            dist.get_rank()
        ]
    print("val_files:", len(val_files))

    if torch.cuda.device_count() > 1:
        device = torch.device(f"cuda:{dist.get_rank()}")
    else:
        device = torch.device("cuda:0")
    torch.cuda.set_device(device)
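    # cache pre-processed samples in RAM: with DDP each rank caches its own partition fully,
    # while a single-GPU run caches only 12.5% of the data to limit host-memory usage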
    if torch.cuda.device_count() > 1:
        train_ds_a = monai.data.CacheDataset(
            data=train_files_a, transform=train_transforms, cache_rate=1.0, num_workers=8
        )
        train_ds_w = monai.data.CacheDataset(
            data=train_files_w, transform=train_transforms, cache_rate=1.0, num_workers=8
        )
        val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=2)
    else:
        train_ds_a = monai.data.CacheDataset(
            data=train_files_a, transform=train_transforms, cache_rate=0.125, num_workers=8
        )
        train_ds_w = monai.data.CacheDataset(
            data=train_files_w, transform=train_transforms, cache_rate=0.125, num_workers=8
        )
        val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=0.125, num_workers=2)

    train_loader_a = ThreadDataLoader(train_ds_a, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
    train_loader_w = ThreadDataLoader(train_ds_w, num_workers=6, batch_size=num_images_per_batch, shuffle=True)
    val_loader = ThreadDataLoader(val_ds, num_workers=0, batch_size=1, shuffle=False)

    model = parser.get_parsed_content("network")
    dints_space = parser.get_parsed_content("dints_space")

    model = model.to(device)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    post_pred = transforms.Compose(
        [transforms.EnsureType(), transforms.AsDiscrete(argmax=True, to_onehot=output_classes)]
    )
    post_label = transforms.Compose([transforms.EnsureType(), transforms.AsDiscrete(to_onehot=output_classes)])

    loss_func = parser.get_parsed_content("loss")
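    # SGD updates the network weights; two separate Adam optimizers update the DiNTS
    # architecture parameters log_alpha_a and log_alpha_c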
    optimizer = torch.optim.SGD(
        model.weight_parameters(), lr=learning_rate * world_size, momentum=0.9, weight_decay=0.00004
    )
    arch_optimizer_a = torch.optim.Adam(
        [dints_space.log_alpha_a], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
    )
    arch_optimizer_c = torch.optim.Adam(
        [dints_space.log_alpha_c], lr=learning_rate_arch * world_size, betas=(0.5, 0.999), weight_decay=0.0
    )

    if torch.cuda.device_count() > 1:
        model = DistributedDataParallel(model, device_ids=[device], find_unused_parameters=True)

    if amp:
        from torch.cuda.amp import GradScaler, autocast

        scaler = GradScaler()
        if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
            print("[info] amp enabled")

    val_interval = num_epochs_per_validation
    best_metric = -1
    best_metric_epoch = -1
    idx_iter = 0

    if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
        writer = SummaryWriter(log_dir=os.path.join(arch_ckpt_path, "Events"))

        with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
            f.write("epoch\tmetric\tloss\tlr\ttime\titer\n")

    dataloader_a_iterator = iter(train_loader_a)
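    # main search loop: each epoch first updates network weights on the "w" loader, then
    # (after the warm-up epochs) updates the architecture parameters on the "a" loader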
    start_time = time.time()
    for epoch in range(num_epochs):
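        # step-decay schedule: halve the weight learning rate for every milestone in
        # learning_rate_milestones already passed by the post-warm-up progress fraction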
        decay = 0.5 ** np.sum(
            [(epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup) > learning_rate_milestones]
        )
        lr = learning_rate * decay * world_size
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

        if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
            print("-" * 10)
            print(f"epoch {epoch + 1}/{num_epochs}")
            print("learning rate is set to {}".format(lr))

        model.train()
        epoch_loss = 0
        loss_torch = torch.zeros(2, dtype=torch.float, device=device)
        epoch_loss_arch = 0
        loss_torch_arch = torch.zeros(2, dtype=torch.float, device=device)
        step = 0
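        # 1) update network weights on the "w" split (architecture parameters are frozen)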
        for batch_data in train_loader_w:
            step += 1
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = True
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = True
            dints_space.log_alpha_a.requires_grad = False
            dints_space.log_alpha_c.requires_grad = False

            optimizer.zero_grad()

            if amp:
                with autocast():
                    outputs = model(inputs)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                    else:
                        loss = loss_func(outputs, labels)

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs, dims=[1]), 1 - labels)
                else:
                    loss = loss_func(outputs, labels)
                loss.backward()
                optimizer.step()

            epoch_loss += loss.item()
            loss_torch[0] += loss.item()
            loss_torch[1] += 1.0
            epoch_len = len(train_loader_w)
            idx_iter += 1

            if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                print("[{0}] ".format(str(datetime.now())[:19]) + f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
                writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

            if epoch < num_epochs_warmup:
                continue
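            # 2) update architecture parameters on the "a" split; network weights are frozen.
            # this step is skipped during the warm-up epochs by the `continue` above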
            try:
                sample_a = next(dataloader_a_iterator)
            except StopIteration:
                dataloader_a_iterator = iter(train_loader_a)
                sample_a = next(dataloader_a_iterator)
            inputs_search, labels_search = (sample_a["image"].to(device), sample_a["label"].to(device))
            if world_size == 1:
                for _ in model.weight_parameters():
                    _.requires_grad = False
            else:
                for _ in model.module.weight_parameters():
                    _.requires_grad = False
            dints_space.log_alpha_a.requires_grad = True
            dints_space.log_alpha_c.requires_grad = True

            entropy_alpha_c = torch.tensor(0.0).to(device)
            entropy_alpha_a = torch.tensor(0.0).to(device)
            ram_cost_full = torch.tensor(0.0).to(device)
            ram_cost_usage = torch.tensor(0.0).to(device)
            ram_cost_loss = torch.tensor(0.0).to(device)
            topology_loss = torch.tensor(0.0).to(device)
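            # architecture regularization terms: entropies of the alpha_a / alpha_c
            # distributions, a RAM-cost penalty that keeps the predicted memory usage near
            # ram_cost_factor of the full-model cost, and a topology entropy term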
            probs_a, arch_code_prob_a = dints_space.get_prob_a(child=True)
            entropy_alpha_a = -((probs_a) * torch.log(probs_a + 1e-5)).mean()
            entropy_alpha_c = -(
                F.softmax(dints_space.log_alpha_c, dim=-1) * F.log_softmax(dints_space.log_alpha_c, dim=-1)
            ).mean()
            topology_loss = dints_space.get_topology_entropy(probs_a)

            ram_cost_full = dints_space.get_ram_cost_usage(inputs.shape, full=True)
            ram_cost_usage = dints_space.get_ram_cost_usage(inputs.shape)
            ram_cost_loss = torch.abs(ram_cost_factor - ram_cost_usage / ram_cost_full)

            arch_optimizer_a.zero_grad()
            arch_optimizer_c.zero_grad()

            combination_weights = (epoch - num_epochs_warmup) / (num_epochs - num_epochs_warmup)

            if amp:
                with autocast():
                    outputs_search = model(inputs_search)
                    if output_classes == 2:
                        loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                    else:
                        loss = loss_func(outputs_search, labels_search)

                    loss += combination_weights * (
                        (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
                    )

                scaler.scale(loss).backward()
                scaler.step(arch_optimizer_a)
                scaler.step(arch_optimizer_c)
                scaler.update()
            else:
                outputs_search = model(inputs_search)
                if output_classes == 2:
                    loss = loss_func(torch.flip(outputs_search, dims=[1]), 1 - labels_search)
                else:
                    loss = loss_func(outputs_search, labels_search)

                loss += 1.0 * (
                    combination_weights * (entropy_alpha_a + entropy_alpha_c) + ram_cost_loss + 0.001 * topology_loss
                )

                loss.backward()
                arch_optimizer_a.step()
                arch_optimizer_c.step()

            epoch_loss_arch += loss.item()
            loss_torch_arch[0] += loss.item()
            loss_torch_arch[1] += 1.0

            if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                print(
                    "[{0}] ".format(str(datetime.now())[:19])
                    + f"{step}/{epoch_len}, train_loss_arch: {loss.item():.4f}"
                )
                writer.add_scalar("train_loss_arch", loss.item(), epoch_len * epoch + step)

        if torch.cuda.device_count() > 1:
            dist.barrier()
            dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)

        loss_torch = loss_torch.tolist()
        loss_torch_arch = loss_torch_arch.tolist()
        if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
            loss_torch_epoch = loss_torch[0] / loss_torch[1]
            print(
                f"epoch {epoch + 1} average loss: {loss_torch_epoch:.4f}, "
                f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
            )

            if epoch >= num_epochs_warmup:
                loss_torch_arch_epoch = loss_torch_arch[0] / loss_torch_arch[1]
                print(
                    f"epoch {epoch + 1} average arch loss: {loss_torch_arch_epoch:.4f}, "
                    f"best mean dice: {best_metric:.4f} at epoch {best_metric_epoch}"
                )
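        # periodic validation: sliding-window inference over full volumes, scored with mean Dice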
        if (epoch + 1) % val_interval == 0 or (epoch + 1) == num_epochs:
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                metric = torch.zeros((output_classes - 1) * 2, dtype=torch.float, device=device)
                metric_sum = 0.0
                metric_count = 0
                metric_mat = []
                val_images = None
                val_labels = None
                val_outputs = None

                _index = 0
                for val_data in val_loader:
                    val_images = val_data["image"].to(device)
                    val_labels = val_data["label"].to(device)

                    roi_size = patch_size_valid
                    sw_batch_size = num_sw_batch_size

                    if amp:
                        with torch.cuda.amp.autocast():
                            pred = sliding_window_inference(
                                val_images,
                                roi_size,
                                sw_batch_size,
                                lambda x: model(x),
                                mode="gaussian",
                                overlap=overlap_ratio,
                            )
                    else:
                        pred = sliding_window_inference(
                            val_images,
                            roi_size,
                            sw_batch_size,
                            lambda x: model(x),
                            mode="gaussian",
                            overlap=overlap_ratio,
                        )
                    val_outputs = pred

                    val_outputs = post_pred(val_outputs[0, ...])
                    val_outputs = val_outputs[None, ...]
                    val_labels = post_label(val_labels[0, ...])
                    val_labels = val_labels[None, ...]

                    value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)

                    print(_index + 1, "/", len(val_loader), value)

                    metric_count += len(value)
                    metric_sum += value.sum().item()
                    metric_vals = value.cpu().numpy()
                    if len(metric_mat) == 0:
                        metric_mat = metric_vals
                    else:
                        metric_mat = np.concatenate((metric_mat, metric_vals), axis=0)

                    for _c in range(output_classes - 1):
                        val0 = torch.nan_to_num(value[0, _c], nan=0.0)
                        # count a class only when its own Dice value is defined (not NaN)
                        val1 = 1.0 - torch.isnan(value[0, _c]).float()
                        metric[2 * _c] += val0 * val1
                        metric[2 * _c + 1] += val1

                    _index += 1
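                # aggregate the per-class Dice sums and counts across ranks before averaging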
                if torch.cuda.device_count() > 1:
                    dist.barrier()
                    dist.all_reduce(metric, op=torch.distributed.ReduceOp.SUM)

                metric = metric.tolist()
                if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
                    for _c in range(output_classes - 1):
                        print("evaluation metric - class {0:d}:".format(_c + 1), metric[2 * _c] / metric[2 * _c + 1])
                    avg_metric = 0
                    for _c in range(output_classes - 1):
                        avg_metric += metric[2 * _c] / metric[2 * _c + 1]
                    avg_metric = avg_metric / float(output_classes - 1)
                    print("avg_metric", avg_metric)

                    if avg_metric > best_metric:
                        best_metric = avg_metric
                        best_metric_epoch = epoch + 1
                        best_metric_iterations = idx_iter
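                        # decode the searched architecture (node_a / arch_code_a / arch_code_c)
                        # and save it as the new best search checkpoint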
                        (node_a_d, arch_code_a_d, arch_code_c_d, arch_code_a_max_d) = dints_space.decode()
                        torch.save(
                            {
                                "node_a": node_a_d,
                                "arch_code_a": arch_code_a_d,
                                "arch_code_a_max": arch_code_a_max_d,
                                "arch_code_c": arch_code_c_d,
                                "iter_num": idx_iter,
                                "epochs": epoch + 1,
                                "best_dsc": best_metric,
                                "best_path": best_metric_iterations,
                            },
                            os.path.join(arch_ckpt_path, "search_code_" + str(idx_iter) + ".pt"),
                        )
                        print("saved new best metric model")

                    dict_file = {}
                    dict_file["best_avg_dice_score"] = float(best_metric)
                    dict_file["best_avg_dice_score_epoch"] = int(best_metric_epoch)
                    dict_file["best_avg_dice_score_iteration"] = int(idx_iter)
                    with open(os.path.join(arch_ckpt_path, "progress.yaml"), "w") as out_file:
                        _ = yaml.dump(dict_file, stream=out_file)

                    print(
                        "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                            epoch + 1, avg_metric, best_metric, best_metric_epoch
                        )
                    )

                    current_time = time.time()
                    elapsed_time = (current_time - start_time) / 60.0
                    with open(os.path.join(arch_ckpt_path, "accuracy_history.csv"), "a") as f:
                        f.write(
                            "{0:d}\t{1:.5f}\t{2:.5f}\t{3:.5f}\t{4:.1f}\t{5:d}\n".format(
                                epoch + 1, avg_metric, loss_torch_epoch, lr, elapsed_time, idx_iter
                            )
                        )

                if torch.cuda.device_count() > 1:
                    dist.barrier()

                torch.cuda.empty_cache()

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    if torch.cuda.device_count() == 1 or dist.get_rank() == 0:
        writer.close()

    if torch.cuda.device_count() > 1:
        dist.destroy_process_group()


if __name__ == "__main__":
    from monai.utils import optional_import

    fire, _ = optional_import("fire")
    fire.Fire()