Hzqhssn committed on
Commit
262c67d
·
1 Parent(s): 9e9c8bb

fixed issues with image res

Browse files
Files changed (1) hide show
  1. app.py +32 -24
app.py CHANGED
@@ -5,51 +5,59 @@ from transformers import AutoProcessor, AutoModelForCausalLM
5
  from io import BytesIO
6
  import torch
7
 
 
8
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
9
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
10
 
 
11
  model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
12
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
13
 
14
-
15
  def predict_from_url(url):
16
  prompt = "<OCR>"
17
  if not url:
18
- return "Error: Please input a URL"
19
 
20
  try:
21
- image = Image.open(BytesIO(requests.get(url).content))
 
22
  except Exception as e:
23
- return f"Error: Failed to load image: {str(e)}"
24
 
25
- inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
26
- generated_ids = model.generate(
27
- input_ids=inputs["input_ids"],
28
- pixel_values=inputs["pixel_values"],
29
- max_new_tokens=4096,
30
- num_beams=3,
31
- do_sample=False
32
- )
33
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
34
- parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
35
-
36
-
37
- ocr_text = parsed_answer.get("<OCR>", "")
38
-
39
- text = ocr_text.replace("\n", " ").strip()
40
- output_text = " ".join(text.split())
 
 
41
 
42
  return output_text, image
43
 
 
44
  demo = gr.Interface(
45
- fn=predict_from_url,
46
  inputs=gr.Textbox(label="Enter Image URL"),
47
  outputs=[
48
- gr.Textbox(label="Output Text"),
49
- gr.Image(label="Image")
50
  ],
51
- title="OCR Text Extractor",
 
52
  allow_flagging="never"
53
  )
54
 
 
55
  demo.launch()
 
5
  from io import BytesIO
6
  import torch
7
 
8
+ # Set device
9
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
11
 
12
+ # Load model and processor
13
  model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
14
  processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
15
 
16
+ # Prediction function
17
  def predict_from_url(url):
18
  prompt = "<OCR>"
19
  if not url:
20
+ return "Error: Please input a URL", None
21
 
22
  try:
23
+ # Open the image and convert to RGB format to handle grayscale or other formats
24
+ image = Image.open(BytesIO(requests.get(url).content)).convert("RGB")
25
  except Exception as e:
26
+ return f"Error: Failed to load or process the image: {str(e)}", None
27
 
28
+ # Preprocess and perform inference
29
+ try:
30
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
31
+ generated_ids = model.generate(
32
+ input_ids=inputs["input_ids"],
33
+ pixel_values=inputs["pixel_values"],
34
+ max_new_tokens=4096,
35
+ num_beams=3,
36
+ do_sample=False
37
+ )
38
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
39
+ parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
40
+
41
+ # Extract OCR text
42
+ ocr_text = parsed_answer.get("<OCR>", "")
43
+ output_text = " ".join(ocr_text.replace("\n", " ").strip().split())
44
+ except Exception as e:
45
+ return f"Error: Failed to process the image for OCR: {str(e)}", image
46
 
47
  return output_text, image
48
 
49
+ # Gradio Interface
50
  demo = gr.Interface(
51
+ fn=predict_from_url,
52
  inputs=gr.Textbox(label="Enter Image URL"),
53
  outputs=[
54
+ gr.Textbox(label="Extracted Text"),
55
+ gr.Image(label="Uploaded Image")
56
  ],
57
+ title="Enhanced OCR Text Extractor",
58
+ description="Provide an image URL, and this tool will extract text using OCR.",
59
  allow_flagging="never"
60
  )
61
 
62
+ # Launch the app
63
  demo.launch()