from typing import List, Tuple, Generator
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from surya.model.detection.model import EfficientViTForSemanticSegmentation
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
from surya.schema import TextDetectionResult
from surya.settings import settings

def get_batch_size():
    batch_size = settings.DETECTOR_BATCH_SIZE
    if batch_size is None:
        batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 36
    return batch_size
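
# Note (not in the original module): DETECTOR_BATCH_SIZE is read from surya's
# settings object, so an explicit value normally wins over the device defaults
# above. Assuming the usual pydantic-settings behaviour, it can be supplied via
# a DETECTOR_BATCH_SIZE environment variable or set in code, e.g.:
#
#     settings.DETECTOR_BATCH_SIZE = 16  # hypothetical override of the CPU/MPS/CUDA defaults
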
def batch_detection(
    images: List,
    model: EfficientViTForSemanticSegmentation,
    processor,
    batch_size=None
) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = get_batch_size()
    heatmap_count = model.config.num_labels

    orig_sizes = [image.size for image in images]
    splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]

    # Greedily group images into batches so that the total number of vertical
    # splits per batch stays within batch_size.
    batches = []
    current_batch_size = 0
    current_batch = []
    for i in range(len(images)):
        if current_batch_size + splits_per_image[i] > batch_size:
            if len(current_batch) > 0:
                batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(i)
        current_batch_size += splits_per_image[i]

    if len(current_batch) > 0:
        batches.append(current_batch)

    for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
        batch_image_idxs = batches[batch_idx]
        batch_images = [images[j].convert("RGB") for j in batch_image_idxs]

        # Split tall images into processor-sized parts, remembering which image
        # each part belongs to and how tall the unpadded part is.
        split_index = []
        split_heights = []
        image_splits = []
        for image_idx, image in enumerate(batch_images):
            image_parts, split_height = split_image(image, processor)
            image_splits.extend(image_parts)
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(split_height)

        image_splits = [prepare_image_detection(image, processor) for image in image_splits]
        # Batch images in dim 0
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

        with torch.inference_mode():
            pred = model(pixel_values=batch)

        logits = pred.logits
        correct_shape = [processor.size["height"], processor.size["width"]]
        current_shape = list(logits.shape[2:])
        if current_shape != correct_shape:
            logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)

        logits = logits.cpu().detach().numpy().astype(np.float32)
        preds = []
        for i, (idx, height) in enumerate(zip(split_index, split_heights)):
            # If nothing has been stored for this image index yet, this is the first
            # split of a new image; otherwise stack the new split below the existing heatmaps.
            if len(preds) <= idx:
                preds.append([logits[i][k] for k in range(heatmap_count)])
            else:
                heatmaps = preds[idx]
                pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                if height < processor.size["height"]:
                    # Cut off padding to get original height
                    pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]

                for k in range(heatmap_count):
                    heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                preds[idx] = heatmaps

        yield preds, [orig_sizes[j] for j in batch_image_idxs]
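
# A rough sketch of consuming batch_detection directly (hypothetical usage; this
# module only consumes it through batch_text_detection below). `det_model` and
# `det_processor` stand in for a loaded detection model and processor and are not
# defined in this file.
#
#     for preds, sizes in batch_detection(pil_images, det_model, det_processor):
#         for heatmaps, (width, height) in zip(preds, sizes):
#             # heatmaps is one list per page: model.config.num_labels float32
#             # arrays, with the page's vertical splits stacked back together
#             ...
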
def parallel_get_lines(preds, orig_sizes):
    # Despite the plural names, this operates on a single page: `preds` is a
    # [heatmap, affinity_map] pair and `orig_sizes` is that page's (width, height).
    heatmap, affinity_map = preds
    heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
    aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
    affinity_size = list(reversed(affinity_map.shape))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
    vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes)

    result = TextDetectionResult(
        bboxes=bboxes,
        vertical_lines=vertical_lines,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
    )
    return result
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

    results = []
    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    if parallelize:
        # Post-process each batch of heatmaps in worker processes when there are
        # enough pages to justify the overhead (and we're not running in Streamlit).
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            for preds, orig_sizes in detection_generator:
                batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
                results.extend(batch_results)
    else:
        for preds, orig_sizes in detection_generator:
            for pred, orig_size in zip(preds, orig_sizes):
                results.append(parallel_get_lines(pred, orig_size))

    return results
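

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the original module. It assumes the
    # detection loaders live alongside EfficientViTForSemanticSegmentation in
    # surya.model.detection.model as load_model / load_processor; the loader API
    # has moved between surya releases, so adjust for the installed version.
    from surya.model.detection.model import load_model, load_processor  # assumed API

    det_model = load_model()
    det_processor = load_processor()

    page = Image.open("page.png")  # hypothetical input image
    [detection] = batch_text_detection([page], det_model, det_processor)
    for box in detection.bboxes:
        print(box.bbox)  # [x1, y1, x2, y2] in original image coordinates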