Spaces:
Build error
Build error
import pickle as pkl | |
import numpy as np | |
import numpy.typing as npt | |
from PIL import Image | |
from PIL.Image import Image as ImageType | |
from pathlib import Path | |
def build_data(data_path: Path) -> dict:
    """Index the image files that sit directly inside ``data_path``.

    Files whose names end in ``.png``, ``.jpg`` or ``.jpeg`` are collected.
    The match is case-insensitive, so ``photo.JPG`` is picked up on
    case-sensitive filesystems as well. Sub-directories are not searched, and
    entries that are not regular files are skipped.

    Args:
        data_path: Directory to scan. If it is not an existing directory the
            result is an empty dict.

    Returns:
        A dict keyed by file stem (file name without its last suffix). Each
        value is a fresh record with the keys ``"image"`` (the file's path),
        ``"labels"`` (empty list), ``"emb"`` (``None``) and ``"meta_data"``
        (``None``). When several files share a stem, the one with the later
        suffix in the order png, jpg, jpeg is kept.
    """
    suffixes = (".png", ".jpg", ".jpeg")

    def suffix_rank(path: Path) -> int:
        # Position of the matching suffix in ``suffixes``; only called for
        # paths that already passed the ``endswith`` filter below.
        lowered = path.name.lower()
        return next(i for i, suffix in enumerate(suffixes) if lowered.endswith(suffix))

    data_path = Path(data_path)
    if not data_path.is_dir():
        return {}

    image_paths = [
        path
        for path in data_path.iterdir()
        if path.is_file() and path.name.lower().endswith(suffixes)
    ]
    # Sort by (suffix, name) so the result does not depend on directory
    # listing order, while later suffixes still overwrite earlier ones when
    # two files share a stem.
    image_paths.sort(key=lambda path: (suffix_rank(path), path.name))

    return {
        image_path.stem: {
            "image": image_path,
            "labels": [],
            "emb": None,
            "meta_data": None,
        }
        for image_path in image_paths
    }
class Data:
    """Pickle-backed index of images and the files derived from them.

    ``self.data`` maps an image name to a record dict. The methods of this
    class read and write the record keys ``"image"``, ``"emb"``,
    ``"emb.<i>"``, ``"masks"``, ``"labels"``, ``"bboxes"`` and
    ``"meta_data"``. Image, mask and embedding files are written into the
    directory that contains the pickle file, and the whole index is
    re-pickled after every mutating call.
    """

    def __init__(self, data_path: Path):
        """Load the index stored at ``data_path``, or create an empty one.

        Args:
            data_path: Location of the pickle file. Missing parent
                directories are created when the file does not exist yet.
        """
        # Normalise once so every method can rely on ``Path`` attributes
        # such as ``.parent``, even when a plain string is passed in.
        self.data_path = Path(data_path)
        if self.data_path.exists():
            # NOTE(security): unpickling can execute arbitrary code. Only
            # point this class at index files from a trusted source.
            with open(self.data_path, "rb") as f:
                self.data = pkl.load(f)
        else:
            self.data_path.parent.mkdir(parents=True, exist_ok=True)
            self.data = {}
            self._save_data()

    def _save_data(self) -> None:
        """Pickle the in-memory index to ``self.data_path``, overwriting it."""
        with open(self.data_path, "wb") as f:
            pkl.dump(self.data, f)

    @staticmethod
    def _unlink(path) -> None:
        """Delete the file at ``path``; ``None`` and missing files are ignored."""
        if path is not None:
            Path(path).unlink(missing_ok=True)

    @staticmethod
    def _hq_emb_keys(record: dict) -> list[str]:
        """Return the consecutive keys ``emb.0``, ``emb.1``, ... present in ``record``."""
        keys: list[str] = []
        while f"emb.{len(keys)}" in record:
            keys.append(f"emb.{len(keys)}")
        return keys

    def __contains__(self, image: str) -> bool:
        """Return whether ``image`` has a record in the index."""
        return image in self.data

    def emb_exists(self, image: str) -> bool:
        """Return whether a non-``None`` ``"emb"`` entry is recorded for ``image``.

        An image without a record counts as having no embedding.
        """
        return self.data.get(image, {}).get("emb") is not None

    def save_labels(
        self, image: str, masks: list[ImageType], bboxes: list[tuple[int, ...]], labels: list[str]
    ) -> None:
        """Replace the stored labels of ``image``.

        Existing labels are cleared first. Each ``(mask, label)`` pair
        produced by ``zip(masks, labels)`` is saved as
        ``<image>.<label>.<index>.png`` next to the index file; ``labels``
        and ``bboxes`` are stored as given.

        Raises:
            KeyError: If ``image`` has no record.
        """
        self.clear_labels(image)
        label_paths = []
        for i, (mask, label) in enumerate(zip(masks, labels)):
            label_path = self.data_path.parent / f"{image}.{label}.{i}.png"
            mask.save(label_path)
            label_paths.append(str(label_path))
        record = self.data[image]
        record["masks"] = label_paths
        record["labels"] = labels
        record["bboxes"] = bboxes
        self._save_data()

    def save_meta_data(self, image: str, meta_data: dict) -> None:
        """Store ``meta_data`` on the record of ``image`` and persist the index.

        Raises:
            KeyError: If ``image`` has no record.
        """
        self.data[image]["meta_data"] = meta_data
        self._save_data()

    def save_emb(self, image: str, emb: npt.NDArray) -> None:
        """Write ``emb`` to ``<image>.emb.npy`` and record its path under ``"emb"``.

        Raises:
            KeyError: If ``image`` has no record.
        """
        emb_path = self.data_path.parent / f"{image}.emb.npy"
        np.save(emb_path, emb)
        self.data[image]["emb"] = emb_path
        self._save_data()

    def save_hq_emb(self, image: str, embs: list[npt.NDArray]) -> None:
        """Write ``embs`` to ``<image>.emb.<i>.npy`` files and record their paths.

        Entries left over from an earlier, longer call are removed (record
        key and file), so ``get_hq_emb`` returns exactly what was saved last.

        Raises:
            KeyError: If ``embs`` is non-empty and ``image`` has no record.
        """
        written = 0
        for i, emb in enumerate(embs):
            emb_path = self.data_path.parent / f"{image}.emb.{i}.npy"
            np.save(emb_path, emb)
            self.data[image][f"emb.{i}"] = emb_path
            written = i + 1
        record = self.data.get(image, {})
        # Anything numbered at or beyond ``written`` belongs to a previous save.
        for stale_key in self._hq_emb_keys(record)[written:]:
            self._unlink(record.pop(stale_key))
        self._save_data()

    def save_image(self, image: str, image_pil: ImageType) -> None:
        """Save ``image_pil`` as ``<image>.png`` and start a fresh record for it.

        Any existing record for ``image`` is replaced by one that only holds
        the ``"image"`` path. Files referenced by the old record are not
        removed here.
        """
        image_path = self.data_path.parent / f"{image}.png"
        image_pil.save(image_path)
        self.data[image] = {"image": image_path}
        self._save_data()

    def clear_labels(self, image: str) -> None:
        """Delete the mask files of ``image`` and empty its label fields.

        ``"masks"``, ``"labels"`` and ``"bboxes"`` are reset together (when
        present) so the record never points at mask files that were just
        deleted.

        Raises:
            KeyError: If ``image`` has no record.
        """
        record = self.data[image]
        for label_path in record.get("masks", []):
            self._unlink(label_path)
        for key in ("masks", "labels", "bboxes"):
            if key in record:
                record[key] = []
        self._save_data()

    def delete_image(self, image: str) -> None:
        """Remove ``image`` from the index together with its recorded files.

        The image file, the ``"emb"`` file, every ``"emb.<i>"`` file and all
        mask files are deleted. Entries that are ``None`` or whose file is
        already gone are skipped. Unknown images are ignored.
        """
        if image not in self.data:
            return
        record = self.data[image]
        self._unlink(record.get("image"))
        self._unlink(record.get("emb"))
        for key in self._hq_emb_keys(record):
            self._unlink(record[key])
        for label_path in record.get("masks", []):
            self._unlink(label_path)
        del self.data[image]
        self._save_data()

    def get_all_images(self) -> list:
        """Return the names of all indexed images."""
        return list(self.data)

    def get_image(self, image: str) -> ImageType:
        """Open the file recorded under ``"image"`` for ``image``.

        Raises:
            KeyError: If ``image`` has no record or no ``"image"`` entry.
        """
        return Image.open(self.data[image]["image"])

    def get_emb(self, image: str) -> npt.NDArray:
        """Load the array recorded under ``"emb"`` for ``image``.

        Raises:
            KeyError: If ``image`` has no record or no ``"emb"`` entry.
        """
        return np.load(self.data[image]["emb"])

    def get_hq_emb(self, image: str) -> list[npt.NDArray]:
        """Load the arrays recorded under ``emb.0``, ``emb.1``, ... in order.

        Loading stops at the first missing index; the list is empty when
        there is no ``emb.0`` entry.

        Raises:
            KeyError: If ``image`` has no record.
        """
        record = self.data[image]
        return [np.load(record[key]) for key in self._hq_emb_keys(record)]

    def get_labels(
        self, image: str
    ) -> tuple[list[ImageType], list[tuple[int, ...]], list[str]]:
        """Return ``(masks, bboxes, labels)`` for ``image``.

        Masks are opened from their recorded paths and each bbox is
        converted to a tuple. Three empty lists are returned when any of the
        ``"masks"``, ``"labels"`` or ``"bboxes"`` entries is missing.

        Raises:
            KeyError: If ``image`` has no record.
        """
        record = self.data[image]
        if any(key not in record for key in ("masks", "labels", "bboxes")):
            return [], [], []
        return (
            [Image.open(mask) for mask in record["masks"]],
            [tuple(e) for e in record["bboxes"]],
            record["labels"],
        )

    def get_meta_data(self, image: str) -> dict:
        """Return the value recorded under ``"meta_data"`` for ``image``.

        Raises:
            KeyError: If ``image`` has no record or no ``"meta_data"`` entry.
        """
        return self.data[image]["meta_data"]