|
|
|
|
|
import random
from itertools import zip_longest
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple

import torch
from torch import nn
|
|
|
SampledData = Any |
|
ModelOutput = Any |
|
|
|
|
|
def _grouper(iterable: Iterable[Any], n: int, fillvalue=None) -> Iterator[Tuple[Any]]: |
|
""" |
|
Group elements of an iterable by chunks of size `n`, e.g. |
|
grouper(range(9), 4) -> |
|
(0, 1, 2, 3), (4, 5, 6, 7), (8, None, None, None) |
|
""" |
|
it = iter(iterable) |
|
while True: |
|
values = [] |
|
for _ in range(n): |
|
try: |
|
value = next(it) |
|
except StopIteration: |
|
if values: |
|
values.extend([fillvalue] * (n - len(values))) |
|
yield tuple(values) |
|
return |
|
values.append(value) |
|
yield tuple(values) |
|
|
|
|
|
class ScoreBasedFilter: |
|
""" |
|
Filters entries in model output based on their scores |
|
Discards all entries with score less than the specified minimum |
|
""" |
|
|
|
def __init__(self, min_score: float = 0.8): |
|
self.min_score = min_score |
|
|
|
def __call__(self, model_output: ModelOutput) -> ModelOutput: |
|
for model_output_i in model_output: |
|
instances = model_output_i["instances"] |
|
if not instances.has("scores"): |
|
continue |
|
instances_filtered = instances[instances.scores >= self.min_score] |
|
model_output_i["instances"] = instances_filtered |
|
return model_output |
|
|
|
|
|
class InferenceBasedLoader: |
|
""" |
|
Data loader based on results inferred by a model. Consists of: |
|
- a data loader that provides batches of images |
|
- a model that is used to infer the results |
|
- a data sampler that converts inferred results to annotations |
|
""" |
|
|
|
def __init__( |
|
self, |
|
model: nn.Module, |
|
data_loader: Iterable[List[Dict[str, Any]]], |
|
data_sampler: Optional[Callable[[ModelOutput], List[SampledData]]] = None, |
|
data_filter: Optional[Callable[[ModelOutput], ModelOutput]] = None, |
|
shuffle: bool = True, |
|
batch_size: int = 4, |
|
inference_batch_size: int = 4, |
|
drop_last: bool = False, |
|
category_to_class_mapping: Optional[dict] = None, |
|
): |
|
""" |
|
Constructor |
|
|
|
Args: |
|
model (torch.nn.Module): model used to produce data |
|
data_loader (Iterable[List[Dict[str, Any]]]): iterable that provides |
|
dictionaries with "images" and "categories" fields to perform inference on |
|
data_sampler (Callable: ModelOutput -> SampledData): functor |
|
that produces annotation data from inference results; |
|
(optional, default: None) |
|
data_filter (Callable: ModelOutput -> ModelOutput): filter |
|
that selects model outputs for further processing |
|
(optional, default: None) |
|
shuffle (bool): if True, the input images get shuffled |
|
batch_size (int): batch size for the produced annotation data |
|
inference_batch_size (int): batch size for input images |
|
drop_last (bool): if True, drop the last batch if it is undersized |
|
category_to_class_mapping (dict): category to class mapping |
|
""" |
|
self.model = model |
|
self.model.eval() |
|
self.data_loader = data_loader |
|
self.data_sampler = data_sampler |
|
self.data_filter = data_filter |
|
self.shuffle = shuffle |
|
self.batch_size = batch_size |
|
self.inference_batch_size = inference_batch_size |
|
self.drop_last = drop_last |
|
if category_to_class_mapping is not None: |
|
self.category_to_class_mapping = category_to_class_mapping |
|
else: |
|
self.category_to_class_mapping = {} |
|
|
|
def __iter__(self) -> Iterator[List[SampledData]]: |
|
for batch in self.data_loader: |
|
|
|
|
|
|
|
images_and_categories = [ |
|
{"image": image, "category": category} |
|
for element in batch |
|
for image, category in zip(element["images"], element["categories"]) |
|
] |
|
if not images_and_categories: |
|
continue |
|
if self.shuffle: |
|
random.shuffle(images_and_categories) |
|
yield from self._produce_data(images_and_categories) |
|
|
|
def _produce_data( |
|
self, images_and_categories: List[Tuple[torch.Tensor, Optional[str]]] |
|
) -> Iterator[List[SampledData]]: |
|
""" |
|
Produce batches of data from images |
|
|
|
Args: |
|
images_and_categories (List[Tuple[torch.Tensor, Optional[str]]]): |
|
list of images and corresponding categories to process |
|
|
|
Returns: |
|
Iterator over batches of data sampled from model outputs |
|
""" |
|
data_batches: List[SampledData] = [] |
|
category_to_class_mapping = self.category_to_class_mapping |
|
batched_images_and_categories = _grouper(images_and_categories, self.inference_batch_size) |
|
for batch in batched_images_and_categories: |
|
batch = [ |
|
{ |
|
"image": image_and_category["image"].to(self.model.device), |
|
"category": image_and_category["category"], |
|
} |
|
for image_and_category in batch |
|
if image_and_category is not None |
|
] |
|
if not batch: |
|
continue |
|
with torch.no_grad(): |
|
model_output = self.model(batch) |
|
for model_output_i, batch_i in zip(model_output, batch): |
|
assert len(batch_i["image"].shape) == 3 |
|
model_output_i["image"] = batch_i["image"] |
|
instance_class = category_to_class_mapping.get(batch_i["category"], 0) |
|
model_output_i["instances"].dataset_classes = torch.tensor( |
|
[instance_class] * len(model_output_i["instances"]) |
|
) |
|
model_output_filtered = ( |
|
model_output if self.data_filter is None else self.data_filter(model_output) |
|
) |
|
data = ( |
|
model_output_filtered |
|
if self.data_sampler is None |
|
else self.data_sampler(model_output_filtered) |
|
) |
|
for data_i in data: |
|
if len(data_i["instances"]): |
|
data_batches.append(data_i) |
|
if len(data_batches) >= self.batch_size: |
|
yield data_batches[: self.batch_size] |
|
data_batches = data_batches[self.batch_size :] |
|
if not self.drop_last and data_batches: |
|
yield data_batches |
|
|