Heramb26 committed on
Commit
134092a
1 Parent(s): e705ba3
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -1,16 +1,12 @@
1
  import torch
2
  from PIL import Image
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
- from huggingface_hub import hf_hub_download
5
- import os
6
-
7
- # Load the model checkpoint and tokenizer files from Hugging Face Model Hub
8
- # checkpoint_folder = hf_hub_download(repo_id="Heramb26/tr-ocr-custom-checkpoints", filename="checkpoint-2070")
9
 
10
  # Set up the device (GPU or CPU)
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
- # Load the fine-tuned model and processor from the downloaded folder
14
  model = VisionEncoderDecoderModel.from_pretrained("Heramb26/TC-OCR-Custom").to(device)
15
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
16
 
@@ -26,8 +22,12 @@ def ocr_image(image):
26
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
27
  return generated_text
28
 
29
- # Example usage
30
- image_path = "path/to/your/image.jpg" # Update with the path to your image
31
- image = Image.open(image_path) # Open the image file using PIL
32
- extracted_text = ocr_image(image) # Perform OCR on the image
33
- print("Extracted Text:", extracted_text)
 
 
 
 
 
1
  import torch
2
  from PIL import Image
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
+ import gradio as gr
 
 
 
 
5
 
6
  # Set up the device (GPU or CPU)
7
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
 
9
+ # Load the fine-tuned model and processor from the Hugging Face repository
10
  model = VisionEncoderDecoderModel.from_pretrained("Heramb26/TC-OCR-Custom").to(device)
11
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
12
 
 
22
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
23
  return generated_text
24
 
25
+ # Create a Gradio interface
26
+ interface = gr.Interface(fn=ocr_image, # Function to be called when an image is uploaded
27
+ inputs=gr.inputs.Image(type="pil"), # Input is an image file
28
+ outputs="text", # Output is extracted text
29
+ title="OCR Inference", # Title of the app
30
+ description="Upload an image with handwritten text to extract the text.") # Description
31
+
32
+ # Launch the Gradio app
33
+ interface.launch()