AdritRao's picture
Upload 62 files
a3290d1
raw
history blame
3.85 kB
"""
@author: louisblankemeier
"""
import os
import shutil
import sys
import time
import traceback
from datetime import datetime
from pathlib import Path
from comp2comp.io import io_utils
def process_2d(args, pipeline_builder):
    """Build and run a 2D pipeline against a fresh, timestamped output directory.

    Both directories are resolved relative to the directory containing this
    module: outputs go to ``../../outputs/<YYYY-mm-dd_HH-MM-SS>`` and models
    are kept in ``../../models``. Missing directories are created.

    Args:
        args: Opaque argument object; it is forwarded unchanged to
            ``pipeline_builder``.
        pipeline_builder: Callable invoked as ``pipeline_builder(args)``. The
            object it returns is called with the keyword arguments
            ``output_dir`` (a ``pathlib.Path``) and ``model_dir`` (a ``str``).
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))

    output_dir = Path(
        os.path.join(
            module_dir,
            "../../outputs",
            datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )
    )
    # exist_ok avoids the check-then-create race: two runs started within the
    # same second (or concurrently) would otherwise hit FileExistsError.
    output_dir.mkdir(parents=True, exist_ok=True)

    model_dir = os.path.join(module_dir, "../../models")
    os.makedirs(model_dir, exist_ok=True)

    pipeline = pipeline_builder(args)
    pipeline(output_dir=output_dir, model_dir=model_dir)
# Longest suffix first so "scan.nii.gz" loses the whole extension in one step.
_NIFTI_SUFFIXES = (".nii.gz", ".nii")


def _is_nifti(path):
    """Return True when ``path`` ends with ``.nii`` or ``.nii.gz``."""
    return path.endswith(_NIFTI_SUFFIXES)


def _strip_nifti_suffix(file_name):
    """Return ``file_name`` without a trailing ``.nii`` / ``.nii.gz`` extension."""
    for suffix in _NIFTI_SUFFIXES:
        if file_name.endswith(suffix):
            return file_name[: -len(suffix)]
    return file_name


def _scan_output_dir(output_path, input_path, path):
    """Return the output directory (a ``Path``) for the scan located at ``path``."""
    if _is_nifti(path):
        file_name = os.path.basename(os.path.normpath(path))
        return Path(os.path.join(output_path, _strip_nifti_suffix(file_name)))

    # Non-NIfTI input: mirror the scan's location relative to the input root,
    # nested under a folder named after the input root itself.
    input_root = os.path.normpath(input_path)
    return Path(
        os.path.join(
            output_path,
            os.path.basename(input_root),
            os.path.relpath(os.path.normpath(path), input_root),
        )
    )


def _discard_failed_output(output_dir):
    """Best-effort removal of a failed scan's output dir and its empty parent."""
    try:
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        parent_dir = os.path.dirname(output_dir)
        # Only inspect the parent when it is really there; a missing parent
        # must not abort the processing of the remaining scans.
        if os.path.isdir(parent_dir) and not os.listdir(parent_dir):
            os.rmdir(parent_dir)
    except OSError:
        print(f"Could not clean up {output_dir}\n")
        traceback.print_exc()


def process_3d(args, pipeline_builder):
    """Run a pipeline over every scan found under ``args.input_path``.

    Scans are enumerated with ``io_utils.get_dicom_or_nifti_paths_and_num``,
    which yields ``(path, num)`` pairs. Paths ending in ``.nii`` / ``.nii.gz``
    are always processed; any other path is skipped when ``num`` (reported as
    its slice count) is below 30. A failure while handling one scan is
    printed, that scan's output directory is removed, and processing
    continues with the next scan.

    Args:
        args: Object exposing ``input_path``, ``output_path`` (``None`` selects
            ``../../outputs`` relative to this module), ``overwrite_outputs``
            (when falsy a timestamped sub-directory is appended to the output
            path) and ``save_segmentations`` (when falsy the ``segmentations``
            folder inside each scan's output directory is deleted afterwards).
        pipeline_builder: Callable invoked as ``pipeline_builder(path, args)``.
            The object it returns is called with the keyword arguments
            ``output_dir`` (a ``pathlib.Path``) and ``model_dir`` (a ``str``).
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))

    model_dir = os.path.join(module_dir, "../../models")
    # exist_ok avoids the check-then-create race between concurrent runs.
    os.makedirs(model_dir, exist_ok=True)

    if args.output_path is not None:
        output_path = Path(args.output_path)
    else:
        output_path = os.path.join(module_dir, "../../outputs")

    if not args.overwrite_outputs:
        date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_path = os.path.join(output_path, date_time)

    min_slices = 30

    for path, num in io_utils.get_dicom_or_nifti_paths_and_num(args.input_path):
        # Reset per scan: the error handler must never see the directory of a
        # previously processed scan, or it would delete that scan's results.
        output_dir = None
        try:
            st = time.time()

            if _is_nifti(path):
                print("Processing: ", path)
            else:
                print("Processing: ", path, " with ", num, " slices")
                if num < min_slices:
                    print(f"Number of slices is less than {min_slices}, skipping\n")
                    continue

            print("")
            try:
                sys.stdout.flush()
            except Exception:
                # Flushing is purely cosmetic; a broken stdout must not fail the scan.
                pass

            output_dir = _scan_output_dir(output_path, args.input_path, path)
            output_dir.mkdir(parents=True, exist_ok=True)

            pipeline = pipeline_builder(path, args)
            pipeline(output_dir=output_dir, model_dir=model_dir)

            if not args.save_segmentations:
                # remove the segmentations folder
                segmentations_dir = os.path.join(output_dir, "segmentations")
                if os.path.exists(segmentations_dir):
                    shutil.rmtree(segmentations_dir)

            print(f"Finished processing {path} in {time.time() - st:.1f} seconds\n")

        except Exception:
            print(f"ERROR PROCESSING {path}\n")
            traceback.print_exc()
            # Nothing to clean up when the failure happened before an output
            # directory was chosen for this scan.
            if output_dir is not None:
                _discard_failed_output(output_dir)