Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import threading | |
from torch._utils import ExceptionWrapper | |
import logging | |
import torch.nn.functional as F | |
def get_a_var(obj):
    """Return the first ``torch.Tensor`` found inside ``obj``, or ``None``.

    ``obj`` itself is returned when it is a tensor. Lists and tuples are
    searched element by element, and dicts item by item (each item being a
    ``(key, value)`` pair), recursing depth-first and stopping at the first
    tensor encountered. Any other object yields ``None``.
    """
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        candidates = obj
    elif isinstance(obj, dict):
        # Items are (key, value) tuples, which the recursive call searches.
        candidates = obj.items()
    else:
        return None

    for candidate in candidates:
        found = get_a_var(candidate)
        if isinstance(found, torch.Tensor):
            return found
    return None
def parallel_apply(fct, model, inputs, device_ids):
    """Call ``fct`` on replicas of ``model``, one replica per entry of ``inputs``.

    ``model`` is copied onto each id in ``device_ids`` with
    ``nn.parallel.replicate``. Replica ``i`` is then invoked as
    ``fct(replica_i, *inputs[i])``; an ``inputs[i]`` that is not a list/tuple
    is passed as a single positional argument. With more than one replica each
    call runs in its own thread; with a single replica it runs in the
    caller's thread.

    Args:
        fct: callable invoked as ``fct(module, *args)`` for each replica.
        model: the module to replicate.
        inputs: one entry per replica. Each entry must contain at least one
            tensor; the device of the first tensor found (via ``get_a_var``)
            is made current with ``torch.cuda.device`` during the call.
        device_ids: devices to replicate ``model`` onto.

    Returns:
        A list with the value returned by ``fct`` for each replica, in the
        order of ``inputs``.

    Raises:
        ValueError: if the number of replicas differs from ``len(inputs)``.
        Exception: if any replica fails, the failure of the lowest-indexed
            failing replica is re-raised in the caller's thread.
    """
    modules = nn.parallel.replicate(model, device_ids)
    if len(modules) != len(inputs):
        raise ValueError(
            "parallel_apply: got {} replicas but {} inputs".format(
                len(modules), len(inputs)))
    lock = threading.Lock()
    results = {}
    # Grad mode is per-thread; capture the caller's setting so the worker
    # threads run under the same mode.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, replica_input):
        torch.set_grad_enabled(grad_enabled)
        device = None  # remains None if device resolution itself fails
        try:
            # Resolve the device inside ``try``: a failure here is now reported
            # through ``results`` instead of killing the thread and leaving
            # ``results[i]`` unset (which surfaced as an unrelated KeyError).
            tensor = get_a_var(replica_input)
            if tensor is None:
                raise RuntimeError(
                    "no tensor found in input; cannot determine its device")
            device = tensor.get_device()
            with torch.cuda.device(device):
                # Wrap a lone argument so a Tensor input is not unpacked
                # (sliced along its first dimension) by ``*``.
                if not isinstance(replica_input, (list, tuple)):
                    replica_input = (replica_input,)
                output = fct(module, *replica_input)
            with lock:
                results[i] = output
        except Exception:
            # Record the exception being handled so the caller's thread can
            # re-raise it with this replica's context.
            with lock:
                results[i] = ExceptionWrapper(
                    where="in replica {} on device {}".format(i, device))

    if len(modules) > 1:
        threads = [
            threading.Thread(target=_worker, args=(i, module, replica_input))
            for i, (module, replica_input) in enumerate(zip(modules, inputs))
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0])

    outputs = []
    for i in range(len(modules)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
def get_logger(filename=None):
    """Return the logger named ``'logger'``, optionally also logging to a file.

    The returned logger is set to DEBUG. ``logging.basicConfig`` is called to
    give the root logger a timestamped INFO-level configuration (this does
    nothing if the root logger already has handlers).

    Args:
        filename: if given, a DEBUG-level ``FileHandler`` for this path is
            attached to the root logger, so records propagated from any logger
            are written to the file. Calling again with the same path reuses
            the existing handler instead of adding a duplicate.

    Returns:
        The ``logging.Logger`` named ``'logger'``.
    """
    import os

    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)
    if filename is not None:
        root = logging.getLogger()
        path = os.path.abspath(filename)
        # FileHandler keeps the absolute path in ``baseFilename``. Without this
        # check, every call attached another handler and each log line was
        # written to the file once per call.
        already_attached = any(
            isinstance(h, logging.FileHandler) and h.baseFilename == path
            for h in root.handlers)
        if not already_attached:
            handler = logging.FileHandler(filename, encoding='utf-8')
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(
                logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
            root.addHandler(handler)
    return logger