|
import re |
|
import gradio as gr |
|
|
|
import torch |
|
|
|
|
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
|
|
# Hugging Face Hub checkpoint used for BOTH preprocessing and classification.
# Defined once so the extractor and the model cannot drift to different checkpoints.
MODEL_ID = "DunnBC22/dit-base-Business_Documents_Classified_v2"

# Preprocessor that converts an input image into the `pixel_values` fed to the model.
extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)

# Classifier; `model.config.id2label` maps output indices to class-name labels.
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

# Use the GPU when one is available. Any input tensor must be moved to this
# same device before it is passed to `model`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
|
|
|
|
|
def classify_documents(image):
    """Classify a single document image and return its predicted label.

    Args:
        image: The value supplied by the Gradio ``"image"`` input, passed
            straight to ``extractor``. Presumably a NumPy array under
            Gradio's default settings (TODO confirm for the installed Gradio
            version). It is ``None`` when the user submits without uploading.

    Returns:
        A dict ``{"result": <label>}``, where ``<label>`` is the
        ``model.config.id2label`` entry for the highest-scoring class.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of an opaque
        # preprocessing traceback.
        raise gr.Error("Please upload a document image.")

    # `return_tensors` (plural) makes the extractor return a batched torch
    # tensor directly. It must sit on the model's device; otherwise inference
    # fails whenever CUDA is available.
    inputs = extractor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits

    # A single image was submitted, so row 0 holds its class scores.
    predicted_index = logits[0].argmax(dim=-1).item()
    document_class = model.config.id2label[predicted_index]

    return {
        "result": str(document_class)
    }
|
|
|
# HTML credit line passed to the interface as its `article`.
article = "<p style='text-align: center'><a href='https://www.xelpmoc.in/' target='_blank'>Made by Xelpmoc</a></p>"

# Web UI: one image in, a JSON object holding the predicted document class out.
demo = gr.Interface(
    fn=classify_documents,
    inputs="image",
    outputs="json",
    title="Document Classification",
    article=article,
    examples=[
        ["./test_images/email_image_2.jpg"],
        ["./test_images/form_image_3.jpg"],
    ],
    cache_examples=False,
)

# `gr.Interface(enable_queue=...)` is rejected by Gradio 4+ (TypeError at
# startup). `.queue()` is the supported way to enable request queuing.
demo.queue()
demo.launch()