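# Training entry point for SegTransVAE on BraTS 2021: builds the dataloaders,
# wraps the model in a PyTorch Lightning module (BRATS), and trains with
# checkpointing and early stopping on the validation mean Dice score.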
import os

import torch
import pytorch_lightning as pl
from monai.utils import set_determinism
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from trainer import BRATS
from dataset.utils import get_loader

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix random seeds for reproducible training
set_determinism(seed=0)

os.system('cls||clear')  # clear the terminal (Windows or Unix)
print("Training ...")

# Dataset and training configuration
data_dir = "/app/brats_2021_task1"
json_list = "/app/info.json"
roi = (128, 128, 128)   # crop size of the input volumes
batch_size = 1
fold = 1
max_epochs = 500
val_every = 10          # run validation every N epochs
train_loader, val_loader, test_loader = get_loader(
    batch_size, data_dir, json_list, fold, roi, volume=1, test_size=0.2
)
print("Done initializing dataloaders!")

# Lightning module for training; use_VAE=True enables the VAE regularization branch
model = BRATS(use_VAE=True, train_loader=train_loader, val_loader=val_loader, test_loader=test_loader)

# Keep the three best checkpoints by validation mean Dice, plus the last epoch
checkpoint_callback = ModelCheckpoint(
    monitor='val/MeanDiceScore',
    dirpath='./checkpoints/{}'.format("SegTransVAE"),
    filename='Epoch{epoch:3d}-MeanDiceScore{val/MeanDiceScore:.4f}',
    save_top_k=3,
    mode='max',
    save_last=True,
    auto_insert_metric_name=False
)

# Stop early if validation mean Dice has not improved by at least 1e-4 for 15 validation runs
early_stop_callback = EarlyStopping(
    monitor='val/MeanDiceScore',
    min_delta=0.0001,
    patience=15,
    verbose=False,
    mode='max'
)

tensorboardlogger = TensorBoardLogger(
    'logs',
    name="SegTransVAE",
    default_hp_metric=None
)
trainer = pl.Trainer(
    devices=[0],
    precision=16,                  # mixed-precision training
    max_epochs=max_epochs,
    enable_progress_bar=True,
    callbacks=[checkpoint_callback, early_stop_callback],
    num_sanity_val_steps=1,
    logger=tensorboardlogger,
    check_val_every_n_epoch=val_every,
    # Debugging options, kept for reference:
    # fast_dev_run=10,
    # accelerator='ddp',
    # overfit_batches=5,
    # auto_lr_find=True,
    # limit_train_batches=0.01,
    # limit_val_batches=0.01,
)
# trainer.tune(model)
trainer.fit(model)