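"""Webdataset data-loading pipeline for the FDH dataset."""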
import torch
import tops
import numpy as np
import io
import webdataset as wds
import os
import json
from pathlib import Path
from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
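

# --- Decoders: convert raw bytes from the tar shards into tensors ---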
def kp_decoder(x):
    # Keypoint coordinates are normalized to [0, 1] in the webdataset files.
    keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
    def check_outside(x): return (x < 0).logical_or(x > 1)
    is_outside = check_outside(keypoints[:, 0]).logical_or(
        check_outside(keypoints[:, 1])
    )
    # Mark keypoints that fall outside the image as not visible.
    keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
    return keypoints
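

# Decode vertex indices stored as .npy; returns a tensor of shape (1, N).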
def vertices_decoder(x):
vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
return vertices.squeeze()[None]
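

# Replaces each sample's stored keypoints with updated annotations loaded
# from a JSON file keyed by the webdataset sample key.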
class InsertNewKeypoints:
    def __init__(self, keypoints_path: Path) -> None:
        with open(keypoints_path, "r") as fp:
            self.keypoints = json.load(fp)

    def __call__(self, sample):
        key = sample["__key__"]
        keypoints = torch.tensor(self.keypoints[key], dtype=torch.float32)
        def check_outside(x): return (x < 0).logical_or(x > 1)
        is_outside = check_outside(keypoints[:, 0]).logical_or(
            check_outside(keypoints[:, 1])
        )
        # Same visibility rule as kp_decoder: out-of-bounds points are not visible.
        keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
        sample["keypoints.npy"] = keypoints
        return sample
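

# Builds the FDH webdataset pipeline: shard listing -> node/worker splitting ->
# decoding -> optional keypoint replacement -> batching -> GPU prefetching.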
def get_dataloader_fdh_wds(
path,
batch_size: int,
num_workers: int,
transform: torch.nn.Module,
gpu_transform: torch.nn.Module,
infinite: bool,
shuffle: bool,
partial_batches: bool,
load_embedding: bool,
sample_shuffle=10_000,
tar_shuffle=100,
read_condition=False,
channels_last=False,
load_new_keypoints=False,
keypoints_split=None,
):
    # split_by_node requires RANK and WORLD_SIZE to be set in the environment.
os.environ["RANK"] = str(tops.rank())
os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        # Resample shards with replacement, giving an endless shard stream.
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        # Shuffle shard (tar file) order before splitting across nodes/workers.
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        # Shuffle individual samples within each worker's stream.
        pipeline.append(wds.shuffle(sample_shuffle))
    # Per-extension decoders applied by wds.decode.
    decoder = [
wds.handle_extension("image.png", png_decoder),
wds.handle_extension("mask.png", mask_decoder),
wds.handle_extension("maskrcnn_mask.png", mask_decoder),
wds.handle_extension("keypoints.npy", kp_decoder),
]
    # Each entry is [new_key, old_key]; applied after batching via wds.rename_keys.
    rename_keys = [
["img", "image.png"], ["mask", "mask.png"],
["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
["__key__", "__key__"]
]
    if load_embedding:
        # Additionally decode vertex indices and the embedding mask (E_mask).
decoder.extend([
wds.handle_extension("vertices.npy", vertices_decoder),
wds.handle_extension("E_mask.png", mask_decoder)
])
rename_keys.extend([
["vertices", "vertices.npy"],
["E_mask", "e_mask.png"]
])
if read_condition:
decoder.append(
wds.handle_extension("condition.png", png_decoder)
)
rename_keys.append(["condition", "condition.png"])
pipeline.extend([
wds.tarfile_to_samples(),
wds.decode(*decoder),
])
    if load_new_keypoints:
        assert keypoints_split in ["train", "val"]
        # Download the updated keypoint annotations for the requested split.
keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f"
file_name = "fdh_keypoints_val-050133b34d.json"
if keypoints_split == "train":
keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58"
file_name = "fdh_keypoints_train-2cff11f69a.json"
        # Set check_hash=True if you suspect the downloaded file is corrupt.
filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
pipeline.append(
wds.map(InsertNewKeypoints(filepath))
)
    # Batch samples, then rename keys to the short names used downstream.
    pipeline.extend([
wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
wds.rename_keys(*rename_keys),
])
if transform is not None:
pipeline.append(wds.map(transform))
pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        # Repeat for an effectively unbounded number of epochs.
        pipeline = pipeline.repeat(nepochs=1_000_000)
loader = wds.WebLoader(
pipeline, batch_size=None, shuffle=False,
num_workers=get_num_workers(num_workers),
persistent_workers=True,
)
    # Prefetch batches to the GPU, applying gpu_transform there.
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
return loader
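

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the shard pattern, batch size,
    # and transforms below are hypothetical assumptions, not values taken from
    # the original training configs.
    loader = get_dataloader_fdh_wds(
        path="data/fdh/train-{000000..000099}.tar",  # hypothetical shard pattern
        batch_size=32,
        num_workers=4,
        transform=None,  # no CPU-side transform in this sketch
        gpu_transform=torch.nn.Identity(),  # applied on GPU by tops.DataPrefetcher
        infinite=True,
        shuffle=True,
        partial_batches=False,
        load_embedding=False,
    )
    batch = next(iter(loader))
    print(batch["img"].shape, batch["keypoints"].shape)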