|
import re |
|
import jaconv |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel |
|
from PIL import Image |
|
import torch |
|
import spaces |
|
|
|
# Hugging Face Hub checkpoint shared by every component of the OCR pipeline.
MODEL_ID = "kha-white/manga-ocr-base"

# Tokenizer used to decode the generated token ids back into a string.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Vision-encoder / text-decoder model; moved to the GPU once, at startup.
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)
model.to("cuda")

# Preprocessor that turns a PIL image into the model's pixel-value tensor.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
|
|
|
# Sample image filenames "00.jpg" .. "11.jpg" (zero-padded to two digits).
examples = [f"{index:02d}.jpg" for index in range(12)]
|
|
|
# Two or more consecutive '・' or '.' characters (an ellipsis-like run).
_DOT_RUN = re.compile('[・.]{2,}')


def post_process(text):
    """Normalize raw decoder output into the final OCR string.

    Steps, applied in this order:
      1. remove every whitespace character;
      2. expand the single-character ellipsis '…' into '...';
      3. rewrite each run of two or more '・' / '.' characters as the same
         number of '.' characters;
      4. convert half-width characters to full-width with ``jaconv.h2z``
         (ASCII and digit conversion explicitly enabled).
    """
    compact = "".join(text.split())
    compact = compact.replace('…', '...')
    compact = _DOT_RUN.sub(lambda match: '.' * len(match.group(0)), compact)
    return jaconv.h2z(compact, ascii=True, digit=True)
|
|
|
@spaces.GPU
def manga_ocr(img):
    """Run Manga OCR on a single image and return the recognized text.

    Args:
        img: PIL image from the Gradio ``Image(type='pil')`` input, or
            ``None`` when the user submits without uploading an image.

    Returns:
        The decoded text after ``post_process`` normalization, or an empty
        string when no image was provided.
    """
    # Gradio's Image input yields None when nothing was uploaded; return an
    # empty result instead of crashing on ``None.convert``.
    if img is None:
        return ""

    # Drop color information but keep a 3-channel image for the extractor.
    img = img.convert('L').convert('RGB')

    # The model was moved to the GPU at startup, so the input tensor must live
    # on the same device; otherwise generate() fails with a device mismatch.
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(model.device)

    # generate() returns a batch of token-id sequences; take the only one.
    output_ids = model.generate(pixel_values)[0]
    text = tokenizer.decode(output_ids, skip_special_tokens=True)
    return post_process(text)
|
|
|
# Web UI: one PIL image in, recognized text out.
# NOTE(review): the `examples` list defined earlier in this file is never
# passed to the Interface; consider `examples=examples` once the listed image
# files are confirmed to exist alongside this script.
iface = gr.Interface(
    fn=manga_ocr,
    inputs=gr.Image(type='pil'),
    outputs="text",
    title="Manga OCR",
    description="Extract manga text at lightning speed ⚡",
)

iface.launch()