---
license: mit
language:
- en
library_name: transformers
inference: false
pipeline_tag: document-question-answering
---

LiLT (Language-Independent Layout Transformer) model ([paper](https://arxiv.org/pdf/2202.13669v1.pdf)). This model has been fine-tuned on the English DocVQA dataset for extractive document question answering.

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

model_checkpoint = "TusharGoel/LiLT-Document-QA"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
model_predict = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)
model_predict.eval()

# One FUNSD form: a list of words plus a matching list of word boxes.
dataset = load_dataset("nielsr/funsd", split="train")
example = dataset[0]
print(example)

question = "What is the Licensee Number?"
print(question)

words = example["words"]
boxes = example["bboxes"]

# Question first, document words second. Truncate only the document so the
# question is never cut. 512 tokens is assumed to be this checkpoint's limit
# (RoBERTa-based LiLT) -- adjust if the model config says otherwise.
encoding = tokenizer(
    question,
    words,
    boxes=boxes,
    truncation="only_second",
    max_length=512,
    return_token_type_ids=True,
    return_tensors="pt",
)

# word_ids() gives each token's word index *within its own sequence*, so a
# question token and a document token can share the same index.
# sequence_ids() tells them apart (0 = question, 1 = document, None = special).
word_ids = encoding.word_ids(0)
sequence_ids = encoding.sequence_ids(0)

with torch.no_grad():  # inference only: no autograd graph needed
    outputs = model_predict(**encoding)

start_scores = outputs.start_logits[0]
end_scores = outputs.end_logits[0]

# Only document tokens may start or end the answer. Masking the rest keeps
# argmax off question tokens (whose word ids index the question, not `words`)
# and off special tokens (whose word id is None).
context_mask = torch.tensor([seq_id == 1 for seq_id in sequence_ids])
start_scores = start_scores.masked_fill(~context_mask, float("-inf"))
end_scores = end_scores.masked_fill(~context_mask, float("-inf"))

start_token = int(start_scores.argmax())
# The answer cannot end before it starts, so look for the end at or after it.
end_token = start_token + int(end_scores[start_token:].argmax())

start, end = word_ids[start_token], word_ids[end_token]
print(" ".join(words[start : end + 1]))
```