fixed issues with image res
Browse files
app.py
CHANGED
@@ -5,51 +5,59 @@ from transformers import AutoProcessor, AutoModelForCausalLM
|
|
5 |
from io import BytesIO
|
6 |
import torch
|
7 |
|
|
|
8 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
9 |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
10 |
|
|
|
11 |
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
|
12 |
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
13 |
|
14 |
-
|
15 |
def predict_from_url(url):
|
16 |
prompt = "<OCR>"
|
17 |
if not url:
|
18 |
-
return "Error: Please input a URL"
|
19 |
|
20 |
try:
|
21 |
-
image
|
|
|
22 |
except Exception as e:
|
23 |
-
return f"Error: Failed to load image: {str(e)}"
|
24 |
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
41 |
|
42 |
return output_text, image
|
43 |
|
|
|
44 |
demo = gr.Interface(
|
45 |
-
fn=predict_from_url,
|
46 |
inputs=gr.Textbox(label="Enter Image URL"),
|
47 |
outputs=[
|
48 |
-
gr.Textbox(label="
|
49 |
-
gr.Image(label="Image")
|
50 |
],
|
51 |
-
title="OCR Text Extractor",
|
|
|
52 |
allow_flagging="never"
|
53 |
)
|
54 |
|
|
|
55 |
demo.launch()
|
|
|
5 |
from io import BytesIO
|
6 |
import torch
|
7 |
|
8 |
+
# Set device
|
9 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
10 |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
11 |
|
12 |
+
# Load model and processor
|
13 |
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
|
14 |
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
15 |
|
16 |
+
# Prediction function
|
17 |
def predict_from_url(url):
|
18 |
prompt = "<OCR>"
|
19 |
if not url:
|
20 |
+
return "Error: Please input a URL", None
|
21 |
|
22 |
try:
|
23 |
+
# Open the image and convert to RGB format to handle grayscale or other formats
|
24 |
+
image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
|
25 |
except Exception as e:
|
26 |
+
return f"Error: Failed to load or process the image: {str(e)}", None
|
27 |
|
28 |
+
# Preprocess and perform inference
|
29 |
+
try:
|
30 |
+
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
|
31 |
+
generated_ids = model.generate(
|
32 |
+
input_ids=inputs["input_ids"],
|
33 |
+
pixel_values=inputs["pixel_values"],
|
34 |
+
max_new_tokens=4096,
|
35 |
+
num_beams=3,
|
36 |
+
do_sample=False
|
37 |
+
)
|
38 |
+
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
39 |
+
parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
|
40 |
+
|
41 |
+
# Extract OCR text
|
42 |
+
ocr_text = parsed_answer.get("<OCR>", "")
|
43 |
+
output_text = " ".join(ocr_text.replace("\n", " ").strip().split())
|
44 |
+
except Exception as e:
|
45 |
+
return f"Error: Failed to process the image for OCR: {str(e)}", image
|
46 |
|
47 |
return output_text, image
|
48 |
|
49 |
+
# Gradio Interface
|
50 |
demo = gr.Interface(
|
51 |
+
fn=predict_from_url,
|
52 |
inputs=gr.Textbox(label="Enter Image URL"),
|
53 |
outputs=[
|
54 |
+
gr.Textbox(label="Extracted Text"),
|
55 |
+
gr.Image(label="Uploaded Image")
|
56 |
],
|
57 |
+
title="Enhanced OCR Text Extractor",
|
58 |
+
description="Provide an image URL, and this tool will extract text using OCR.",
|
59 |
allow_flagging="never"
|
60 |
)
|
61 |
|
62 |
+
# Launch the app
|
63 |
demo.launch()
|