---
license: mit
language:
- en
library_name: transformers
pipeline_tag: document-question-answering
---
Fine-tuned on the DocVQA dataset (40,000 questions). The snippet below defines a small helper that runs extractive question answering over a document image, given externally OCR'd words and bounding boxes:
```python
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForDocumentQuestionAnswering

model_name = "TusharGoel/LayoutLMv2-finetuned-docvqa"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_name)

def pipeline(question, words, boxes, **kwargs):
    images = kwargs["images"]
    try:
        # Encode the image together with the question and the externally
        # OCR'd words and bounding boxes.
        encoding = processor(
            images, question, words, boxes=boxes,
            return_token_type_ids=True, return_tensors="pt", truncation=True,
        )
        word_ids = encoding.word_ids(0)

        outputs = model(**encoding)
        start_scores = outputs.start_logits
        end_scores = outputs.end_logits

        # Map the highest-scoring start/end tokens back to word indices and
        # join the corresponding words into the answer string.
        start, end = word_ids[start_scores.argmax(-1)], word_ids[end_scores.argmax(-1)]
        answer = " ".join(words[start : end + 1])

        # Compute a confidence score: mask out padding tokens, softmax the
        # start/end logits, and score the best valid (start, end) span of at
        # most max_answer_len tokens.
        start_scores, end_scores = start_scores.detach().numpy(), end_scores.detach().numpy()
        undesired_tokens = encoding["attention_mask"]
        undesired_tokens_mask = undesired_tokens == 0.0

        start_ = np.where(undesired_tokens_mask, -10000.0, start_scores)
        end_ = np.where(undesired_tokens_mask, -10000.0, end_scores)

        start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True)))
        end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True)))

        # Outer product of start and end probabilities; keep only spans where
        # end >= start and the span length is at most max_answer_len.
        outer = np.matmul(np.expand_dims(start_, -1), np.expand_dims(end_, 1))
        max_answer_len = 20
        candidates = np.tril(np.triu(outer), max_answer_len - 1)

        scores_flat = candidates.flatten()
        idx_sort = [np.argmax(scores_flat)]
        start, end = np.unravel_index(idx_sort, candidates.shape)[1:]
        scores = candidates[0, start, end]
        score = scores[0]
    except Exception:
        answer, score = "", 0.0
    return answer, score
```
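
For reference, a minimal usage sketch. It assumes `pytesseract` for OCR; the image path, the question, and the normalization of boxes to LayoutLMv2's 0-1000 coordinate range are illustrative choices, not part of the original card:

```python
# Hypothetical usage: OCR the page with pytesseract, normalize the boxes to
# the 0-1000 range LayoutLMv2 expects, then ask a question.
from PIL import Image
import pytesseract

image = Image.open("invoice.png").convert("RGB")  # example path (assumption)
width, height = image.size

ocr = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
words, boxes = [], []
for word, x, y, w, h in zip(ocr["text"], ocr["left"], ocr["top"], ocr["width"], ocr["height"]):
    if word.strip():
        words.append(word)
        boxes.append([
            int(1000 * x / width),
            int(1000 * y / height),
            int(1000 * (x + w) / width),
            int(1000 * (y + h) / height),
        ])

answer, score = pipeline("What is the invoice number?", words, boxes, images=image)
print(answer, score)
```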