import re

import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Hugging Face Hub id shared by the feature extractor and the classifier.
MODEL_ID = "DunnBC22/dit-base-Business_Documents_Classified_v2"

extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def classify_documents(image):
    """Classify a document image and return the predicted label.

    Args:
        image: The image handed over by the Gradio "image" input component
            (presumably a NumPy array, Gradio's default for that component
            -- confirm against the installed Gradio version).

    Returns:
        A dict of the form ``{"result": <label>}`` where ``<label>`` is the
        entry of ``model.config.id2label`` with the highest logit.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        raise ValueError("No image was provided for classification.")

    # `return_tensors` (plural) is the keyword the extractor understands; it
    # yields a batched torch tensor directly, so no NumPy round-trip is needed.
    inputs = extractor(images=image, return_tensors="pt")

    # The model lives on `device`, so the input must be moved there as well;
    # otherwise the forward pass fails with a device mismatch on CUDA hosts.
    pixel_values = inputs["pixel_values"].to(device)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits

    # Single-image batch, so the argmax over the class axis has one element.
    predicted_index = logits.argmax(dim=-1).item()
    document_class = model.config.id2label[predicted_index]

    return {"result": str(document_class)}


# Text passed to the interface's `article` argument; it holds only a line break.
article = "\n"

demo = gr.Interface(
    fn=classify_documents,
    inputs="image",
    outputs="json",
    title="Document Classification",
    article=article,
    enable_queue=True,
    examples=[
        ["./test_images/email_image_2.jpg"],
        ["./test_images/form_image_3.jpg"],
    ],
    cache_examples=False,
)

demo.launch()