File size: 1,682 Bytes
4331eba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30e33e1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os

import gradio as gr
import pytesseract
import torch
from transformers import LayoutLMForSequenceClassification

from preprocess import apply_ocr, encode_example

# Run on the GPU when one is available. predict() must put its input tensors
# on this same device, otherwise the forward pass fails on CUDA machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default install location of the Tesseract binary on Windows. Only point
# pytesseract at it when the file actually exists; otherwise keep pytesseract's
# default command (the `tesseract` executable resolved via PATH), so the app
# also works on Linux/macOS or with a non-default Windows install.
TESSERACT_WINDOWS_PATH = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
if os.path.isfile(TESSERACT_WINDOWS_PATH):
    pytesseract.pytesseract.tesseract_cmd = TESSERACT_WINDOWS_PATH

# Fine-tuned LayoutLM classifier loaded from "models/document_model".
# NOTE(review): presumably a local directory relative to the working
# directory -- confirm it is not meant to be a Hugging Face Hub repo id.
model = LayoutLMForSequenceClassification.from_pretrained("models/document_model")
model.to(device)
# Inference only: disable dropout etc. once at load time.
model.eval()

# Label names indexed by the model's output logit position (predict() maps the
# argmax index into this list). The order must match the label ids used during
# training -- NOTE(review): consider model.config.id2label if it was saved
# with the checkpoint.
classes = ['questionnaire', 'memo', 'budget', 'file_folder', 'specification', 'invoice', 'resume',
           'advertisement', 'news_article', 'email', 'scientific_publication', 'presentation',
           'letter', 'form', 'handwritten', 'scientific_report']


def predict(image):
    """Classify a document image and return the predicted class name.

    The image is OCR'd with ``apply_ocr``, then encoded into LayoutLM inputs with
    ``encode_example``. It is fed to ``model`` as a batch of one, and the class
    with the highest softmax probability is returned.

    Args:
        image: The image to classify. The Gradio input is configured with
            ``type="pil"``, so this is a PIL image when called from the UI.

    Returns:
        The entry of ``classes`` at the index of the highest probability
        (the first one on ties).
    """
    example = apply_ocr(image)
    encoded_example = encode_example(example)

    # Build a batch of one (unsqueeze(0)) and put every tensor on the same
    # device as the model; leaving them on CPU breaks inference on CUDA.
    inputs = {
        key: torch.tensor(encoded_example[key]).unsqueeze(0).to(device)
        for key in ("input_ids", "bbox", "attention_mask", "token_type_ids")
    }

    model.eval()
    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.softmax(outputs.logits, dim=1)[0]
    # torch.argmax returns the first maximal index, matching list.index(max(...)).
    max_prob_index = int(torch.argmax(probabilities))
    return classes[max_prob_index]



title = "Document Image Classification"

# Web UI: upload a document image, get back the predicted class name.
# Uses the top-level gr.Image / gr.Textbox components; the legacy
# gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
# removed in Gradio 4.
demo = gr.Interface(
    fn=predict,
    # type="pil" makes Gradio pass predict() a PIL image rather than a numpy array.
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(label="Predicted Class"),
    title=title,
)
demo.launch()