|
|
|
"""LiLT For Deployment |
|
|
|
Automatically generated by Colaboratory. |
|
|
|
Original file is located at |
|
https://colab.research.google.com/drive/1ol6RWyff15SF6ZJPf47X5380hBTEDiUH |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os

# Colab-style dependency setup: pinned pyyaml plus pytesseract for OCR.
os.system('pip install pyyaml==5.1')
os.system('pip install -q pytesseract')

# Silence fork-related warnings from the Hugging Face tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
from PIL import Image
from transformers import RobertaTokenizer
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

# Project-local modules: OCR feature extraction, the LiLT model, and its
# PyTorch Lightning wrapper.
from dataset import create_features
from modeling import LiLT
from utils import LiLTPL

import gdown
import gradio as gr

# Seed Python, NumPy, and torch RNGs for reproducible behaviour.
seed = 42
pl.seed_everything(seed)
|
|
|
|
|
# LiLT hyper-parameters: text-stream (_t) and layout-stream (_l) widths,
# attention geometry, and sequence/coordinate limits.
config = {
    "hidden_dropout_prob": 0.1,
    "hidden_size_t": 768,
    "hidden_size": 768,
    "hidden_size_l": 768 // 6,
    "intermediate_ff_size_factor": 4,
    "max_2d_position_embeddings": 1001,
    "max_seq_len_l": 512,
    "max_seq_len_t": 512,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "dim_head": 64,
    "shape_size": 96,
    "vocab_size": 50265,
    "eps": 1e-12,
    "fine_tune": True,
}
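
# A quick sanity check on the config, under the assumption that `modeling.LiLT`
# splits the text-stream width across attention heads in the usual way
# (num_attention_heads * dim_head = hidden_size_t, i.e. 12 * 64 = 768).
assert config["hidden_size_t"] == config["num_attention_heads"] * config["dim_head"]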
|
|
|
# The 16 RVL-CDIP document classes, indexed by model output position.
id2label = [
    'scientific_report',
    'resume',
    'memo',
    'file_folder',
    'specification',
    'news_article',
    'letter',
    'form',
    'budget',
    'handwritten',
    'email',
    'invoice',
    'presentation',
    'scientific_publication',
    'questionnaire',
    'advertisement',
]
|
|
|
|
|
# LiLT's text stream reuses the RoBERTa vocabulary (vocab_size = 50265 above).
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Fetch the fine-tuned checkpoint from Google Drive.
url = 'https://drive.google.com/uc?id=1eRV4fS_LFwI5MHqcRwLUNQZgewxI6Se_'
output = 'lilt_ckpt.ckpt'
gdown.download(url, output, quiet=False)
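
# Fail fast if the download did not actually produce the checkpoint file;
# gdown can return without raising on quota or permission errors.
assert os.path.exists(output), f"Checkpoint download failed: {output}"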
|
|
|
class RVLCDIPData(Dataset):
    """Dataset over document images: runs OCR and builds LiLT input features.

    `create_features` returns token bounding boxes, token ids, and the
    unnormalized boxes; only the first two are fed to the model.
    """

    def __init__(self, image_list, label_list, tokenizer, max_len=512, size=1000):
        self.image_list = image_list
        self.label_list = label_list
        self.tokenizer = tokenizer
        self.max_seq_length = max_len
        self.size = size

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_path = self.image_list[idx]
        label = self.label_list[idx]

        boxes, words, normal_box = create_features(
            img_path=img_path,
            tokenizer=self.tokenizer,
            max_seq_length=self.max_seq_length,
            size=self.size,
            use_ocr=True,
        )

        final_encoding = {'input_boxes': boxes, 'input_words': words}
        final_encoding['label'] = torch.as_tensor(label).long()

        return final_encoding
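
# Illustrative helper showing how RVLCDIPData is meant to be consumed; it is
# not called in this deployment script. In a real evaluation run, the image
# paths and labels would come from the RVL-CDIP annotation files.
def make_eval_loader(image_list, label_list, batch_size=2):
    """Wrap RVLCDIPData in a DataLoader for batched evaluation."""
    eval_ds = RVLCDIPData(image_list, label_list, tokenizer)
    return DataLoader(eval_ds, batch_size=batch_size, shuffle=False)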
|
|
|
# `load_from_checkpoint` is a classmethod that returns a new, weight-loaded
# module; the original in-place call discarded its result. `config` is passed
# through to the constructor in case it is not stored in the checkpoint.
lilt = LiLTPL.load_from_checkpoint('lilt_ckpt.ckpt', config=config)
lilt.eval()  # inference only: disable dropout
|
|
|
|
|
|
|
# Gradio UI components (the old gr.inputs / gr.outputs namespaces were
# removed in Gradio 3.x; the top-level components take the same arguments).
image = gr.Image(type="pil")
label = gr.Label(num_top_classes=5)
examples = [['00093726.png'], ['00866042.png']]
title = "Interactive demo: LiLT for Image Classification"
description = "Demo for classifying document images with the LiLT model. To use it, \
simply upload an image or use one of the example images below and click 'Submit' to let the model \
predict the 5 most probable document classes. Results will show up in a few seconds."
|
|
|
def classify_image(image):
    # `create_features` expects a file path, so write the uploaded PIL image
    # to disk before running OCR + feature extraction on it.
    image.save('sample_img.png')
    boxes, words, normal_box = create_features(
        img_path='sample_img.png',
        tokenizer=tokenizer,
        max_seq_length=512,
        size=1000,
        use_ocr=True,
    )

    # Add a batch dimension and run the model without tracking gradients.
    final_encoding = {'input_boxes': boxes.unsqueeze(0), 'input_words': words.unsqueeze(0)}
    with torch.no_grad():
        output = lilt(final_encoding)
    probs = output[0].softmax(dim=-1)

    # Map each class probability to its label, as expected by gr.Label.
    final_pred = {id2label[i]: score.item() for i, score in enumerate(probs)}

    return final_pred
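
# Optional smoke test before launching the UI: classify one bundled example
# image if it is present (the same files referenced in `examples` above).
if os.path.exists('00093726.png'):
    sample_preds = classify_image(Image.open('00093726.png').convert('RGB'))
    print('Top prediction:', max(sample_preds, key=sample_preds.get))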
|
|
|
# Build and launch the demo; `.queue()` replaces the old `enable_queue=True`
# constructor flag in current Gradio versions.
gr.Interface(fn=classify_image,
             inputs=image,
             outputs=label,
             title=title,
             description=description,
             examples=examples).queue().launch(debug=True)
|
|
|
|