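# Web-page residue kept as a comment: "haakohu's picture" / "initial" / "5d756f1"
# (presumably uploader avatar text, commit message and commit hash — confirm).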
import torch
import tops
import numpy as np
import io
import webdataset as wds
import os
import json
from pathlib import Path
from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn
def kp_decoder(x):
    """Decode serialized ``.npy`` bytes into a float32 keypoint tensor.

    Column 2 of the result is rewritten in place: it becomes 1.0 when its
    original value is positive and columns 0 and 1 both lie in [0, 1], and
    0.0 otherwise.

    Args:
        x: Raw bytes of a NumPy ``.npy`` array.

    Returns:
        The decoded tensor with column 2 replaced as described above.
    """
    # Keypoints are between [0, 1] for webdataset
    decoded = np.load(io.BytesIO(x))
    kps = torch.from_numpy(decoded).float()
    xs, ys = kps[:, 0], kps[:, 1]
    # A keypoint is out of bounds if either coordinate leaves [0, 1].
    out_of_bounds = (xs < 0) | (xs > 1) | (ys < 0) | (ys > 1)
    kps[:, 2] = (kps[:, 2] > 0) & ~out_of_bounds
    return kps
def vertices_decoder(x):
    """Decode serialized ``.npy`` bytes into an int32 tensor.

    The array is cast to int32 in NumPy, all size-1 dimensions are squeezed
    away, and a new leading dimension of size 1 is prepended.

    Args:
        x: Raw bytes of a NumPy ``.npy`` array.

    Returns:
        An int32 tensor whose first dimension has size 1.
    """
    as_int32 = np.load(io.BytesIO(x)).astype(np.int32)
    squeezed = torch.from_numpy(as_int32).squeeze()
    return squeezed.unsqueeze(0)
class InsertNewKeypoints:
    """Sample-level map that overwrites ``sample["keypoints.npy"]``.

    The replacement keypoints are read once from a JSON file and looked up
    per sample through ``sample["__key__"]``.
    """

    def __init__(self, keypoints_path: Path) -> None:
        """Load the JSON keypoint table found at ``keypoints_path``."""
        with open(keypoints_path, "r") as handle:
            self.keypoints = json.load(handle)

    def __call__(self, sample):
        """Replace the sample's keypoints and return the same sample dict.

        Column 2 of the new keypoints is set to 1.0 when its stored value is
        positive and columns 0 and 1 both lie in [0, 1], and to 0.0 otherwise.

        Raises:
            KeyError: If the sample key is missing from the loaded table.
        """
        kps = torch.tensor(self.keypoints[sample["__key__"]], dtype=torch.float32)
        xs, ys = kps[:, 0], kps[:, 1]
        # A keypoint is out of bounds if either coordinate leaves [0, 1].
        out_of_bounds = (xs < 0) | (xs > 1) | (ys < 0) | (ys > 1)
        kps[:, 2] = (kps[:, 2] > 0) & ~out_of_bounds
        sample["keypoints.npy"] = kps
        return sample
def get_dataloader_fdh_wds(
    path,
    batch_size: int,
    num_workers: int,
    transform: torch.nn.Module,
    gpu_transform: torch.nn.Module,
    infinite: bool,
    shuffle: bool,
    partial_batches: bool,
    load_embedding: bool,
    sample_shuffle=10_000,
    tar_shuffle=100,
    read_condition=False,
    channels_last=False,
    load_new_keypoints=False,
    keypoints_split=None,
):
    """Build a WebDataset-backed loader wrapped in ``tops.DataPrefetcher``.

    The pipeline lists (or resamples) shards, optionally shuffles them,
    splits them by node and worker, optionally shuffles samples, decodes the
    known file extensions, optionally swaps in downloaded keypoints, batches,
    renames keys and finally applies ``transform``.

    Args:
        path: Shard specification; converted with ``str`` and handed to the
            webdataset shard source.
        batch_size: Number of samples per batch.
        num_workers: Requested worker count; resolved via ``get_num_workers``.
        transform: Optional callable mapped over each batch (skipped if None).
        gpu_transform: Passed to ``tops.DataPrefetcher``.
        infinite: Use ``wds.ResampledShards`` and repeat the pipeline instead
            of a single pass over ``wds.SimpleShardList``.
        shuffle: Enable both shard-level and sample-level shuffling.
        partial_batches: Forwarded as ``partial`` to ``wds.batched``.
        load_embedding: Also decode ``vertices.npy`` and ``E_mask.png``.
        sample_shuffle: Buffer size of the sample-level shuffle.
        tar_shuffle: Buffer size of the shard-level shuffle.
        read_condition: Also decode ``condition.png``.
        channels_last: Forwarded to ``tops.DataPrefetcher``.
        load_new_keypoints: Download a keypoint JSON file and use it to
            overwrite each sample's ``keypoints.npy`` entry.
        keypoints_split: ``"train"`` or ``"val"``; selects which keypoint file
            is downloaded. Only checked when ``load_new_keypoints`` is set.

    Returns:
        The ``tops.DataPrefetcher`` wrapping the ``wds.WebLoader``.

    Raises:
        AssertionError: If ``load_new_keypoints`` is set and
            ``keypoints_split`` is neither ``"train"`` nor ``"val"``.
    """
    # Need to set this for split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        # Shard-level shuffle, applied before shards are split across nodes/workers.
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        # Second shuffle stage, placed after the node/worker split.
        pipeline.append(wds.shuffle(sample_shuffle))

    decoder = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("mask.png", mask_decoder),
        wds.handle_extension("maskrcnn_mask.png", mask_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]
    # Pairs of [new_name, original_name] applied after batching.
    rename_keys = [
        ["img", "image.png"], ["mask", "mask.png"],
        ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
        ["__key__", "__key__"]
    ]
    if load_embedding:
        decoder.extend([
            wds.handle_extension("vertices.npy", vertices_decoder),
            wds.handle_extension("E_mask.png", mask_decoder)
        ])
        rename_keys.extend([
            ["vertices", "vertices.npy"],
            ["E_mask", "e_mask.png"]
        ])
    if read_condition:
        decoder.append(
            wds.handle_extension("condition.png", png_decoder)
        )
        rename_keys.append(["condition", "condition.png"])

    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(*decoder),
    ])

    if load_new_keypoints:
        assert keypoints_split in ["train", "val"], (
            "keypoints_split must be 'train' or 'val' when load_new_keypoints is set"
        )
        keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f"
        file_name = "fdh_keypoints_val-050133b34d.json"
        if keypoints_split == "train":
            keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58"
            file_name = "fdh_keypoints_train-2cff11f69a.json"
        # Set check_hash=True if you suspect download is incorrect.
        filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
        pipeline.append(
            wds.map(InsertNewKeypoints(filepath))
        )

    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
        wds.rename_keys(*rename_keys),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    resolved_workers = get_num_workers(num_workers)
    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=resolved_workers,
        # torch's DataLoader raises ValueError for persistent_workers=True
        # combined with num_workers == 0, so only enable it with real workers.
        persistent_workers=resolved_workers > 0,
    )
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader