"""
@author: louisblankemeier
"""

import inspect
import os
from typing import Dict, List, Optional

from comp2comp.inference_class_base import InferenceClass
from comp2comp.io.io import DicomLoader, NiftiSaver


class InferencePipeline(InferenceClass):
    """Inference pipeline.

    Runs a sequence of inference classes in order. Each stage is called as
    ``stage(inference_pipeline=<pipeline>, **previous_output)``. The first
    stage receives no outputs. Each stage's return value (expected to be a
    dict) becomes the keyword input of the next stage. The last stage's
    return value is returned by ``__call__``.
    """

    def __init__(
        self, inference_classes: Optional[List] = None, config: Optional[Dict] = None
    ):
        """Create a pipeline.

        Args:
            inference_classes: Callables to run in order. ``None`` is treated
                as an empty pipeline.
            config: Optional mapping. Each item is copied onto the pipeline as
                an attribute, so ``{"output_dir": x}`` makes
                ``self.output_dir == x``.
        """
        # NOTE(review): InferenceClass.__init__ is not called (same as before);
        # confirm the base class needs no initialization.
        self.config = config
        # assign values from config to attributes
        if self.config is not None:
            for key, value in self.config.items():
                setattr(self, key, value)
        # Normalize None to an empty list; __call__ iterates this unconditionally.
        self.inference_classes = (
            inference_classes if inference_classes is not None else []
        )

    def __call__(self, inference_pipeline=None, **kwargs):
        """Run every inference class in order and return the final output.

        Args:
            inference_pipeline: Pipeline object handed to every stage. This is
                set when this pipeline is nested inside another one. If it is
                ``None``, ``self`` is used.
            **kwargs: Set as attributes on the pipeline object handed to the
                stages. They are NOT passed to the first stage as inputs.

        Returns:
            Whatever the last inference class returned, or ``{}`` if there
            are no inference classes.

        Raises:
            AssertionError: If a stage's named parameters (excluding
                ``inference_pipeline`` and any ``*args``/``**kwargs``) do not
                exactly match the keys returned by the previous stage.
            KeyError: If a stage has no ``inference_pipeline`` parameter.
        """
        # Resolve the pipeline handed to stages once. Compare against None so
        # that a falsy-but-valid pipeline object is not replaced by self.
        pipeline = inference_pipeline if inference_pipeline is not None else self

        # print out the class names for each inference class
        print("")
        print("Inference pipeline:")
        for i, inference_class in enumerate(self.inference_classes):
            print(f"({i + 1}) {repr(inference_class)}")
        print("")

        print("Starting inference pipeline.\n")

        for key, value in kwargs.items():
            setattr(pipeline, key, value)

        # Catch-all parameters can never be satisfied by a named output key,
        # so they are not required inputs.
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )

        output = {}
        for index, inference_class in enumerate(self.inference_classes):
            parameters = inspect.signature(inference_class).parameters
            # Named inputs this stage expects from the previous stage's
            # output. "kwargs" is also excluded by name for backward
            # compatibility with the original check.
            function_keys = {
                name
                for name, param in parameters.items()
                if param.kind not in variadic_kinds and name != "kwargs"
            }
            # Every stage must accept the pipeline handle; KeyError otherwise.
            function_keys.remove("inference_pipeline")
            assert function_keys == set(
                output.keys()
            ), "Input to inference class, {}, does not have the correct parameters".format(
                repr(inference_class)
            )

            print(
                "Running {} with input keys {}".format(
                    repr(inference_class),
                    parameters.keys(),
                )
            )

            output = inference_class(inference_pipeline=pipeline, **output)

            # Log intermediate output keys. Position (not equality) identifies
            # the last stage, so repeated or __eq__-overriding stages are
            # handled correctly.
            if index != len(self.inference_classes) - 1:
                print(
                    "Finished {} with output keys {}\n".format(
                        repr(inference_class), output.keys()
                    )
                )

        print("Inference pipeline finished.\n")

        return output


if __name__ == "__main__":
    # Example usage of InferencePipeline.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dicom_dir", type=str, required=True)
    args = parser.parse_args()

    # Normalize away the ".." so makedirs never has to create a pardir
    # component. exist_ok avoids the exists()/mkdir() race and also creates
    # missing parents.
    output_dir = os.path.normpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), "../outputs")
    )
    os.makedirs(output_dir, exist_ok=True)
    output_file_path = os.path.join(output_dir, "test.nii.gz")

    # NOTE(review): the "output_dir" config key is given a file path here;
    # confirm this is what NiftiSaver expects.
    pipeline = InferencePipeline(
        [DicomLoader(args.dicom_dir), NiftiSaver()],
        config={"output_dir": output_file_path},
    )
    pipeline()
    print("Done.")