Karol Idaszak committed on
Commit
4ecae1b
1 Parent(s): d15bee8

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +70 -63
main.py CHANGED
@@ -78,69 +78,76 @@ def create_pretty_table(data):
78
 
79
 
80
  def interference(example, page_number=0):
81
- image, words, boxes = extract_data_from_pdf(example, page_number)
82
- boxes = [list(map(int, box)) for box in boxes]
83
-
84
- # Process the image and words
85
- model = AutoModelForTokenClassification.from_pretrained(
86
- "karida/LayoutLMv3_RFP",
87
- use_auth_token=MODEL_KEY
88
- )
89
- processor = AutoProcessor.from_pretrained(
90
- "microsoft/layoutlmv3-base", apply_ocr=False
91
- )
92
- encoding = processor(image, words, boxes=boxes, return_tensors="pt")
93
-
94
- # Prediction
95
- with torch.no_grad():
96
- outputs = model(**encoding)
97
-
98
- logits = outputs.logits
99
- predictions = logits.argmax(-1).squeeze().tolist()
100
- model_words = encoding.word_ids()
101
-
102
- # Process predictions
103
- token_boxes = encoding.bbox.squeeze().tolist()
104
- width, height = image.size
105
-
106
- true_predictions = [model.config.id2label[pred] for pred in predictions]
107
- true_boxes = token_boxes
108
- # Draw annotations on the image
109
- draw = ImageDraw.Draw(image)
110
- font = ImageFont.load_default()
111
-
112
- def iob_to_label(label):
113
- label = label[2:]
114
- return "other" if not label else label.lower()
115
-
116
- label2color = {
117
- "question": "blue",
118
- "answer": "green",
119
- "header": "orange",
120
- "other": "violet",
121
- }
122
-
123
- # print(len(true_predictions), len(true_boxes), len(model_words))
124
-
125
- table = []
126
- ids = set()
127
-
128
- for prediction, box, model_word in zip(
129
- true_predictions, true_boxes, model_words
130
- ):
131
- predicted_label = iob_to_label(prediction)
132
- draw.rectangle(box, outline=label2color[predicted_label], width=2)
133
- # draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
134
- if model_word and model_word not in ids and predicted_label != "other":
135
- ids.add(model_word)
136
- table.append([predicted_label[0], words[model_word]])
137
-
138
- values = merge_pairs_v2(table)
139
- values = [
140
- ["Heder", x[1]] if x[0] == "q" else ["Section", x[1]] for x in values
141
- ]
142
- table = create_pretty_table(values)
143
- return image, table
 
 
 
 
 
 
 
144
 
145
 
146
  import gradio as gr
 
78
 
79
 
80
  def interference(example, page_number=0):
81
+ try:
82
+ image, words, boxes = extract_data_from_pdf(example, page_number)
83
+ boxes = [list(map(int, box)) for box in boxes]
84
+
85
+ # Process the image and words
86
+ model = AutoModelForTokenClassification.from_pretrained(
87
+ "karida/LayoutLMv3_RFP",
88
+ use_auth_token=MODEL_KEY
89
+ )
90
+ processor = AutoProcessor.from_pretrained(
91
+ "microsoft/layoutlmv3-base", apply_ocr=False
92
+ )
93
+ encoding = processor(image, words, boxes=boxes, return_tensors="pt")
94
+
95
+ # Prediction
96
+ with torch.no_grad():
97
+ outputs = model(**encoding)
98
+
99
+ logits = outputs.logits
100
+ predictions = logits.argmax(-1).squeeze().tolist()
101
+ model_words = encoding.word_ids()
102
+
103
+ # Process predictions
104
+ token_boxes = encoding.bbox.squeeze().tolist()
105
+ width, height = image.size
106
+
107
+ true_predictions = [model.config.id2label[pred] for pred in predictions]
108
+ true_boxes = token_boxes
109
+ # Draw annotations on the image
110
+ draw = ImageDraw.Draw(image)
111
+ font = ImageFont.load_default()
112
+
113
+ def iob_to_label(label):
114
+ label = label[2:]
115
+ return "other" if not label else label.lower()
116
+
117
+ label2color = {
118
+ "question": "blue",
119
+ "answer": "green",
120
+ "header": "orange",
121
+ "other": "violet",
122
+ }
123
+
124
+ # print(len(true_predictions), len(true_boxes), len(model_words))
125
+
126
+ table = []
127
+ ids = set()
128
+
129
+ for prediction, box, model_word in zip(
130
+ true_predictions, true_boxes, model_words
131
+ ):
132
+ predicted_label = iob_to_label(prediction)
133
+ draw.rectangle(box, outline=label2color[predicted_label], width=2)
134
+ # draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
135
+ if model_word and model_word not in ids and predicted_label != "other":
136
+ ids.add(model_word)
137
+ table.append([predicted_label[0], words[model_word]])
138
+
139
+ values = merge_pairs_v2(table)
140
+ values = [
141
+ ["Heder", x[1]] if x[0] == "q" else ["Section", x[1]] for x in values
142
+ ]
143
+ table = create_pretty_table(values)
144
+ return image, table
145
+ except IndexError as e:
146
+ # Return a custom HTML-styled error message if an IndexError occurs
147
+ return f"<div style='color: grey; font-weight: bold;'>Error: in the current version of the model, the maximum number of words per page is 512.</div>"
148
+ except Exception as e:
149
+ # Handle other exceptions
150
+ return f"<div style='color: grey; font-weight: bold;'>An error occurred: {str(e)}</div>"
151
 
152
 
153
  import gradio as gr