import torch
import tops
import numpy as np
import io
import webdataset as wds
import os
import json
from pathlib import Path
from ..utils import png_decoder, mask_decoder, get_num_workers, collate_fn


def kp_decoder(x):
    # Keypoints are stored in normalized [0, 1] image coordinates for webdataset.
    keypoints = torch.from_numpy(np.load(io.BytesIO(x))).float()
    def check_outside(x): return (x < 0).logical_or(x > 1)
    is_outside = check_outside(keypoints[:, 0]).logical_or(
        check_outside(keypoints[:, 1])
    )
    # Mark keypoints that fall outside the image as not visible.
    keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
    return keypoints
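
# For illustration (hypothetical values): given decoded keypoints
# [[0.5, 0.5, 1.0], [1.2, 0.3, 1.0]], the second point lies outside [0, 1],
# so kp_decoder zeroes its visibility and returns
# [[0.5, 0.5, 1.0], [1.2, 0.3, 0.0]].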


def vertices_decoder(x):
    # Load vertex indices stored as a .npy blob and add a leading batch dimension.
    vertices = torch.from_numpy(np.load(io.BytesIO(x)).astype(np.int32))
    return vertices.squeeze()[None]


class InsertNewKeypoints:
    """Replaces the keypoints of a sample with keypoints loaded from a JSON file
    that maps each sample __key__ to its list of [x, y, visibility] keypoints."""

    def __init__(self, keypoints_path: Path) -> None:
        with open(keypoints_path, "r") as fp:
            self.keypoints = json.load(fp)

    def __call__(self, sample):
        key = sample["__key__"]
        keypoints = torch.tensor(self.keypoints[key], dtype=torch.float32)
        def check_outside(x): return (x < 0).logical_or(x > 1)
        is_outside = check_outside(keypoints[:, 0]).logical_or(
            check_outside(keypoints[:, 1])
        )
        # As in kp_decoder: mark out-of-image keypoints as not visible.
        keypoints[:, 2] = (keypoints[:, 2] > 0).logical_and(is_outside.logical_not())
        sample["keypoints.npy"] = keypoints
        return sample


def get_dataloader_fdh_wds(
        path,
        batch_size: int,
        num_workers: int,
        transform: torch.nn.Module,
        gpu_transform: torch.nn.Module,
        infinite: bool,
        shuffle: bool,
        partial_batches: bool,
        load_embedding: bool,
        sample_shuffle=10_000,
        tar_shuffle=100,
        read_condition=False,
        channels_last=False,
        load_new_keypoints=False,
        keypoints_split=None,
        ):
    # RANK and WORLD_SIZE must be set for wds.split_by_node to work.
    os.environ["RANK"] = str(tops.rank())
    os.environ["WORLD_SIZE"] = str(tops.world_size())
    if infinite:
        # Resample shards with replacement for an effectively infinite stream.
        pipeline = [wds.ResampledShards(str(path))]
    else:
        pipeline = [wds.SimpleShardList(str(path))]
    if shuffle:
        # Shuffle the order of tar shards before splitting across nodes/workers.
        pipeline.append(wds.shuffle(tar_shuffle))
    pipeline.extend([
        wds.split_by_node,
        wds.split_by_worker,
    ])
    if shuffle:
        # Shuffle samples within a buffer after shard assignment.
        pipeline.append(wds.shuffle(sample_shuffle))

    decoder = [
        wds.handle_extension("image.png", png_decoder),
        wds.handle_extension("mask.png", mask_decoder),
        wds.handle_extension("maskrcnn_mask.png", mask_decoder),
        wds.handle_extension("keypoints.npy", kp_decoder),
    ]

    rename_keys = [
        ["img", "image.png"], ["mask", "mask.png"],
        ["keypoints", "keypoints.npy"], ["maskrcnn_mask", "maskrcnn_mask.png"],
        ["__key__", "__key__"]
    ]
    if load_embedding:
        decoder.extend([
            wds.handle_extension("vertices.npy", vertices_decoder),
            wds.handle_extension("E_mask.png", mask_decoder)
        ])
        rename_keys.extend([
            ["vertices", "vertices.npy"],
            ["E_mask", "e_mask.png"]
        ])

    if read_condition:
        decoder.append(
            wds.handle_extension("condition.png", png_decoder)
        )
        rename_keys.append(["condition", "condition.png"])

    pipeline.extend([
        wds.tarfile_to_samples(),
        wds.decode(*decoder),
    ])
    if load_new_keypoints:
        assert keypoints_split in ["train", "val"]
        keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/1eb88522-8b91-49c7-b56a-ed98a9c7888cef9c0429-a385-4248-abe3-8682de26d041f268aed1-7c88-4677-baad-7623c2ee330f"
        file_name = "fdh_keypoints_val-050133b34d.json"
        if keypoints_split == "train":
            keypoint_url = "https://api.loke.aws.unit.no/dlr-gui-backend-resources-content/v2/contents/links/3e828b1c-d6c0-4622-90bc-1b2cce48ccfff14ab45d-0a5c-431d-be13-7e60580765bd7938601c-e72e-41d9-8836-fffc49e76f58"
            file_name = "fdh_keypoints_train-2cff11f69a.json"
        # Set check_hash=True if you suspect the download is incorrect.
        filepath = tops.download_file(keypoint_url, file_name=file_name, check_hash=False)
        pipeline.append(
            wds.map(InsertNewKeypoints(filepath))
        )
    pipeline.extend([
        wds.batched(batch_size, collation_fn=collate_fn, partial=partial_batches),
        wds.rename_keys(*rename_keys),
    ])
    if transform is not None:
        pipeline.append(wds.map(transform))
    pipeline = wds.DataPipeline(*pipeline)
    if infinite:
        pipeline = pipeline.repeat(nepochs=1000000)

    loader = wds.WebLoader(
        pipeline, batch_size=None, shuffle=False,
        num_workers=get_num_workers(num_workers),
        persistent_workers=True,
    )
    # Prefetch batches to the GPU, applying gpu_transform there.
    loader = tops.DataPrefetcher(loader, gpu_transform, channels_last=channels_last, to_float=False)
    return loader
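

# --- Minimal usage sketch (illustrative, not part of the original module). ---
# The shard pattern, batch size, and transforms below are hypothetical
# placeholders; since this module uses relative imports, the call must be
# made from within the package:
#
#   loader = get_dataloader_fdh_wds(
#       path="data/fdh/train/shard-{000000..000099}.tar",  # hypothetical shards
#       batch_size=32,
#       num_workers=4,
#       transform=None,                 # optional CPU-side batch transform
#       gpu_transform=torch.nn.Identity(),
#       infinite=True,
#       shuffle=True,
#       partial_batches=False,
#       load_embedding=False,
#   )
#   batch = next(iter(loader))          # dict with "img", "mask", "keypoints", ...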