import gradio as gr
import requests
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from io import BytesIO
import torch

# Set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load model and processor
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

# Prediction function
def predict_from_url(url):
    prompt = "<OCR>"
    if not url:
        return "Error: Please input a URL", None

    try:
        # Fetch the image with a timeout and check the HTTP status,
        # then convert to RGB to handle grayscale, palette, or RGBA inputs
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    except Exception as e:
        return f"Error: Failed to load or process the image: {str(e)}", None
    
    # Preprocess and perform inference
    try:
        inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=4096,
            num_beams=3,
            do_sample=False
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
        
        # Extract OCR text
        ocr_text = parsed_answer.get("<OCR>", "")
        output_text = " ".join(ocr_text.replace("\n", " ").strip().split())
    except Exception as e:
        return f"Error: Failed to process the image for OCR: {str(e)}", image

    return output_text, image

# Gradio Interface
demo = gr.Interface(
    fn=predict_from_url,
    inputs=gr.Textbox(label="Enter Image URL"),
    outputs=[
        gr.Textbox(label="Extracted Text"),
        gr.Image(label="Uploaded Image")
    ],
    title="OCR Text Extractor",
    description="Provide an image URL, and this tool will extract text using OCR.",
    allow_flagging="never"
)

# Launch the app
demo.launch()
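
# Minimal usage sketch (the URL below is a placeholder, not a real endpoint):
# calling the prediction function directly returns the extracted text and the
# PIL image, which is handy for a quick check without starting the Gradio UI.
#
#   text, img = predict_from_url("https://example.com/sample.png")
#   print(text)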