---
license: mit
---

# ERNIE-Layout_Pytorch

[This repo](https://github.com/NormXU/ERNIE-Layout-Pytorch) is an unofficial PyTorch implementation of [ERNIE-Layout](http://arxiv.org/abs/2210.06155), which was originally released through PaddleNLP.
The model weights are translated from [PaddlePaddle/ernie-layoutx-base-uncased](https://huggingface.co/PaddlePaddle/ernie-layoutx-base-uncased) with [tools/convert2torch.py](https://github.com/NormXU/ERNIE-Layout-Pytorch/blob/main/tools/convert2torch.py), a script that converts the state dicts of PaddlePaddle-pretrained ERNIE models into PyTorch format. Feel free to edit it if necessary.
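
At a high level, such a conversion boils down to loading the Paddle checkpoint, renaming keys where the two frameworks disagree, and transposing 2-D linear weights (Paddle's `nn.Linear` stores weights as `[in_features, out_features]`, PyTorch as `[out_features, in_features]`). Below is a minimal sketch of that idea, not the actual contents of `convert2torch.py`; the helper name and the name-based embedding exclusion are illustrative assumptions:

```python
import paddle  # only needed for the one-off conversion
import torch


def convert_paddle_to_torch(paddle_ckpt_path: str, torch_ckpt_path: str):
    """Hypothetical sketch of a Paddle -> PyTorch state-dict conversion."""
    paddle_state = paddle.load(paddle_ckpt_path)
    torch_state = {}
    for key, value in paddle_state.items():
        # Depending on the Paddle version, values may be tensors or numpy arrays.
        array = value.numpy() if hasattr(value, "numpy") else value
        tensor = torch.from_numpy(array)
        # Linear weights need a transpose; embedding tables are also 2-D
        # but must be kept as-is, hence the (assumed) name-based filter.
        if tensor.ndim == 2 and key.endswith(".weight") and "embedding" not in key:
            tensor = tensor.t()
        torch_state[key] = tensor
    torch.save(torch_state, torch_ckpt_path)
```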

**A Quick Example**
```python
import torch
from PIL import Image
import numpy as np
import torch.nn.functional as F
from networks.model_util import ernie_qa_processing
from networks import ErnieLayoutConfig, ErnieLayoutForQuestionAnswering, ErnieLayoutImageProcessor, \
    ERNIELayoutProcessor, ErnieLayoutTokenizerFast

pretrain_torch_model_or_path = "Norm/ERNIE-Layout-Pytorch"
doc_image_path = "/path/to/dummy_input.jpeg"

device = torch.device("cuda:0")

# Dummy Input
context = ['This is an example document', 'All ocr boxes are inserted into this list']
layout = [[381, 91, 505, 115], [738, 96, 804, 122]]  # all boxes are normalized to the 0 - 1000 range
pil_image = Image.open(doc_image_path).convert("RGB")

# initialize tokenizer
tokenizer = ErnieLayoutTokenizerFast.from_pretrained(pretrained_model_name_or_path=pretrain_torch_model_or_path)

# initialize feature extractor
feature_extractor = ErnieLayoutImageProcessor(apply_ocr=False)
processor = ERNIELayoutProcessor(image_processor=feature_extractor, tokenizer=tokenizer)

# Tokenize context & question
context_encodings = processor(pil_image, context)
question = "what is it?"
tokenized_res = ernie_qa_processing(tokenizer, question, layout, context_encodings)
tokenized_res['input_ids'] = torch.tensor([tokenized_res['input_ids']]).to(device)
tokenized_res['bbox'] = torch.tensor([tokenized_res['bbox']]).to(device)
tokenized_res['pixel_values'] = torch.tensor(np.array(context_encodings.data['pixel_values'])).to(device)

# dummy answer start & end indices
tokenized_res['start_positions'] = torch.tensor([6]).to(device)
tokenized_res['end_positions'] = torch.tensor([12]).to(device)

# initialize config
config = ErnieLayoutConfig.from_pretrained(pretrained_model_name_or_path=pretrain_torch_model_or_path)
config.num_classes = 2  # start and end

# initialize ERNIE for VQA
model = ErnieLayoutForQuestionAnswering.from_pretrained(
    pretrained_model_name_or_path=pretrain_torch_model_or_path,
    config=config,
)
model.to(device)

output = model(**tokenized_res)

# decode output
start_max = torch.argmax(F.softmax(output.start_logits, dim=-1))
end_max = torch.argmax(F.softmax(output.end_logits, dim=-1)) + 1  # add one because Python slicing is end-exclusive
answer = tokenizer.decode(tokenized_res["input_ids"][0][start_max: end_max])
print(answer)
```
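
Note that ERNIE-Layout expects every box on the 0 - 1000 scale regardless of the page size. If your OCR engine returns pixel coordinates, a small helper like the (hypothetical) one below can normalize them before they go into `layout`:

```python
def normalize_bbox(bbox, page_width, page_height):
    """Scale a pixel-coordinate box [x0, y0, x1, y1] to the 0 - 1000 range."""
    x0, y0, x1, y1 = bbox
    return [
        int(1000 * x0 / page_width),
        int(1000 * y0 / page_height),
        int(1000 * x1 / page_width),
        int(1000 * y1 / page_height),
    ]


# e.g. on a 595 x 842 pt page:
# normalize_bbox([119, 77, 300, 97], 595, 842) -> [200, 91, 504, 115]
```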