keesephillips committed
Commit 40d3327 · verified · 1 Parent(s): b6cfd37

updated for error handling in cases with no table
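
The change wraps every step after table detection in a try/except: when crop_image finds no table it returns None, the first downstream call raises, and the app now reports the problem instead of crashing. An equivalent guard could also test the no-table case explicitly. The sketch below is illustrative only, reusing the crop_image, model, and image names from main.py; it is not part of the commit.

    # Hypothetical explicit check, equivalent in effect to the commit's
    # try/except (names are from main.py below; this code is not in the commit).
    cropped_table = crop_image(model, image)
    if cropped_table is None:
        st.write("Please upload a scanned document with a table")
    else:
        st.image(cropped_table, caption='Cropped Table', use_column_width=True)

The committed version instead uses one try/except around the whole pipeline, which also catches OCR and generation failures, at the cost of masking the specific error.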

Files changed (1)
    main.py +233 -221
main.py CHANGED
@@ -1,221 +1,233 @@
- """
- Attribution: https://github.com/AIPI540/AIPI540-Deep-Learning-Applications/
-
- Jon Reifschneider
- Brinnae Bent
-
- """
-
- import streamlit as st
- from PIL import Image
- import numpy as np
- import os
- import numpy as np
- import pandas as pd
- import pandas as pd
- import os
- import json
- import pandas as pd
- import torch
- import numpy as np
- import pandas as pd
- import torch.nn as nn
- import torch.nn.functional as F
- import matplotlib.pyplot as plt
- from ultralytics import YOLO
- from PIL import Image, ImageDraw, ImageFont
- import numpy as np
- import cv2
- import pytesseract
- from PIL import ImageEnhance
- import numpy as np
- import os
- import json
- from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
- from datasets import load_dataset
- from transformers import DataCollatorForLanguageModeling
- from PIL import Image, ImageEnhance
- from io import StringIO
-
-
- def crop_image(model, original_image):
-     """
-     Crop the region of interest (table) from an image using a YOLO model.
-
-     Inputs:
-         model (YOLO): The YOLO model used for object detection.
-         original_image (PIL.image): The image to be processed.
-
-     Returns:
-         PIL.Image: The cropped image containing the detected table.
-     """
-     image_array = np.array(image)
-     results = model(image_array)
-
-     for r in results:
-         boxes = r.boxes
-
-         for box in boxes:
-             if box.cls == 3:
-                 x1, y1, x2, y2 = box.xyxy[0]
-                 x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
-
-                 table_image = original_image.crop((x1, y1, x2, y2))
-
-                 return table_image
-     return
-
- def process_image(model, image):
-     """
-     Process the uploaded image with YOLO model and draw bounding boxes with class-specific colors.
-
-     Inputs:
-         model: The trained YOLO model
-         image: The image file uploaded through Streamlit.
-
-     Returns:
-         PIL.Image: The processed image with bounding boxes and labels.
-     """
-     colors = {'title': (255, 0, 0),
-               'text': (0, 255, 0),
-               'figure': (0, 0, 255),
-               'table': (255, 255, 0),
-               'list': (0, 255, 255)}
-
-     image_array = np.array(image)
-     results = model(image_array)
-
-     for result in results:
-         boxes = result.boxes.cpu().numpy()
-         for box in boxes:
-             r = box.xyxy[0].astype(int)
-             label = result.names[int(box.cls)]
-             color = colors.get(label.lower(), (255, 255, 255))
-
-             cv2.rectangle(image_array, r[:2], r[2:], color, 2)
-
-             label_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
-             top_left = (r[0], r[1] - label_size[1] - baseline)
-             bottom_right = (r[0] + label_size[0], r[1])
-             cv2.rectangle(image_array, top_left, bottom_right, color, cv2.FILLED)
-             cv2.putText(image_array, label, (r[0], r[1] - baseline),
-                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
-
-     return Image.fromarray(image_array)
-
- def improve_ocr_accuracy(img):
-     """
-     Preprocess the image to improve OCR accuracy.
-
-     This function resizes the image, increases contrast, and applies thresholding
-     to enhance the image for better OCR results.
-
-     Inputs:
-         img (PIL.Image): The input image to be processed.
-
-     Returns:
-         numpy.ndarray: A binary thresholded image as a numpy array.
-     """
-     img = img.resize((img.width * 4, img.height * 4))
-
-     enhancer = ImageEnhance.Contrast(img)
-     img = enhancer.enhance(2)
-
-     _, thresh = cv2.threshold(np.array(img), 127, 255, cv2.THRESH_BINARY_INV)
-
-     return thresh
-
- def ocr_core(image):
-     """
-     Perform OCR on the given image and process the extracted text.
-
-     This function uses pytesseract to extract text from the image and then
-     processes the extracted data to format it with appropriate line breaks
-     and spacing.
-
-     Inputs:
-         image (numpy.ndarray): The preprocessed image as a numpy array.
-
-     Returns:
-         str: The extracted and formatted text from the image.
-     """
-     data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
-     df = pd.DataFrame(data)
-     df = df[df['conf'] != -1]
-     df['left_diff'] = df.groupby('block_num')['left'].diff().fillna(0).astype(int)
-     df['prev_width'] = df['width'].shift(1).fillna(0).astype(int)
-     df['spacing'] = (df['left_diff'] - df['prev_width']).fillna(0).astype(int)
-     df['text'] = df.apply(lambda x: '\n' + x['text'] if (x['word_num'] == 1) & (x['block_num'] != 1) else x['text'], axis=1)
-     df['text'] = df.apply(lambda x: ',' + x['text'] if x['spacing'] > 80 else x['text'], axis=1)
-     ocr_text = ""
-     for text in df['text']:
-         ocr_text += text + ' '
-     return ocr_text
-
- def generate_csv_from_text(tokenizer, model, ocr_text):
-     """
-     Generate CSV text from OCR extracted text using the gpt model
-
-     This function takes the OCR extracted text, processes it through a language model,
-     and generates CSV formatted text.
-
-     Inputs:
-         tokenizer: The tokenizer for the gpt model
-         model: The gpt model used for csv
-         ocr_text (str): The text extracted from OCR
-
-     Returns:
-         str: The generated CSV formatted text.
-     """
-     inputs = tokenizer.encode(ocr_text, return_tensors='pt')
-     outputs = model.generate(inputs, max_length=1000, num_return_sequences=1)
-     csv_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     return csv_text
-
- if __name__ == '__main__':
-     pytesseract.pytesseract.tesseract_cmd = r'C:/Program Files/Tesseract-OCR/tesseract.exe' # Update this path for your system
-
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     model = YOLO(os.getcwd() + '/models/trained_yolov8.pt')
-     gpt_model = GPT2LMHeadModel.from_pretrained(os.getcwd() + '/models/gpt_model')
-     tokenizer = GPT2Tokenizer.from_pretrained(os.getcwd() + '/models/gpt_model')
-
-     st.header('''
-     Intelligent Document Processing: Table Extraction
-     ''')
-
-     header_img = Image.open('assets/header_img.png')
-     st.image(header_img, use_column_width=True)
-
-     with st.sidebar:
-         user_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
-
-     if user_image is not None:
-         st.divider()
-         image = Image.open(user_image)
-         st.image(image, caption='Uploaded Image', use_column_width=True)
-
-         st.divider()
-         st.subheader("Document Classes:")
-         processed_image = process_image(model, image)
-         st.image(processed_image, caption='Processed Image', use_column_width=True)
-
-         st.divider()
-         st.subheader("Table Cropped Image:")
-         cropped_table = crop_image(model, image)
-         st.image(cropped_table, caption='Cropped Table', use_column_width=True)
-
-         st.divider()
-         st.subheader("OCR Text:")
-         improved_image = improve_ocr_accuracy(cropped_table)
-         ocr_text = ocr_core(improved_image)
-         st.write(ocr_text)
-
-         st.divider()
-         st.subheader("CSV Output:")
-         csv_output = generate_csv_from_text(tokenizer,gpt_model,ocr_text)
-         data = StringIO(csv_output)
-         st.dataframe(pd.read_csv(data, sep=",").head())
+ """
+ Attribution: https://github.com/AIPI540/AIPI540-Deep-Learning-Applications/
+
+ Jon Reifschneider
+ Brinnae Bent
+
+ """
+
+ import os
+
+ import cv2
+ import numpy as np
+ import pandas as pd
+ import pytesseract
+ import streamlit as st
+ import torch
+ from PIL import Image, ImageEnhance
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+ from ultralytics import YOLO
+
+
+ def crop_image(model, original_image):
+     """
+     Crop the region of interest (table) from an image using a YOLO model.
+
+     Inputs:
+         model (YOLO): The YOLO model used for object detection.
+         original_image (PIL.Image): The image to be processed.
+
+     Returns:
+         PIL.Image: The cropped image containing the detected table,
+         or None if no table is detected.
+     """
+     image_array = np.array(original_image)
+     results = model(image_array)
+
+     for r in results:
+         boxes = r.boxes
+
+         for box in boxes:
+             if box.cls == 3:  # class index 3 is the 'table' class for this model
+                 x1, y1, x2, y2 = box.xyxy[0]
+                 x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+
+                 table_image = original_image.crop((x1, y1, x2, y2))
+
+                 return table_image
+     return None
+
+ def process_image(model, image):
+     """
+     Process the uploaded image with the YOLO model and draw bounding boxes with class-specific colors.
+
+     Inputs:
+         model: The trained YOLO model
+         image: The image file uploaded through Streamlit.
+
+     Returns:
+         PIL.Image: The processed image with bounding boxes and labels.
+     """
+     colors = {'title': (255, 0, 0),
+               'text': (0, 255, 0),
+               'figure': (0, 0, 255),
+               'table': (255, 255, 0),
+               'list': (0, 255, 255)}
+
+     image_array = np.array(image)
+     results = model(image_array)
+
+     for result in results:
+         boxes = result.boxes.cpu().numpy()
+         for box in boxes:
+             r = box.xyxy[0].astype(int)
+             label = result.names[int(box.cls)]
+             color = colors.get(label.lower(), (255, 255, 255))
+
+             # draw the detection box, then a filled label background above it
+             cv2.rectangle(image_array, r[:2], r[2:], color, 2)
+
+             label_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
+             top_left = (r[0], r[1] - label_size[1] - baseline)
+             bottom_right = (r[0] + label_size[0], r[1])
+             cv2.rectangle(image_array, top_left, bottom_right, color, cv2.FILLED)
+             cv2.putText(image_array, label, (r[0], r[1] - baseline),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
+
+     return Image.fromarray(image_array)
+
+ def improve_ocr_accuracy(img):
+     """
+     Preprocess the image to improve OCR accuracy.
+
+     This function resizes the image, increases contrast, and applies thresholding
+     to enhance the image for better OCR results.
+
+     Inputs:
+         img (PIL.Image): The input image to be processed.
+
+     Returns:
+         numpy.ndarray: A binary thresholded image as a numpy array.
+     """
+     img = img.resize((img.width * 4, img.height * 4))  # upscale 4x for OCR
+
+     enhancer = ImageEnhance.Contrast(img)
+     img = enhancer.enhance(2)  # double the contrast
+
+     # binarize (inverted) so text stands out for pytesseract
+     _, thresh = cv2.threshold(np.array(img), 127, 255, cv2.THRESH_BINARY_INV)
+
+     return thresh
+
+ def ocr_core(image):
+     """
+     Perform OCR on the given image and process the extracted text.
+
+     This function uses pytesseract to extract text from the image and then
+     processes the extracted data to format it with appropriate line breaks
+     and spacing.
+
+     Inputs:
+         image (numpy.ndarray): The preprocessed image as a numpy array.
+
+     Returns:
+         str: The extracted and formatted text from the image.
+     """
+     data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
+     df = pd.DataFrame(data)
+     df = df[df['conf'] != -1]  # drop entries with no recognized text
+     df['left_diff'] = df.groupby('block_num')['left'].diff().fillna(0).astype(int)
+     df['prev_width'] = df['width'].shift(1).fillna(0).astype(int)
+     df['spacing'] = (df['left_diff'] - df['prev_width']).fillna(0).astype(int)
+     # start a new line at the first word of each block after the first
+     df['text'] = df.apply(lambda x: '\n' + x['text'] if (x['word_num'] == 1) & (x['block_num'] != 1) else x['text'], axis=1)
+     # treat a wide horizontal gap as a column break (comma separator)
+     df['text'] = df.apply(lambda x: ',' + x['text'] if x['spacing'] > 80 else x['text'], axis=1)
+     ocr_text = ""
+     for text in df['text']:
+         ocr_text += text + ' '
+     return ocr_text
+
+ def generate_csv_from_text(tokenizer, model, ocr_text):
+     """
+     Generate CSV text from OCR-extracted text using the GPT-2 model.
+
+     This function takes the OCR-extracted text, processes it through a language model,
+     and generates CSV-formatted text.
+
+     Inputs:
+         tokenizer: The tokenizer for the GPT-2 model
+         model: The GPT-2 model used for CSV generation
+         ocr_text (str): The text extracted from OCR
+
+     Returns:
+         str: The generated CSV-formatted text.
+     """
+     inputs = tokenizer.encode(ocr_text, return_tensors='pt')
+     outputs = model.generate(inputs, max_length=1000, num_return_sequences=1)
+     csv_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     return csv_text
+
+ if __name__ == '__main__':
+     # pytesseract.pytesseract.tesseract_cmd = r'C:/Program Files/Tesseract-OCR/tesseract.exe' # Update this path for your system
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model = YOLO(os.getcwd() + '/models/trained_yolov8.pt')
+     gpt_model = GPT2LMHeadModel.from_pretrained(os.getcwd() + '/models/gpt_model')
+     tokenizer = GPT2Tokenizer.from_pretrained(os.getcwd() + '/models/gpt_model')
+
+     st.header('''
+     Intelligent Document Processing: Table Extraction
+     ''')
+
+     header_img = Image.open('assets/header_img.png')
+     st.image(header_img, use_column_width=True)
+
+     st.subheader("Please upload an image of a scanned document with a table using the sidebar")
+
+     with st.sidebar:
+         user_image = st.file_uploader("Upload an image of a scanned document", type=["png", "jpg", "jpeg"])
+
+     if user_image is not None:
+         st.divider()
+         image = Image.open(user_image)
+         st.image(image, caption='Uploaded Image', use_column_width=True)
+
+         st.divider()
+         st.subheader("Document Classes:")
+         processed_image = process_image(model, image)
+         st.image(processed_image, caption='Processed Image', use_column_width=True)
+
+         # crop_image returns None when no table is detected, which makes the
+         # steps below raise; catch that and ask for a document with a table.
+         try:
+             cropped_table = crop_image(model, image)
+             st.divider()
+             st.subheader("Table Cropped Image:")
+             st.image(cropped_table, caption='Cropped Table', use_column_width=True)
+
+             improved_image = improve_ocr_accuracy(cropped_table)
+             st.divider()
+             st.subheader("Improved Table Image:")
+             st.image(improved_image, caption='Improved Table Image', use_column_width=True)
+
+             ocr_text = ocr_core(improved_image)
+             st.divider()
+             st.subheader("OCR Text:")
+             st.write(ocr_text)
+
+             csv_output = generate_csv_from_text(tokenizer, gpt_model, ocr_text)
+             st.divider()
+             st.subheader("CSV Output:")
+             st.write(csv_output.encode('utf-8'))
+         except Exception:
+             st.divider()
+             st.subheader("Error:")
+             st.write("Please upload a scanned document with a table")
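
For context, a minimal way to exercise the new no-table path outside Streamlit might look like the following. The weights path and the test image name are assumptions for illustration, not part of the repository.

    # Hypothetical smoke test (not part of this commit). Assumes
    # models/trained_yolov8.pt exists and page_without_table.png is a
    # scanned page that contains no table.
    import os

    from PIL import Image
    from ultralytics import YOLO

    from main import crop_image

    model = YOLO(os.path.join(os.getcwd(), 'models', 'trained_yolov8.pt'))
    page = Image.open('page_without_table.png')

    # crop_image returns None when YOLO finds no table-class box; in the
    # app, this is what triggers the new except branch.
    assert crop_image(model, page) is None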