# Scraped from the Hugging Face file viewer (author: MohamedRashad).
# Commit cfeccec: "Add GPU support for text extraction" -- file size 3.54 kB.
from transformers import NougatProcessor, VisionEncoderDecoderModel
import gradio as gr
import torch
from PIL import Image
from pathlib import Path
from pdf2image import convert_from_path
import spaces
# Load the Nougat processor and the vision encoder-decoder model from the Hub.
processor = NougatProcessor.from_pretrained("MohamedRashad/arabic-small-nougat")
model = VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-small-nougat")

# Prefer CUDA when torch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
print(f"Using {device} device")

# Passed to generate() as max_new_tokens, i.e. the per-image generation cap.
context_length = 2048
@spaces.GPU
def extract_text_from_image(image):
    """
    Extract text from PIL image

    Args:
        image (PIL.Image): Input image

    Returns:
        str: Extracted text from the image

    Raises:
        gr.Error: If ``image`` is None, which is what Gradio passes when the
            Submit button is clicked before an image has been provided.
    """
    # Fail early with a readable message instead of an opaque processor error.
    if image is None:
        raise gr.Error("Please provide an image before submitting.")

    # prepare the image for the model
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # generate transcription; the tokenizer's unknown token is banned from the
    # output and generation stops after at most `context_length` new tokens
    outputs = model.generate(
        pixel_values.to(device),
        min_length=1,
        max_new_tokens=context_length,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )

    # a single image was processed, so only the first decoded sequence is used
    page_sequence = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    page_sequence = processor.post_process_generation(page_sequence, fix_markdown=False)
    return page_sequence
def extract_text_from_pdf(pdf_path, progress=gr.Progress()):
    """
    Extract text from PDF

    Every page is converted to an image and transcribed with
    ``extract_text_from_image``; the page texts are joined with newlines.

    Args:
        pdf_path (str): Path to the PDF file
        progress (gr.Progress): Progress bar

    Returns:
        str: Extracted text from the PDF

    Raises:
        gr.Error: If ``pdf_path`` is empty or None, which is what Gradio
            passes when the Submit button is clicked before a file is uploaded.
    """
    # Fail early with a readable message instead of crashing in pdf2image.
    if not pdf_path:
        raise gr.Error("Please upload a PDF file before submitting.")

    progress(0, desc="Starting...")
    images = convert_from_path(pdf_path)

    # progress.tqdm drives the Gradio progress bar while iterating the pages
    texts = [extract_text_from_image(image) for image in progress.tqdm(images)]
    return "\n".join(texts)
# Markdown blurb rendered at the top of the demo page.
model_description = """
This is a demo for the Arabic Small Nougat model. It is an end-to-end OCR model that can extract text from images and PDFs.
- The model is trained on the [Khatt dataset](https://huggingface.co/datasets/Fakhraddin/khatt) and custom made dataset.
- The model is a finetune of [facebook/nougat-small](https://huggingface.co/facebook/nougat-small) model.
**Note**: The model is a prototype in my book and may not work well on all types of images and PDFs. **Check the output carefully before using it for any serious work.**
"""

# The sample page is looked up next to this script file.
_example_page_path = Path(__file__).with_name("book_page.jpeg")
example_images = [Image.open(_example_page_path)]
# Two-tab UI: one tab for a single image, one for a whole PDF.
# NOTE(review): the nesting below was reconstructed from a flattened source;
# confirm the Row/Column layout against the deployed app.
with gr.Blocks(title="Arabic Small Nougat") as demo:
    gr.HTML(value="<h1 style='text-align: center'>Arabic End-to-End Structured OCR for textbooks</h1>")
    gr.Markdown(value=model_description)

    with gr.Tab(label="Extract Text from Image"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                image_submit_button = gr.Button(value="Submit", variant="primary")
            output = gr.Markdown(label="Output Markdown", rtl=True)
        image_submit_button.click(
            fn=extract_text_from_image,
            inputs=[input_image],
            outputs=output,
        )
        gr.Examples(
            examples=example_images,
            inputs=[input_image],
            outputs=output,
            fn=extract_text_from_image,
            cache_examples=True,
        )

    with gr.Tab(label="Extract Text from PDF"):
        with gr.Row():
            with gr.Column():
                pdf = gr.File(label="Input PDF", type="filepath")
                pdf_submit_button = gr.Button(value="Submit", variant="primary")
            output = gr.Markdown(label="Output Markdown", rtl=True)
        pdf_submit_button.click(
            fn=extract_text_from_pdf,
            inputs=[pdf],
            outputs=output,
        )

# Enable request queuing, then start the app without a public share link.
demo.queue()
demo.launch(share=False)