|
import gradio as gr |
|
import requests |
|
import torch |
|
import os |
|
from tqdm import tqdm |
|
|
|
from ultralytics import YOLO |
|
import cv2 |
|
import numpy as np |
|
import pandas as pd |
|
from skimage.transform import resize |
|
from skimage import img_as_bool |
|
from skimage.morphology import convex_hull_image |
|
import json |
|
|
|
|
|
|
|
def tableConvexHull(img, masks):
    """Union the convex hulls of a set of instance masks into one boolean mask.

    Params:
        img: the source image; only ``img.shape[:2]`` is used, to size the
            result when ``masks`` is empty.
        masks: a (possibly empty) sequence of mask tensors; every element must
            support ``.cpu().detach().numpy()``.

    Returns:
        A boolean ndarray shaped like ``masks[0]`` that is True wherever any
        mask's convex hull is True. If ``masks`` is empty, an all-False mask of
        shape ``img.shape[:2]`` is returned instead.
    """
    if len(masks) == 0:
        # Nothing to hull; indexing masks[0] below would raise IndexError.
        return np.zeros(img.shape[:2], dtype=bool)

    hull_union = np.zeros(tuple(masks[0].shape), dtype=bool)
    for msk in masks:
        # convex_hull_image operates on NumPy arrays, so move each mask to host memory.
        single = msk.cpu().detach().numpy()
        hull_union |= convex_hull_image(single)
    return hull_union
|
|
|
def cls_exists(clss, cls):
    """Tell whether class id ``cls`` occurs anywhere in the class-id tensor ``clss``.

    Returns a plain Python ``bool``.
    """
    matches = clss == cls
    return bool(matches.any())
|
|
|
def empty_mask(img):
    """Build an all-False boolean mask covering the first two dimensions of ``img``."""
    height_width = img.shape[:2]
    return np.zeros(height_width, dtype=bool)
|
|
|
def extract_img_mask(img_model, img, config):
    """Run ``img_model`` on ``img`` and merge all of its class-0 masks into one mask.

    Params:
        img_model: model passed straight to ``get_predictions``.
        img: the input image.
        config: prediction settings dict forwarded to ``get_predictions``.

    Returns:
        ``{'status': -1}`` if prediction failed. Otherwise ``{'status': 1,
        'mask': <bool ndarray>}``; the mask is all-False when nothing was
        predicted (status 0 from ``get_predictions``).
    """
    preds = get_predictions(img_model, img, config)
    status = preds['status']

    if status == -1:
        return {'status': -1}
    if status == 0:
        return {'status': 1, 'mask': empty_mask(img)}

    boxes = preds['boxes']
    # Column 5 of the boxes tensor is read as the class id.
    merged = extract_mask(img, preds['masks'], boxes, boxes[:, 5], 0)
    return {'status': 1, 'mask': merged}
|
|
|
def get_predictions(model, img2, config):
    """Run one streamed YOLO prediction and return the first result's raw tensors.

    Params:
        model: object exposing ``predict(...)`` (an ultralytics ``YOLO`` here).
        img2: prediction source passed as ``source=``.
        config: dict with keys ``'rm'`` (-> ``retina_masks``), ``'sz'``
            (-> ``imgsz``), ``'conf'`` (-> ``conf``) and ``'classes'``
            (-> ``classes``).

    Returns:
        A dict whose ``'status'`` is
          *  1 -- ``'masks'`` (``result.masks.data``) and ``'boxes'``
                  (``result.boxes.data``) are present;
          *  0 -- the first result had no masks/boxes, or the stream yielded
                  no result at all;
          * -1 -- building or iterating the prediction raised an exception.
    """
    res_dict = {'status': 1}
    try:
        results = model.predict(source=img2, verbose=False, retina_masks=config['rm'],
                                imgsz=config['sz'], conf=config['conf'], stream=True,
                                classes=config['classes'])
        for result in results:
            # Only the first result is used. A result with nothing detected
            # carries None instead of a masks/boxes container.
            if result.masks is None or result.boxes is None:
                res_dict['status'] = 0
            else:
                res_dict['masks'] = result.masks.data
                res_dict['boxes'] = result.boxes.data
            return res_dict
    except Exception:
        # Deliberately broad at this boundary: callers treat -1 as "prediction
        # failed" and may retry with different settings.
        res_dict['status'] = -1
        return res_dict

    # The stream produced no results; report "nothing predicted" rather than a
    # status-1 dict without 'masks'/'boxes' keys.
    res_dict['status'] = 0
    return res_dict
|
|
|
def extract_mask(img, masks, boxes, clss, cls):
    """Union every instance mask whose class id equals ``cls`` into one boolean array.

    Params:
        img: input image; used only to size the all-False fallback mask.
        masks: stacked instance-mask tensor, indexed along its first dimension.
        boxes: unused here; kept for call-site compatibility.
        clss: per-instance class-id tensor aligned with ``masks``.
        cls: class id to select.

    Returns:
        A NumPy boolean array that is True wherever any selected mask is set.
        If ``cls`` does not occur in ``clss``, returns ``empty_mask(img)``.
    """
    if not cls_exists(clss, cls):
        return empty_mask(img)

    (selected,) = torch.where(clss == cls)
    merged = torch.any(masks[selected], dim=0).bool()
    return merged.cpu().detach().numpy()
|
|
|
|
|
def _paratext_masks(img, model, config):
    """Masks for general-model classes 0 and 1, or None if prediction failed."""
    res = get_predictions(model, img, config)
    if res['status'] == -1:
        return None
    if res['status'] == 0:
        return [empty_mask(img), empty_mask(img)]

    masks, boxes = res['masks'], res['boxes']
    clss = boxes[:, 5]
    return [extract_mask(img, masks, boxes, clss, cls) for cls in range(2)]


def _imgtab_masks(img, model, img_model, configs):
    """Masks for class 2 (refined by img_model) and class 3 (convex hull), or None on failure."""
    res = get_predictions(model, img, configs['imgtab'])
    if res['status'] == -1:
        return None
    if res['status'] == 0:
        return [empty_mask(img), empty_mask(img)]

    masks, boxes = res['masks'], res['boxes']
    clss = boxes[:, 5]

    # When class 2 is detected, the mask itself comes from a second model pass.
    if cls_exists(clss, 2):
        img_res = extract_img_mask(img_model, img, configs['image'])
        if img_res['status'] != 1:
            return None
        img_mask = img_res['mask']
    else:
        img_mask = empty_mask(img)

    if cls_exists(clss, 3):
        indices = torch.where(clss == 3)
        tbl_mask = tableConvexHull(img, masks[indices])
    else:
        tbl_mask = empty_mask(img)

    return [img_mask, tbl_mask]


def get_masks(img, model, img_model, flags, configs):
    """Predict the four layout masks for ``img``.

    Params:
        img: input image array (as returned by ``cv2.imread`` in this file).
        model: general model; run once with ``configs['paratext']`` and once
            with ``configs['imgtab']``.
        img_model: secondary model, run with ``configs['image']`` only when
            class 2 is detected.
        flags: currently not read here; kept for interface compatibility.
        configs: dict of prediction settings keyed 'paratext', 'imgtab', 'image'.

    Returns:
        ``{'status': -1}`` if any prediction pass failed. Otherwise
        ``{'status': 1, 'masks': [m0, m1, m2, m3]}`` where m0/m1 are the
        class-0/1 masks, m2 the class-2 mask from ``img_model`` and m3 the
        union of convex hulls of the class-3 masks. When
        ``configs['paratext']['rm']`` is False (no retina masks), every mask is
        resized to the image's height and width.
    """
    paratext = _paratext_masks(img, model, configs['paratext'])
    if paratext is None:
        return {'status': -1}

    imgtab = _imgtab_masks(img, model, img_model, configs)
    if imgtab is None:
        return {'status': -1}

    ans_masks = paratext + imgtab

    if not configs['paratext']['rm']:
        # Without retina masks, the masks come back at model resolution; bring
        # them to image size. shape[:2] also works for single-channel images.
        h, w = img.shape[:2]
        ans_masks = [img_as_bool(resize(m, (h, w))) for m in ans_masks]

    return {'status': 1, 'masks': ans_masks}
|
|
|
def overlay(image, mask, color, alpha, resize=None):
    """Blend a solid-colour rendering of ``mask`` onto ``image``.

    Adapted from
    https://www.kaggle.com/code/purplejester/showing-samples-with-segmentation-mask-overlay

    Params:
        image: HxWx3 channels-last image array. The mask is broadcast to three
            trailing channels, so the image must be channels-last.
        mask: HxW boolean mask; True pixels are painted with ``color``.
        color: colour tuple, e.g. (255, 0, 0). It is reversed before use, which
            turns an RGB tuple into BGR to match images loaded with ``cv2.imread``.
        alpha: weight of the painted copy in the blend; the original image gets
            ``1 - alpha``.
        resize: optional ``(width, height)`` for ``cv2.resize``. If given, both
            the image and the painted copy are resized before blending.

    Returns:
        image_combined: ``cv2.addWeighted(image, 1 - alpha, painted, alpha, 0)``.
    """
    color = color[::-1]

    # Replicate the HxW mask across three trailing channels -> HxWx3.
    colored_mask = np.repeat(np.expand_dims(mask, -1), 3, axis=-1)

    # Masked pixels are replaced by `color` (broadcast per channel); filled()
    # returns a copy, so `image` itself is not modified.
    masked = np.ma.MaskedArray(image, mask=colored_mask, fill_value=color)
    image_overlay = masked.filled()

    if resize is not None:
        # Both arrays are already channels-last (HxWx3), which is the layout
        # cv2.resize expects, so no transpose is applied.
        image = cv2.resize(image, resize)
        image_overlay = cv2.resize(image_overlay, resize)

    image_combined = cv2.addWeighted(image, 1 - alpha, image_overlay, alpha, 0)
    return image_combined
|
|
|
|
|
|
|
# Weight files are looked up inside the local 'models' directory.
model_path = 'models'
general_model_name = 'e50_aug.pt'
image_model_name = 'e100_img.pt'

# Both models are loaded once, at import time.
general_model = YOLO(os.path.join(model_path, general_model_name))
image_model = YOLO(os.path.join(model_path, image_model_name))

# Example inputs offered by the UI.
image_path = 'examples'
sample_name = [
    '0040da34-25c8-4a5a-a6aa-36733ea3b8eb.png',
    '0050a8ee-382b-447e-9c5b-8506d9507bef.png',
    '0064d3e2-3ba2-4332-a28f-3a165f2b84b1.png',
]
sample_path = [os.path.join(image_path, name) for name in sample_name]

# Passed through evaluate() into get_masks(), which does not currently read it.
flags = {'hist': False, 'bz': False}

# Per-pass prediction settings consumed by get_predictions():
#   'sz' -> imgsz, 'conf' -> conf, 'rm' -> retina_masks, 'classes' -> classes.
configs = {
    # General model, classes 0 and 1.
    'paratext': {'sz': 640, 'conf': 0.25, 'rm': True, 'classes': [0, 1]},
    # General model, classes 2 and 3.
    'imgtab': {'sz': 640, 'conf': 0.35, 'rm': True, 'classes': [2, 3]},
    # Secondary image model, class 0.
    'image': {'sz': 640, 'conf': 0.35, 'rm': True, 'classes': [0]},
}
|
|
|
def evaluate(img_path, model=general_model, img_model=image_model,
             configs=configs, flags=flags):
    """Segment a document image and return it with the four layout masks overlaid.

    Params:
        img_path: path to the image file (Gradio passes a filepath).
        model: general model handed to ``get_masks``.
        img_model: secondary image model handed to ``get_masks``.
        configs: prediction settings; never mutated (a copy is used for retries).
        flags: forwarded to ``get_masks``.

    Returns:
        An RGB uint8 array: the input with each mask blended in its colour.

    Raises:
        ValueError: if no path is given or the file cannot be decoded.
        RuntimeError: if prediction fails even with retina masks disabled.
    """
    if img_path is None:
        raise ValueError("No input image was provided.")
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Could not read image file: {img_path}")

    res = get_masks(img, model, img_model, flags, configs)

    if res['status'] == -1 and any(cfg['rm'] for cfg in configs.values()):
        # Retry once without retina masks. Build a per-call copy: mutating
        # `configs` would change the shared default for every later request.
        fallback = {name: {**cfg, 'rm': False} for name, cfg in configs.items()}
        res = get_masks(img, model, img_model, flags, fallback)

    if res['status'] == -1:
        raise RuntimeError("Layout prediction failed even with retina masks disabled.")

    # Colours are given as RGB; overlay() reverses them to match the BGR image.
    color_map = {
        0: (255, 0, 0),
        1: (0, 255, 0),
        2: (0, 0, 255),
        3: (255, 255, 0),
    }
    for i, mask in enumerate(res['masks']):
        img = overlay(image=img, mask=mask, color=color_map[i], alpha=0.4)

    # cv2 images are BGR; Gradio's numpy Image output expects RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
|
|
|
|
inputs_image = [
    gr.components.Image(type="filepath", label="Input Image"),
]

outputs_image = [
    gr.components.Image(type="numpy", label="Output Image"),
]

# Build the Interface first so `interface_image` refers to the Interface
# object itself rather than to whatever launch() returns.
interface_image = gr.Interface(
    fn=evaluate,
    inputs=inputs_image,
    outputs=outputs_image,
    title="Document Layout Segmentor",
    examples=sample_path,
    cache_examples=True,
)

interface_image.launch()