Heramb26 commited on
Commit
9206f14
1 Parent(s): ede1ece

Add application file

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
6
+
7
+ # Set up device
8
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9
+
10
+ # Load the fine-tuned model
11
+ checkpoint_path = './checkpoint-2070' # Path to your fine-tuned model checkpoint
12
+ model = VisionEncoderDecoderModel.from_pretrained(checkpoint_path).to(device)
13
+
14
+ # Use the original model's processor (tokenizer and feature extractor)
15
+ processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
16
+
17
+ def ocr_image(image):
18
+ """
19
+ Perform OCR on a single image.
20
+ :param image: PIL Image object.
21
+ :return: Extracted text from the image.
22
+ """
23
+ pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
24
+ generated_ids = model.generate(pixel_values)
25
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
26
+ return generated_text
27
+
28
+ # Define the Gradio interface
29
+ interface = gr.Interface(
30
+ fn=ocr_image, # Function to call for prediction
31
+ inputs=gr.inputs.Image(type="pil"), # Accept an image as input
32
+ outputs="text", # Return extracted text
33
+ title="OCR with TrOCR",
34
+ description="Upload an image, and the fine-tuned TrOCR model will extract the text for you."
35
+ )
36
+
37
+ # Launch the Gradio app
38
+ if __name__ == "__main__":
39
+ interface.launch()