# GOT-OCR2.0 Gradio demo: upload an image, run OCR, and search/highlight the result.
import os
import re
import tempfile
import warnings

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
# Suppress warnings
warnings.simplefilter("ignore")
# Retrieve Hugging Face token
hf_token = os.getenv("HF_TOKEN")
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, use_auth_token=hf_token)
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
low_cpu_mem_usage=True,
device_map='cuda' if torch.cuda.is_available() else 'cpu',
use_safetensors=True,
pad_token_id=tokenizer.eos_token_id,
use_auth_token=hf_token)
model = model.eval()
# Global variable to store OCR result
ocr_result = ""
# Perform OCR function
def perform_ocr(image):
global ocr_result
# Convert the numpy array to a PIL image
pil_image = Image.fromarray(image)
# Save the image temporarily
image_file = "temp_image.png"
pil_image.save(image_file)
# Perform OCR with the model
with torch.no_grad():
ocr_result = model.chat(tokenizer, image_file, ocr_type='ocr')
# Optionally remove the temporary image file
os.remove(image_file)
return ocr_result
# Function to highlight search term with a different color (e.g., light blue)
def highlight_text(text, query):
# Use regex to wrap the search query with a span for styling
pattern = re.compile(re.escape(query), re.IGNORECASE)
highlighted_text = pattern.sub(f"<span style='background-color: #ADD8E6; color: black;'>{query}</span>", text)
return highlighted_text
# Search functionality to search within OCR result, highlight, and return the modified text
def search_text(query):
# If no query is provided, return the original OCR result
if not query:
return ocr_result, "No matches found."
# Highlight the searched term in the OCR text
highlighted_result = highlight_text(ocr_result, query)
# Split OCR result into lines and search for the query
lines = ocr_result.split('\n')
matching_lines = [line for line in lines if query.lower() in line.lower()]
if matching_lines:
return highlighted_result, '\n'.join(matching_lines) # Return highlighted text and matched lines
else:
return highlighted_result, "No matches found."
# Set up Gradio interface
with gr.Blocks() as demo:
# Section for uploading image and getting OCR results
with gr.Row():
with gr.Column():
image_input = gr.Image(type="numpy", label="Upload Image")
ocr_output = gr.HTML(label="OCR Output") # Changed to HTML for displaying highlighted text
ocr_button = gr.Button("Run OCR")
# Section for searching within the OCR result
with gr.Row():
with gr.Column():
search_input = gr.Textbox(label="Search Text")
search_output = gr.HTML(label="Search Result") # Separate output for search matches
search_button = gr.Button("Search in OCR Text")
# Define button actions
ocr_button.click(perform_ocr, inputs=image_input, outputs=ocr_output)
search_button.click(search_text, inputs=search_input, outputs=[ocr_output, search_output])
# Launch the Gradio interface
demo.launch(share=True)