from typing import Any, List, Union

import numpy as np
import torch
from PIL import Image
from PIL.Image import Image as Img

from fis.feature_extraction.detection.base import BaseDetector
from fis.feature_extraction.embedding.base import BaseEncoder


class EncodingPipeline:
    """Apply the detection and embedding models to an image."""

    def __init__(self, name: str, detection_model: BaseDetector, embedding_model: BaseEncoder) -> None:
        """Initialize the encoding pipeline.

        Args:
            name: Name of the pipeline.
            detection_model: Model used to detect the fashion items in the images.
            embedding_model: Model used to generate embeddings for each detected item.
        """
        self._name = name
        self._detection_model = detection_model
        self._embedding_model = embedding_model

    def encode(self, image: Union[str, np.ndarray, Img]) -> List[torch.Tensor]:
        """Encode each item detected in an image into an embedding.

        Args:
            image: Image to encode, given as a path, a NumPy array or a PIL Image.

        Returns:
            Embeddings for each detected item in the image.
        """
        pil_image = self._load_image(image)

        # Detect the fashion items, crop each one out, then embed the crops.
        bboxes = self._detection_model(pil_image)
        items = self._crop_images(pil_image, bboxes)

        embeddings = []
        for item in items:
            embedding = self._embedding_model(item)
            embeddings.append(embedding)

        return embeddings

    def _load_image(self, image: Union[str, np.ndarray, Img]) -> Img:
        """Coerce the input into a PIL Image, reading it from disk if needed.

        Args:
            image: PIL Image, NumPy array, or path to an image on disk.

        Raises:
            TypeError: If the type of image is not supported.

        Returns:
            PIL Image.
        """
        if isinstance(image, Img):
            return image
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        if isinstance(image, str):
            return Image.open(image)
        raise TypeError(f"Unknown type for image: {type(image)}")

    def _crop_images(self, image: Img, bboxes: Any) -> List[Img]:
        """Crop the detected items out of an image.

        Args:
            image: Image to crop items from.
            bboxes: Bounding boxes of the detected items, as (left, upper,
                right, lower) coordinates accepted by PIL's Image.crop.

        Returns:
            List of cropped images, one per bounding box.
        """
        return [image.crop(bbox) for bbox in bboxes]
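

# Usage sketch (illustrative, not part of the original module). The concrete
# detector and encoder classes below are hypothetical stand-ins; any
# BaseDetector / BaseEncoder subclasses from the fis package would do:
#
#     from fis.feature_extraction.detection.yolo import YoloDetector  # hypothetical
#     from fis.feature_extraction.embedding.clip import ClipEncoder   # hypothetical
#
#     pipeline = EncodingPipeline(
#         name="fashion-search",
#         detection_model=YoloDetector(),
#         embedding_model=ClipEncoder(),
#     )
#     embeddings = pipeline.encode("outfit.jpg")  # one tensor per detected item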