Spaces:
Runtime error
Runtime error
from typing import Any, List | |
import numpy as np | |
import torch | |
from PIL import Image | |
from PIL.Image import Image as Img | |
from fis.feature_extraction.detection.base import BaseDetector | |
from fis.feature_extraction.embedding.base import BaseEncoder | |
class EncodingPipeline:
    """Apply the detection and embedding models to an image.

    The pipeline loads an image, asks the detection model for bounding boxes,
    crops one sub-image per box, and passes each crop to the embedding model.
    """

    def __init__(
        self,
        name: str,
        detection_model: BaseDetector,
        embedding_model: BaseEncoder,
    ) -> None:
        """Initialize the encoding pipeline.

        Args:
            name: Name of the pipeline.
            detection_model: Model used to detect the fashion items in the images.
            embedding_model: Model used to generate embeddings for each detected item.
        """
        self._name = name
        self._detection_model = detection_model
        self._embedding_model = embedding_model

    def encode(self, image: Any) -> List[torch.Tensor]:
        """Encode each item from an image into an embedding.

        Args:
            image: Path to the image (``str`` or path-like object), a NumPy
                array, or an already-loaded PIL image. Any input accepted by
                ``_load_images`` works.

        Returns:
            One embedding per detected item, in the same order as the bounding
            boxes returned by the detection model.

        Raises:
            TypeError: If ``image`` is of an unsupported type.
        """
        # Use a separate name so the parameter keeps its caller-supplied value
        # and the local clearly holds a PIL image.
        pil_image = self._load_images(image)
        bboxes = self._detection_model(pil_image)
        items = self._crop_images(pil_image, bboxes)
        return [self._embedding_model(item) for item in items]

    def _load_images(self, image: Any) -> Img:
        """Turn the supported input types into a PIL image.

        Args:
            image: A PIL image (returned unchanged), a NumPy array (converted
                with ``Image.fromarray``), or a path to an image on disk given
                as ``str`` or any ``os.PathLike`` object.

        Raises:
            TypeError: if the type of image is incorrect. Errors raised by
                ``Image.open`` (e.g. for a missing file) propagate unchanged.

        Returns:
            PIL Image.
        """
        import os  # local import so this block needs no change to the file header

        if isinstance(image, Img):
            return image
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        if isinstance(image, (str, os.PathLike)):
            loaded = Image.open(image)
            # Image.open is lazy and keeps the file handle open until pixel data
            # is read. Loading eagerly lets Pillow release the file now, at least
            # for single-frame formats, instead of whenever the object is
            # garbage-collected.
            loaded.load()
            return loaded
        raise TypeError(f"Unknown type for image: {type(image)}")

    def _crop_images(self, image: Img, bboxes: Any) -> List[Img]:
        """Crop one sub-image per bounding box.

        Args:
            image: Image to crop items from.
            bboxes: Iterable of bounding boxes. Each box is passed as-is to
                ``PIL.Image.Image.crop``, which expects a
                ``(left, upper, right, lower)`` pixel box. This presumes the
                detection model returns boxes in that format; confirm against
                the ``BaseDetector`` implementations.

        Returns:
            List of cropped images, one per bounding box, in input order.
        """
        return [image.crop(bbox) for bbox in bboxes]