updated ocr
Browse files
app.py
CHANGED
@@ -1,31 +1,26 @@
|
|
1 |
import gradio as gr
|
2 |
import requests
|
3 |
-
from PIL import Image
|
4 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
5 |
from io import BytesIO
|
6 |
import torch
|
7 |
|
8 |
-
# Set device
|
9 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
10 |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
11 |
|
12 |
-
# Load model and processor
|
13 |
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
|
14 |
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
15 |
|
16 |
-
# List of colors to cycle through for bounding boxes
|
17 |
-
COLORS = ["red", "blue", "green", "yellow", "purple", "orange", "cyan", "magenta"]
|
18 |
|
19 |
-
# Prediction function
|
20 |
def predict_from_url(url):
|
21 |
-
prompt = "<
|
22 |
if not url:
|
23 |
-
return
|
24 |
|
25 |
try:
|
26 |
image = Image.open(BytesIO(requests.get(url).content))
|
27 |
except Exception as e:
|
28 |
-
return
|
29 |
|
30 |
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
|
31 |
generated_ids = model.generate(
|
@@ -36,34 +31,25 @@ def predict_from_url(url):
|
|
36 |
do_sample=False
|
37 |
)
|
38 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
39 |
-
parsed_answer = processor.post_process_generation(generated_text, task=
|
40 |
|
41 |
-
labels = parsed_answer.get('<OD>', {}).get('labels', [])
|
42 |
-
bboxes = parsed_answer.get('<OD>', {}).get('bboxes', [])
|
43 |
-
|
44 |
-
# Draw bounding boxes on the image
|
45 |
-
draw = ImageDraw.Draw(image)
|
46 |
-
legend = [] # Store legend entries
|
47 |
-
for idx, (bbox, label) in enumerate(zip(bboxes, labels)):
|
48 |
-
x1, y1, x2, y2 = bbox
|
49 |
-
color = COLORS[idx % len(COLORS)] # Cycle through colors
|
50 |
-
draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
|
51 |
-
legend.append(f"{label}: {color}")
|
52 |
-
|
53 |
|
54 |
-
|
|
|
|
|
|
|
|
|
|
|
55 |
|
56 |
-
# Gradio interface
|
57 |
demo = gr.Interface(
|
58 |
fn=predict_from_url,
|
59 |
inputs=gr.Textbox(label="Enter Image URL"),
|
60 |
outputs=[
|
61 |
-
gr.Textbox(label="
|
62 |
-
gr.Image(label="Image
|
63 |
],
|
64 |
-
title="
|
65 |
allow_flagging="never"
|
66 |
)
|
67 |
|
68 |
-
# Launch the interface
|
69 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import requests
|
3 |
+
from PIL import Image
|
4 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
5 |
from io import BytesIO
|
6 |
import torch
|
7 |
|
|
|
8 |
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
9 |
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
10 |
|
|
|
11 |
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
|
12 |
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
|
13 |
|
|
|
|
|
14 |
|
|
|
15 |
def predict_from_url(url):
|
16 |
+
prompt = "<OCR>"
|
17 |
if not url:
|
18 |
+
return "Error: Please input a URL"
|
19 |
|
20 |
try:
|
21 |
image = Image.open(BytesIO(requests.get(url).content))
|
22 |
except Exception as e:
|
23 |
+
return f"Error: Failed to load image: {str(e)}"
|
24 |
|
25 |
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
|
26 |
generated_ids = model.generate(
|
|
|
31 |
do_sample=False
|
32 |
)
|
33 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
34 |
+
parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
+
ocr_text = parsed_answer.get("<OCR>", "")
|
38 |
+
|
39 |
+
text = ocr_text.replace("\n", " ").strip()
|
40 |
+
output_text = " ".join(text.split())
|
41 |
+
|
42 |
+
return output_text, image
|
43 |
|
|
|
44 |
demo = gr.Interface(
|
45 |
fn=predict_from_url,
|
46 |
inputs=gr.Textbox(label="Enter Image URL"),
|
47 |
outputs=[
|
48 |
+
gr.Textbox(label="Output Text"),
|
49 |
+
gr.Image(label="Image")
|
50 |
],
|
51 |
+
title="OCR Text Extractor",
|
52 |
allow_flagging="never"
|
53 |
)
|
54 |
|
|
|
55 |
demo.launch()
|