update search function to match monai 1.2
- configs/metadata.json +2 -1
- scripts/prepare_datalist.py +5 -5
- scripts/search.py +2 -5
configs/metadata.json
CHANGED
@@ -1,7 +1,8 @@
 {
     "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
-    "version": "0.4.1",
+    "version": "0.4.2",
     "changelog": {
+        "0.4.2": "update search function to match monai 1.2",
         "0.4.1": "fix the wrong GPU index issue of multi-node",
         "0.4.0": "remove error dollar symbol in readme",
         "0.3.9": "add cpu ram requirement in readme",
scripts/prepare_datalist.py
CHANGED
@@ -11,11 +11,10 @@ def produce_sample_dict(line: str):
     return {"label": line, "image": line.replace("labelsTr", "imagesTr")}


-def produce_datalist(dataset_dir: str):
+def produce_datalist(dataset_dir: str, train_size: int = 196):
     """
     This function is used to split the dataset.
-    It will produce
-    into val and test sets.
+    It will produce "train_size" number of samples for training.
     """

     samples = sorted(glob.glob(os.path.join(dataset_dir, "labelsTr", "*"), recursive=True))
@@ -23,7 +22,7 @@ def produce_datalist(dataset_dir: str):
     datalist = []
     for line in samples:
         datalist.append(produce_sample_dict(line))
-    train_list, other_list = train_test_split(datalist, train_size=196)
+    train_list, other_list = train_test_split(datalist, train_size=train_size)
     val_list, test_list = train_test_split(other_list, train_size=0.66)

     return {"training": train_list, "validation": val_list, "testing": test_list}
@@ -37,7 +36,7 @@ def main(args):
     output_json = args.output
     # produce deterministic data splits
     monai.utils.set_determinism(seed=123)
-    datalist = produce_datalist(dataset_dir=data_file_base_dir)
+    datalist = produce_datalist(dataset_dir=data_file_base_dir, train_size=args.train_size)
     with open(output_json, "w") as f:
         json.dump(datalist, f, ensure_ascii=True, indent=4)

@@ -53,6 +52,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--output", type=str, default="dataset_0.json", help="relative path of output datalist json file."
     )
+    parser.add_argument("--train_size", type=int, default=196, help="number of training samples.")
    args = parser.parse_args()

     main(args)
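For context, here is a minimal sketch of the split logic as it reads after this change. It assumes scikit-learn's train_test_split (already used by the script) and a Decathlon-style dataset folder with "labelsTr"/"imagesTr" subdirectories; the dataset path in the example call is illustrative only.

# Sketch of the updated datalist split (assumptions: sklearn and MONAI installed,
# dataset_dir laid out with "labelsTr"/"imagesTr" folders).
import glob
import os

import monai
from sklearn.model_selection import train_test_split


def produce_sample_dict(line: str):
    return {"label": line, "image": line.replace("labelsTr", "imagesTr")}


def produce_datalist(dataset_dir: str, train_size: int = 196):
    samples = sorted(glob.glob(os.path.join(dataset_dir, "labelsTr", "*"), recursive=True))
    datalist = [produce_sample_dict(line) for line in samples]
    # take "train_size" samples for training, then split the remainder ~2:1 into val/test
    train_list, other_list = train_test_split(datalist, train_size=train_size)
    val_list, test_list = train_test_split(other_list, train_size=0.66)
    return {"training": train_list, "validation": val_list, "testing": test_list}


if __name__ == "__main__":
    monai.utils.set_determinism(seed=123)  # deterministic splits, as in the script's main()
    print(produce_datalist("/path/to/dataset", train_size=196))  # example path only

With the new flag, an invocation simply adds --train_size 196 alongside the script's existing arguments (e.g. --output dataset_0.json).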
scripts/search.py
CHANGED
@@ -28,7 +28,7 @@ from monai import transforms
 from monai.bundle import ConfigParser
 from monai.data import ThreadDataLoader, partition_dataset
 from monai.inferers import sliding_window_inference
-from monai.metrics import compute_meandice
+from monai.metrics import compute_dice
 from monai.utils import set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.tensorboard import SummaryWriter
@@ -100,14 +100,12 @@ def run(config_file: Union[str, Sequence[str]]):
         train_files_w = partition_dataset(
             data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
         )[dist.get_rank()]
-    print("train_files_w:", len(train_files_w))

     train_files_a = train_files[len(train_files) // 2 :]
     if torch.cuda.device_count() > 1:
         train_files_a = partition_dataset(
             data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
         )[dist.get_rank()]
-    print("train_files_a:", len(train_files_a))

     # validation data
     files = []
@@ -125,7 +123,6 @@ def run(config_file: Union[str, Sequence[str]]):
     val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[
         dist.get_rank()
     ]
-    print("val_files:", len(val_files))

     # network architecture
     if torch.cuda.device_count() > 1:
@@ -421,7 +418,7 @@ def run(config_file: Union[str, Sequence[str]]):
                 val_labels = post_label(val_labels[0, ...])
                 val_labels = val_labels[None, ...]

-                value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)
+                value = compute_dice(y_pred=val_outputs, y=val_labels, include_background=False)

                 print(_index + 1, "/", len(val_loader), value)

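As a quick reference for the MONAI 1.2 API adopted above, the sketch below calls monai.metrics.compute_dice on small synthetic one-hot tensors (shapes and values are made up for illustration); in MONAI releases before 1.2 the equivalent function was compute_meandice.

# Standalone check of the metric call used in search.py (synthetic data only).
import torch
from monai.metrics import compute_dice

# one-hot tensors shaped (batch, channel, D, H, W): channel 0 = background, channel 1 = foreground
y_pred = torch.zeros(1, 2, 8, 8, 8)
y = torch.zeros(1, 2, 8, 8, 8)
y_pred[:, 1, :4] = 1.0           # prediction covers 4 of 8 slices
y[:, 1, :6] = 1.0                # ground truth covers 6 of 8 slices
y_pred[:, 0] = 1.0 - y_pred[:, 1]
y[:, 0] = 1.0 - y[:, 1]

# per-class Dice, skipping background as in the bundle's validation loop
value = compute_dice(y_pred=y_pred, y=y, include_background=False)
print(value)  # tensor([[0.8000]]) : 2 * 4 / (4 + 6) overlapping slices

compute_dice expects batched one-hot tensors and returns one score per (batch, class), which is consistent with the val_labels[None, ...] re-batching visible in the hunk above.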