acverma commited on
Commit
bf830aa
1 Parent(s): ba26445
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -42,14 +42,14 @@ from transformers import AutoProcessor
42
  from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D
43
  from datasets import load_dataset # this dataset uses the new Image feature :)
44
 
45
- from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
46
 
47
  #import cv2
48
  from PIL import Image, ImageDraw, ImageFont
49
 
50
  processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base",apply_ocr = True)
51
 
52
- model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
53
 
54
  dataset = load_dataset("nielsr/funsd-layoutlmv3")
55
 
@@ -93,7 +93,7 @@ if isinstance(features[label_column_name].feature, ClassLabel):
93
  id2label = {k: v for k,v in enumerate(label_list)}
94
  label2id = {v: k for k,v in enumerate(label_list)}
95
  else:
96
- label_list = get_label_list(dataset["train"][label_column_name])
97
  id2label = {k: v for k,v in enumerate(label_list)}
98
  label2id = {v: k for k,v in enumerate(label_list)}
99
  num_labels = len(label_list)
 
42
  from datasets import Features, Sequence, ClassLabel, Value, Array2D, Array3D
43
  from datasets import load_dataset # this dataset uses the new Image feature :)
44
 
45
+ from transformers import LayoutLMv3Processor,LayoutLMv3ForTokenClassification, AutoProcessor ,AutoModelForTokenClassification
46
 
47
  #import cv2
48
  from PIL import Image, ImageDraw, ImageFont
49
 
50
  processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base",apply_ocr = True)
51
 
52
+ model = LayoutLMv3ForTokenClassification.from_pretrained("nielsr/layoutlmv3-finetuned-funsd")
53
 
54
  dataset = load_dataset("nielsr/funsd-layoutlmv3")
55
 
 
93
  id2label = {k: v for k,v in enumerate(label_list)}
94
  label2id = {v: k for k,v in enumerate(label_list)}
95
  else:
96
+ label_list = get_label_list(dataset["test"][label_column_name])
97
  id2label = {k: v for k,v in enumerate(label_list)}
98
  label2id = {v: k for k,v in enumerate(label_list)}
99
  num_labels = len(label_list)