File size: 3,279 Bytes
a3290d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
@author: louisblankemeier
"""

import inspect
import os
from typing import Dict, List

from comp2comp.inference_class_base import InferenceClass
from comp2comp.io.io import DicomLoader, NiftiSaver


class InferencePipeline(InferenceClass):
    """Inference pipeline that chains inference classes together.

    Each inference class is called in order. The dict returned by one class is
    unpacked as keyword arguments into the next one, so the named parameters
    of each class's call signature (excluding ``inference_pipeline`` and any
    variadic ``*args`` / ``**kwargs`` parameters) must exactly match the keys
    produced by the previous class. The first class receives an empty dict,
    so it may only require ``inference_pipeline``.
    """

    def __init__(self, inference_classes: List = None, config: Dict = None):
        """Create a pipeline.

        Args:
            inference_classes: Callables to run in order. ``None`` is treated
                as an empty pipeline.
            config: Optional mapping whose items are copied onto the pipeline
                as attributes (e.g. ``config["output_dir"]`` becomes
                ``self.output_dir``).
        """
        self.config = config
        # assign values from config to attributes
        if self.config is not None:
            for key, value in self.config.items():
                setattr(self, key, value)

        # A None default previously crashed __call__ (enumerate(None)); treat
        # it as "no stages". A provided list is kept as-is (same object).
        self.inference_classes = (
            inference_classes if inference_classes is not None else []
        )

    def __call__(self, inference_pipeline=None, **kwargs):
        """Run every inference class in order and return the final output.

        Args:
            inference_pipeline: Optional outer pipeline object. When truthy, it
                (rather than ``self``) receives ``kwargs`` as attributes and is
                passed to each inference class, which allows nesting this
                pipeline inside another one.
            **kwargs: Values set as attributes on the active pipeline object
                before any inference class runs.

        Returns:
            The dict returned by the last inference class, or an empty dict
            if there are no inference classes.

        Raises:
            AssertionError: If an inference class's named parameters do not
                match the keys produced by the previous class.
            KeyError: If an inference class does not accept an
                ``inference_pipeline`` parameter.
        """
        # print out the class names for each inference class
        print("")
        print("Inference pipeline:")
        for i, inference_class in enumerate(self.inference_classes):
            print(f"({i + 1}) {inference_class.__repr__()}")
        print("")

        print("Starting inference pipeline.\n")

        # The object that receives kwargs and is handed to every stage.
        active_pipeline = inference_pipeline if inference_pipeline else self
        for key, value in kwargs.items():
            setattr(active_pipeline, key, value)

        output = {}
        last_index = len(self.inference_classes) - 1
        for index, inference_class in enumerate(self.inference_classes):
            parameters = inspect.signature(inference_class).parameters

            # Only named parameters have to be supplied by the previous
            # stage's output; variadic *args/**kwargs are skipped regardless
            # of what they are called.
            function_keys = {
                name
                for name, param in parameters.items()
                if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
            }
            function_keys.remove("inference_pipeline")

            # Explicit raise (not `assert`) so the check survives `python -O`;
            # the exception type and message are unchanged.
            if function_keys != set(output.keys()):
                raise AssertionError(
                    "Input to inference class, {}, does not have the correct parameters".format(
                        inference_class.__repr__()
                    )
                )

            print(
                "Running {} with input keys {}".format(
                    inference_class.__repr__(),
                    parameters.keys(),
                )
            )

            output = inference_class(inference_pipeline=active_pipeline, **output)

            # Report the output keys of every stage except the last. The index
            # comparison stays correct even if the same stage object appears
            # more than once in the pipeline.
            if index != last_index:
                print(
                    "Finished {} with output keys {}\n".format(
                        inference_class.__repr__(), output.keys()
                    )
                )

        print("Inference pipeline finished.\n")

        return output


if __name__ == "__main__":
    # Example usage of InferencePipeline: a DicomLoader stage followed by a
    # NiftiSaver stage, run on the DICOM directory given on the command line.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dicom_dir", type=str, required=True)
    args = parser.parse_args()

    output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../outputs")
    # makedirs(exist_ok=True) also creates missing parent directories and
    # avoids the race between an exists() check and mkdir().
    os.makedirs(output_dir, exist_ok=True)
    output_file_path = os.path.join(output_dir, "test.nii.gz")

    # NOTE(review): the "output_dir" config key is given a file path here,
    # not a directory; confirm this is what NiftiSaver expects.
    pipeline = InferencePipeline(
        [DicomLoader(args.dicom_dir), NiftiSaver()],
        config={"output_dir": output_file_path},
    )
    pipeline()

    print("Done.")