HiraganaOCR / app.py
sabarinathan
Update app.py
a2855b2
raw
history blame
3.56 kB
import cv2
import math
import argparse
from tensorflow.keras.models import load_model
from flask import Flask, request, jsonify
import cv2
import json
import numpy as np
from tensorflow.keras import backend as K
from get_coordinate import get_object_coordinates
import requests
import gradio as gr
import os
# Sample files fetched at start-up (see the download loop below the helper).
# Dropbox share links must carry ``dl=1`` to return the file bytes; with
# ``dl=0`` Dropbox answers with an HTML preview page, which would be written
# to disk in place of the image.
file_urls = [
    'https://www.dropbox.com/scl/fi/skt4o9a37ccrxvruojk3o/2.png?rlkey=kxppvdnvbs9852rj6ly123xfk&dl=1',
    'https://www.dropbox.com/scl/fi/3opkr5aoca1fq0wrudlcx/3.png?rlkey=wm4vog7yyk5naoqu68vr6v48s&dl=1',
    'https://www.dropbox.com/scl/fi/t74nd09fod52x0gua93ty/1.png?rlkey=er4ktuephlapzyvh5glkym5b4&dl=1',
]
def download_file(url, save_name):
    """Download ``url`` to ``save_name`` unless that file already exists.

    Args:
        url: Address fetched with an HTTP GET.
        save_name: Destination path. An existing file is left untouched and
            no request is made.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    if os.path.exists(save_name):
        return
    # A timeout keeps start-up from hanging forever on a stalled connection.
    response = requests.get(url, timeout=60)
    # Fail loudly instead of saving an error page under the target name,
    # where the existence check above would keep it forever.
    response.raise_for_status()
    with open(save_name, 'wb') as out_file:
        out_file.write(response.content)
# Fetch every sample at import time; files already on disk are skipped by
# download_file.
for i, url in enumerate(file_urls):
    # Decide by the path extension only: the query string carries random
    # share keys that could contain "mp4" by accident.
    url_path = url.split('?', 1)[0]
    if url_path.lower().endswith('.mp4'):
        download_file(url, "video.mp4")
    else:
        # The index keeps the local names unique: image_0.jpg, image_1.jpg, ...
        download_file(url, f"image_{i}.jpg")
class OCR():
    """Character recogniser: locates characters in an image and classifies each.

    Character boxes come from ``get_object_coordinates``; every crop is
    resized to 64x64 and scored by the Keras model loaded in ``__init__``.
    """

    def __init__(self, path="model-ocr-0.1829.h5", config_path="config.json"):
        """Load the label configuration and the trained model.

        Args:
            path: Model file handed to ``load_model``.
            config_path: UTF-8 JSON file whose ``hiragana`` section provides
                the ``threshold`` and ``label`` entries.
        """
        # Read the config JSON file
        with open(config_path, 'r', encoding="utf-8") as file:
            self.config_data = json.load(file)
        hiragana_config = self.config_data['hiragana']
        # Threshold value from the config; run() does not apply it.
        self.threshold = hiragana_config['threshold']
        # Label dictionary, looked up by the stringified class index.
        self.label_dict = hiragana_config['label']
        # Load the model from local disk. "K" is registered as a custom
        # object -- presumably the saved model refers to the Keras backend
        # under that name; confirm against the training code.
        self.model = load_model(path, custom_objects={"K": K})

    def run(self, image):
        """Recognise every character found in ``image``.

        Args:
            image: Image passed unchanged to ``get_object_coordinates``.

        Returns:
            A list with one dict per recognised character:
            ``"text"`` (label), ``"prob"`` (score of the winning class as a
            plain float) and ``"coord"`` (box as plain ints, in the order
            Xmin, Ymin, Xmax, Ymax). Boxes whose crop is empty are skipped.
        """
        # Extract the character coordinates using the cv2 contours
        coordinate, thresholdedImage = get_object_coordinates(image)
        # Reused single-sample input buffer for the model.
        image_batch = np.zeros((1, 64, 64, 1))
        output = []
        for box in coordinate:
            x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
            # Crop the character out of the thresholded image
            cropImg = thresholdedImage[y_min:y_max, x_min:x_max]
            if cropImg.size == 0:
                # cv2.resize rejects empty input; a degenerate box holds no
                # character to classify.
                continue
            # Resize to the model input size. The factor 255 is kept from the
            # original code -- presumably the thresholded image is 0/1;
            # confirm against get_object_coordinates.
            image_batch[0, :, :, 0] = cv2.resize(cropImg, (64, 64)) * 255
            # Predict the class scores for this single crop
            predict = self.model.predict(image_batch)
            # Flatten before indexing: the prediction carries a leading batch
            # axis, so indexing it directly with the class index is wrong.
            scores = np.ravel(predict)
            position = int(np.argmax(scores))
            output.append({
                "text": self.label_dict[str(position)],
                # Plain Python numbers so the result can be serialised as JSON.
                "prob": float(scores[position]),
                "coord": [int(value) for value in box],  # Xmin,Ymin,Xmax,Ymax
            })
        return output
def getOCRResults(image_path):
    """Run OCR on one or more image files and package the results.

    Args:
        image_path: A single file path (a string or path object), a sequence
            of file paths, or ``None`` for "no image".

    Returns:
        ``{"result": [...]}`` with one list of detections per image, in input
        order. Inside a Flask application context the dict is wrapped in a
        UTF-8 JSON response; otherwise the plain dict is returned, because
        ``jsonify`` cannot be used without an application context.

    Raises:
        ValueError: If an image file cannot be read.
    """
    from flask import has_app_context

    # Normalise the argument: a bare string must be treated as ONE path, not
    # iterated character by character.
    if image_path is None:
        image_paths = []
    elif isinstance(image_path, (str, os.PathLike)):
        image_paths = [image_path]
    else:
        image_paths = list(image_path)

    per_image_results = []
    for single_path in image_paths:
        image = cv2.imread(os.fspath(single_path))
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read image: {single_path!r}")
        per_image_results.append(ocrAPP.run(image))

    result_json = {"result": per_image_results}
    if has_app_context():
        response = jsonify(result_json)
        response.headers['Content-Type'] = 'application/json; charset=utf-8'
        return response
    return result_json
# Shared recogniser instance; getOCRResults looks this name up when called.
ocrAPP = OCR()

# Example rows, one inner list per row. ``video_path`` is not referenced
# anywhere else in this script.
video_path = [['video.mp4']]
path = [['image_0.jpg'], ['image_1.jpg']]

# The image input hands the callback a file path; the output is shown as JSON.
inputs_image = [gr.components.Image(type="filepath", label="Input Image")]
outputs = [gr.components.JSON(label="Output Json")]

interface_image = gr.Interface(
    fn=getOCRResults,
    inputs=inputs_image,
    outputs=outputs,
    title="Hiragana Character Recognition",
    examples=path,
    cache_examples=False,
)

# Wrap the single interface in a tabbed layout, enable queuing and start.
tabbed_app = gr.TabbedInterface([interface_image], tab_names=['Image inference'])
tabbed_app.queue().launch()