from transformers import (
    NougatProcessor,
    VisionEncoderDecoderModel,
    TextIteratorStreamer,
)
import gradio as gr
import torch
from pathlib import Path
from pdf2image import convert_from_path
import spaces
from threading import Thread

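# All three checkpoints are loaded eagerly at import time so the first request
# does not pay the initialization cost. The base and large models run in
# bfloat16 with FlashAttention-2 on the decoder only; the vision encoder stays
# on eager attention, likely because the Swin-based encoder lacks
# FlashAttention-2 support.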
models_supported = {
    "arabic-small-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-small-nougat"),
        VisionEncoderDecoderModel.from_pretrained("MohamedRashad/arabic-small-nougat"),
    ],
    "arabic-base-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-base-nougat"),
        VisionEncoderDecoderModel.from_pretrained(
            "MohamedRashad/arabic-base-nougat",
            torch_dtype=torch.bfloat16,
            attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
        ),
    ],
    "arabic-large-nougat": [
        NougatProcessor.from_pretrained("MohamedRashad/arabic-large-nougat"),
        VisionEncoderDecoderModel.from_pretrained(
            "MohamedRashad/arabic-large-nougat",
            torch_dtype=torch.bfloat16,
            attn_implementation={"decoder": "flash_attention_2", "encoder": "eager"},
        ),
    ],
}


@spaces.GPU
def extract_text_from_image(image, model_name):
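    """Stream Markdown extracted from a single PIL image.

    Generation runs on a background thread; the growing output string is
    yielded as tokens arrive from the TextIteratorStreamer so Gradio can
    render partial results.
    """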
    print(f"Extracting text from image using model: {model_name}")
    processor, model = models_supported[model_name]
    context_length = model.decoder.config.max_position_embeddings
    torch_dtype = model.dtype
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    pixel_values = (
        processor(image, return_tensors="pt").pixel_values.to(torch_dtype).to(device)
    )
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)

    # Start generation in a separate thread
    generation_kwargs = {
        "pixel_values": pixel_values,
        "min_length": 1,
        "max_new_tokens": context_length,
        "streamer": streamer,
    }

    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield tokens as they become available
    output = ""
    for token in streamer:
        output += token
        yield output

    thread.join()


@spaces.GPU
def extract_text_from_pdf(pdf_path, model_name):
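    """Stream Markdown extracted from a PDF, page by page.

    Each page is rasterized with pdf2image and decoded in sequence; the
    accumulated text is yielded after every token so Gradio can render
    partial results.
    """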
    processor, model = models_supported[model_name]
    context_length = model.decoder.config.max_position_embeddings
    torch_dtype = model.dtype
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    print(f"Extracting text from PDF: {pdf_path}")
    images = convert_from_path(pdf_path)

    pdf_output = ""
    for image in images:
        # Create a fresh streamer per page so state left over from the
        # previous page's generation cannot bleed into the next one.
        streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
        pixel_values = (
            processor(image, return_tensors="pt")
            .pixel_values.to(torch_dtype)
            .to(device)
        )

        # Start generation in a separate thread
        generation_kwargs = {
            "pixel_values": pixel_values,
            "min_length": 1,
            "max_new_tokens": context_length,
            "streamer": streamer,
        }

        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield tokens as they become available
        for token in streamer:
            pdf_output += token
            yield pdf_output

        thread.join()
        pdf_output += "\n\n"
        yield pdf_output


model_description = """This is the official demo for the Arabic Nougat models: end-to-end Markdown extraction models that read images or PDFs and write their contents out as Markdown.

There are three models available:
- [arabic-small-nougat](https://huggingface.co/MohamedRashad/arabic-small-nougat): A small model that is faster but less accurate (a finetune of [facebook/nougat-small](https://huggingface.co/facebook/nougat-small)).
- [arabic-base-nougat](https://huggingface.co/MohamedRashad/arabic-base-nougat): A base model that is more accurate but slower (a finetune of [facebook/nougat-base](https://huggingface.co/facebook/nougat-base)).
- [arabic-large-nougat](https://huggingface.co/MohamedRashad/arabic-large-nougat): The largest of the three (built from scratch using the [riotu-lab/Aranizer-PBE-86k](https://huggingface.co/riotu-lab/Aranizer-PBE-86k) tokenizer and a larger transformer decoder).

**Disclaimer**: These models can hallucinate text and are not perfect. They are trained on a mix of synthetic and real data and may not work well on all types of images.
"""

example_images = list(Path(__file__).parent.glob("*.jpeg"))

with gr.Blocks(title="Arabic Nougat") as demo:
    gr.HTML(
        "<h1 style='text-align: center'>Arabic End-to-End Structured OCR for textbooks</h1>"
    )
    gr.Markdown(model_description)

    with gr.Tab("Extract Text from Image"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                model_dropdown = gr.Dropdown(
                    label="Model",
                    choices=list(models_supported.keys()),
                    value="arabic-small-nougat",
                )
                image_submit_button = gr.Button(value="Submit", variant="primary")
            output = gr.Markdown(label="Output Markdown", rtl=True)
        image_submit_button.click(
            extract_text_from_image,
            inputs=[input_image, model_dropdown],
            outputs=output,
        )
        # The examples only pre-fill the image input; the user still submits
        # with their chosen model (passing a fn here would be inert with
        # cache_examples=False, and its signature would not match anyway).
        gr.Examples(
            example_images,
            [input_image],
            cache_examples=False,
        )

    with gr.Tab("Extract Text from PDF"):
        with gr.Row():
            with gr.Column():
                pdf = gr.File(label="Input PDF", type="filepath")
                model_dropdown = gr.Dropdown(
                    label="Model",
                    choices=list(models_supported.keys()),
                    value="arabic-small-nougat",
                )
                pdf_submit_button = gr.Button(value="Submit", variant="primary")
            output = gr.Markdown(label="Output Markdown", rtl=True)
        pdf_submit_button.click(
            extract_text_from_pdf, inputs=[pdf, model_dropdown], outputs=output
        )

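# .queue() is required for the streaming (generator) handlers above to push
# partial output to the browser as it is produced.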
demo.queue().launch(share=False)