vierundvi / VISAM /datasets /data_prefetcher.py
mart9992's picture
m
2cd560a
raw
history blame
4.71 kB
# ------------------------------------------------------------------------
# Copyright (c) 2022 megvii-research. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import torch
from functools import partial
from models.structures import Instances
def to_cuda(samples, targets, device):
    """Move one batch (``samples`` plus a list of target dicts) onto ``device``.

    Every copy is requested with ``non_blocking=True``. ``samples`` must expose
    ``.to(device, non_blocking=...)``; each element of ``targets`` is a dict
    whose values must expose the same method.

    Returns:
        ``(samples, targets)`` where ``targets`` is a freshly built list of
        freshly built dicts (the input dicts are not modified).
    """
    moved_samples = samples.to(device, non_blocking=True)
    moved_targets = []
    for target in targets:
        moved_targets.append(
            dict((key, value.to(device, non_blocking=True)) for key, value in target.items())
        )
    return moved_samples, moved_targets
def tensor_to_cuda(tensor: torch.Tensor, device):
    """Return ``tensor`` moved to ``device`` via its own ``.to`` method.

    Unlike ``to_cuda`` this requests a default (blocking) copy. Any object
    with a compatible ``.to(device)`` method is accepted, not only tensors.
    """
    moved = tensor.to(device)
    return moved
def is_tensor_or_instances(data):
    """Return ``True`` when ``data`` is a ``torch.Tensor`` or an ``Instances`` object."""
    # Tensor check first: for tensors the Instances check is never reached.
    if isinstance(data, torch.Tensor):
        return True
    return isinstance(data, Instances)
def data_apply(data, check_func, apply_func):
    """Recursively transform, in place, the selected leaves of a nested dict/list.

    Every element of ``data`` for which ``check_func`` is truthy is replaced by
    ``apply_func(element)``. Elements that are themselves dicts or lists are
    recursed into. Any other element is rejected.

    Args:
        data: dict or list to transform in place.
        check_func: predicate selecting the leaves to transform.
        apply_func: function returning the replacement for a selected leaf.

    Returns:
        ``data`` itself, after in-place modification.

    Raises:
        ValueError: if ``data`` is neither a dict nor a list, or if it holds an
            element not selected by ``check_func`` that is not a dict or list.
            The message names the offending type in every case.
    """
    if isinstance(data, dict):
        # Snapshot the keys so the loop is unaffected by what apply_func does.
        slots = list(data.keys())
    elif isinstance(data, list):
        slots = range(len(data))
    else:
        raise ValueError("invalid type {}".format(type(data)))

    for slot in slots:
        value = data[slot]
        if check_func(value):
            data[slot] = apply_func(value)
        elif isinstance(value, (dict, list)):
            data_apply(value, check_func, apply_func)
        else:
            # Previously the dict branch raised an empty ValueError(); report
            # the offending type consistently for dicts and lists alike.
            raise ValueError("invalid type {}".format(type(value)))
    return data
def data_dict_to_cuda(data_dict, device):
    """Move every tensor / ``Instances`` leaf of a nested dict/list to ``device``.

    The container is modified in place by ``data_apply`` and also returned.
    Each selected leaf is moved with ``tensor_to_cuda``.
    """
    def _move(leaf):
        return tensor_to_cuda(leaf, device=device)

    return data_apply(data_dict, is_tensor_or_instances, _move)
class data_prefetcher:
    """Iterate a loader yielding ``(samples, targets)`` and move each batch to ``device``.

    With ``prefetch=True`` the copy of the *next* batch is issued inside
    ``torch.cuda.stream(self.stream)`` on a dedicated CUDA stream. ``next()``
    makes the caller's current stream wait for that side stream before
    returning the batch, then immediately starts prefetching the following one.

    With ``prefetch=False`` each batch is fetched and copied inside ``next()``.

    Once the loader is exhausted, ``next()`` returns ``(None, None)``.
    """

    def __init__(self, loader, device, prefetch=True):
        self.loader = iter(loader)
        self.prefetch = prefetch
        self.device = device
        if not prefetch:
            return
        # Side stream used only for the copies issued in preload().
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and launch its copy to ``self.device`` on ``self.stream``.

        The result is stored in ``self.next_samples`` / ``self.next_targets``.
        Both are set to ``None`` when the loader is exhausted.
        """
        try:
            self.next_samples, self.next_targets = next(self.loader)
        except StopIteration:
            self.next_samples = self.next_targets = None
        else:
            with torch.cuda.stream(self.stream):
                self.next_samples, self.next_targets = to_cuda(
                    self.next_samples, self.next_targets, self.device
                )

    def next(self):
        """Return the next ``(samples, targets)`` pair on ``self.device``.

        Returns ``(None, None)`` once the loader is exhausted.
        """
        if not self.prefetch:
            try:
                samples, targets = next(self.loader)
                samples, targets = to_cuda(samples, targets, self.device)
            except StopIteration:
                print("catch_stop_iter")
                samples = targets = None
            return samples, targets

        current = torch.cuda.current_stream()
        # The copies were queued on self.stream; order the caller's stream after them.
        current.wait_stream(self.stream)
        samples = self.next_samples
        targets = self.next_targets
        # Mark the side-stream allocations as in use on the current stream
        # (Tensor.record_stream) so the caching allocator does not recycle them
        # early. Assumes samples and every target value provide record_stream.
        if samples is not None:
            samples.record_stream(current)
        if targets is not None:
            for target in targets:
                for value in target.values():
                    value.record_stream(current)
        self.preload()
        return samples, targets