monai / medical
Commit c834ad4 · parent: 13ef090
katielink committed: update search function to match monai 1.2

configs/metadata.json CHANGED
@@ -1,7 +1,8 @@
 {
     "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
-    "version": "0.4.1",
+    "version": "0.4.2",
     "changelog": {
+        "0.4.2": "update search function to match monai 1.2",
         "0.4.1": "fix the wrong GPU index issue of multi-node",
         "0.4.0": "remove error dollar symbol in readme",
         "0.3.9": "add cpu ram requirement in readme",
scripts/prepare_datalist.py CHANGED
@@ -11,11 +11,10 @@ def produce_sample_dict(line: str):
     return {"label": line, "image": line.replace("labelsTr", "imagesTr")}
 
 
-def produce_datalist(dataset_dir: str):
+def produce_datalist(dataset_dir: str, train_size: int = 196):
     """
     This function is used to split the dataset.
-    It will produce 200 samples for training, and the other samples are divided equally
-    into val and test sets.
+    It will produce "train_size" number of samples for training.
     """
 
     samples = sorted(glob.glob(os.path.join(dataset_dir, "labelsTr", "*"), recursive=True))
@@ -23,7 +22,7 @@ def produce_datalist(dataset_dir: str):
     datalist = []
     for line in samples:
         datalist.append(produce_sample_dict(line))
-    train_list, other_list = train_test_split(datalist, train_size=196)
+    train_list, other_list = train_test_split(datalist, train_size=train_size)
     val_list, test_list = train_test_split(other_list, train_size=0.66)
 
     return {"training": train_list, "validation": val_list, "testing": test_list}
@@ -37,7 +36,7 @@ def main(args):
     output_json = args.output
     # produce deterministic data splits
    monai.utils.set_determinism(seed=123)
-    datalist = produce_datalist(dataset_dir=data_file_base_dir)
+    datalist = produce_datalist(dataset_dir=data_file_base_dir, train_size=args.train_size)
     with open(output_json, "w") as f:
         json.dump(datalist, f, ensure_ascii=True, indent=4)
 
@@ -53,6 +52,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--output", type=str, default="dataset_0.json", help="relative path of output datalist json file."
     )
+    parser.add_argument("--train_size", type=int, default=196, help="number of training samples.")
     args = parser.parse_args()
 
     main(args)
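
With this change the number of training samples is driven by the new --train_size argument instead of a hard-coded 196. A small illustrative sketch of how that argument flows through the split, assuming train_test_split here is scikit-learn's (its import sits outside the hunks shown above) and using a fake 300-item datalist:

# Illustrative only: how train_size drives the split in produce_datalist.
from sklearn.model_selection import train_test_split

# Fake datalist entries in the same {"image": ..., "label": ...} shape the script builds.
datalist = [{"image": f"imagesTr/img_{i}.nii.gz", "label": f"labelsTr/img_{i}.nii.gz"} for i in range(300)]

train_list, other_list = train_test_split(datalist, train_size=196)  # --train_size default
val_list, test_list = train_test_split(other_list, train_size=0.66)  # remainder split roughly 2:1

print(len(train_list), len(val_list), len(test_list))  # roughly 196 68 36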
scripts/search.py CHANGED
@@ -28,7 +28,7 @@ from monai import transforms
 from monai.bundle import ConfigParser
 from monai.data import ThreadDataLoader, partition_dataset
 from monai.inferers import sliding_window_inference
-from monai.metrics import compute_meandice
+from monai.metrics import compute_dice
 from monai.utils import set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.tensorboard import SummaryWriter
@@ -100,14 +100,12 @@ def run(config_file: Union[str, Sequence[str]]):
         train_files_w = partition_dataset(
             data=train_files_w, shuffle=True, num_partitions=world_size, even_divisible=True
         )[dist.get_rank()]
-    print("train_files_w:", len(train_files_w))
 
     train_files_a = train_files[len(train_files) // 2 :]
     if torch.cuda.device_count() > 1:
         train_files_a = partition_dataset(
             data=train_files_a, shuffle=True, num_partitions=world_size, even_divisible=True
         )[dist.get_rank()]
-    print("train_files_a:", len(train_files_a))
 
     # validation data
     files = []
@@ -125,7 +123,6 @@ def run(config_file: Union[str, Sequence[str]]):
     val_files = partition_dataset(data=val_files, shuffle=False, num_partitions=world_size, even_divisible=False)[
         dist.get_rank()
     ]
-    print("val_files:", len(val_files))
 
     # network architecture
     if torch.cuda.device_count() > 1:
@@ -421,7 +418,7 @@ def run(config_file: Union[str, Sequence[str]]):
         val_labels = post_label(val_labels[0, ...])
         val_labels = val_labels[None, ...]
 
-        value = compute_meandice(y_pred=val_outputs, y=val_labels, include_background=False)
+        value = compute_dice(y_pred=val_outputs, y=val_labels, include_background=False)
 
         print(_index + 1, "/", len(val_loader), value)
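
Besides dropping the debug prints, the functional change is the metric call: compute_meandice is replaced by compute_dice, the renamed API available in MONAI 1.2, with the same keyword arguments. A standalone sketch of the call on toy one-hot tensors (shapes are illustrative, not taken from the bundle):

# Standalone sketch of the renamed metric API; shapes here are illustrative.
import torch
from monai.metrics import compute_dice

# One-hot tensors in (batch, channel, H, W, D) layout; channel 0 is background.
y_pred = torch.zeros(1, 2, 8, 8, 8)
y = torch.zeros(1, 2, 8, 8, 8)
y_pred[:, 1, :4] = 1
y[:, 1, :5] = 1
y_pred[:, 0] = 1 - y_pred[:, 1]
y[:, 0] = 1 - y[:, 1]

# Same call pattern as in search.py above, just on toy data.
value = compute_dice(y_pred=y_pred, y=y, include_background=False)
print(value)  # per-sample, per-foreground-class Dice, here about 0.889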