mrm8488 commited on
Commit
b154577
1 Parent(s): 3f23a54

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +43 -9
README.md CHANGED
@@ -1,5 +1,4 @@
1
-
2
- # LayoutLM fine-tuned on FUNSD for Document token classification
3
 
4
  ## Usage
5
 
@@ -8,12 +7,13 @@ import torch
8
  import numpy as np
9
  from PIL import Image, ImageDraw, ImageFont
10
  import pytesseract
11
- from transformers import LayoutLMForTokenClassification
12
 
13
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
- model = LayoutLMForTokenClassification.from_pretrained("mrm8488/layoutlm-finetuned-funsd", num_labels=num_labels)
 
17
  model.to(device)
18
 
19
 
@@ -29,10 +29,8 @@ width, height = image.size
29
  w_scale = 1000/width
30
  h_scale = 1000/height
31
 
32
- ocr_df = pytesseract.image_to_data(image, output_type='data.frame') \
33
-
34
- ocr_df = ocr_df.dropna() \
35
- .assign(left_scaled = ocr_df.left*w_scale,
36
  width_scaled = ocr_df.width*w_scale,
37
  top_scaled = ocr_df.top*h_scale,
38
  height_scaled = ocr_df.height*h_scale,
@@ -41,7 +39,7 @@ ocr_df = ocr_df.dropna() \
41
 
42
  float_cols = ocr_df.select_dtypes('float').columns
43
  ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
44
- ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
45
  ocr_df = ocr_df.dropna().reset_index(drop=True)
46
  ocr_df[:20]
47
 
@@ -140,5 +138,41 @@ bbox = torch.tensor(token_boxes, device=device).unsqueeze(0)
140
 
141
  outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
142
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  ```
 
1
+ \n# LayoutLM fine-tuned on FUNSD for Document token classification
 
2
 
3
  ## Usage
4
 
 
7
  import numpy as np
8
  from PIL import Image, ImageDraw, ImageFont
9
  import pytesseract
10
+ from transformers import LayoutLMForTokenClassification, LayoutLMTokenizer
11
 
12
 
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
+ tokenizer = LayoutLMTokenizer.from_pretrained("mrm8488/layoutlm-finetuned-funsd")
16
+ model = LayoutLMForTokenClassification.from_pretrained("mrm8488/layoutlm-finetuned-funsd", num_labels=13)
17
  model.to(device)
18
 
19
 
 
29
  w_scale = 1000/width
30
  h_scale = 1000/height
31
 
32
+ ocr_df = pytesseract.image_to_data(image, output_type='data.frame') \\n
33
+ ocr_df = ocr_df.dropna() \\n .assign(left_scaled = ocr_df.left*w_scale,
 
 
34
  width_scaled = ocr_df.width*w_scale,
35
  top_scaled = ocr_df.top*h_scale,
36
  height_scaled = ocr_df.height*h_scale,
 
39
 
40
  float_cols = ocr_df.select_dtypes('float').columns
41
  ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
42
+ ocr_df = ocr_df.replace(r'^\s*{{%htmlContent%}}#39;, np.nan, regex=True)
43
  ocr_df = ocr_df.dropna().reset_index(drop=True)
44
  ocr_df[:20]
45
 
 
138
 
139
  outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)
140
 
141
+ token_predictions = outputs.logits.argmax(-1).squeeze().tolist() # the predictions are at the token level
142
+
143
+ word_level_predictions = [] # let's turn them into word level predictions
144
+ final_boxes = []
145
+ for id, token_pred, box in zip(input_ids.squeeze().tolist(), token_predictions, token_actual_boxes):
146
+ if (tokenizer.decode([id]).startswith("##")) or (id in [tokenizer.cls_token_id,
147
+ tokenizer.sep_token_id,
148
+ tokenizer.pad_token_id]):
149
+ # skip prediction + bounding box
150
+
151
+ continue
152
+ else:
153
+ word_level_predictions.append(token_pred)
154
+ final_boxes.append(box)
155
+
156
+ #print(word_level_predictions)
157
+
158
+
159
+ draw = ImageDraw.Draw(image)
160
+
161
+ font = ImageFont.load_default()
162
+
163
+ def iob_to_label(label):
164
+ if label != 'O':
165
+ return label[2:]
166
+ else:
167
+ return "other"
168
+
169
+ label2color = {'question':'blue', 'answer':'green', 'header':'orange', 'other':'violet'}
170
+
171
+ for prediction, box in zip(word_level_predictions, final_boxes):
172
+ predicted_label = iob_to_label(label_map[prediction]).lower()
173
+ draw.rectangle(box, outline=label2color[predicted_label])
174
+ draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
175
+
176
+ # Display the result (image)
177
 
178
  ```