Hzqhssn commited on
Commit
9e9c8bb
·
1 Parent(s): 48b33ab

updated ocr

Browse files
Files changed (1) hide show
  1. app.py +14 -28
app.py CHANGED
@@ -1,31 +1,26 @@
1
  import gradio as gr
2
  import requests
3
- from PIL import Image, ImageDraw
4
  from transformers import AutoProcessor, AutoModelForCausalLM
5
  from io import BytesIO
6
  import torch
7
 
8
- # Set device
9
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
11
 
12
- # Load model and processor
13
  model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
14
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
15
 
16
- # List of colors to cycle through for bounding boxes
17
- COLORS = ["red", "blue", "green", "yellow", "purple", "orange", "cyan", "magenta"]
18
 
19
- # Prediction function
20
  def predict_from_url(url):
21
- prompt = "<OD>"
22
  if not url:
23
- return {"Error": "Please input a URL"}, None
24
 
25
  try:
26
  image = Image.open(BytesIO(requests.get(url).content))
27
  except Exception as e:
28
- return {"Error": f"Failed to load image: {str(e)}"}, None
29
 
30
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
31
  generated_ids = model.generate(
@@ -36,34 +31,25 @@ def predict_from_url(url):
36
  do_sample=False
37
  )
38
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
39
- parsed_answer = processor.post_process_generation(generated_text, task="<OD>", image_size=(image.width, image.height))
40
 
41
- labels = parsed_answer.get('<OD>', {}).get('labels', [])
42
- bboxes = parsed_answer.get('<OD>', {}).get('bboxes', [])
43
-
44
- # Draw bounding boxes on the image
45
- draw = ImageDraw.Draw(image)
46
- legend = [] # Store legend entries
47
- for idx, (bbox, label) in enumerate(zip(bboxes, labels)):
48
- x1, y1, x2, y2 = bbox
49
- color = COLORS[idx % len(COLORS)] # Cycle through colors
50
- draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
51
- legend.append(f"{label}: {color}")
52
-
53
 
54
- return "\n".join(legend), image
 
 
 
 
 
55
 
56
- # Gradio interface
57
  demo = gr.Interface(
58
  fn=predict_from_url,
59
  inputs=gr.Textbox(label="Enter Image URL"),
60
  outputs=[
61
- gr.Textbox(label="Legend"), # Output the legend
62
- gr.Image(label="Image with Bounding Boxes") # Output the processed image
63
  ],
64
- title="Item Classifier with Bounding Boxes and Legend",
65
  allow_flagging="never"
66
  )
67
 
68
- # Launch the interface
69
  demo.launch()
 
1
  import gradio as gr
2
  import requests
3
+ from PIL import Image
4
  from transformers import AutoProcessor, AutoModelForCausalLM
5
  from io import BytesIO
6
  import torch
7
 
 
8
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
9
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
10
 
 
11
  model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
12
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
13
 
 
 
14
 
 
15
  def predict_from_url(url):
16
+ prompt = "<OCR>"
17
  if not url:
18
+ return "Error: Please input a URL"
19
 
20
  try:
21
  image = Image.open(BytesIO(requests.get(url).content))
22
  except Exception as e:
23
+ return f"Error: Failed to load image: {str(e)}"
24
 
25
  inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
26
  generated_ids = model.generate(
 
31
  do_sample=False
32
  )
33
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
34
+ parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
35
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ ocr_text = parsed_answer.get("<OCR>", "")
38
+
39
+ text = ocr_text.replace("\n", " ").strip()
40
+ output_text = " ".join(text.split())
41
+
42
+ return output_text, image
43
 
 
44
  demo = gr.Interface(
45
  fn=predict_from_url,
46
  inputs=gr.Textbox(label="Enter Image URL"),
47
  outputs=[
48
+ gr.Textbox(label="Output Text"),
49
+ gr.Image(label="Image")
50
  ],
51
+ title="OCR Text Extractor",
52
  allow_flagging="never"
53
  )
54
 
 
55
  demo.launch()