# Repository page metadata captured with this file:
# Jerome Blin
# Add application file
# 7484424
# raw
# history blame
# 2.45 kB
import os
from typing import Any, List, Union

import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as Img

from fis.feature_extraction.detection.base import BaseDetector
from fis.feature_extraction.embedding.base import BaseEncoder
class EncodingPipeline:
    """Apply the detection and embedding models to an image.

    The pipeline works in four steps:
        1. Load the input as a PIL image.
        2. Run the detection model on it to obtain bounding boxes.
        3. Crop each box out of the image.
        4. Run the embedding model on every crop.
    """

    def __init__(
        self,
        name: str,
        detection_model: BaseDetector,
        embedding_model: BaseEncoder,
    ) -> None:
        """Initialize the encoding pipeline.

        Args:
            name: Name of the pipeline.
            detection_model: Model used to detect the fashion items in the images.
            embedding_model: Model used to generate embeddings for each detected item.
        """
        self._name = name
        self._detection_model = detection_model
        self._embedding_model = embedding_model

    def encode(self, image: Union[str, os.PathLike, np.ndarray, Img]) -> List[torch.Tensor]:
        """Encode each item from an image into an embedding.

        Args:
            image: Path to the image (``str`` or ``os.PathLike``), a numpy
                array, or an already loaded PIL image.

        Returns:
            Embeddings for each detected item in the image, in the same order
            as the bounding boxes returned by the detection model.

        Raises:
            TypeError: If ``image`` is not one of the supported types.
        """
        # Keep the caller's argument untouched and work on the loaded PIL image.
        pil_image = self._load_images(image)
        bboxes = self._detection_model(pil_image)
        items = self._crop_images(pil_image, bboxes)
        return [self._embedding_model(item) for item in items]

    def _load_images(self, image: Any) -> Img:
        """Convert a supported image input into a PIL image.

        Args:
            image: Either of:
                - a PIL image, which is returned unchanged;
                - a numpy array, which is converted with ``Image.fromarray``;
                - a path to an image on disk (``str`` or ``os.PathLike``).

        Raises:
            TypeError: if the type of image is incorrect.

        Returns:
            PIL Image.
        """
        if isinstance(image, Img):
            return image
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        if isinstance(image, (str, os.PathLike)):
            # Image.open only reads the header and keeps the file handle open
            # until the pixel data is read. Load eagerly inside a context
            # manager so the handle is released here rather than whenever the
            # image object happens to be garbage collected. After load() the
            # image remains usable outside the `with` block.
            with Image.open(image) as loaded:
                loaded.load()
            return loaded
        raise TypeError(f"Unknown type for image: {type(image)}")

    def _crop_images(self, image: Img, bboxes: Any) -> List[Img]:
        """Crop one sub-image per bounding box.

        Args:
            image: Image to crop items from.
            bboxes: Iterable of bounding boxes. Each box is passed as-is to
                ``PIL.Image.Image.crop``, which expects a
                ``(left, upper, right, lower)`` box.

        Returns:
            List of cropped images, one per bounding box, in the same order.
        """
        return [image.crop(bbox) for bbox in bboxes]