keesephillips committed on
Commit
24f52b5
·
verified ·
1 Parent(s): b3fc8d0

created dev branch and refactored code (#1)

Browse files

- created dev branch and refactored code (31e0d1a8efa41e56aa53d9e147159a71066ee1bc)

Files changed (5) hide show
  1. .gitignore +1 -0
  2. README.md +4 -2
  3. config.yaml +2 -2
  4. main.py +185 -17
  5. setup.py +5 -5
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ data/
README.md CHANGED
@@ -8,6 +8,7 @@ The purpose of this project is to perform very basic intelligent document proces
8
 
9
  ### If you want to run the full pipeline and train the model from scratch
10
  1. You will need to install all of the necessary packages to run the setup.py script beforehand
 
11
  3. You will then need to run setup.py to create the data pipeline and train the model
12
  4. You will then need to run the frontend to use the model
13
  ```bash
@@ -17,7 +18,7 @@ streamlit run main.py
17
  ```
18
 
19
  ### If you want to just run the frontend
20
- 1. You will need to install all of the necessary packages to run the setup.py script beforehand
21
  2. You will then need to run the frontend to use the model
22
  ```bash
23
  pip install -r requirements.txt
@@ -34,7 +35,8 @@ streamlit run main.py
34
  > - build_features.py: script to prepare the dataset for training
35
  > - model.py: script to train model and predict
36
  > - models: directory for trained models
37
- > - recommendation.pt: pytorch trained model for album recommendations
 
38
  > - data: directory for project data
39
  > - raw: directory for raw data
40
  > - processed: directory to store the processed data
 
8
 
9
  ### If you want to run the full pipeline and train the model from scratch
10
  1. You will need to install all of the necessary packages to run the setup.py script beforehand
11
+ 2. You will need to install the Tesseract OCR engine (the program that pytesseract calls) and, if you are using Windows, add it to your PATH
12
  3. You will then need to run setup.py to create the data pipeline and train the model
13
  4. You will then need to run the frontend to use the model
14
  ```bash
 
18
  ```
19
 
20
  ### If you want to just run the frontend
21
+ 1. You will need to install all of the necessary packages beforehand, and install the Tesseract OCR engine that pytesseract relies on
22
  2. You will then need to run the frontend to use the model
23
  ```bash
24
  pip install -r requirements.txt
 
35
  > - build_features.py: script to prepare the dataset for training
36
  > - model.py: script to train model and predict
37
  > - models: directory for trained models
38
+ > - trained_yolov8.pt: trained PyTorch YOLOv8 model for detecting document layout classes (title, text, figure, table, list)
39
+ > - gpt_model: directory to store the gpt model
40
  > - data: directory for project data
41
  > - raw: directory for raw data
42
  > - processed: directory to store the processed data
config.yaml CHANGED
@@ -1,6 +1,6 @@
1
  path: C:/Users/keese/term_project
2
- train: training/images
3
- val: validation/images
4
 
5
  names:
6
  0: text
 
1
  path: C:/Users/keese/term_project
2
+ train: data/processed/training/images
3
+ val: data/processed/validation/images
4
 
5
  names:
6
  0: text
main.py CHANGED
@@ -22,31 +22,199 @@ import pandas as pd
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
  import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
27
 
28
- if __name__ == '__main__':
 
29
 
30
- st.header('Spotify Playlists')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- img1, img2 = st.columns(2)
 
 
 
 
 
 
 
33
 
34
- music_notes = Image.open('assets/music_notes.png')
35
- img1.image(music_notes, use_column_width=True)
36
 
37
- trumpet = Image.open('assets/trumpet.png')
38
- img2.image(trumpet, use_column_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  with st.sidebar:
41
- playlist_name = st.selectbox(
42
- "Playlist Selection",
43
- ( list(set([1,2])) )
44
- )
45
-
46
- col1, col2 = st.columns(2)
47
- with col1:
48
- st.write(f'Artist')
49
- with col2:
50
- st.write(f'Album')
 
 
 
 
 
 
 
 
 
 
 
 
51
 
 
 
 
 
 
52
 
 
22
  import torch.nn as nn
23
  import torch.nn.functional as F
24
  import matplotlib.pyplot as plt
25
+ from ultralytics import YOLO
26
+ from PIL import Image, ImageDraw, ImageFont
27
+ import numpy as np
28
+ import cv2
29
+ import pytesseract
30
+ from PIL import ImageEnhance
31
+ import numpy as np
32
+ import os
33
+ import json
34
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
35
+ from datasets import load_dataset
36
+ from transformers import DataCollatorForLanguageModeling
37
+ from PIL import Image, ImageEnhance
38
+ from io import StringIO
39
+
40
 
41
def crop_image(model, original_image, table_class=3):
    """
    Crop the first detected table region out of an image using a YOLO model.

    Inputs:
        model (YOLO): The YOLO model used for object detection.
        original_image (PIL.Image): The image to run detection on and crop.
        table_class (int): Class id that marks a table detection. Defaults
            to 3, the id this function has always matched on.

    Returns:
        PIL.Image or None: The cropped image containing the first box whose
            class equals ``table_class``, or None when no such box is found.
    """
    # Detect on the argument itself. The previous body read a global named
    # ``image``, which only exists when this file runs as the Streamlit script.
    # convert('RGB') gives an HxWx3 array even for RGBA or greyscale input.
    image_array = np.array(original_image.convert('RGB'))
    results = model(image_array)

    for result in results:
        for box in result.boxes:
            # box.cls carries one class id per box; compare it as a plain int.
            if int(box.cls[0]) != table_class:
                continue

            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])
            # Crop from the untouched original so the caller gets back the
            # same image mode it passed in.
            return original_image.crop((x1, y1, x2, y2))

    # No table detected: callers must handle None.
    return None
67
 
68
def process_image(model, image):
    """
    Run the YOLO model on an image and draw a labelled, class-coloured box for every detection.

    Inputs:
        model (YOLO): The YOLO model used for object detection.
        image (PIL.Image): The image to annotate. It is not modified; drawing
            happens on an array copy.

    Returns:
        PIL.Image: A copy of the image with bounding boxes and class labels drawn on it.
    """
    # Box colour per class name; names missing from this table are drawn in white.
    colors = {'title': (255, 0, 0),
              'text': (0, 255, 0),
              'figure': (0, 0, 255),
              'table': (255, 255, 0),
              'list': (0, 255, 255)}

    # Normalise to 3-channel RGB so RGBA or greyscale images still yield the
    # HxWx3 array that the 3-tuple colours above assume.
    # NOTE(review): Ultralytics documents numpy inputs as BGR while this array
    # is RGB - confirm whether a channel swap is wanted before inference.
    image_array = np.array(image.convert('RGB'))
    results = model(image_array)

    for result in results:
        for box in result.boxes.cpu().numpy():
            # Plain Python ints: cv2 point arguments can reject NumPy scalars,
            # and int() on a whole 1-element array is deprecated in NumPy.
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0])
            label = result.names[int(box.cls[0])]
            color = colors.get(label.lower(), (255, 255, 255))

            cv2.rectangle(image_array, (x1, y1), (x2, y2), color, 2)

            # Filled strip just above the box's top-left corner, sized to the
            # label text, with the label written in black on top of it.
            (text_w, text_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            cv2.rectangle(image_array, (x1, y1 - text_h - baseline), (x1 + text_w, y1), color, cv2.FILLED)
            cv2.putText(image_array, label, (x1, y1 - baseline),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    return Image.fromarray(image_array)
104
+
105
def improve_ocr_accuracy(img):
    """
    Prepare an image for OCR.

    The image is enlarged to four times its width and height, its contrast
    is doubled, and the result is passed through an inverted binary threshold.

    Inputs:
        img (PIL.Image): The input image to be processed.

    Returns:
        numpy.ndarray: The thresholded image as a numpy array.
    """
    upscaled_size = (img.width * 4, img.height * 4)
    upscaled = img.resize(upscaled_size)

    boosted = ImageEnhance.Contrast(upscaled).enhance(2)
    pixels = np.array(boosted)

    # cv2.threshold returns (threshold_used, image); only the image is kept.
    # THRESH_BINARY_INV sends values above 127 to 0 and everything else to 255.
    # NOTE(review): for a colour input this is applied to each channel
    # separately - confirm that is intended rather than a greyscale threshold.
    thresholded = cv2.threshold(pixels, 127, 255, cv2.THRESH_BINARY_INV)[1]

    return thresholded
126
+
127
def ocr_core(image):
    """
    Perform OCR on the given image and format the recognised words.

    pytesseract returns one row per detected element. Rows with a confidence
    of -1 are dropped, and the remaining words are joined with spaces after
    two markers are inserted:
      * a newline before every word with ``word_num == 1`` outside block 1
      * a comma before every word whose gap to the previous word exceeds 80
        (gap = difference in ``left`` within the block minus the previous
        word's ``width``)

    Inputs:
        image (numpy.ndarray): The preprocessed image as a numpy array.

    Returns:
        str: The extracted and formatted text; every word is followed by a
            single space. Empty string when no words were recognised.
    """
    data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    # .copy() so the frame is independent of the unfiltered one.
    words = pd.DataFrame(data)
    words = words[words['conf'] != -1].copy()

    # Nothing recognised: return early instead of running the formatting
    # steps on an empty frame.
    if words.empty:
        return ""

    left_diff = words.groupby('block_num')['left'].diff().fillna(0).astype(int)
    prev_width = words['width'].shift(1).fillna(0).astype(int)
    spacing = left_diff - prev_width

    text = words['text']

    # Newline marker first, then the comma marker, matching the original
    # order so a word flagged by both becomes ',' + '\n' + word.
    starts_line = (words['word_num'] == 1) & (words['block_num'] != 1)
    text = text.mask(starts_line, '\n' + text)
    text = text.mask(spacing > 80, ',' + text)

    # Each word keeps its trailing space, including the last one.
    return ''.join(word + ' ' for word in text)
153
+
154
def generate_csv_from_text(tokenizer, model, ocr_text):
    """
    Turn OCR-extracted text into CSV-formatted text with the gpt model.

    The text is encoded with the tokenizer, passed to the model's
    ``generate`` method, and the first returned sequence is decoded back
    into a string.

    Inputs:
        tokenizer: The tokenizer for the gpt model
        model: The gpt model used to generate the csv text
        ocr_text (str): The text extracted from OCR

    Returns:
        str: The decoded text of the first generated sequence, with special
            tokens skipped.
    """
    prompt_ids = tokenizer.encode(ocr_text, return_tensors='pt')

    generated = model.generate(
        prompt_ids,
        max_length=1000,
        num_return_sequences=1,
    )

    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)
174
+
175
if __name__ == '__main__':
    # Streamlit entry point: upload an image, show the detected layout
    # classes, crop the table, OCR it, and render the generated CSV.

    # Only override pytesseract's executable when the default Windows install
    # location actually exists; otherwise leave pytesseract's own default in
    # place so a tesseract found on PATH is used.
    tesseract_exe = r'C:/Program Files/Tesseract-OCR/tesseract.exe'  # Update this path for your system
    if os.path.isfile(tesseract_exe):
        pytesseract.pytesseract.tesseract_cmd = tesseract_exe

    # NOTE(review): `device` is computed but nothing below uses it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model files are resolved relative to the current working directory.
    models_dir = os.path.join(os.getcwd(), 'models')
    model = YOLO(os.path.join(models_dir, 'trained_yolov8.pt'))
    gpt_model = GPT2LMHeadModel.from_pretrained(os.path.join(models_dir, 'gpt_model'))
    tokenizer = GPT2Tokenizer.from_pretrained(os.path.join(models_dir, 'gpt_model'))

    st.header('''
    Intelligent Document Processing: Table Extraction
    ''')

    header_img = Image.open('assets/header_img.png')
    st.image(header_img, use_column_width=True)

    with st.sidebar:
        user_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

    if user_image is not None:
        st.divider()
        image = Image.open(user_image)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        st.divider()
        st.subheader("Document Classes:")
        processed_image = process_image(model, image)
        st.image(processed_image, caption='Processed Image', use_column_width=True)

        st.divider()
        st.subheader("Table Cropped Image:")
        cropped_table = crop_image(model, image)

        if cropped_table is None:
            # crop_image returns None when no table box is found; stop here
            # instead of passing None to st.image and the OCR step.
            st.warning("No table was detected in the uploaded image.")
        else:
            st.image(cropped_table, caption='Cropped Table', use_column_width=True)

            st.divider()
            st.subheader("OCR Text:")
            improved_image = improve_ocr_accuracy(cropped_table)
            ocr_text = ocr_core(improved_image)
            st.write(ocr_text)

            st.divider()
            st.subheader("CSV Output:")
            csv_output = generate_csv_from_text(tokenizer, gpt_model, ocr_text)
            try:
                table = pd.read_csv(StringIO(csv_output), sep=",")
            except (pd.errors.ParserError, pd.errors.EmptyDataError):
                # The generated text is free-form model output and may not be
                # valid CSV; show it raw rather than crashing the page.
                st.warning("The generated text could not be parsed as CSV; showing it as plain text.")
                st.text(csv_output)
            else:
                st.dataframe(table.head())
220
 
setup.py CHANGED
@@ -1,14 +1,14 @@
1
  import subprocess
2
  import sys
3
 
4
- script = 'make_dataset.py'
5
- command = f'{sys.executable} scripts/{script}'
6
- subprocess.run(command, shell=True)
7
 
8
  script = 'build_features.py'
9
- command = f'{sys.executable} python scripts/{script}'
10
  subprocess.run(command, shell=True)
11
 
12
  script = 'model.py'
13
- command = f'{sys.executable} python scripts/{script}'
14
  subprocess.run(command, shell=True)
 
1
  import subprocess
2
  import sys
3
 
4
# Scripts under scripts/ that make up the pipeline, in execution order.
# 'make_dataset.py' used to run first; that step was commented out in this
# revision, so it is left out of the tuple. Add it back at the front to
# re-enable it.
PIPELINE_SCRIPTS = ('build_features.py', 'model.py')


def run_pipeline(scripts=PIPELINE_SCRIPTS):
    """
    Run each pipeline script, in order, with the current Python interpreter.

    Inputs:
        scripts (iterable of str): File names inside the scripts/ directory.

    Raises:
        subprocess.CalledProcessError: If a script exits with a non-zero
            status. Later scripts are not started in that case, so training
            never runs on the output of a failed earlier step.
    """
    for script in scripts:
        # Argument list with no shell: an interpreter path that contains
        # spaces (e.g. under "C:/Program Files") is passed through intact.
        subprocess.run([sys.executable, f'scripts/{script}'], check=True)


if __name__ == '__main__':
    run_pipeline()