|
import inspect |
|
import traceback |
|
|
|
import torch |
|
|
|
import vision_models |
|
|
|
# Registry mapping a process name to the callable that runs it; populated by load_models().
consumers = {}
|
|
|
|
|
def load_models():
    """Instantiate every ``vision_models.BaseModel`` subclass and register its processes.

    Subclasses are discovered by introspecting the ``vision_models`` module (the
    abstract ``BaseModel`` itself is excluded) and are loaded in ascending
    ``load_order``. For every name returned by a class's ``list_processes()``, the
    module-level ``consumers`` dict maps that name to the callable built by
    ``make_fn``. A running counter, advanced once per registered process, is passed to
    ``make_fn`` so that instances are spread round-robin over the visible GPUs.

    A CUDA memory summary is printed after each model class and once at the end.
    These summaries are printed only when CUDA is available, so that purely diagnostic
    output cannot abort loading on a CPU-only host.
    """
    # `consumers` is only mutated by item assignment, never rebound, so no `global`
    # declaration is needed.
    list_models = [
        cls
        for _, cls in inspect.getmembers(vision_models, inspect.isclass)
        if issubclass(cls, vision_models.BaseModel) and cls != vision_models.BaseModel
    ]
    list_models.sort(key=lambda x: x.load_order)
    print("-" * 10, "List models", list_models)

    counter_ = 0
    for model_class_ in list_models:
        print("-" * 10, "Now loading {}:".format(model_class_))
        for process_name_ in model_class_.list_processes():
            consumers[process_name_] = make_fn(model_class_, process_name_, counter_)
            # Advance per process, so each process's instance gets its own slot in
            # make_fn's GPU round-robin.
            counter_ += 1
        print("-" * 10, "Loading {} finished. Current gpu:".format(model_class_))
        _print_gpu_memory()

    print("-" * 10, "Model loading finished. Final gpu:")
    _print_gpu_memory()


def _print_gpu_memory():
    """Print ``torch.cuda.memory_summary()`` when CUDA is available, else a short note."""
    if torch.cuda.is_available():
        print(torch.cuda.memory_summary())
    else:
        print("CUDA not available; no GPU memory summary.")
|
|
|
|
|
def make_fn(model_class, process_name, counter):
    """Instantiate ``model_class`` and return a callable that runs one of its processes.

    model_class.name and process_name will be the same unless the same model is used
    in multiple processes, for different tasks.

    Args:
        model_class: Model class exposing ``name``, ``to_batch`` and ``forward``. It is
            constructed once, here, as ``model_class(gpu_number=...)``.
        process_name: Name of the process the returned callable serves. When it
            differs from ``model_class.name``, it is passed to ``forward`` as the
            ``process_name`` keyword argument.
        counter: Running index used to choose a GPU round-robin
            (``counter % torch.cuda.device_count()``).

    Returns:
        ``_function(*args, **kwargs)``, which calls the instance's ``forward``.

        If ``model_class.to_batch`` is true (read once, at construction time):

        * every argument is wrapped in a one-element list, i.e. a batch of one;
        * positional parameters the caller did not supply are filled with their
          wrapped defaults;
        * the first element of ``forward``'s result is returned.

        Any exception raised by ``forward``, or by indexing its batched result, is
        printed with its traceback and ``None`` is returned.
    """
    num_gpus = torch.cuda.device_count()
    # With no visible GPU, `counter % 0` would raise ZeroDivisionError. Fall back to
    # index 0 and leave CPU handling to the model class.
    # NOTE(review): assumes model classes cope with gpu_number=0 when CUDA is absent.
    gpu_number = counter % num_gpus if num_gpus > 0 else 0

    model_instance = model_class(gpu_number=gpu_number)

    to_batch = model_class.to_batch
    if to_batch:
        # forward's signature is fixed, so inspect it once here rather than per call.
        # getfullargspec keeps the already-bound `self` for bound methods, hence [1:].
        spec = inspect.getfullargspec(model_instance.forward)
        positional_params = spec.args[1:]
        defaults = spec.defaults or ()
        default_dict = dict(zip(spec.args[len(spec.args) - len(defaults):], defaults))

    def _function(*args, **kwargs):
        if process_name != model_class.name:
            kwargs['process_name'] = process_name

        if to_batch:
            # Wrap every argument as a batch of one.
            args = [[arg] for arg in args]
            kwargs = {k: [v] for k, v in kwargs.items()}
            # Batched forward methods expect list-valued parameters, so defaults
            # must be wrapped too. A missing parameter that has no default raises
            # KeyError here; this is outside the try below, so it reaches the caller.
            for arg_name in positional_params[len(args):]:
                if arg_name not in kwargs:
                    kwargs[arg_name] = [default_dict[arg_name]]

        try:
            out = model_instance.forward(*args, **kwargs)
            if to_batch:
                out = out[0]
        except Exception as e:
            # Deliberate best-effort boundary: a failing model reports and yields None.
            print(f'Error in {process_name} model:', e)
            traceback.print_exc()
            out = None
        return out

    return _function
|
|
|
|
|
def forward(model_name, *args, **kwargs):
    """Dispatch a call to the process registered under ``model_name``.

    Looks ``model_name`` up in the module-level ``consumers`` registry and calls the
    stored function with the remaining positional and keyword arguments, returning its
    result. An unregistered name raises ``KeyError``.
    """
    process_fn = consumers[model_name]
    return process_fn(*args, **kwargs)
|
|