Spaces:
Build error
Build error
File size: 4,868 Bytes
6723494 a50e161 6723494 8f3faa9 6723494 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import pickle as pkl
import numpy as np
import numpy.typing as npt
from PIL import Image
from PIL.Image import Image as ImageType
from pathlib import Path
def build_data(data_path: Path) -> dict:
    """Build a fresh index for the image files directly inside ``data_path``.

    Files ending in ``.png``, ``.jpg`` or ``.jpeg`` become entries keyed by
    their file stem, each of the form
    ``{"image": path, "labels": [], "emb": None, "meta_data": None}``.
    Extensions are matched case-insensitively, so ``photo.JPG`` is picked up
    on every platform (plain ``glob`` is case-sensitive on POSIX but not on
    Windows). Sub-directories are not searched and directories whose name
    happens to end in an image extension are skipped.

    Args:
        data_path: Directory to scan; a ``str`` is accepted as well.

    Returns:
        Mapping of image stem to entry dict; empty if ``data_path`` is not an
        existing directory. When several files share a stem, ``.jpeg`` wins
        over ``.jpg``, which wins over ``.png``.
    """
    data_path = Path(data_path)
    if not data_path.is_dir():
        # Globbing a missing directory yields nothing; keep returning {}.
        return {}
    # Rank by extension so later-ranked files overwrite earlier ones on a stem
    # collision (same precedence as concatenating png, jpg, jpeg globs); the
    # name tiebreak makes the iteration order deterministic.
    priority = {".png": 0, ".jpg": 1, ".jpeg": 2}
    image_paths = sorted(
        (
            path
            for path in data_path.iterdir()
            if path.suffix.lower() in priority and path.is_file()
        ),
        key=lambda path: (priority[path.suffix.lower()], path.name),
    )
    return {
        path.stem: {"image": path, "labels": [], "emb": None, "meta_data": None}
        for path in image_paths
    }
class Data:
    """Pickle-backed index of images plus their masks, labels and embeddings.

    ``self.data`` maps an image name to an entry dict and is persisted at
    ``data_path``. Image, mask and embedding files are written into
    ``data_path.parent``; entries store only their paths.
    """

    def __init__(self, data_path: Path):
        """Load the index from ``data_path``, or create an empty one there.

        Args:
            data_path: Location of the pickle file. A ``str`` is accepted and
                converted to ``Path``, since every other method relies on
                ``self.data_path.parent``.
        """
        self.data_path = Path(data_path)
        if self.data_path.exists():
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this at index files written by this class or another trusted source.
            with open(self.data_path, "rb") as f:
                self.data = pkl.load(f)
        else:
            self.data_path.parent.mkdir(parents=True, exist_ok=True)
            self.data = {}
            self._save_data()

    def _save_data(self) -> None:
        """Persist ``self.data`` to ``self.data_path``.

        The pickle is first written to a sibling ``.tmp`` file and then moved
        over the real file with ``Path.replace``, so an interrupted write
        leaves the previous index in place instead of a truncated one.
        """
        tmp_path = self.data_path.with_name(self.data_path.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pkl.dump(self.data, f)
        tmp_path.replace(self.data_path)

    def __contains__(self, image: str) -> bool:
        """Return True if ``image`` has an entry in the index."""
        return image in self.data

    def emb_exists(self, image: str) -> bool:
        """Return True if a non-None ``emb`` path is recorded for ``image``."""
        return self.data[image].get("emb") is not None

    def save_labels(
        self,
        image: str,
        masks: list[ImageType],
        bboxes: list[tuple[int, ...]],
        labels: list[str],
    ) -> None:
        """Replace the labels of ``image`` with new masks, boxes and names.

        Previously saved mask files are deleted first. Mask ``i`` is written as
        ``<image>.<label>.<i>.png`` next to the index file.

        Raises:
            ValueError: if ``masks`` and ``labels`` differ in length.
            KeyError: if ``image`` is not in the index.
        """
        if len(masks) != len(labels):
            # zip() would silently drop the extras and store a ``labels`` list
            # that no longer lines up with ``masks``.
            raise ValueError(
                f"got {len(masks)} masks but {len(labels)} labels for {image!r}"
            )
        self.clear_labels(image)
        label_paths = []
        for i, (mask, label) in enumerate(zip(masks, labels)):
            # Labels are free text; neutralise path separators so a label such
            # as "a/b" or "../x" cannot redirect the write out of the data dir.
            safe_label = label.replace("/", "_").replace("\\", "_")
            label_path = self.data_path.parent / f"{image}.{safe_label}.{i}.png"
            mask.save(label_path)
            label_paths.append(str(label_path))
        entry = self.data[image]
        entry["masks"] = label_paths
        entry["labels"] = labels
        entry["bboxes"] = bboxes
        self._save_data()

    def save_meta_data(self, image: str, meta_data: dict) -> None:
        """Attach ``meta_data`` to ``image`` and persist the index."""
        self.data[image]["meta_data"] = meta_data
        self._save_data()

    def save_emb(self, image: str, emb: npt.NDArray) -> None:
        """Save ``emb`` as ``<image>.emb.npy`` and record its path."""
        emb_path = self.data_path.parent / f"{image}.emb.npy"
        np.save(emb_path, emb)
        self.data[image]["emb"] = emb_path
        self._save_data()

    def save_hq_emb(self, image: str, embs: list[npt.NDArray]) -> None:
        """Save ``embs[i]`` as ``<image>.emb.<i>.npy`` under key ``emb.<i>``.

        Any ``emb.<i>`` entries from an earlier call are removed first (files
        included), so ``get_hq_emb`` returns exactly ``embs`` even when the
        new list is shorter than the old one.
        """
        entry = self.data[image]
        self._remove_hq_embs(entry)
        for i, emb in enumerate(embs):
            emb_path = self.data_path.parent / f"{image}.emb.{i}.npy"
            np.save(emb_path, emb)
            entry[f"emb.{i}"] = emb_path
        self._save_data()

    def save_image(self, image: str, image_pil: ImageType) -> None:
        """Save ``image_pil`` as ``<image>.png`` and start a fresh entry for it.

        An existing entry for ``image`` is replaced outright; files that entry
        referenced are left on disk.
        """
        image_path = self.data_path.parent / f"{image}.png"
        image_pil.save(image_path)
        self.data[image] = {"image": image_path}
        self._save_data()

    def clear_labels(self, image: str) -> None:
        """Delete the mask files of ``image`` and empty its label lists.

        ``masks`` is reset along with the files so ``get_labels`` never tries
        to open a deleted mask; ``bboxes`` is reset so it stays parallel to
        ``masks`` and ``labels``.
        """
        entry = self.data[image]
        if "masks" in entry:
            for label_path in entry["masks"]:
                Path(label_path).unlink(missing_ok=True)
            entry["masks"] = []
        if "labels" in entry:
            entry["labels"] = []
        if "bboxes" in entry:
            entry["bboxes"] = []
        self._save_data()

    def delete_image(self, image: str) -> None:
        """Remove ``image`` from the index and delete the files it references.

        This covers the image file, the ``emb`` embedding, every ``emb.<i>``
        embedding and every mask. Names not in the index are ignored.
        """
        if image not in self.data:
            return
        entry = self.data[image]
        for key in ("image", "emb"):
            # Entries from build_data() carry "emb": None, which Path() rejects.
            if entry.get(key) is not None:
                Path(entry[key]).unlink(missing_ok=True)
        for label_path in entry.get("masks", []):
            Path(label_path).unlink(missing_ok=True)
        # Also remove save_hq_emb() files so they are not orphaned.
        self._remove_hq_embs(entry)
        del self.data[image]
        self._save_data()

    def get_all_images(self) -> list:
        """Return the names of all images in the index."""
        return list(self.data.keys())

    def get_image(self, image: str) -> ImageType:
        """Return the stored image for ``image``, fully loaded into memory."""
        return self._load_image(self.data[image]["image"])

    def get_emb(self, image: str) -> npt.NDArray:
        """Load and return the array saved by ``save_emb`` for ``image``."""
        return np.load(self.data[image]["emb"])

    def get_hq_emb(self, image: str) -> list[npt.NDArray]:
        """Load the arrays stored under ``emb.0``, ``emb.1``, ... in order.

        Stops at the first missing index; returns ``[]`` if there are none.
        """
        entry = self.data[image]
        embs = []
        i = 0
        while f"emb.{i}" in entry:
            embs.append(np.load(entry[f"emb.{i}"]))
            i += 1
        return embs

    def get_labels(
        self, image: str
    ) -> tuple[list[ImageType], list[tuple[int, ...]], list[str]]:
        """Return ``(masks, bboxes, labels)`` for ``image``.

        Returns three empty lists unless all of ``masks``, ``labels`` and
        ``bboxes`` are recorded. Masks are loaded fully into memory.
        """
        entry = self.data[image]
        if "masks" not in entry or "labels" not in entry or "bboxes" not in entry:
            return [], [], []
        return (
            [self._load_image(mask) for mask in entry["masks"]],
            [tuple(e) for e in entry["bboxes"]],
            entry["labels"],
        )

    def get_meta_data(self, image: str) -> dict:
        """Return the metadata of ``image``, or None if none has been saved.

        Entries created by ``save_image`` have no ``meta_data`` key, so
        ``.get`` avoids a KeyError there.
        """
        return self.data[image].get("meta_data")

    @staticmethod
    def _load_image(path: "str | Path") -> ImageType:
        """Open an image file, read its pixels, and release the file handle.

        ``Image.open`` is lazy and keeps the file open; loading eagerly lets
        the file be deleted or overwritten later (notably on Windows) while
        the returned image is still in use.
        """
        with Image.open(path) as img:
            img.load()
        return img

    @staticmethod
    def _remove_hq_embs(entry: dict) -> None:
        """Delete every ``emb.<i>`` file recorded in ``entry`` and drop the keys."""
        for key in [k for k in entry if k.startswith("emb.")]:
            Path(entry.pop(key)).unlink(missing_ok=True)
|