# JAP-OCR / app.py
# Snowad's picture
# Update app.py
# 8f8226b verified
# raw history blame
# No virus
# 1.81 kB
import re
import jaconv
import gradio as gr
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
from PIL import Image
import torch, time
import spaces
# All three components come from the same pretrained manga-ocr checkpoint
# on the Hugging Face Hub; keep the id in one place.
_MODEL_REPO = "kha-white/manga-ocr-base"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_REPO)

# Vision encoder-decoder OCR model, placed on the GPU once at import time.
model = VisionEncoderDecoderModel.from_pretrained(_MODEL_REPO)
model.to("cuda")

# Image preprocessor that turns PIL images into the model's `pixel_values`.
feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_REPO)
def post_process(text):
    """Normalize the raw string decoded from the OCR model.

    Steps, in order:
      1. remove every whitespace character (split on whitespace, rejoin with "");
      2. replace each ellipsis character '…' with three ASCII dots;
      3. rewrite every run of two or more '・' / '.' characters as the same
         number of '.' characters;
      4. convert half-width characters to full-width with
         ``jaconv.h2z(..., ascii=True, digit=True)``.

    Args:
        text: string produced by ``tokenizer.decode``.

    Returns:
        The normalized string.
    """
    def _dots_for(match):
        # One ASCII dot per character consumed by the match.
        return '.' * (match.end() - match.start())

    squeezed = "".join(text.split())
    squeezed = squeezed.replace('…', '...')
    squeezed = re.sub('[・.]{2,}', _dots_for, squeezed)
    return jaconv.h2z(squeezed, ascii=True, digit=True)
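# Legacy single-image version of manga_ocr, kept for reference only
# (superseded by the multi-file manga_ocr defined below).
# @spaces.GPU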
# def manga_ocr(img):
# img = img.convert('L').convert('RGB')
# pixel_values = feature_extractor(img, return_tensors="pt").pixel_values.to("cuda")
# start_time = time.time()
# output = model.generate(pixel_values)[0]
# print("Time taken for OCR:", time.time() - start_time)
# text = tokenizer.decode(output, skip_special_tokens=True)
# text = post_process(text)
# return text
def _ocr_image(img):
    """Run OCR on one uploaded image file and return its post-processed text.

    Args:
        img: a path or file-like object accepted by ``PIL.Image.open``.

    Returns:
        The decoded text after ``post_process``.
    """
    # Use a context manager so the file handle Pillow opens is released even
    # if preprocessing fails; convert() loads the pixel data before close.
    with Image.open(img) as opened:
        # Drop colour, then return to 3 channels (presumably because the
        # encoder expects RGB input; mirrors upstream manga-ocr).
        rgb = opened.convert('L').convert('RGB')
    pixel_values = feature_extractor(rgb, return_tensors="pt").pixel_values.to("cuda")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    output = model.generate(pixel_values)[0]
    print("Time taken for OCR:", time.perf_counter() - start_time)
    text = tokenizer.decode(output, skip_special_tokens=True)
    return post_process(text)


@spaces.GPU(duration=8)
def manga_ocr(imgs):
    """OCR every uploaded image and join the results.

    Runs inside a ZeroGPU allocation requested for 8 seconds
    (``spaces.GPU(duration=8)``). NOTE(review): large batches may exceed that
    budget, so confirm the limit suits expected upload sizes.

    Args:
        imgs: value of the multi-file ``gr.File`` input, an iterable of paths
            or file-like objects. It may be ``None`` or empty when the user
            submits without uploading anything.

    Returns:
        The text of each image in input order, joined with ``"|||"``. Returns
        an empty string when no files were given.
    """
    if not imgs:
        return ""
    return "|||".join(_ocr_image(img) for img in imgs)
# Gradio UI: upload one or more images and get their OCR text back as a single
# string (per-image results are separated by "|||" in manga_ocr).
iface = gr.Interface(
    fn=manga_ocr,
    inputs=gr.File(file_types=["image"], file_count="multiple"),
    outputs="text",
    title="Manga OCR",
    description="Extract text from manga at lightning speed ⚡",
)

# Only start the server when run as a script (e.g. `python app.py`), so
# importing this module does not launch the app as a side effect.
if __name__ == "__main__":
    iface.launch()