|
import os
import re
import tempfile
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
|
|
|
|
|
|
warnings.simplefilter("ignore")
|
|
|
|
|
|
hf_token = os.getenv("HF_TOKEN")
|
|
|
|
|
|
# Load the GOT-OCR2 tokenizer and model from the Hugging Face Hub.
# `token` replaces the deprecated `use_auth_token` keyword (removed in
# recent transformers releases); `trust_remote_code` is required because
# this checkpoint ships its own modeling/tokenizer code.
tokenizer = AutoTokenizer.from_pretrained(
    'ucaslcl/GOT-OCR2_0',
    trust_remote_code=True,
    token=hf_token,
)
model = AutoModel.from_pretrained(
    'ucaslcl/GOT-OCR2_0',
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    # Put the weights on the GPU when one is available, else run on CPU.
    device_map='cuda' if torch.cuda.is_available() else 'cpu',
    use_safetensors=True,
    # The tokenizer defines no pad token; reuse EOS for padding.
    pad_token_id=tokenizer.eos_token_id,
    token=hf_token,
)
# Inference only: switch off dropout / training-mode behavior.
model = model.eval()

# Module-level cache of the most recent OCR output; the search handler
# reads it so users can query the last result without re-running OCR.
ocr_result = ""
|
|
|
|
|
|
def perform_ocr(image):
    """Run GOT-OCR2 on an uploaded image and return the recognized text.

    Parameters
    ----------
    image : numpy.ndarray
        Image array as delivered by the Gradio ``Image(type="numpy")``
        component.

    Returns
    -------
    str
        The OCR text produced by ``model.chat``. The result is also stored
        in the module-level ``ocr_result`` so it can be searched later.
    """
    global ocr_result

    pil_image = Image.fromarray(image)

    # Use a unique temporary file rather than a fixed "temp_image.png":
    # a fixed name lets concurrent requests clobber each other's image.
    # The try/finally guarantees cleanup even when model.chat raises,
    # which previously leaked the temp file.
    fd, image_file = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        pil_image.save(image_file)
        with torch.no_grad():
            # model.chat takes a file path; it returns the plain OCR text.
            ocr_result = model.chat(tokenizer, image_file, ocr_type='ocr')
    finally:
        os.remove(image_file)

    return ocr_result
|
|
|
|
|
|
def highlight_text(text, query):
    """Wrap every case-insensitive occurrence of *query* in a highlight span.

    Parameters
    ----------
    text : str
        The text to annotate (the OCR result).
    query : str
        The literal substring to highlight; regex metacharacters are escaped.

    Returns
    -------
    str
        *text* with each match wrapped in a styled ``<span>``. The matched
        text keeps its original casing — previously the as-typed query was
        substituted, so a case-insensitive match like "Hello" for query
        "hello" was rewritten as "hello".
    """
    if not query:
        # An empty pattern matches between every character; nothing to do.
        return text

    pattern = re.compile(re.escape(query), re.IGNORECASE)
    # A callable replacement preserves the matched text verbatim and keeps
    # re.sub from interpreting backslashes in the query as group references.
    return pattern.sub(
        lambda m: f"<span style='background-color: #ADD8E6; color: black;'>{m.group(0)}</span>",
        text,
    )
|
|
|
|
|
|
def search_text(query):
    """Search the cached OCR result for *query* (case-insensitive).

    Parameters
    ----------
    query : str
        Substring to look for in the module-level ``ocr_result``.

    Returns
    -------
    tuple[str, str]
        ``(highlighted_html, matches)`` where ``highlighted_html`` is the
        OCR text with matches wrapped in highlight spans, and ``matches``
        is the newline-joined matching lines or ``"No matches found."``.
    """
    # Guard clause: an empty query has nothing to highlight or match.
    if not query:
        return ocr_result, "No matches found."

    highlighted = highlight_text(ocr_result, query)

    # Collect every line of the OCR output that contains the query,
    # comparing in lowercase for a case-insensitive match.
    needle = query.lower()
    hits = [line for line in ocr_result.split('\n') if needle in line.lower()]

    summary = '\n'.join(hits) if hits else "No matches found."
    return highlighted, summary
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Two stacked rows: the first runs OCR on an uploaded image, the second
# searches within (and highlights inside) the most recent OCR result.
with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="numpy", label="Upload Image")
            # HTML component so highlight <span>s render instead of showing
            # as literal markup.
            ocr_output = gr.HTML(label="OCR Output")
            ocr_button = gr.Button("Run OCR")

    with gr.Row():
        with gr.Column():
            search_input = gr.Textbox(label="Search Text")
            search_output = gr.HTML(label="Search Result")
            search_button = gr.Button("Search in OCR Text")

    # Wire the buttons: OCR fills the OCR pane; search rewrites that same
    # pane with highlights and lists matching lines separately.
    ocr_button.click(perform_ocr, inputs=image_input, outputs=ocr_output)
    search_button.click(search_text, inputs=search_input, outputs=[ocr_output, search_output])

# share=True requests a public tunnel link when running locally.
demo.launch(share=True)
|
|
|