Update README.md
Browse files
README.md
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
-
|
2 |
-
# LayoutLM fine-tuned on FUNSD for Document token classification
|
3 |
|
4 |
## Usage
|
5 |
|
@@ -8,12 +7,13 @@ import torch
|
|
8 |
import numpy as np
|
9 |
from PIL import Image, ImageDraw, ImageFont
|
10 |
import pytesseract
|
11 |
-
from transformers import LayoutLMForTokenClassification
|
12 |
|
13 |
|
14 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
15 |
|
16 |
-
|
|
|
17 |
model.to(device)
|
18 |
|
19 |
|
@@ -29,10 +29,8 @@ width, height = image.size
|
|
29 |
w_scale = 1000/width
|
30 |
h_scale = 1000/height
|
31 |
|
32 |
-
ocr_df = pytesseract.image_to_data(image, output_type='data.frame')
|
33 |
-
|
34 |
-
ocr_df = ocr_df.dropna() \
|
35 |
-
.assign(left_scaled = ocr_df.left*w_scale,
|
36 |
width_scaled = ocr_df.width*w_scale,
|
37 |
top_scaled = ocr_df.top*h_scale,
|
38 |
height_scaled = ocr_df.height*h_scale,
|
@@ -41,7 +39,7 @@ ocr_df = ocr_df.dropna() \
|
|
41 |
|
42 |
float_cols = ocr_df.select_dtypes('float').columns
|
43 |
ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
|
44 |
-
ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
|
45 |
ocr_df = ocr_df.dropna().reset_index(drop=True)
|
46 |
ocr_df[:20]
|
47 |
|
@@ -140,5 +138,41 @@ bbox = torch.tensor(token_boxes, device=device).unsqueeze(0)
|
|
140 |
|
141 |
outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
```
|
|
|
1 |
+
# LayoutLM fine-tuned on FUNSD for Document token classification
|
|
|
2 |
|
3 |
## Usage
|
4 |
|
|
|
7 |
import numpy as np
|
8 |
from PIL import Image, ImageDraw, ImageFont
|
9 |
import pytesseract
|
10 |
+
from transformers import LayoutLMForTokenClassification, LayoutLMTokenizer
|
11 |
|
12 |
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
|
15 |
+
tokenizer = LayoutLMTokenizer.from_pretrained("mrm8488/layoutlm-finetuned-funsd")
|
16 |
+
model = LayoutLMForTokenClassification.from_pretrained("mrm8488/layoutlm-finetuned-funsd", num_labels=13)
|
17 |
model.to(device)
|
18 |
|
19 |
|
|
|
29 |
w_scale = 1000/width
|
30 |
h_scale = 1000/height
|
31 |
|
32 |
+
ocr_df = pytesseract.image_to_data(image, output_type='data.frame')
|
33 |
+
ocr_df = ocr_df.dropna().assign(left_scaled = ocr_df.left*w_scale,
|
|
|
|
|
34 |
width_scaled = ocr_df.width*w_scale,
|
35 |
top_scaled = ocr_df.top*h_scale,
|
36 |
height_scaled = ocr_df.height*h_scale,
|
|
|
39 |
|
40 |
float_cols = ocr_df.select_dtypes('float').columns
|
41 |
ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
|
42 |
+
ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
|
43 |
ocr_df = ocr_df.dropna().reset_index(drop=True)
|
44 |
ocr_df[:20]
|
45 |
|
|
|
138 |
|
139 |
outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
140 |
|
141 |
+
token_predictions = outputs.logits.argmax(-1).squeeze().tolist() # the predictions are at the token level
|
142 |
+
|
143 |
+
word_level_predictions = [] # let's turn them into word level predictions
|
144 |
+
final_boxes = []
|
145 |
+
for id, token_pred, box in zip(input_ids.squeeze().tolist(), token_predictions, token_actual_boxes):
|
146 |
+
if (tokenizer.decode([id]).startswith("##")) or (id in [tokenizer.cls_token_id,
|
147 |
+
tokenizer.sep_token_id,
|
148 |
+
tokenizer.pad_token_id]):
|
149 |
+
# skip prediction + bounding box
|
150 |
+
|
151 |
+
continue
|
152 |
+
else:
|
153 |
+
word_level_predictions.append(token_pred)
|
154 |
+
final_boxes.append(box)
|
155 |
+
|
156 |
+
#print(word_level_predictions)
|
157 |
+
|
158 |
+
|
159 |
+
draw = ImageDraw.Draw(image)
|
160 |
+
|
161 |
+
font = ImageFont.load_default()
|
162 |
+
|
163 |
+
def iob_to_label(label):
|
164 |
+
if label != 'O':
|
165 |
+
return label[2:]
|
166 |
+
else:
|
167 |
+
return "other"
|
168 |
+
|
169 |
+
label2color = {'question':'blue', 'answer':'green', 'header':'orange', 'other':'violet'}
|
170 |
+
|
171 |
+
for prediction, box in zip(word_level_predictions, final_boxes):
|
172 |
+
predicted_label = iob_to_label(label_map[prediction]).lower()
|
173 |
+
draw.rectangle(box, outline=label2color[predicted_label])
|
174 |
+
draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
|
175 |
+
|
176 |
+
# Display the result (image)
|
177 |
|
178 |
```
|