Spaces:
Sleeping
Sleeping
""" | |
@author: louisblankemeier | |
""" | |
import inspect | |
import os | |
from typing import Dict, List | |
from comp2comp.inference_class_base import InferenceClass | |
from comp2comp.io.io import DicomLoader, NiftiSaver | |
class InferencePipeline(InferenceClass):
    """Inference pipeline that runs a sequence of inference classes.

    Each stage is called with ``inference_pipeline=<pipeline>`` plus the
    previous stage's returned mapping unpacked as keyword arguments (the first
    stage receives no extra keyword arguments). The last stage's return value
    is the pipeline's result.

    Args:
        inference_classes: Ordered, callable stages to run. ``None`` is
            treated as an empty pipeline.
        config: Optional mapping whose key/value pairs are copied onto the
            pipeline as attributes so stages can read them.
    """

    def __init__(self, inference_classes: List = None, config: Dict = None):
        self.config = config
        # Expose every config entry as an attribute on the pipeline.
        if self.config is not None:
            for key, value in self.config.items():
                setattr(self, key, value)
        # Fall back to an empty pipeline so __call__ does not fail iterating None.
        self.inference_classes = (
            inference_classes if inference_classes is not None else []
        )

    def __call__(self, inference_pipeline=None, **kwargs):
        """Run every stage in order and return the final stage's output.

        Args:
            inference_pipeline: Object passed to each stage as
                ``inference_pipeline`` and on which ``kwargs`` are set.
                Defaults to ``self`` when ``None``.
            **kwargs: Values set as attributes on that object before any
                stage runs.

        Returns:
            The value returned by the last stage, or ``{}`` if there are no
            stages.

        Raises:
            AssertionError: If a stage's named parameters (excluding
                ``inference_pipeline`` and any ``*args``/``**kwargs``) do not
                exactly match the keys returned by the previous stage.
        """
        print("")
        print("Inference pipeline:")
        for i, inference_class in enumerate(self.inference_classes):
            print(f"({i + 1}) {repr(inference_class)}")
        print("")
        print("Starting inference pipeline.\n")

        # `is not None` so a falsy pipeline object is still honoured.
        target = inference_pipeline if inference_pipeline is not None else self
        for key, value in kwargs.items():
            setattr(target, key, value)

        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        output = {}
        for index, inference_class in enumerate(self.inference_classes):
            signature = inspect.signature(inference_class)
            # Keys this stage expects from the previous stage's output. The
            # pipeline handle is injected separately, and *args/**kwargs are
            # catch-alls (detected by kind, not by name), so neither counts.
            input_keys = {
                name
                for name, param in signature.parameters.items()
                if name != "inference_pipeline" and param.kind not in variadic_kinds
            }
            # Explicit raise (rather than `assert`) so the check survives
            # `python -O`, while keeping the original exception type.
            if input_keys != set(output.keys()):
                raise AssertionError(
                    "Input to inference class, {}, does not have the correct parameters".format(
                        repr(inference_class)
                    )
                )
            print(
                "Running {} with input keys {}".format(
                    repr(inference_class),
                    signature.parameters.keys(),
                )
            )
            output = inference_class(inference_pipeline=target, **output)
            # Compare by position, not equality, so a stage instance that
            # appears more than once is not mistaken for the last one.
            if index != len(self.inference_classes) - 1:
                print(
                    "Finished {} with output keys {}\n".format(
                        repr(inference_class), output.keys()
                    )
                )
        print("Inference pipeline finished.\n")
        return output
if __name__ == "__main__":
    # Example usage of InferencePipeline: a DicomLoader stage followed by a
    # NiftiSaver stage, writing into an "outputs" directory next to this
    # module's parent directory.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dicom_dir", type=str, required=True)
    args = parser.parse_args()

    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../outputs")
    # makedirs(exist_ok=True) avoids the race between an exists() check and
    # mkdir(), and also creates any missing parent directories.
    os.makedirs(output_dir, exist_ok=True)
    output_file_path = os.path.join(output_dir, "test.nii.gz")

    # NOTE(review): the config key is "output_dir" but it is given a file
    # path; confirm this matches what NiftiSaver reads from the pipeline.
    pipeline = InferencePipeline(
        [DicomLoader(args.dicom_dir), NiftiSaver()],
        config={"output_dir": output_file_path},
    )
    pipeline()
    print("Done.")