AdritRao's picture
Upload 62 files
a3290d1
raw
history blame
8.52 kB
#!/usr/bin/env python
import argparse
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
from comp2comp.aaa import aaa
from comp2comp.aortic_calcium import (
aortic_calcium,
aortic_calcium_visualization,
)
from comp2comp.contrast_phase.contrast_phase import ContrastPhaseDetection
from comp2comp.hip import hip
from comp2comp.inference_pipeline import InferencePipeline
from comp2comp.io import io
from comp2comp.liver_spleen_pancreas import (
liver_spleen_pancreas,
liver_spleen_pancreas_visualization,
)
from comp2comp.muscle_adipose_tissue import (
muscle_adipose_tissue,
muscle_adipose_tissue_visualization,
)
from comp2comp.spine import spine
from comp2comp.utils import orientation
from comp2comp.utils.process import process_3d
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
### AAA Pipeline
def AAAPipelineBuilder(path, args):
    """Assemble the pipeline dispatched for the ``aaa`` command.

    The stage list is, in order: the sub-pipeline returned by
    ``AxialCropperPipelineBuilder(path, args)``, then
    ``aaa.AortaSegmentation``, ``aaa.AortaDiameter`` and
    ``aaa.AortaMetricsSaver``.

    Args:
        path: Forwarded unchanged to ``AxialCropperPipelineBuilder``.
        args: Parsed command-line namespace, forwarded unchanged to
            ``AxialCropperPipelineBuilder``.

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    stages = [
        AxialCropperPipelineBuilder(path, args),
        aaa.AortaSegmentation(),
        aaa.AortaDiameter(),
        aaa.AortaMetricsSaver(),
    ]
    return InferencePipeline(stages)
def MuscleAdiposeTissuePipelineBuilder(args):
    """Assemble the muscle / adipose tissue stage sequence.

    The stage list is, in order: segmentation, post-processing, metric
    computation, visualization, H5 saving and metrics saving.

    Args:
        args: Parsed command-line namespace; ``args.muscle_fat_model`` is
            passed to ``MuscleAdiposeTissueSegmentation``.

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    mat = muscle_adipose_tissue
    mat_vis = muscle_adipose_tissue_visualization
    stages = [
        # NOTE(review): the literal 16 is presumably a batch size -- confirm
        # against MuscleAdiposeTissueSegmentation's signature.
        mat.MuscleAdiposeTissueSegmentation(16, args.muscle_fat_model),
        mat.MuscleAdiposeTissuePostProcessing(),
        mat.MuscleAdiposeTissueComputeMetrics(),
        mat_vis.MuscleAdiposeTissueVisualizer(),
        mat.MuscleAdiposeTissueH5Saver(),
        mat.MuscleAdiposeTissueMetricsSaver(),
    ]
    return InferencePipeline(stages)
def MuscleAdiposeTissueFullPipelineBuilder(args):
    """Assemble a pipeline that finds DICOMs and then analyses them.

    The stage list is ``io.DicomFinder(args.input_path)`` followed by the
    sub-pipeline from ``MuscleAdiposeTissuePipelineBuilder(args)``.

    Args:
        args: Parsed command-line namespace; ``args.input_path`` is read
            here and the whole namespace is forwarded to
            ``MuscleAdiposeTissuePipelineBuilder``.

    Returns:
        The ``InferencePipeline`` wrapping the two stages.
    """
    finder = io.DicomFinder(args.input_path)
    analysis = MuscleAdiposeTissuePipelineBuilder(args)
    return InferencePipeline([finder, analysis])
def SpinePipelineBuilder(path, args):
    """Assemble the pipeline dispatched for the ``spine`` command.

    The stage list is, in order: DICOM-to-NIfTI conversion, spine
    segmentation (constructed with ``save=True``), ``ToCanonical``
    re-orientation, ROI computation, metrics saving, the coronal/sagittal
    visualizer and the report, the last two constructed with
    ``format="png"``.

    Args:
        path: Passed to ``io.DicomToNifti``.
        args: Parsed command-line namespace; ``args.spine_model`` is passed
            to both ``SpineSegmentation`` and ``SpineComputeROIs``.

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    stages = [
        io.DicomToNifti(path),
        spine.SpineSegmentation(args.spine_model, save=True),
        orientation.ToCanonical(),
        spine.SpineComputeROIs(args.spine_model),
        spine.SpineMetricsSaver(),
        spine.SpineCoronalSagittalVisualizer(format="png"),
        spine.SpineReport(format="png"),
    ]
    return InferencePipeline(stages)
def AxialCropperPipelineBuilder(path, args):
    """Assemble the axial-cropping sub-pipeline used by the AAA pipeline.

    The stage list is, in order: DICOM-to-NIfTI conversion, spine
    segmentation, ``ToCanonical`` re-orientation, and an ``AxialCropper``
    constructed with ``lower_level="L5"``, ``upper_level="L1"`` and
    ``save=True``.

    Args:
        path: Passed to ``io.DicomToNifti``.
        args: Parsed command-line namespace; ``args.spine_model`` is passed
            to ``SpineSegmentation``.

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    cropper = [
        io.DicomToNifti(path),
        spine.SpineSegmentation(args.spine_model),
        orientation.ToCanonical(),
        spine.AxialCropper(lower_level="L5", upper_level="L1", save=True),
    ]
    return InferencePipeline(cropper)
def SpineMuscleAdiposeTissuePipelineBuilder(path, args):
    """Assemble the ``spine_muscle_adipose_tissue`` pipeline.

    The stage list is, in order: the spine sub-pipeline,
    ``spine.SpineFindDicoms``, the muscle / adipose tissue sub-pipeline,
    and ``spine.SpineMuscleAdiposeTissueReport``.

    Args:
        path: Forwarded to ``SpinePipelineBuilder``.
        args: Parsed command-line namespace, forwarded to both
            sub-pipeline builders.

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    stages = [
        SpinePipelineBuilder(path, args),
        spine.SpineFindDicoms(),
        MuscleAdiposeTissuePipelineBuilder(args),
        spine.SpineMuscleAdiposeTissueReport(),
    ]
    return InferencePipeline(stages)
def LiverSpleenPancreasPipelineBuilder(path, args):
    """Assemble the pipeline for the ``liver_spleen_pancreas`` command.

    The stage list is, in order: DICOM-to-NIfTI conversion,
    liver/spleen/pancreas segmentation, ``ToCanonical`` re-orientation,
    the visualizer and the metrics printer.

    Args:
        path: Passed to ``io.DicomToNifti``.
        args: Parsed command-line namespace. Not read by this builder.

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    lsp = liver_spleen_pancreas
    lsp_vis = liver_spleen_pancreas_visualization
    stages = [
        io.DicomToNifti(path),
        lsp.LiverSpleenPancreasSegmentation(),
        orientation.ToCanonical(),
        lsp_vis.LiverSpleenPancreasVisualizer(),
        lsp_vis.LiverSpleenPancreasMetricsPrinter(),
    ]
    return InferencePipeline(stages)
def AorticCalciumPipelineBuilder(path, args):
    """Assemble the pipeline for the ``aortic_calcium`` command.

    The stage list is, in order: DICOM-to-NIfTI conversion, spine
    segmentation, ``ToCanonical``, aorta segmentation, ``ToCanonical``
    again, calcium segmentation, calcium metrics, the visualizer and the
    printer.

    Args:
        path: Passed to ``io.DicomToNifti``.
        args: Parsed command-line namespace. Not read by this builder; the
            spine model name is fixed to ``"ts_spine"`` (the
            ``aortic_calcium`` sub-command defines no ``--spine_model``).

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    ac = aortic_calcium
    ac_vis = aortic_calcium_visualization
    stages = [
        io.DicomToNifti(path),
        spine.SpineSegmentation(model_name="ts_spine"),
        orientation.ToCanonical(),
        ac.AortaSegmentation(),
        # NOTE(review): a second ToCanonical stage follows the aorta
        # segmentation -- presumably intentional; confirm.
        orientation.ToCanonical(),
        ac.AorticCalciumSegmentation(),
        ac.AorticCalciumMetrics(),
        ac_vis.AorticCalciumVisualizer(),
        ac_vis.AorticCalciumPrinter(),
    ]
    return InferencePipeline(stages)
def ContrastPhasePipelineBuilder(path, args):
    """Assemble the pipeline for the ``contrast_phase`` command.

    The stage list is ``io.DicomToNifti(path)`` followed by
    ``ContrastPhaseDetection(path)``.

    Args:
        path: Passed to both stages.
        args: Parsed command-line namespace. Not read by this builder.

    Returns:
        The ``InferencePipeline`` wrapping the two stages.
    """
    converter = io.DicomToNifti(path)
    detector = ContrastPhaseDetection(path)
    return InferencePipeline([converter, detector])
def HipPipelineBuilder(path, args):
    """Assemble the pipeline dispatched for the ``hip`` command.

    The stage list is, in order: DICOM-to-NIfTI conversion, hip
    segmentation, ``ToCanonical`` re-orientation, ROI computation, metrics
    saving and the visualizer.

    Args:
        path: Passed to ``io.DicomToNifti``.
        args: Parsed command-line namespace; ``args.hip_model`` is passed
            to both ``HipSegmentation`` and ``HipComputeROIs``.

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    stages = [
        io.DicomToNifti(path),
        hip.HipSegmentation(args.hip_model),
        orientation.ToCanonical(),
        hip.HipComputeROIs(args.hip_model),
        hip.HipMetricsSaver(),
        hip.HipVisualizer(),
    ]
    return InferencePipeline(stages)
def AllPipelineBuilder(path, args):
    """Assemble the combined pipeline dispatched for the ``all`` command.

    The stage list is, in order: DICOM-to-NIfTI conversion, the spine +
    muscle/adipose tissue sub-pipeline, the liver/spleen/pancreas
    sub-pipeline and the hip sub-pipeline.

    Args:
        path: Passed to ``io.DicomToNifti`` and to every sub-pipeline
            builder.
        args: Parsed command-line namespace, forwarded to every
            sub-pipeline builder.

    Returns:
        The ``InferencePipeline`` wrapping the stage list.
    """
    # NOTE(review): each sub-pipeline below also begins with its own
    # io.DicomToNifti(path) stage, so the conversion stage appears more
    # than once overall -- confirm whether that repetition is needed.
    sub_pipelines = [
        io.DicomToNifti(path),
        SpineMuscleAdiposeTissuePipelineBuilder(path, args),
        LiverSpleenPancreasPipelineBuilder(path, args),
        HipPipelineBuilder(path, args),
    ]
    return InferencePipeline(sub_pipelines)
def argument_parser():
    """Build the command-line parser with one sub-command per pipeline.

    Every sub-command inherits the shared options ``--input_path``/``-i``
    (required), ``--output_path``/``-o``, ``--save_segmentations`` and
    ``--overwrite_outputs``. The chosen sub-command name is stored in the
    ``pipeline`` attribute of the parsed namespace.

    Returns:
        The configured top-level ``argparse.ArgumentParser``.
    """
    # Options shared by every sub-command; add_help=False so the parent
    # does not contribute a second -h/--help to each child parser.
    shared = argparse.ArgumentParser(add_help=False)
    shared.add_argument("--input_path", "-i", type=str, required=True)
    shared.add_argument("--output_path", "-o", type=str)
    for switch in ("--save_segmentations", "--overwrite_outputs"):
        shared.add_argument(switch, action="store_true")

    root = argparse.ArgumentParser()
    commands = root.add_subparsers(dest="pipeline", help="Pipeline to run")

    # One row per sub-command, in registration order:
    # (name, extra add_parser kwargs, ((flag, default, help text), ...)).
    layout = (
        (
            "muscle_adipose_tissue",
            {},
            (("--muscle_fat_model", "abCT_v0.0.1", None),),
        ),
        ("spine", {}, (("--spine_model", "ts_spine", None),)),
        (
            "spine_muscle_adipose_tissue",
            {},
            (
                ("--muscle_fat_model", "stanford_v0.0.2", None),
                ("--spine_model", "ts_spine", None),
            ),
        ),
        ("liver_spleen_pancreas", {}, ()),
        ("aortic_calcium", {}, ()),
        ("contrast_phase", {}, ()),
        ("hip", {}, (("--hip_model", "ts_hip", None),)),
        (
            "aaa",
            {"help": "aorta diameter"},
            (
                ("--aorta_model", "ts_spine", "aorta model to use for inference"),
                ("--spine_model", "ts_spine", "spine model to use for inference"),
            ),
        ),
        (
            "all",
            {},
            (
                ("--muscle_fat_model", "abCT_v0.0.1", None),
                ("--spine_model", "ts_spine", None),
                ("--hip_model", "ts_hip", None),
            ),
        ),
    )

    for command, parser_kwargs, model_options in layout:
        sub = commands.add_parser(command, parents=[shared], **parser_kwargs)
        for flag, default, help_text in model_options:
            # help=None is argparse's own default, so rows without help
            # text behave exactly like a call that omits the keyword.
            sub.add_argument(flag, default=default, type=str, help=help_text)

    return root
def main():
    """Parse the command line and run the selected pipeline.

    The sub-command name stored in ``args.pipeline`` selects a pipeline
    builder, which is handed to ``process_3d`` together with the parsed
    arguments.

    Raises:
        AssertionError: If no sub-command was given, or the sub-command has
            no builder registered here.
    """
    args = argument_parser().parse_args()

    builders = {
        "spine_muscle_adipose_tissue": SpineMuscleAdiposeTissuePipelineBuilder,
        "spine": SpinePipelineBuilder,
        "contrast_phase": ContrastPhasePipelineBuilder,
        "liver_spleen_pancreas": LiverSpleenPancreasPipelineBuilder,
        "aortic_calcium": AorticCalciumPipelineBuilder,
        "hip": HipPipelineBuilder,
        "aaa": AAAPipelineBuilder,
        "all": AllPipelineBuilder,
    }
    # NOTE(review): the parser also accepts "muscle_adipose_tissue", but no
    # builder is dispatched for it, so it ends up in the error path below.
    # Presumably it needs a 2D processing entry point -- confirm.
    builder = builders.get(args.pipeline)
    if builder is None:
        # The sub-command is stored under ``pipeline`` (the subparsers'
        # dest); the namespace has no ``action`` attribute, so formatting
        # ``args.action`` raised AttributeError instead of this error.
        raise AssertionError("{} command not supported".format(args.pipeline))

    process_3d(args, builder)


if __name__ == "__main__":
    main()