Spaces:
Runtime error
Runtime error
import torch | |
import tops | |
import numpy as np | |
import io | |
import webdataset as wds | |
import os | |
from ..utils import png_decoder, get_num_workers, collate_fn | |
def kp_decoder(x): | |
# Keypoints are between [0, 1] for webdataset | |
keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float().view(7, 2).clamp(0, 1) | |
keypoints = torch.cat((keypoints, torch.ones((7, 1))), dim=-1) | |
return keypoints | |
def bbox_decoder(x): | |
return torch.from_numpy(np.load(io.BytesIO(x))).float().view(4) | |
class BBoxToMask: | |
def __call__(self, sample): | |
imsize = sample["image.png"].shape[-1] | |
bbox = sample["bounding_box.npy"] * imsize | |
x0, y0, x1, y1 = np.round(bbox).astype(np.int64) | |
mask = torch.ones((1, imsize, imsize), dtype=torch.bool) | |
mask[:, y0:y1, x0:x1] = 0 | |
sample["mask"] = mask | |
return sample | |
def get_dataloader_fdf_wds( | |
path, | |
batch_size: int, | |
num_workers: int, | |
transform: torch.nn.Module, | |
gpu_transform: torch.nn.Module, | |
infinite: bool, | |
shuffle: bool, | |
partial_batches: bool, | |
sample_shuffle=10_000, | |
tar_shuffle=100, | |
channels_last=False, | |
): | |
# Need to set this for split_by_node to work. | |
os.environ["RANK"] = str(tops.rank()) | |
os.environ["WORLD_SIZE"] = str(tops.world_size()) | |
if infinite: | |
pipeline = [wds.ResampledShards(str(path))] | |
else: | |
pipeline = [wds.SimpleShardList(str(path))] | |
if shuffle: | |
pipeline.append(wds.shuffle(tar_shuffle)) | |
pipeline.extend([ | |
wds.split_by_node, | |
wds.split_by_worker, | |
]) | |
if shuffle: | |
pipeline.append(wds.shuffle(sample_shuffle)) | |
decoder = [ | |
wds.handle_extension("image.png", png_decoder), | |
wds.handle_extension("keypoints.npy", kp_decoder), | |
] | |
rename_keys = [ | |
["img", "image.png"], | |
["keypoints", "keypoints.npy"], | |
["__key__", "__key__"], | |
["mask", "mask"] | |
] | |
pipeline.extend([ | |
wds.tarfile_to_samples(), | |
wds.decode(*decoder), | |
]) | |
pipeline.append(wds.map(BBoxToMask())) | |
pipeline.extend([ | |
wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches), | |
wds.rename_keys(*rename_keys), | |
]) | |
if transform is not None: | |
pipeline.append(wds.map(transform)) | |
pipeline = wds.DataPipeline(*pipeline) | |
if infinite: | |
pipeline = pipeline.repeat(nepochs=1000000) | |
loader = wds.WebLoader( | |
pipeline, batch_size=None, shuffle=False, | |
num_workers=get_num_workers(num_workers), | |
persistent_workers=True, | |
) | |
loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False) | |
return loader | |