File size: 3,574 Bytes
f563d24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import numpy as np
import gradio as gr
import os
from transformers import AutoModel, AutoTokenizer
import torch
from PIL import Image
import warnings
import re

# Silence every warning category for the whole process
warnings.simplefilter("ignore")

# Hugging Face access token taken from the environment (None when HF_TOKEN is unset)
hf_token = os.getenv("HF_TOKEN")

# Checkpoint shared by the tokenizer and the model
_MODEL_ID = 'ucaslcl/GOT-OCR2_0'

# NOTE(review): recent transformers releases deprecate `use_auth_token` in
# favour of `token` - confirm against the pinned version before switching.
tokenizer = AutoTokenizer.from_pretrained(
    _MODEL_ID,
    trust_remote_code=True,
    use_auth_token=hf_token,
)

# Keyword arguments for the model load, gathered in one place for readability
_model_kwargs = dict(
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda' if torch.cuda.is_available() else 'cpu',  # GPU when present
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
    use_auth_token=hf_token,
)

# Load the weights and switch straight to evaluation mode
model = AutoModel.from_pretrained(_MODEL_ID, **_model_kwargs).eval()

# Most recent OCR text: written by perform_ocr, read by search_text
ocr_result = ""

# Perform OCR function
def perform_ocr(image):
    """Run OCR on an uploaded image and remember the text for later searches.

    Args:
        image: The uploaded picture as a NumPy array, or ``None`` when the
            button is pressed before anything has been uploaded.

    Returns:
        The text produced by ``model.chat``. When ``image`` is ``None`` a
        short prompt asking for an upload is returned instead. The recognised
        text is also stored in the module-level ``ocr_result`` (reset to an
        empty string when there is no image).
    """
    global ocr_result

    import tempfile  # local import: only this function needs it

    if image is None:
        # Nothing uploaded: drop any stale text instead of crashing in PIL.
        ocr_result = ""
        return "Please upload an image before running OCR."

    # Convert the numpy array to a PIL image
    pil_image = Image.fromarray(image)

    # The model is handed a file path, so the image has to be written to disk.
    # A unique temporary name keeps concurrent requests from overwriting each
    # other's file; the handle is closed right away so the path can be reused.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        image_file = handle.name

    try:
        pil_image.save(image_file)

        # Inference only: no gradients needed
        with torch.no_grad():
            ocr_result = model.chat(tokenizer, image_file, ocr_type='ocr')
    finally:
        # Always clean up, even when saving or inference raised.
        if os.path.exists(image_file):
            os.remove(image_file)

    return ocr_result

# Function to highlight search term with a different color (e.g., light blue)
def highlight_text(text, query):
    """Return ``text`` as HTML with every occurrence of ``query`` highlighted.

    Matching is case-insensitive and ``query`` is treated as a literal string,
    not a regular expression. Each match is wrapped in a light-blue ``<span>``
    and keeps the casing it has in ``text``. Both the matched and the
    unmatched parts are HTML-escaped so that characters such as ``<`` or ``&``
    are displayed literally instead of being parsed as markup.

    Args:
        text: The plain text to search in.
        query: The literal string to highlight. An empty query highlights
            nothing.

    Returns:
        An HTML string containing the escaped text with highlight spans.
    """
    import html  # local import: only this function needs it

    if not query:
        # An empty pattern would match between every character.
        return html.escape(text, quote=False)

    pattern = re.compile(re.escape(query), re.IGNORECASE)

    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        # Unmatched text between the previous match and this one
        pieces.append(html.escape(text[cursor:match.start()], quote=False))
        # Wrap the text as it appears in the source, not the user's query,
        # so the original casing is preserved and no backslash in the query
        # is ever interpreted as a replacement-template escape.
        found = html.escape(match.group(0), quote=False)
        pieces.append(
            f"<span style='background-color: #ADD8E6; color: black;'>{found}</span>"
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:], quote=False))

    return "".join(pieces)

# Search functionality to search within OCR result, highlight, and return the modified text
def search_text(query):
    """Search the most recent OCR text for ``query``.

    Args:
        query: The text to look for (compared case-insensitively).

    Returns:
        A pair ``(ocr_markup, matches)``. ``ocr_markup`` is the stored OCR
        text with the query highlighted (or the untouched OCR text when the
        query is empty). ``matches`` is the newline-joined list of lines that
        contain the query, or ``"No matches found."`` when there are none.
    """
    # Guard clause: with nothing to search for, hand back the stored text as-is
    if not query:
        return ocr_result, "No matches found."

    # Mark up the full OCR text first
    marked_up = highlight_text(ocr_result, query)

    # Then collect each line that contains the query, ignoring case
    hits = []
    for row in ocr_result.split('\n'):
        if query.lower() in row.lower():
            hits.append(row)

    summary = '\n'.join(hits) if hits else "No matches found."
    return marked_up, summary

# Set up Gradio interface: a single Blocks app with an OCR section followed by
# a search section. Components are created inside the Row/Column contexts in
# the order written below.
with gr.Blocks() as demo:
    # Section for uploading image and getting OCR results
    with gr.Row():
        with gr.Column():
            # type="numpy": perform_ocr passes the upload to Image.fromarray
            image_input = gr.Image(type="numpy", label="Upload Image")
            ocr_output = gr.HTML(label="OCR Output")  # Changed to HTML for displaying highlighted text
            ocr_button = gr.Button("Run OCR")
    
    # Section for searching within the OCR result
    with gr.Row():
        with gr.Column():
            search_input = gr.Textbox(label="Search Text")
            search_output = gr.HTML(label="Search Result")  # Separate output for search matches
            search_button = gr.Button("Search in OCR Text")
    
    # Define button actions
    # "Run OCR": feed the uploaded image to perform_ocr and show its return
    # value in ocr_output.
    ocr_button.click(perform_ocr, inputs=image_input, outputs=ocr_output)
    # "Search in OCR Text": search_text returns two values; the first
    # (highlighted OCR text) replaces ocr_output, the second (matching lines
    # or a "no matches" message) goes to search_output.
    search_button.click(search_text, inputs=search_input, outputs=[ocr_output, search_output])

# Launch the Gradio interface. This runs whenever the module is executed or
# imported, since there is no __main__ guard.
# NOTE(review): share=True asks Gradio for a public share link - confirm this
# is intended for the deployment target.
demo.launch(share=True)