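"""Dispatch layer for the vision models.

load_models() instantiates every concrete subclass of vision_models.BaseModel, spreads
the instances across the available GPUs, and registers one consumer function per process
name in the module-level `consumers` dict. forward() then routes each call to its consumer.
"""
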
import inspect
import traceback

import torch

import vision_models

consumers = dict()


def load_models():
    global consumers
    list_models = [m[1] for m in inspect.getmembers(vision_models, inspect.isclass)
                   if issubclass(m[1], vision_models.BaseModel) and m[1] != vision_models.BaseModel]
    list_models.sort(key=lambda x: x.load_order)
    print("-" * 10, "List models", list_models)

    counter_ = 0
    for model_class_ in list_models:
        print("-" * 10, "Now loading {}:".format(model_class_))
        for process_name_ in model_class_.list_processes():
            consumers[process_name_] = make_fn(model_class_, process_name_, counter_)
            counter_ += 1
        print("-" * 10, "Loading {} finished. Current gpu:".format(model_class_))
        print(torch.cuda.memory_summary())

    print("-" * 10, "Model loading finished. Final gpu:")
    print(torch.cuda.memory_summary())


def make_fn(model_class, process_name, counter):
    """
    model_class.name and process_name will be the same unless the same model is used in multiple processes, for
    different tasks
    """
    # Assign models to GPUs round-robin, so the load is spread across devices and no
    # single GPU runs out of memory
    num_gpus = torch.cuda.device_count()
    gpu_number = counter % num_gpus

    model_instance = model_class(gpu_number=gpu_number)

    def _function(*args, **kwargs):
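        # When one model class backs several processes, forward the process name so the
        # model can tell which task this call belongs to.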
        if process_name != model_class.name:
            kwargs['process_name'] = process_name

        if model_class.to_batch:
            # Wrap the input in a batch of size one; the model expects batched inputs.
            # The output is un-batched after the forward call below.
            args = [[arg] for arg in args]
            kwargs = {k: [v] for k, v in kwargs.items()}

            # Defaults that were passed neither in args nor kwargs also need to be wrapped in lists
            full_arg_spec = inspect.getfullargspec(model_instance.forward)
            if full_arg_spec.defaults is None:
                default_dict = {}
            else:
                default_dict = dict(zip(full_arg_spec.args[-len(full_arg_spec.defaults):], full_arg_spec.defaults))
            non_given_args = full_arg_spec.args[1:][len(args):]
            non_given_args = set(non_given_args) - set(kwargs.keys())
            for arg_name in non_given_args:
                kwargs[arg_name] = [default_dict[arg_name]]
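            # Example: for a hypothetical forward(self, image, threshold=0.5) called as
            # fn(img), this leaves args == [[img]] and kwargs == {'threshold': [0.5]}.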

        try:
            out = model_instance.forward(*args, **kwargs)
            if model_class.to_batch:
                out = out[0]
        except Exception as e:
            print(f'Error in {process_name} model:', e)
            traceback.print_exc()
            out = None
        return out

    return _function


def forward(model_name, *args, **kwargs):
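    """Dispatch a call to the consumer function registered under model_name."""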
    return consumers[model_name](*args, **kwargs)
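

if __name__ == "__main__":
    # Minimal usage sketch: load every model, then dispatch a single call.
    # 'object_detector' and the zero tensor are illustrative only; the actual process
    # names and expected input types depend on the classes defined in vision_models.
    load_models()
    dummy_image = torch.zeros(3, 224, 224)
    print(forward("object_detector", dummy_image))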