# VDebugger-generalist-for-VQA / vision_processes.py

import inspect
import traceback

import torch

import vision_models

# Registry mapping each process name to the callable produced by make_fn.
consumers = dict()


def load_models():
    """Instantiate every model class defined in vision_models and register its processes."""
    global consumers
    # Collect all concrete subclasses of BaseModel, then load them in their declared order.
    list_models = [m[1] for m in inspect.getmembers(vision_models, inspect.isclass)
                   if issubclass(m[1], vision_models.BaseModel) and m[1] != vision_models.BaseModel]
    list_models.sort(key=lambda x: x.load_order)
    print("-" * 10, "List models", list_models)
    counter_ = 0
    for model_class_ in list_models:
        print("-" * 10, "Now loading {}:".format(model_class_))
        for process_name_ in model_class_.list_processes():
            consumers[process_name_] = make_fn(model_class_, process_name_, counter_)
            counter_ += 1
        print("-" * 10, "Loading {} finished. Current gpu:".format(model_class_))
        print(torch.cuda.memory_summary())
    print("-" * 10, "Model loading finished. Final gpu:")
    print(torch.cuda.memory_summary())


def make_fn(model_class, process_name, counter):
    """
    model_class.name and process_name will be the same unless the same model is used in multiple
    processes for different tasks.
    """
    # Initialize each model on its own GPU (round-robin over the available devices) to make
    # sure there are no out-of-memory errors.
    num_gpus = torch.cuda.device_count()
    gpu_number = counter % num_gpus
    model_instance = model_class(gpu_number=gpu_number)

    def _function(*args, **kwargs):
        if process_name != model_class.name:
            kwargs['process_name'] = process_name

        if model_class.to_batch:
            # Batchify the input: the model expects a batch, so wrap every argument in a
            # singleton list; the output is un-batchified below.
            args = [[arg] for arg in args]
            kwargs = {k: [v] for k, v in kwargs.items()}

            # Defaults that were passed in neither args nor kwargs also need to be listified.
            full_arg_spec = inspect.getfullargspec(model_instance.forward)
            if full_arg_spec.defaults is None:
                default_dict = {}
            else:
                default_dict = dict(zip(full_arg_spec.args[-len(full_arg_spec.defaults):],
                                        full_arg_spec.defaults))
            # Skip 'self', then keep only the parameters not already covered by args or kwargs.
            non_given_args = full_arg_spec.args[1:][len(args):]
            non_given_args = set(non_given_args) - set(kwargs.keys())
            for arg_name in non_given_args:
                kwargs[arg_name] = [default_dict[arg_name]]

        try:
            out = model_instance.forward(*args, **kwargs)
            if model_class.to_batch:
                out = out[0]  # Un-batchify: return the single result.
        except Exception as e:
            print(f'Error in {process_name} model:', e)
            traceback.print_exc()
            out = None
        return out

    return _function
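

# A minimal sketch of what the wrapper produced by make_fn does for a to_batch model.
# Hypothetical example (SomeModel is not from this repo): a class with name == 'some_process',
# to_batch == True, and forward(self, image, question, top_k=5).
#
#     fn = make_fn(SomeModel, 'some_process', counter=0)
#     out = fn(image, question)
#     # internally calls: model_instance.forward([image], [question], top_k=[5])[0]
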
def forward(model_name, *args, **kwargs):
    """Dispatch a call to the consumer registered under model_name."""
    return consumers[model_name](*args, **kwargs)
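
# Minimal usage sketch (assumption: the real process names depend on the classes defined in
# vision_models; 'object_detection' is a hypothetical example):
#
#     load_models()                                # instantiate and register every model
#     result = forward('object_detection', image)  # route the call through consumers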