pierreguillou committed
Commit fe811a3
1 Parent(s): 62697e4

Update files/functions.py

Files changed (1)
  1. files/functions.py +74 -57
files/functions.py CHANGED
@@ -44,39 +44,14 @@ import pathlib
from pathlib import Path
import shutil

+ from functools import partial
+
# Tesseract
print(os.popen(f'cat /etc/debian_version').read())
print(os.popen(f'cat /etc/issue').read())
print(os.popen(f'apt search tesseract').read())
import pytesseract

- ## model / feature extractor / tokenizer
-
- import torch
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- # model 1
- from transformers import AutoTokenizer, AutoModelForTokenClassification
- model_id = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
- tokenizer1 = AutoTokenizer.from_pretrained(model_id)
- model1 = AutoModelForTokenClassification.from_pretrained(model_id);
- model1.to(device);
-
- from transformers import LayoutLMv2ForTokenClassification
- # model 2
- model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
- model2 = LayoutLMv2ForTokenClassification.from_pretrained(model_id);
- model2.to(device);
-
- # feature extractor
- from transformers import LayoutLMv2FeatureExtractor
- feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
-
- # tokenizer
- from transformers import AutoTokenizer
- tokenizer_id = "xlm-roberta-base"
- tokenizer2 = AutoTokenizer.from_pretrained(tokenizer_id)
-
## Key parameters

# categories colors
@@ -96,27 +71,36 @@ label2color = {

# bounding boxes start and end of a sequence
cls_box = [0, 0, 0, 0]
- sep_box = [1000, 1000, 1000, 1000]
+ sep_box_lilt = cls_box
+ sep_box_layoutxlm = [1000, 1000, 1000, 1000]

- # model
- model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
+ # models
+ model_id_lilt = "pierreguillou/lilt-xlm-roberta-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
+ model_id_layoutxlm = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"

- # tokenizer
- tokenizer_id = "xlm-roberta-base"
+ # tokenizer for LayoutXLM
+ tokenizer_id_layoutxlm = "xlm-roberta-base"

# (tokenization) The maximum length of a feature (sequence)
- if str(384) in model_id:
-     max_length = 384
- elif str(512) in model_id:
-     max_length = 512
+ if str(384) in model_id_lilt:
+     max_length_lilt = 384
+ elif str(512) in model_id_lilt:
+     max_length_lilt = 512
else:
-     print("Error with max_length of chunks!")
+     print("Error with max_length_lilt of chunks!")
+
+ if str(384) in model_id_layoutxlm:
+     max_length_layoutxlm = 384
+ elif str(512) in model_id_layoutxlm:
+     max_length_layoutxlm = 512
+ else:
+     print("Error with max_length_layoutxlm of chunks!")

# (tokenization) overlap
doc_stride = 128 # The authorized overlap between two parts of the context when splitting is needed.

# max PDF page images that will be displayed
- max_imgboxes = 2
+ max_imgboxes = 1

# get files
examples_dir = 'files/'
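Both if/elif chains above infer the sequence length from the checkpoint name. As a minimal sketch (not part of this commit; the helper name is hypothetical), the same inference fits in one function:

def infer_max_length(model_id, candidates=(384, 512)):
    # return the first candidate length that appears in the checkpoint name
    for length in candidates:
        if str(length) in model_id:
            return length
    raise ValueError(f"Error with max_length of chunks: {model_id}")

# max_length_lilt = infer_max_length(model_id_lilt)            # -> 384
# max_length_layoutxlm = infer_max_length(model_id_layoutxlm)  # -> 384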
@@ -125,7 +109,7 @@ from huggingface_hub import hf_hub_download
files = ["example.pdf", "blank.pdf", "blank.png", "languages_iso.csv", "languages_tesseract.csv", "wo_content.png"]
for file_name in files:
    path_to_file = hf_hub_download(
-         repo_id = "pierreguillou/Inference-APP-Document-Understanding-at-linelevel-v2",
+         repo_id = "pierreguillou/Inference-APP-Document-Understanding-at-linelevel-v3",
        filename = "files/" + file_name,
        repo_type = "space"
    )
@@ -162,6 +146,32 @@ for lang_t, langcode_t in zip(langs_t,langscode_t):

langdetect2Tesseract = {v:k for k,v in Tesseract2langdetect.items()}

+ ## model / feature extractor / tokenizer
+
+ # get device
+ import torch
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ ## model LiLT
+ import transformers
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+ tokenizer_lilt = AutoTokenizer.from_pretrained(model_id_lilt)
+ model_lilt = AutoModelForTokenClassification.from_pretrained(model_id_lilt);
+ model_lilt.to(device);
+
+ ## model LayoutXLM
+ from transformers import LayoutLMv2ForTokenClassification # LayoutXLMTokenizerFast,
+ model_layoutxlm = LayoutLMv2ForTokenClassification.from_pretrained(model_id_layoutxlm);
+ model_layoutxlm.to(device);
+
+ # feature extractor
+ from transformers import LayoutLMv2FeatureExtractor
+ feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)
+
+ # tokenizer
+ from transformers import AutoTokenizer
+ tokenizer_layoutxlm = AutoTokenizer.from_pretrained(tokenizer_id_layoutxlm)
+
## General

# get text and bounding boxes from an image
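A quick smoke-test sketch for the objects loaded above (an editor's assumption, not part of the commit): both checkpoints were fine-tuned on DocLayNet at line level, so their label spaces should match.

# encode a dummy sentence and compare the two models' label spaces
enc = tokenizer_lilt("hello world", return_tensors="pt")
print(enc["input_ids"].shape)        # torch.Size([1, sequence_length])
print(model_lilt.config.id2label)    # DocLayNet line-level labels
assert model_lilt.config.num_labels == model_layoutxlm.config.num_labels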
@@ -477,7 +487,7 @@ def extraction_data_from_image(images):

## Inference

- def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
+ def prepare_inference_features(example, tokenizer, max_length, cls_box, sep_box):

    images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list, images_pixels_list = list(), list(), list(), list(), list(), list()

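prepare_inference_features now receives the tokenizer, max_length and the special boxes explicitly instead of reading globals, which is presumably why functools.partial is imported at the top of the file: the per-model arguments can be bound once before the function is handed to datasets' map, which only supplies the example. A sketch of such a call site (the map call itself is an assumption about the calling code):

from functools import partial

prepare_lilt = partial(prepare_inference_features,
                       tokenizer=tokenizer_lilt,
                       max_length=max_length_lilt,
                       cls_box=cls_box,
                       sep_box=sep_box_lilt)
# encoded_dataset = dataset.map(prepare_lilt, batched=True, batch_size=64)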
@@ -600,7 +610,7 @@ class CustomDataset(Dataset):
import torch.nn.functional as F

# get predictions at token level
- def predictions_token_level(images, custom_encoded_dataset):
+ def predictions_token_level(images, custom_encoded_dataset, model_id, model):

    num_imgs = len(images)
    if num_imgs > 0:
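With model_id and model now parameters, one function serves both checkpoints. A hypothetical call-site sketch (the variable names follow the commit; how the results are consumed is an assumption):

outputs_lilt = predictions_token_level(images, encoded_dataset, model_id_lilt, model_lilt)
outputs_layoutxlm = predictions_token_level(images, encoded_dataset, model_id_layoutxlm, model_layoutxlm)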
@@ -635,12 +645,20 @@ def predictions_token_level(images, custom_encoded_dataset):

        # get prediction with forward pass
        with torch.no_grad():
-             output = model(
-                 input_ids=input_id.to(device),
-                 attention_mask=attention_mask.to(device),
-                 bbox=bbox.to(device),
-                 image=pixel_values.to(device)
-             )
+
+             if model_id == model_id_lilt:
+                 output = model(
+                     input_ids=input_id.to(device),
+                     attention_mask=attention_mask.to(device),
+                     bbox=bbox.to(device),
+                 )
+             elif model_id == model_id_layoutxlm:
+                 output = model(
+                     input_ids=input_id.to(device),
+                     attention_mask=attention_mask.to(device),
+                     bbox=bbox.to(device),
+                     image=pixel_values.to(device)
+                 )

        # save probabilities of predictions in dictionary
        if image_id in outputs: outputs[image_id].append(F.softmax(output.logits.squeeze(), dim=-1))
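The branch exists because LiLT is a text-plus-layout model with no visual backbone, while LayoutXLM reuses the LayoutLMv2 forward signature and therefore also needs the page image. Read as one helper, the dispatch looks like this (a sketch, hypothetical name, not part of the commit):

def forward_pass(model, model_id, input_id, attention_mask, bbox, pixel_values):
    kwargs = dict(input_ids=input_id.to(device),
                  attention_mask=attention_mask.to(device),
                  bbox=bbox.to(device))
    if model_id == model_id_layoutxlm:
        kwargs["image"] = pixel_values.to(device)  # only LayoutXLM consumes pixels
    return model(**kwargs)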
@@ -654,7 +672,7 @@ def predictions_token_level(images, custom_encoded_dataset):
from functools import reduce

# Get predictions (line level)
- def predictions_line_level(dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes):
+ def predictions_line_level(max_length, tokenizer, id2label, dataset, outputs, images_ids_list, chunk_ids, input_ids, bboxes, cls_box, sep_box):

    ten_probs_dict, ten_input_ids_dict, ten_bboxes_dict = dict(), dict(), dict()
    bboxes_list_dict, input_ids_dict_dict, probs_dict_dict, df = dict(), dict(), dict(), dict()
@@ -711,14 +729,13 @@ def predictions_line_level(dataset, outputs, images_ids_list, chunk_ids, input_i
    bbox_prev = [-100, -100, -100, -100]
    for probs, input_id, bbox in zip(ten_probs_list, ten_input_ids_list, ten_bboxes_list):
        bbox = denormalize_box(bbox, width, height)
-         if bbox != bbox_prev and bbox != cls_box:
+         if bbox != bbox_prev and bbox != cls_box and bbox != sep_box and bbox[0] != bbox[2] and bbox[1] != bbox[3]:
            bboxes_list.append(bbox)
            input_ids_dict[str(bbox)] = [input_id]
            probs_dict[str(bbox)] = [probs]
-         else:
-             if bbox != cls_box:
-                 input_ids_dict[str(bbox)].append(input_id)
-                 probs_dict[str(bbox)].append(probs)
+         elif bbox != cls_box and bbox != sep_box and bbox[0] != bbox[2] and bbox[1] != bbox[3]:
+             input_ids_dict[str(bbox)].append(input_id)
+             probs_dict[str(bbox)].append(probs)
        bbox_prev = bbox

    probs_bbox = dict()
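The widened condition now also skips the SEP box and degenerate boxes of zero width or height, not only the CLS box and consecutive duplicates. The same filter as a named predicate (a sketch, not in the commit):

def is_real_line_box(bbox, cls_box, sep_box):
    return (bbox != cls_box and bbox != sep_box
            and bbox[0] != bbox[2]   # non-zero width
            and bbox[1] != bbox[3])  # non-zero height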
@@ -749,7 +766,7 @@ def predictions_line_level(dataset, outputs, images_ids_list, chunk_ids, input_i
        print("An error occurred while getting predictions!")

# Get labeled images with lines bounding boxes
- def get_labeled_images(dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):
+ def get_labeled_images(id2label, dataset, images_ids_list, bboxes_list_dict, probs_dict_dict):

    labeled_images = list()
@@ -781,7 +798,7 @@ def get_labeled_images(dataset, images_ids_list, bboxes_list_dict, probs_dict_di
    return labeled_images

# get data of encoded chunk
- def get_encoded_chunk_inference(index_chunk=None):
+ def get_encoded_chunk_inference(tokenizer, dataset, encoded_dataset, index_chunk=None):

    # get datasets
    example = dataset
@@ -833,10 +850,10 @@ def get_encoded_chunk_inference(index_chunk=None):
    return image, df, num_tokens, page_no, num_pages

# display chunk of PDF image and its data
- def display_chunk_lines_inference(index_chunk=None):
+ def display_chunk_lines_inference(dataset, encoded_dataset, index_chunk=None):

    # get image and image data
-     image, df, num_tokens, page_no, num_pages = get_encoded_chunk_inference(index_chunk=index_chunk)
+     image, df, num_tokens, page_no, num_pages = get_encoded_chunk_inference(dataset, encoded_dataset, index_chunk=index_chunk)

    # get data from dataframe
    input_ids = df["input_ids"]