# -*- coding: utf-8 -*-
import copy
import os
import datetime
import glob
import io
import json
from copy import deepcopy
from collections import OrderedDict
from zipfile import ZipFile
import time
import torch
import numpy as np
import pandas as pd
from torchvision.ops import nms
from adjustText import adjust_text
import tqdm
import requests
import gradio as gr
import cv2
from PIL import Image
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon
import matplotlib.patches as patches
from openpyxl import Workbook
from openpyxl.styles import Alignment
from openpyxl.utils import get_column_letter
from openpyxl.utils.dataframe import dataframe_to_rows
# Install the matching ONNX Runtime build at startup so the right execution
# provider is available on this hardware.
if torch.cuda.is_available():
os.system("pip install -q onnxruntime-gpu")
provider = ['CUDAExecutionProvider']
else:
os.system("pip install -q onnxruntime")
provider = ['CPUExecutionProvider']
import onnxruntime as ort
ort_session = ort.InferenceSession(
'./cas_rcnn_swin_t_fpn_mix500_r2_150e_reg_class_agnostic_F5.onnx',
providers=provider
)
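# A quick way to sanity-check the model's expected input (illustrative; the exact
# shape depends on how the ONNX graph was exported):
#   inp = ort_session.get_inputs()[0]
#   print(inp.name, inp.shape)  # this input name is what predict_one_image feeds below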
def get_image(path, show=False):
"""
Load an image from the specified path and return it as a NumPy array.
Args:
path (str): The path to the image file.
show (bool, optional): Whether to display the loaded image using Matplotlib. Default is False.
Returns:
numpy.ndarray: The image as a NumPy array.
Note:
- The image is loaded using the PIL library and converted to the RGB color space.
- If `show` is set to True, the loaded image is displayed using Matplotlib.
"""
with Image.open(path) as img:
img = np.array(img.convert('RGB'))
if show:
plt.imshow(img)
plt.axis('off')
plt.show()
plt.close()
return img
def image_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
"""
Resize the input image while maintaining its aspect ratio.
Args:
image (numpy.ndarray): The input image to be resized.
width (int, optional): The desired width of the output image. If None, the width is calculated based on the given height.
height (int, optional): The desired height of the output image. If None, the height is calculated based on the given width.
inter (int, optional): The interpolation method to be used for resizing the image.
Returns:
numpy.ndarray: The resized image.
Note:
- If both `width` and `height` are None, the original image is returned without any changes.
- If only one of `width` or `height` is provided, the other dimension is calculated to maintain the aspect ratio of the image.
"""
dim = None
(h, w) = image.shape[:2]
if width is None and height is None:
return image
if width is None:
r = height / float(h)
dim = (int(w * r), height)
else:
r = width / float(w)
dim = (width, int(h * r))
    resized = cv2.resize(image, dim, interpolation=inter)
return resized
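# Illustrative example, assuming a 4908x3264 source image (the size implied by the
# default scaling factors in results_postprocess below): resizing to height 800
# keeps the aspect ratio, giving a width of int(4908 * 800 / 3264) = 1202, i.e.
#   image_resize(img, height=800).shape  # -> (800, 1202, 3)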
def preprocess(img):
"""
Preprocess the input image for the task.
Args:
img (numpy.ndarray): The input image to be preprocessed.
Returns:
numpy.ndarray: The preprocessed image.
Note:
- The image is normalized by dividing it by 255 to scale the pixel values between 0 and 1.
- The image is resized to a height of 800 pixels while maintaining the aspect ratio.
    - The image is normalized with the ImageNet statistics: the mean [0.485, 0.456, 0.406] is subtracted and the result is divided by the standard deviation [0.229, 0.224, 0.225].
- The image dimensions are transposed to match the expected input shape.
- The image is converted to the data type np.float32 and expanded to have a batch dimension.
"""
img = img / 255.
img = image_resize(img, width=None, height=800)
img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
img = np.transpose(img, axes=[2, 0, 1])
img = img.astype(np.float32)
img = np.expand_dims(img, axis=0)
return img
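# Illustrative shape flow for a 3264x4908 input image:
#   (3264, 4908, 3) uint8 -> normalize -> resize -> (800, 1202, 3)
#   -> transpose -> (3, 800, 1202) -> expand_dims -> (1, 3, 800, 1202) float32,
# which is the NCHW layout the ONNX session consumes.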
def preprocess_for_imshow(img):
"""
Preprocess the input image for displaying using Matplotlib's imshow function.
Args:
img (numpy.ndarray): The input image to be preprocessed.
Returns:
numpy.ndarray: The preprocessed image.
Note:
- The image is resized to a height of 800 pixels while maintaining the aspect ratio.
"""
img = image_resize(img, width=None, height=800)
return img
def predict_one_image(path):
"""
Predict the labels, scores, and bounding boxes for an input image using a pre-trained model.
Args:
path (str): The path to the image file.
Returns:
dict: A dictionary containing the predicted labels, scores, and bounding boxes.
Note:
- The image is loaded and preprocessed using the `get_image` and `preprocess` functions.
- The preprocessed image is passed through a pre-trained model using an ONNX Runtime session (`ort_session`).
- The predicted results, including the bounding box coordinates, scores, and labels, are extracted from the output of the model.
- The labels, scores, and bounding boxes are returned as a dictionary.
"""
img = get_image(path, show=False)
img = preprocess(img)
ort_inputs = {ort_session.get_inputs()[0].name: img}
preds = ort_session.run(None, ort_inputs)
bbox_and_score = preds[0] # bounding box and score
labels = preds[1][0] # labels
scores = bbox_and_score[0][:, -1]
bboxes = bbox_and_score[0][:, 0:4]
return {"labels": labels, "scores": scores, "bboxes": bboxes}
def results_postprocess(results, classes, scaling_factor=(None, None)):
    """
    Post-process the detection results: rescale the bounding boxes back to the
    original image resolution and append the confidence score to each box.
    Args:
        results (dict): Detection results containing "labels", "bboxes", and "scores".
        classes (list): List of classes representing the detected objects.
        scaling_factor (tuple, optional): (x_ratio, y_ratio) used to rescale the
            bounding boxes. If (None, None), a default factor is used. Defaults to (None, None).
    Returns:
        list: One array per class of rescaled bounding boxes with appended scores,
        each row in the form [x1, y1, x2, y2, score].
    Notes:
        For each class, the boxes predicted for that class are selected, rescaled by
        [x_ratio, y_ratio, x_ratio, y_ratio], and concatenated with their reshaped
        scores via `np.hstack`.
    """
    if scaling_factor == (None, None):
        scaling_factor = (4.083194675540765, 4.08)
    output_bbox = []
    x_ratio, y_ratio = scaling_factor
for i in range(len(classes)):
target_idx = results["labels"] == i
outputbbox = results["bboxes"][target_idx]
outputscore = results["scores"][target_idx]
ratio_array = np.array([x_ratio, y_ratio, x_ratio, y_ratio])
rescaled_boxes = outputbbox * ratio_array
scores = outputscore.reshape(-1, 1)
rescaled_boxes_with_scores = np.hstack([rescaled_boxes, scores])
output_bbox.append(rescaled_boxes_with_scores)
return output_bbox
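# The default scaling factor appears to map the 1202x800 network input back to a
# 4908x3264 original: 4908 / 1202 = 4.083194675540765 and 3264 / 800 = 4.08.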
def bbox_xyxy_to_ulxulywh(bbox):
"""
Convert bbox coordinates from (x1, y1, x2, y2) to (ulx, uly, w, h).
Args:
bbox (Tensor): Shape (n, 4) for bboxes.
Returns:
Tensor: Converted bboxes.
"""
x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1)
    bbox_new = [x1, y1, x2 - x1, y2 - y1]
return torch.cat(bbox_new, dim=-1)
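# Worked example:
#   bbox_xyxy_to_ulxulywh(torch.tensor([[10., 20., 50., 80.]]))
#   # -> tensor([[10., 20., 40., 60.]])  (upper-left x, upper-left y, width, height)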
def bbox_ulxulywh_to_xyxy(bbox):
"""
Convert bounding box coordinates from (upper-left, width, height) format to (x1, y1, x2, y2) format.
Args:
bbox (torch.Tensor): The input bounding box tensor in the (upper-left, width, height) format.
Returns:
torch.Tensor: The converted bounding box tensor in the (x1, y1, x2, y2) format.
Note:
- The input bounding box tensor is expected to have shape (N, 4), where N is the number of bounding boxes.
- Each bounding box is represented by the upper-left coordinates (x1, y1) and the width (w) and height (h) values.
- The output bounding box tensor will have shape (N, 4), where each bounding box is represented by the coordinates (x1, y1, x2, y2).
- The coordinates (x2, y2) are calculated by adding the width (w) to the x1 coordinate and the height (h) to the y1 coordinate.
"""
    x1, y1, w, h = bbox.split((1, 1, 1, 1), dim=-1)
    bbox_new = [x1, y1, x1 + w, y1 + h]
    return torch.cat(bbox_new, dim=-1)
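# Worked example (the inverse of the conversion above):
#   bbox_ulxulywh_to_xyxy(torch.tensor([[10., 20., 40., 60.]]))
#   # -> tensor([[10., 20., 50., 80.]])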
class Det2json_converter:
"""
Utility class for converting detection results to COCO JSON format.
Args:
init_info (dict, optional): The information about the dataset. Default is None.
init_licenses (list, optional): The licenses for the dataset. Default is None.
init_categories (list, optional): The categories/classes in the dataset. Default is None.
convert_image_paths (bool, optional): Whether to convert image paths. Default is None.
Attributes:
info (dict): The information about the dataset.
licenses (list): The licenses for the dataset.
categories (list): The categories/classes in the dataset.
images (list): The list to store image information.
convert_image_paths (bool): Whether to convert image paths.
empty_count (int): The count of empty predictions.
Methods:
_get_img_info_from_path(img_path): Extracts image filename, width, and height from the image path.
_create_image_info(file_name, width, height, date_captured, license_id, coco_url, flickr_url, annotation):
Creates image information dictionary.
_create_annotation_info(score, category_id, is_crowd, area, bounding_box): Creates annotation information dictionary.
build_json_dict(images, annotations): Builds the JSON dictionary for COCO format.
clean_empty_count(): Resets the count of empty predictions.
add_result_into_pool(img_path, result): Adds an image and its corresponding result to the pool.
generate_coco_format_dataset(out_path, score_thr): Generates a COCO format dataset in JSON file.
Note:
- The class provides methods to convert detection results to COCO JSON format.
- The class is initialized with information about the dataset, licenses, categories, and image path conversion option.
- The conversion process involves creating image and annotation information dictionaries, counting items, calculating areas, and building the JSON dictionary.
- The converted dataset can be saved in a JSON file using the `generate_coco_format_dataset` method.
"""
def __init__(self, init_info=None, init_licenses=None, init_categories=None, convert_image_paths=None):
"""
Initializes the Det2json_converter class.
Args:
init_info (dict, optional): The information about the dataset. Default is None.
init_licenses (list, optional): The licenses for the dataset. Default is None.
init_categories (list, optional): The categories/classes in the dataset. Default is None.
convert_image_paths (bool, optional): Whether to convert image paths. Default is None.
"""
self.info = init_info
self.licenses = init_licenses
self.categories = init_categories
self.images = []
self.convert_image_paths = convert_image_paths
self.empty_count = 0
    def _get_img_info_from_path(self, img_path):
        """
        Extracts the image filename, width, and height from the image path.
        Args:
            img_path (str): The path of the image.
        Returns:
            tuple: A tuple containing the filename, width, and height of the image.
        """
        filename = os.path.basename(img_path)
        img = cv2.imread(img_path)
        # cv2 images are (height, width, channels), so unpack in that order.
        H, W = img.shape[:2]
        return filename, W, H
    def _create_image_info(self, file_name, width, height,
                           date_captured=None,
                           license_id=1, annotation=None):
        """
        Creates the image information dictionary.
        Args:
            file_name (str): The filename of the image.
            width (int): The width of the image.
            height (int): The height of the image.
            date_captured (str, optional): The date the image was captured. Defaults to the current UTC date and time.
            license_id (int, optional): The license ID of the image. Default is 1.
            annotation (list, optional): The list of annotation information dictionaries. Default is None.
        Returns:
            dict: The image information dictionary.
        """
        if date_captured is None:
            # Evaluate the timestamp per call; a default argument would be frozen
            # at function-definition time.
            date_captured = datetime.datetime.utcnow().isoformat(' ')
        image_info = {
            "file_name": file_name,
            "width": width,
            "height": height,
            "date_captured": date_captured,
            "license": license_id,
            "annotation": annotation
        }
        return image_info
def _create_annotation_info(self, score, category_id, is_crowd,
area, bounding_box):
"""
Creates the annotation information dictionary.
Args:
score (float): The score of the annotation.
category_id (int): The category ID of the annotation.
is_crowd (int): Flag indicating whether the annotation represents a crowd.
area (float): The area of the annotation.
bounding_box (list): The bounding box coordinates [x, y, width, height].
Returns:
dict: The annotation information dictionary.
"""
annotation_info = {
"score": score,
"category_id": category_id,
"iscrowd": is_crowd,
"area": area,
"bbox": bounding_box, # [x,y,width,height]
}
return annotation_info
def build_json_dict(self, images, annotations):
"""
Builds the final JSON dictionary for the COCO format dataset.
Args:
images (list): The list of image information dictionaries.
annotations (list): The list of annotation information dictionaries.
Returns:
collections.OrderedDict: The JSON dictionary for the COCO format dataset.
"""
return OrderedDict({'info': self.info,
'licenses': self.licenses,
'images': images,
'annotations': annotations,
'categories': self.categories})
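    # The resulting top-level layout follows the standard COCO detection schema:
    #   {"info": {...}, "licenses": [...], "images": [...],
    #    "annotations": [...], "categories": [...]}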
def clean_empty_count(self):
"""
Resets the empty count variable to zero.
"""
self.empty_count = 0
def add_result_into_pool(self, img_path, result):
"""
Adds the result of an image into the pool of images and annotations.
Args:
img_path (str): The path of the image.
result (list): The result list containing the bounding box and score information.
Notes:
This method is used to add one image with its bounding box and score information into the pool.
"""
        filename, W, H = self._get_img_info_from_path(img_path)
        box = result
        annotation_pool = []
        for class_idx in range(len(self.categories)):
            for object_idx, bbox in enumerate(box[class_idx]):
                is_crowd = 0
                score = bbox[4]
                bbox = torch.tensor(bbox[0:4])
                bbox = bbox_xyxy_to_ulxulywh(bbox)
                area = float(bbox[2] * bbox[3])
                bbox = [int(bb.item()) for bb in bbox]
                anno = self._create_annotation_info(score, class_idx, is_crowd, area, bbox)
                annotation_pool.append(anno)
        img_info = self._create_image_info(filename, W, H, annotation=annotation_pool)
        if not annotation_pool:
            self.empty_count += 1
            print(f"Found empty prediction {self.empty_count}.")
        self.images.append(img_info)
def generate_coco_format_dataset(self, out_path, score_thr=0.6):
"""
Generates the COCO format dataset in JSON file format.
Args:
out_path (str): The output file path for the JSON file.
            score_thr (float): The score threshold for filtering annotations (default: 0.6).
Notes:
This method generates a JSON file in the COCO format based on the stored images and annotations.
"""
images = []
annotations = []
annotation_accumulator = 0
for img_id, image_info in enumerate(self.images):
            assert image_info["annotation"] is not None, "Did not find any annotations."
for ann_id, ann_info in enumerate(image_info["annotation"]):
                working_anno = OrderedDict(deepcopy(ann_info))
                working_anno["score"] = float(working_anno["score"])
                if score_thr is not None and working_anno["score"] < score_thr:
                    continue
                # Keep the "score" field: predict_batch and the Excel summary read it
                # back from the exported JSON.
working_anno["id"] = annotation_accumulator
working_anno["image_id"] = img_id
working_anno.move_to_end("image_id", last=False)
working_anno.move_to_end("id", last=False)
annotations.append(working_anno)
annotation_accumulator += 1
working_image = OrderedDict(deepcopy(image_info))
del working_image["annotation"]
working_image["id"] = img_id
working_image.move_to_end("id", last=False)
images.append(working_image)
if not out_path.endswith(".json"):
out_path = out_path + ".json"
out_image_annotation_json = self.build_json_dict(images, annotations)
with open(f"{out_path}", 'w') as file_obj:
json.dump(out_image_annotation_json, file_obj)
self.out_file = out_image_annotation_json
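# Minimal usage sketch (hypothetical paths): feed per-image results into the pool,
# then dump a COCO-format JSON, exactly as predict_batch does further below:
#   converter = get_det2json_from_default()
#   converter.add_result_into_pool("sample.jpg", processed_results)
#   converter.generate_coco_format_dataset("out.json", score_thr=CFG["score_thr"])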
class modified_COCO(COCO):
def showAnns(
self,
anns, draw_bbox=False, palettes=None,
mask_fill_alpha=0.4, mask_linewidth=1, mask_line_alpha=0.5,
bbox_fill_alpha=1, bbox_linewidth=2, bbox_line_alpha=0.7
):
"""
Display the specified annotations.
Args:
anns (array of object): Annotations to display.
draw_bbox (bool): Flag to indicate whether to draw bounding boxes (default: False).
palettes (list or None): Color palettes for different categories (default: None).
mask_fill_alpha (float): Alpha value for mask fill color (default: 0.4).
mask_linewidth (int): Line width for mask boundaries (default: 1).
mask_line_alpha (float): Alpha value for mask boundary lines (default: 0.5).
bbox_fill_alpha (float): Alpha value for bounding box fill color (default: 1).
bbox_linewidth (int): Line width for bounding box boundaries (default: 2).
bbox_line_alpha (float): Alpha value for bounding box boundary lines (default: 0.7).
Returns:
None
Notes:
            This overridden method draws bounding boxes only; the mask_* arguments are kept
            for interface compatibility, but segmentation masks and keypoints are not drawn.
            It supports both instance-level and caption-level annotations.
"""
        if len(anns) == 0:
            return
if 'segmentation' in anns[0] or 'keypoints' in anns[0] or 'bbox' in anns[0]:
datasetType = 'instances'
elif 'caption' in anns[0]:
datasetType = 'captions'
else:
raise Exception('datasetType not supported')
if datasetType == 'instances':
ax = plt.gca()
ax.set_autoscale_on(False)
bbox_polygons = []
bbox_color = []
for ann in anns:
if palettes is not None:
c = palettes[ann["category_id"]]
else:
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
if draw_bbox:
[bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
poly = [[bbox_x, bbox_y], [bbox_x, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y+bbox_h], [bbox_x+bbox_w, bbox_y]]
np_poly = np.array(poly).reshape((4,2))
bbox_polygons.append(Polygon(np_poly))
bbox_color.append(c)
if draw_bbox:
p = PatchCollection(bbox_polygons, facecolor=bbox_color, linewidths=0, alpha=bbox_fill_alpha)
ax.add_collection(p)
p = PatchCollection(bbox_polygons, facecolor='none', edgecolors=bbox_color, linewidths=bbox_linewidth, alpha=bbox_line_alpha)
ax.add_collection(p)
elif datasetType == 'captions':
for ann in anns:
print(ann['caption'])
def json_to_excel(input_file_path, algae_count_holder):
"""
Convert JSON data to an Excel file.
    Args:
        input_file_path (str): Path to the input JSON file.
        algae_count_holder (dict): Mapping from 1-based category id (as a string) to its running count.
    Returns:
        str: Path to the output Excel file.
Notes:
This function reads data from a JSON file, processes it into a pandas DataFrame,
and saves the data into an Excel file. The Excel file is named based on the input
file name and includes a timestamp to indicate when the conversion was performed.
The function returns the path to the output Excel file.
"""
data = read_json_file(input_file_path)
df = process_data_into_dataframe(data, algae_count_holder)
df.index = df.index + 1
wb = Workbook()
ws = wb.active
    for r_idx, row in enumerate(dataframe_to_rows(df, index=True, header=True), start=1):
        # dataframe_to_rows yields an empty index-title row right after the header;
        # shifting r_idx by one overwrites it so the data starts directly under the header.
        if r_idx > 2:
            r_idx -= 1
for c_idx, value in enumerate(row, start=1):
cell = ws.cell(row=r_idx, column=c_idx, value=value)
cell.alignment = Alignment(horizontal="center")
column_widths = [8.5, 22.4, 12.2, 17]
for i, width in enumerate(column_widths):
column = get_column_letter(i + 1)
ws.column_dimensions[column].width = width
timestamp = datetime.datetime.now().strftime('%d%b%Y_%Hh%Mm')
output_file = os.path.splitext(input_file_path)[0] + f'_algae_count_summary_{timestamp}.xlsx'
wb.save(output_file)
return output_file
def txt_to_excel(input_file_path, algae_count_holder, coco_path):
    """
    Convert the per-category counts in the summary text file to an Excel file.
    Args:
        input_file_path (str): Path to the summary text file (JSON-formatted counts).
        algae_count_holder (dict): Mapping from 1-based category id (as a string) to its count.
        coco_path (str): Path of the COCO JSON file, used to name the output Excel file.
    Returns:
        str: Path to the output Excel file.
    Notes:
        This function reads the "name(id) -> count" summary produced by `predict_batch`,
        fills `algae_count_holder`, builds a pandas DataFrame with the English and Chinese
        class names, and saves it as a timestamped Excel file named after `coco_path`.
    """
    data = read_json_file(input_file_path)
    for k, v in data.items():
        # Keys look like "Microcystis(41)"; extract the 1-based category id between the parentheses.
        algae_count_holder[str(k.split("(")[-1][:-1])] = int(v)
    # Construct the dictionary for DataFrame creation
    df_dict = {
        "Class (English)": algae_list,
        "Class (Chinese)": algae_list_zh,
        "Count": list(algae_count_holder.values())
    }
    # Create a pandas DataFrame from the dictionary
    df = pd.DataFrame(data=df_dict)
df.index = df.index + 1
wb = Workbook()
ws = wb.active
    for r_idx, row in enumerate(dataframe_to_rows(df, index=True, header=True), start=1):
        # As in json_to_excel, overwrite the empty index-title row emitted after the header.
        if r_idx > 2:
            r_idx -= 1
for c_idx, value in enumerate(row, start=1):
cell = ws.cell(row=r_idx, column=c_idx, value=value)
cell.alignment = Alignment(horizontal="center")
column_widths = [8.5, 22.4, 12.2, 17]
for i, width in enumerate(column_widths):
column = get_column_letter(i + 1)
ws.column_dimensions[column].width = width
timestamp = datetime.datetime.now().strftime('%d%b%Y_%Hh%Mm')
output_file = os.path.splitext(coco_path)[0] + f'_algae_count_summary_{timestamp}.xlsx'
wb.save(output_file)
return output_file
def read_json_file(input_file_path):
"""
Read and load JSON data from a file.
Args:
input_file_path (str): Path to the input JSON file.
Returns:
dict: Loaded JSON data.
Notes:
This function reads a JSON file from the specified path and returns the loaded data
as a dictionary. It uses the `json.load` function to parse the JSON file and extract
the data. The loaded JSON data is returned by the function.
"""
with open(input_file_path) as f:
data = json.load(f)
return data
def process_data_into_dataframe(data, algae_count_holder):
    """
    Process the COCO-format data and convert it into a pandas DataFrame of per-class counts.
    Args:
        data (dict): Input COCO-format data containing annotations.
        algae_count_holder (dict): Mapping from 1-based category id (as a string) to its count.
    Returns:
        pandas.DataFrame: Per-class counts with English and Chinese class names.
    Notes:
        The annotations are counted per category using the score threshold from CFG.
        `category_id` is 0-based in the COCO file, while `algae_count_holder` is keyed
        by the 1-based ids used in the class names, hence the `+ 1` below.
    """
    # Count the occurrences of each algae category above the score threshold.
    for item in data["annotations"]:
        if item["score"] > CFG["score_thr"]:
            algae_count_holder[str(item["category_id"] + 1)] += 1
    # Construct the dictionary for DataFrame creation
    df_dict = {
        "Class (English)": algae_list,
        "Class (Chinese)": algae_list_zh,
        "Count": list(algae_count_holder.values())
    }
    # Create a pandas DataFrame from the dictionary
    df = pd.DataFrame(data=df_dict)
    return df
CLASSES = [
"Acanthoceras", "Achnanthes", "Actinastrum", "Anabaena", "Ankistrodesmus",
"Aphanizomenon", "Asterionella", "Aulacoseira", "Botryococcus", "Ceratium",
"Chlamydomonas", "Chlorella", "Chlorococcum", "Chroococcus", "Closterium",
"Cocconeis", "Coelastrum", "Cosmarium", "Crucigenia", "Cryptomonas",
"Cyclostephanos", "Cyclotella", "Cylindrospermopsis", "Cymbella", "Diatoma",
"Dictyosphaerium", "Dinobryon", "Elakatothrix", "Eudorina", "Euglena",
"Fragilaria", "Golenkinia", "Gomphonema", "Gomphosphaeria", "Gyrosigma",
"Kirchneriella", "Mallomonas", "Melosira", "Merismopedia", "Micractinium",
"Microcystis", "Mougeotia", "Navicula", "Nitzschia", "Oscillatoria",
"Pandorina", "Pediastrum", "Peridinium", "Phytoconis", "Pinnularia",
"Pseudanabaena", "Raphidiopsis", "Rhizosolenia", "Scenedesmus", "Snowella",
"Sphaerocystis", "Spirogyra", "Spondylosium", "Staurastrum", "Stauroneis",
"Stephanodiscus", "Synedra", "Synura", "Tetraedron", "Trachelomonas",
"Tribonema", "Ulothrix", "Volvox", "Zygnema", "Oocystis"
]
for idx in range(len(CLASSES)):
    CLASSES[idx] = CLASSES[idx] + f"({idx+1})"
PALETTES = [
(0.0, 1.0, 0.0),
(1.0, 0.0, 1.0),
(0.0, 0.5, 1.0),
(1.0, 0.5, 0.0),
(0.5, 0.75, 0.5),
(0.4665116004964742, 0.11982097799124902, 0.5284368834373697),
(0.8514644228951851, 0.5004980751634254, 0.9568567731851599),
(1.0, 0.0, 0.0),
(0.11527826817032671, 0.9961973630797564, 0.9849162882782089),
(0.9734931270466471, 0.9979600377943707, 0.06968337081090525),
(0.0, 0.0, 1.0),
(0.009566364537970662, 0.43878127333066075, 0.30426214721825484),
(0.5151466813324389, 0.3270101219438506, 0.029192184298747592),
(0.0, 1.0, 0.5),
(0.9240457249001183, 0.24122630295393777, 0.4618655263001544),
(0.9951361941692213, 0.7146136378380459, 0.503472458829089),
(0.573222247660118, 0.8259873019868292, 0.9802929519550568),
(0.5, 1.0, 0.0),
(0.5, 0.0, 1.0),
(0.0, 0.0, 0.5)
]
PALETTES = [list(i) for i in PALETTES]
# 1-based category ids (matching the "(idx+1)" suffixes in CLASSES), grouped by rarity.
rare_algae_idx = [18, 54, 5, 65, 64, 45, 59, 29, 40, 37, 48, 8, 47]
common_algae_idx = [23, 2, 43, 20, 62, 41, 4]
remove_algae_idx = [6, 17, 21, 32, 38]
# Map each displayed category's 0-based id to a colour; only these 20 classes are fetched and drawn.
object_palettes = {item - 1: PALETTES[idx] for idx, item in enumerate(common_algae_idx + rare_algae_idx)}
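# For example, common_algae_idx[0] == 23, so the 0-based category id 22
# ("Cylindrospermopsis(23)") is drawn with PALETTES[0], i.e. green (0.0, 1.0, 0.0).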
CFG = {
"display_label": True,
"display_bbox": True,
"display_score": True,
"display_mask": False,
"score_thr": 0.6,
"bbox_fill_alpha": 0,
"bbox_linewidth": 1.5,
"bbox_line_alpha": 1,
"mask_fill_alpha": 0.2,
"mask_linewidth": 1,
"mask_line_alpha": 0.7,
"palettes": object_palettes,
"nms_thr": 0.1,
"adjust_text": True
}
algae_list = [
'Acanthoceras',
'Achnanthes',
'Actinastrum',
'Anabaena',
'Ankistrodesmus',
'Aphanizomenon',
'Asterionella',
'Aulacoseira',
'Botryococcus',
'Ceratium',
'Chlamydomonas',
'Chlorella',
'Chlorococcum',
'Chroococcus',
'Closterium',
'Cocconeis',
'Coelastrum',
'Cosmarium',
'Crucigenia',
'Cryptomonas',
'Cyclostephanos',
'Cyclotella',
'Cylindrospermopsis',
'Cymbella',
'Diatoma',
'Dictyosphaerium',
'Dinobryon',
'Elakatothrix',
'Eudorina',
'Euglena',
'Fragilaria',
'Golenkinia',
'Gomphonema',
'Gomphosphaeria',
'Gyrosigma',
'Kirchneriella',
'Mallomonas',
'Melosira',
'Merismopedia',
'Micractinium',
'Microcystis',
'Mougeotia',
'Navicula',
'Nitzschia',
'Oscillatoria',
'Pandorina',
'Pediastrum',
'Peridinium',
'Phytoconis',
'Pinnularia',
'Pseudanabaena',
'Raphidiopsis',
'Rhizosolenia',
'Scenedesmus',
'Snowella',
'Sphaerocystis',
'Spirogyra',
'Spondylosium',
'Staurastrum',
'Stauroneis',
'Stephanodiscus',
'Synedra',
'Synura',
'Tetraedron',
'Trachelomonas',
'Tribonema',
'Ulothrix',
'Volvox',
'Zygnema',
'Oocystis'
]
algae_list_zh = [
'四棘藻',
'線形曲殼藻',
'集星藻',
'魚腥藻',
'纖維藻',
'束絲藻',
'星杆藻',
'溝鏈藻',
'葡萄藻',
'角藻',
'衣藻',
'小球藻',
'綠球藻',
'色球藻',
'新月藻',
'卵形藻',
'空星藻',
'鼓藻',
'十字藻',
'隱藻',
'環冠藻',
'小環藻',
'柱孢藻',
'橋彎藻',
'等片藻',
'膠網藻',
'錐囊藻',
'紡錘藻',
'空球藻',
'裸藻',
'脆桿藻',
'多芒藻',
'異極藻',
'束球藻',
'布紋藻',
'蹄形藻',
'魚鱗藻',
'直鏈藻',
'平裂藻',
'微芒藻',
'微囊藻',
'轉板藻',
'舟形藻',
'菱形藻',
'顫藻屬',
'實球藻',
'盤星藻',
'多甲藻',
'原球藻',
'羽紋藻',
'偽魚腥藻',
'尖頭藻',
'根管藻',
'柵藻',
'小雪藻',
'球囊藻',
'水棉',
'頂接鼓藻',
'角星鼓藻',
'輻節藻',
'冠盤藻',
'針杆藻',
'黃群藻',
'四角藻',
'囊裸藻',
'黃絲藻',
'絲藻',
'團藻',
'雙星藻',
'卵囊藻'
]
def get_det2json_from_default():
"""
Get the default Det2json converter object.
Returns:
Det2json_converter: Det2json converter object.
Notes:
This function constructs a Det2json converter object using the default configuration. The converter
is responsible for converting detection results to COCO JSON format.
        The function defines a list of classes called `CLASSES`, which represents the names of the algae
        categories. Each class name is modified by appending its 1-based index in parentheses.
The function then constructs the `coco_output` dictionary, which contains the necessary information
for the COCO JSON format, including information about the dataset, licenses, images, annotations,
and categories. The categories are constructed based on the modified class names.
Finally, the Det2json converter object is created using the `coco_output` dictionary and returned.
"""
# Define the list of classes
    CLASSES = [
"Acanthoceras", "Achnanthes", "Actinastrum", "Anabaena", "Ankistrodesmus",
"Aphanizomenon", "Asterionella", "Aulacoseira", "Botryococcus", "Ceratium",
"Chlamydomonas", "Chlorella", "Chlorococcum", "Chroococcus", "Closterium",
"Cocconeis", "Coelastrum", "Cosmarium", "Crucigenia", "Cryptomonas",
"Cyclostephanos", "Cyclotella", "Cylindrospermopsis", "Cymbella", "Diatoma",
"Dictyosphaerium", "Dinobryon", "Elakatothrix", "Eudorina", "Euglena",
"Fragilaria", "Golenkinia", "Gomphonema", "Gomphosphaeria", "Gyrosigma",
"Kirchneriella", "Mallomonas", "Melosira", "Merismopedia", "Micractinium",
"Microcystis", "Mougeotia", "Navicula", "Nitzschia", "Oscillatoria",
"Pandorina", "Pediastrum", "Peridinium", "Phytoconis", "Pinnularia",
"Pseudanabaena", "Raphidiopsis", "Rhizosolenia", "Scenedesmus", "Snowella",
"Sphaerocystis", "Spirogyra", "Spondylosium", "Staurastrum", "Stauroneis",
"Stephanodiscus", "Synedra", "Synura", "Tetraedron", "Trachelomonas",
"Tribonema", "Ulothrix", "Volvox", "Zygnema", "Oocystis"
]
    # Modify class names by appending the 1-based index in parentheses
    for idx in range(len(CLASSES)):
        CLASSES[idx] = CLASSES[idx] + f"({idx+1})"
# Construct the coco_output dictionary
coco_output = {}
coco_output['info'] = {
"description": "Hong Kong Plover Cove Reservoir Algae",
"version": "1.0",
"year": 2023,
"contributor": "WSD",
"date_created": datetime.datetime.utcnow().isoformat(' ')
}
coco_output['licenses'] = [
{
"id": 1,
"name": "WSD"
}
]
supercategory = ["Algae"]
coco_output['images'] = []
coco_output['annotations'] = []
coco_output['categories'] = [
{
'id': idx,
'name': item,
'supercategory': supercategory[0],
        } for idx, item in enumerate(CLASSES)
]
# Create and return the Det2json converter object
det2json = Det2json_converter(
coco_output['info'],
coco_output['licenses'],
coco_output['categories'],
convert_image_paths=None
)
return det2json
def predict_batch(images, genImg, progress=gr.Progress(track_tqdm=True)):
    """
    Run detection on a batch of uploaded images, export a COCO JSON and an Excel
    summary, optionally render annotated images, and bundle everything into results.zip.
    Args:
        images (list): Uploaded image files from the gr.File input component.
        genImg (bool): Whether to generate the annotated result images.
        progress (gr.Progress): Gradio progress tracker (tracks the tqdm bars).
    Returns:
        tuple: (path to results.zip, inference time in whole seconds).
    """
start = time.time()
preds = []
labelled_image_paths = []
cwd = os.getcwd() + '/'
det2json = get_det2json_from_default()
for image_tmp in tqdm.tqdm(images, desc="Performing Detection on Images"):
image_path = image_tmp.name
results = predict_one_image(image_path)
        processed_results = results_postprocess(results=results, classes=CLASSES)
det2json.add_result_into_pool(image_path, processed_results)
det2json.generate_coco_format_dataset(f"{cwd}detailed_inference_results_coco.json", score_thr=CFG["score_thr"])
save_dir = f'{cwd}resultant_images'
if not os.path.exists(save_dir):
os.makedirs(save_dir)
coco_results = modified_COCO(annotation_file=f"{cwd}detailed_inference_results_coco.json")
algae_summary = []
for idx, image_tmp in enumerate(tqdm.tqdm(images, desc="Summarizing Results")):
img_path = image_tmp.name
file_name, _, _ = det2json._get_img_info_from_path(img_path)
img_info = coco_results.loadImgs(ids=idx)[0]
assert file_name == img_info['file_name']
anns_ids = coco_results.getAnnIds(
imgIds=idx,
catIds=list(
CFG["palettes"].keys()
)
)
anns_info = coco_results.loadAnns(anns_ids)
if CFG["score_thr"] is not None:
anns_info_copy = copy.deepcopy(anns_info)
del_idx = []
print(f"Out Loop: ann_info: {ann_info}")
for idx, ann_info in enumerate(anns_info):
print(f"In Loop: ann_info: {ann_info}")
print(f"ann_info Keys: {ann_info.keys()}")
if ann_info["score"] < CFG["score_thr"]:
del_idx.append(idx)
anns_info_copy = [val for idx, val in enumerate(anns_info_copy) if idx not in del_idx]
anns_info = anns_info_copy
anns_info = sorted(anns_info, key=lambda d: d["score"], reverse=False)
bboxes = []
scores = []
for ann in anns_info:
if ann["score"] >= CFG["score_thr"]:
print(f'Keep: {ann["bbox"]} with score = {ann["score"]}; score_thr = {CFG["score_thr"]}')
bboxes.append(ann["bbox"])
scores.append(ann["score"])
else:
print(f'Removed {ann["bbox"]} with score = {ann["score"]}')
bboxes = torch.tensor(bboxes, dtype=torch.float32)
bboxes = bbox_ulxulywh_to_xyxy(bboxes)
scores = torch.tensor(scores, dtype=torch.float32)
print(f"Bounding Boxes (x1, y1, x2, y2):\n{bboxes}")
print(f"Scores:\n{scores}")
assert (bboxes[:, 0] <= bboxes[:, 2]).all() and (bboxes[:, 1] <= bboxes[:, 3]).all(), "Bounding boxes are not in the correct format!"
print(f"NMS Original: {len(anns_info)}")
print(f"bboxes before NMS: {bboxes.shape}")
print(f"scores before NMS: {scores.shape}")
print(f"iou_threshold: {CFG['nms_thr']}")
nms_results = nms(boxes = bboxes, scores = scores, iou_threshold=CFG["nms_thr"]).tolist()
print(f"nms_results: {len(nms_results)}")
print(f"nms_results: {nms_results}")
nms_indexes = nms(boxes = bboxes, scores = scores, iou_threshold=CFG["nms_thr"]).tolist()[::-1]
print(f"nms_indexes: {nms_indexes}")
anns_info = [anns_info[keep_index] for keep_index in nms_indexes]
print(f"After NMS: {len(anns_info)}")
algae_summary.append([ann['category_id'] for ann in anns_info])
if genImg:
fig, ax = plt.subplots(figsize=(49.08, 32.64))
img = plt.imread(img_path)
plt.imshow(img)
coco_results.showAnns(
anns_info,
draw_bbox=True,
palettes=CFG["palettes"],
mask_fill_alpha=CFG["mask_fill_alpha"],
mask_linewidth=CFG["mask_linewidth"],
mask_line_alpha=CFG["mask_line_alpha"],
bbox_fill_alpha=CFG["bbox_fill_alpha"],
bbox_linewidth=CFG["bbox_linewidth"],
bbox_line_alpha=CFG["bbox_line_alpha"],
)
if CFG["display_label"] or CFG["display_score"]:
texts = []
for i, ann in enumerate(anns_info):
text = ""
if CFG["display_label"]:
text += f"{coco_results.loadCats(ann['category_id'])[0]['name']}"
if CFG["display_score"]:
text += " , "
if CFG["display_score"]:
text += f"{ann['score']:.3f}"
ha = "right"
x = ann['bbox'][0]
if x < img_info["height"]/2:
ha = "left"
x += ann['bbox'][2]
va = "bottom"
y = ann['bbox'][1]
if y < img_info["width"]/2:
va = "top"
y += ann['bbox'][3]
texts.append(
plt.text(
x=x,
y=y,
s=text,
ha=ha,
va=va,
fontsize=30,
bbox={
'facecolor': 'white',
'edgecolor': CFG["palettes"][ann["category_id"]],
'alpha': 0.8,
'pad': 0,
'boxstyle':"round,pad=0.2"
}
)
)
if CFG["adjust_text"]:
rect_objects =[]
bbox_rect = []
for ann in anns_info:
bbox_rect.append(ann["bbox"])
                    for bbox in bbox_rect:
                        # Invisible rectangles marking the detection boxes so adjust_text avoids them.
                        patch = patches.Rectangle((bbox[0], bbox[1]), bbox[2], bbox[3], fill=False, alpha=0)
                        rect_objects.append(patch)
                        ax.add_patch(patch)
adjust_text(
texts,
only_move={'points':'xy', 'texts':'xy', 'objects':''},
add_objects=rect_objects,
)
plt.axis('off')
plt.savefig(f"{save_dir}/result_{file_name}", dpi=100, bbox_inches='tight', pad_inches=0)
plt.close()
algae_summary_np = np.array([item for sublist in algae_summary for item in sublist])
algae_count = np.unique(algae_summary_np, return_counts=True)
output_dict = {coco_results.loadCats(ids=int(algae_id))[0]["name"]: int(count) for algae_id, count in np.array(algae_count).T}
with open(f"{save_dir}/summary_for_resultant_images.txt", "w") as f:
json.dump(output_dict, f, indent=4)
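    # The summary file holds counts keyed by display name, e.g. (hypothetical values):
    #   {"Microcystis(41)": 12, "Cyclotella(22)": 3}
    # txt_to_excel parses the 1-based id back out of the parentheses.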
algae_count_holder = {f"{k+1}": 0 for k in range(len(algae_list))}
if os.path.exists(f"{cwd}resultant_images/summary_for_resultant_images.txt"):
        excel_file_path = txt_to_excel(
f"{cwd}resultant_images/summary_for_resultant_images.txt",
algae_count_holder,
coco_path = f"{cwd}detailed_inference_results_coco.json"
)
else:
        excel_file_path = json_to_excel(f"{cwd}detailed_inference_results_coco.json", algae_count_holder)
file_lists = glob.glob(f"{cwd}*.json") + glob.glob(f"{cwd}*.xlsx")
with ZipFile('results.zip', 'w') as zipObj:
for file_path in glob.glob(f"{cwd}resultant_images/*"):
zipObj.write(file_path, os.path.join(*file_path.split("/")[-2:]))
for file_path in file_lists:
zipObj.write(file_path, file_path.split("/")[-1])
destroy_file = glob.glob(f"{cwd}*.json") + glob.glob(f"{cwd}*.xlsx") + glob.glob(f"{cwd}resultant_images/*")
for file_path in destroy_file:
if os.path.isfile(file_path):
os.remove(file_path)
end = time.time()
inference_time = int(end - start)
return 'results.zip', inference_time
title="HKUST Automated Algae Detection System"
description="Automated Algae Detection System for freshwater algae developed by HKUST"
interface = gr.Interface(
fn=predict_batch,
inputs=[gr.File(file_count="multiple", file_types=["image"], label="Input"), gr.Checkbox(label="Generate resultant images", info="Whether to generate the resultant images?")],
outputs=[
gr.File(label="Output").style(height='6'),
gr.Number(label="Inference Time (sec):"),
],
title=title,
description=description,
)
interface.queue().launch()