keesephillips committed on
Commit b3fc8d0 · verified · 1 Parent(s): fe39ef8

initial commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ notebooks/model.ipynb filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,54 @@
- ---
- license: mit
- ---
+ # AIPI Term Project
+ ## Developer: Keese Phillips
+
+ ## About:
+ The purpose of this project is to perform very basic intelligent document processing (IDP) to extract a table from a document image. The source can be a PDF or an image that cannot be mapped directly to a csv file. The steps in this process are table detection, optical character recognition (OCR), table extraction, and conversion to csv format.
+
+ ## How to run the project
+
+ ### If you want to run the full pipeline and train the model from scratch
+ 1. You will need to install all of the necessary packages to run the setup.py script beforehand
+ 2. You will then need to run setup.py to create the data pipeline and train the model
+ 3. You will then need to run the frontend to use the model
+ ```bash
+ pip install -r requirements.txt
+ python setup.py
+ streamlit run main.py
+ ```
+
+ ### If you want to just run the frontend
+ 1. You will need to install all of the necessary packages beforehand
+ 2. You will then need to run the frontend to use the model
+ ```bash
+ pip install -r requirements.txt
+ streamlit run main.py
+ ```
+
+ ## Project Structure
+ > - requirements.txt: list of python libraries to download before running project
+ > - setup.py: script to set up project (get data, train model)
+ > - main.py: main script to run the streamlit user interface
+ > - assets: directory for images used in frontend
+ > - scripts: directory for pipeline scripts or utility scripts
+ >   - make_dataset.py: script to get data
+ >   - build_features.py: script to prepare the dataset for training
+ >   - model.py: script to train model and predict
+ > - models: directory for trained models
+ >   - trained_yolov8.pt: trained YOLOv8 model for table detection
+ >   - gpt: fine-tuned GPT-2 model for converting OCR output to csv
+ > - data: directory for project data
+ >   - raw: directory for raw data
+ >   - processed: directory to store the processed data
+ >   - outputs: directory to store the prepared data
+ > - notebooks: directory to store any exploration notebooks used
+ > - .gitignore: git ignore file
+
+ ## [Data source](https://github.com/ibm-aur-nlp/PubLayNet)
+ The data used to train the model was provided by [IBM](https://developer.ibm.com/exchanges/data/all/publaynet/) and [PubLayNet: largest dataset ever for document layout analysis](https://arxiv.org/abs/1908.07836). As per their dataset description:
+ > PubLayNet is a large dataset of document images, of which the layout is annotated with both bounding boxes and polygonal segmentations. The source of the documents is PubMed Central Open Access Subset (commercial use collection). The annotations are automatically generated by matching the PDF format and the XML format of the articles in the PubMed Central Open Access Subset.
+
+ ## Contributions
+ Brinnae Bent
+ Jon Reifschneider
+ Xu Zhong
+ Jianbin Tang
+ Antonio Jimeno Yepes
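The About section above describes a four-stage pipeline: table detection, OCR, table extraction, and conversion to csv. The commit trains the pieces in scripts/model.py but does not wire them together; the sketch below shows roughly how the detection and OCR stages could be chained, assuming the detector saved as models/trained_yolov8.pt loads through the Ultralytics constructor, that scripts is importable as a package, and that class id 3 is "table" as in config.yaml. The file paths are illustrative, not part of the commit.

```python
# Hedged sketch of the detection + OCR stages described in the README; paths,
# the importability of scripts.model, and the checkpoint loading are assumptions.
from PIL import Image
from ultralytics import YOLO
from scripts.model import improve_ocr_accuracy, ocr_core

detector = YOLO('models/trained_yolov8.pt')            # saved by scripts/model.py
page = Image.open('example_page.jpg')                  # hypothetical input document image
result = detector(page)[0]                             # single image -> single Results object

for box in result.boxes:
    if int(box.cls) == 3:                              # class 3 == "table" in config.yaml
        x1, y1, x2, y2 = map(int, box.xyxy[0])         # pixel coordinates of the detected table
        page.crop((x1, y1, x2, y2)).save('table_crop.jpg')
        print(ocr_core(improve_ocr_accuracy('table_crop.jpg')))
        break
```

The csv conversion step is the GPT-2 model that scripts/model.py fine-tunes; see the sketch after that file below.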
assets/music_notes.png ADDED
assets/trumpet.png ADDED
config.yaml ADDED
@@ -0,0 +1,10 @@
+ path: C:/Users/keese/term_project
+ train: training/images
+ val: validation/images
+
+ names:
+   0: text
+   1: title
+   2: list
+   3: table
+   4: figure
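config.yaml is an Ultralytics data config: path is the dataset root, train/val point at the image directories created by scripts/build_features.py, and names maps the class ids written into the YOLO label files back to layout categories. scripts/model.py passes the file straight to model.train(data='config.yaml'). A small hedged sketch of the class-id mapping (the label line is illustrative):

```python
# Hedged sketch: map a class id from a YOLO label file back to its name via config.yaml.
# Requires PyYAML; the label line below is illustrative, not taken from the dataset.
import yaml

with open('config.yaml') as f:
    cfg = yaml.safe_load(f)

label_line = "3 0.51 0.47 0.82 0.20"   # class x_center y_center width height (normalized)
class_id = int(label_line.split()[0])
print(cfg['names'][class_id])          # -> "table"
```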
hand_labeled_tables.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c72e57db0f0e7b6f770771b8b212e991afda049773d95120d5bae783b110ada8
+ size 2028840
main.py ADDED
@@ -0,0 +1,52 @@
+ """
+ Attribution: https://github.com/AIPI540/AIPI540-Deep-Learning-Applications/
+
+ Jon Reifschneider
+ Brinnae Bent
+ """
+
+ import os
+ import json
+
+ import streamlit as st
+ from PIL import Image
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import matplotlib.pyplot as plt
+
+
+ if __name__ == '__main__':
+
+     st.header('Spotify Playlists')
+
+     img1, img2 = st.columns(2)
+
+     music_notes = Image.open('assets/music_notes.png')
+     img1.image(music_notes, use_column_width=True)
+
+     trumpet = Image.open('assets/trumpet.png')
+     img2.image(trumpet, use_column_width=True)
+
+     with st.sidebar:
+         playlist_name = st.selectbox(
+             "Playlist Selection",
+             list(set([1, 2]))
+         )
+
+     col1, col2 = st.columns(2)
+     with col1:
+         st.write('Artist')
+     with col2:
+         st.write('Album')
notebooks/model.ipynb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0e580316d90588119633a0091618c9eba64d964086822c44c4e41c96101c7177
+ size 17075128
notebooks/svm.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
Binary file (9.92 kB).
 
scripts/build_features.py ADDED
@@ -0,0 +1,333 @@
+ import os
+ import json
+ import shutil
+ from pathlib import Path
+
+ import numpy as np
+ import cv2 as cv
+ from PIL import Image, ImageFont, ImageDraw
+
+
+ def get_data_and_annots():
+     """Loads the PubLayNet training annotations and collects up to 10,000 training images."""
+     images = {}
+     with open('data/raw/label/publaynet/train.json') as t:
+         data = json.load(t)
+
+     # The file list of the train directory is the only part of os.walk that is needed
+     for train_images in os.walk('data/raw/train/publaynet/train'):
+         train_imgs = train_images[2]
+
+     for image in data['images']:
+         if image['file_name'] in train_imgs:
+             images[image['id']] = {'file_name': "data/raw/train/publaynet/train/" + image['file_name'], 'annotations': []}
+         if len(images) == 10000:
+             break
+
+     for ann in data['annotations']:
+         if ann['image_id'] in images.keys():
+             images[ann['image_id']]['annotations'].append(ann)
+     return images, data
+
+
+ def markup(samples, image, annotations):
+     '''Draws the segmentation, bounding box, and label of each annotation.'''
+     draw = ImageDraw.Draw(image, 'RGBA')
+     font = ImageFont.load_default()  # A different font can be specified if needed
+     for annotation in annotations:
+         # Draw segmentation
+         draw.polygon(annotation['segmentation'][0],
+                      fill=colors[samples['categories'][annotation['category_id'] - 1]['name']] + (64,))
+         # Draw bbox
+         draw.rectangle(
+             (annotation['bbox'][0],
+              annotation['bbox'][1],
+              annotation['bbox'][0] + annotation['bbox'][2],
+              annotation['bbox'][1] + annotation['bbox'][3]),
+             outline=colors[samples['categories'][annotation['category_id'] - 1]['name']] + (255,),
+             width=2
+         )
+         # Draw label
+         text = samples['categories'][annotation['category_id'] - 1]['name']
+         bbox = draw.textbbox((0, 0), text, font=font)
+         w = bbox[2] - bbox[0]
+         h = bbox[3] - bbox[1]
+
+         if annotation['bbox'][3] < h:
+             draw.rectangle(
+                 (annotation['bbox'][0] + annotation['bbox'][2],
+                  annotation['bbox'][1],
+                  annotation['bbox'][0] + annotation['bbox'][2] + w,
+                  annotation['bbox'][1] + h),
+                 fill=(64, 64, 64, 255)
+             )
+             draw.text(
+                 (annotation['bbox'][0] + annotation['bbox'][2],
+                  annotation['bbox'][1]),
+                 text=text,
+                 fill=(255, 255, 255, 255)
+             )
+         else:
+             draw.rectangle(
+                 (annotation['bbox'][0],
+                  annotation['bbox'][1],
+                  annotation['bbox'][0] + w,
+                  annotation['bbox'][1] + h),
+                 fill=(64, 64, 64, 255)
+             )
+             draw.text(
+                 (annotation['bbox'][0],
+                  annotation['bbox'][1]),
+                 text=text,
+                 fill=(255, 255, 255, 255)
+             )
+     return np.array(image)
+
+
+ def write_file(image_id, inside, filename, content, check_set):
+     """
+     Writes content to a file. If 'inside' is True, appends the content, otherwise overwrites the file.
+
+     Args:
+         image_id (str): The ID of the image.
+         inside (bool): Flag to determine if content should be appended or overwritten.
+         filename (str): The path to the file.
+         content (str): The content to write to the file.
+         check_set (set): A set to keep track of image IDs.
+     """
+     if inside:
+         with open(filename, "a") as file:
+             file.write("\n")
+             file.write(content)
+     else:
+         check_set.add(image_id)
+         with open(filename, "w") as file:
+             file.write(content)
+
+
+ def get_bb_shape(bboxe, img):
+     """
+     Calculates the shape of the bounding box in the image.
+
+     Args:
+         bboxe (list): Bounding box coordinates [x, y, width, height].
+         img (numpy.ndarray): The image array.
+
+     Returns:
+         tuple: The shape (height, width) of the bounding box.
+     """
+     tleft = (bboxe[0], bboxe[1])
+     tright = (bboxe[0] + bboxe[2], bboxe[1])
+     bleft = (bboxe[0], bboxe[1] + bboxe[3])
+     bright = (bboxe[0] + bboxe[2], bboxe[1] + bboxe[3])
+
+     top_left_x = min([tleft[0], tright[0], bleft[0], bright[0]])
+     top_left_y = min([tleft[1], tright[1], bleft[1], bright[1]])
+     bot_right_x = max([tleft[0], tright[0], bleft[0], bright[0]])
+     bot_right_y = max([tleft[1], tright[1], bleft[1], bright[1]])
+
+     image = img[int(top_left_y):int(bot_right_y) + 1, int(top_left_x):int(bot_right_x) + 1]
+
+     return image.shape[:2]
+
+
+ def coco_to_yolo(x1, y1, w, h, image_w, image_h):
+     """
+     Converts COCO format bounding box to YOLO format.
+
+     Args:
+         x1 (float): Top-left x coordinate.
+         y1 (float): Top-left y coordinate.
+         w (float): Width of the bounding box.
+         h (float): Height of the bounding box.
+         image_w (int): Width of the image.
+         image_h (int): Height of the image.
+
+     Returns:
+         list: YOLO format bounding box [x_center, y_center, width, height].
+     """
+     return [((2 * x1 + w) / (2 * image_w)), ((2 * y1 + h) / (2 * image_h)), w / image_w, h / image_h]
+
+
+ def create_directory(path):
+     """
+     Creates a directory, deleting it first if it already exists.
+
+     Args:
+         path (str): The path to the directory.
+     """
+     dirpath = Path(path)
+     if dirpath.exists() and dirpath.is_dir():
+         shutil.rmtree(dirpath)
+     os.mkdir(dirpath)
+
+
+ def generate_yolo_labels(images):
+     """
+     Generates YOLO format labels from the given images and annotations.
+
+     Args:
+         images (dict): Dictionary containing image data and annotations.
+     """
+     check_set = set()
+
+     create_directory(os.getcwd() + '/data/processed/yolo')
+
+     for key in images:
+         image_id = ','.join(map(str, [ann['image_id'] for ann in images[key]['annotations']]))
+         category_id = ''.join(map(str, [ann['category_id'] - 1 for ann in images[key]['annotations']]))
+         bboxes = [ann['bbox'] for ann in images[key]['annotations']]
+         image_path = images[key]['file_name']
+         filename = os.getcwd() + '/data/processed/yolo/' + image_path.split('/')[-1].split(".")[0] + '.txt'
+
+         img = cv.imread(image_path)  # read the image once instead of once per bounding box
+
+         for index, b in enumerate(bboxes):
+             bbox = [b[0], b[1], b[2], b[3]]
+             shape = get_bb_shape(bbox, img)
+             yolo_bbox = coco_to_yolo(bbox[0], bbox[1], shape[1], shape[0], img.shape[1], img.shape[0])
+             content = category_id[index] + ' ' + str(yolo_bbox[0]) + ' ' + str(yolo_bbox[1]) + ' ' + str(yolo_bbox[2]) + ' ' + str(yolo_bbox[3])
+
+             if image_id in check_set:
+                 write_file(image_id, True, filename, content, check_set)
+             else:
+                 write_file(image_id, False, filename, content, check_set)
+
+
+ def delete_additional_images(old_train_path, temp_images_path, yolo_path):
+     """Moves only the images that have a matching YOLO label file into a temporary images directory."""
+     train = next(os.walk(old_train_path), (None, None, []))[2]
+     label = next(os.walk(yolo_path), (None, None, []))[2]
+
+     dirpath = Path(temp_images_path)
+     if dirpath.exists() and dirpath.is_dir():
+         shutil.rmtree(dirpath)
+     os.mkdir(dirpath)
+
+     for img in train:
+         splited = img.split(".")[0]
+         txt = f"{splited}.txt"
+         if txt in label:
+             shutil.move(f"{old_train_path}/{img}", f"{temp_images_path}/{img}")
+     return
+
+
+ def split_data(temp_images_path):
+     """Splits the labeled images into training (80%) and validation (20%) sets."""
+     image = next(os.walk(temp_images_path), (None, None, []))[2]
+     train = image[int(len(image) * .1): int(len(image) * .90)]
+     validation = list(set(image) - set(train))
+
+     create_directory(os.getcwd() + '/data/processed/training')
+     create_directory(os.getcwd() + '/data/processed/validation')
+     create_directory(os.getcwd() + '/data/processed/training/images/')
+     create_directory(os.getcwd() + '/data/processed/validation/images/')
+
+     for train_img in train:
+         shutil.move(f'{temp_images_path}/{train_img}', os.getcwd() + '/data/processed/training/images/')
+
+     for valid_img in validation:
+         shutil.move(f'{temp_images_path}/{valid_img}', os.getcwd() + '/data/processed/validation/images/')
+
+     validation_without_ext = [i.split('.')[0] for i in validation]
+     return validation_without_ext
+
+
+ def get_labels(yolo_path, valid_without_extension):
+     """Distributes the YOLO label files into the training and validation label directories."""
+     create_directory(os.getcwd() + '/data/processed/training/labels')
+     create_directory(os.getcwd() + '/data/processed/validation/labels')
+
+     label = next(os.walk(yolo_path), (None, None, []))[2]
+     for lab in label:
+         split = lab.split(".")[0]
+         if split in valid_without_extension:
+             shutil.move(f"{yolo_path}/{lab}", os.getcwd() + f'/data/processed/validation/labels/{lab}')
+         else:
+             shutil.move(f"{yolo_path}/{lab}", os.getcwd() + f'/data/processed/training/labels/{lab}')
+
+     return
+
+
+ def final_preparation(old_train_path, temp_images_path, yolo_path, images):
+     """Runs the image/label split and cleans up the temporary images directory."""
+     delete_additional_images(old_train_path, temp_images_path, yolo_path)
+     valid_without_extension = split_data(temp_images_path)
+
+     dirpath = Path(temp_images_path)
+     if dirpath.exists() and dirpath.is_dir():
+         shutil.rmtree(dirpath)
+
+     return get_labels(yolo_path, valid_without_extension)
+
+
+ def annotate_tables(directory):
+     """Crops the table regions (class id 3) out of the hand labeled images using their YOLO label files."""
+     dirpath = Path(os.getcwd() + '/data/processed/tables')
+     if dirpath.exists() and dirpath.is_dir():
+         shutil.rmtree(dirpath)
+     os.mkdir(dirpath)
+
+     # Iterate through the directory of hand labeled table images
+     for filename in os.listdir(directory):
+         # Get the full path of the file
+         file_path = os.path.join(directory, filename)
+
+         # Check if it's a file (not a subdirectory)
+         if os.path.isfile(file_path):
+             img_name = filename.split('.')[0]
+
+             if os.path.isfile(os.getcwd() + f'/data/processed/training/images/{img_name}.jpg'):
+                 with open(os.getcwd() + f'/data/processed/training/labels/{img_name}.txt', 'r') as f:
+                     results = f.read()
+                 original_image = Image.open(os.getcwd() + f'/data/processed/training/images/{img_name}.jpg')
+             elif os.path.isfile(os.getcwd() + f'/data/processed/validation/images/{img_name}.jpg'):
+                 with open(os.getcwd() + f'/data/processed/validation/labels/{img_name}.txt', 'r') as f:
+                     results = f.read()
+                 original_image = Image.open(os.getcwd() + f'/data/processed/validation/images/{img_name}.jpg')
+             else:
+                 # Skip files without a matching processed image/label pair
+                 continue
+
+             img_w, img_h = original_image.size
+
+             # Each label line is "class x_center y_center width height" (normalized, see coco_to_yolo)
+             for line in results.splitlines():
+                 if not line.strip():
+                     continue
+                 cls, xc, yc, w, h = line.split()
+
+                 # Check if the labeled object is a table
+                 if int(cls) == 3:
+                     # Convert the normalized YOLO box back to pixel (left, top, right, bottom) coordinates
+                     x1 = int((float(xc) - float(w) / 2) * img_w)
+                     y1 = int((float(yc) - float(h) / 2) * img_h)
+                     x2 = int((float(xc) + float(w) / 2) * img_w)
+                     y2 = int((float(yc) + float(h) / 2) * img_h)
+
+                     # Crop the original image to the table region and save it
+                     table_image = original_image.crop((x1, y1, x2, y2))
+                     table_image.save(os.getcwd() + f'/data/processed/tables/{img_name}.jpg')
+
+                     # Keep only the first table per image
+                     break
+
+
+ if __name__ == '__main__':
+     colors = {'title': (255, 0, 0),
+               'text': (0, 255, 0),
+               'figure': (0, 0, 255),
+               'table': (255, 255, 0),
+               'list': (0, 255, 255)}
+     images, data = get_data_and_annots()
+     generate_yolo_labels(images)
+     final_preparation(os.getcwd() + '/data/raw/train/publaynet/train',
+                       os.getcwd() + '/data/processed/images',
+                       os.getcwd() + '/data/processed/yolo',
+                       images)
+     annotate_tables(os.getcwd() + '/data/processed/hand_labeled_tables/hand_labeled_tables')
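coco_to_yolo above converts PubLayNet's COCO-style [x, y, width, height] boxes into the normalized [x_center, y_center, width, height] values that generate_yolo_labels writes to the label files. A quick worked example with illustrative numbers:

```python
# Worked example of the COCO -> YOLO conversion used in build_features.py (illustrative numbers).
def coco_to_yolo(x1, y1, w, h, image_w, image_h):
    return [(2 * x1 + w) / (2 * image_w), (2 * y1 + h) / (2 * image_h), w / image_w, h / image_h]

# A 300x100 box with top-left corner (50, 400) on a 600x800 page:
print(coco_to_yolo(50, 400, 300, 100, 600, 800))
# -> [0.3333..., 0.5625, 0.5, 0.125]  i.e. x_center, y_center, width, height, all normalized
```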
scripts/make_dataset.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ import urllib.request
+ import tarfile
+ import zipfile
+ import shutil
+ from pathlib import Path
+
+
+ def get_archive(path, url, Set):
+     """Downloads one PubLayNet archive from `url` into `path` as `{Set}.tar`."""
+     os.makedirs(path, exist_ok=True)
+     urllib.request.urlretrieve(url, f"{path}/{Set}.tar")
+
+
+ def extract_tar(tar_file):
+     """Extracts a downloaded archive into data/raw/{tar_file} and removes the archive."""
+     print(f'{os.getcwd()}/data/raw/{tar_file}.tar', end='\r')
+     file = tarfile.open(f'{os.getcwd()}/data/raw/{tar_file}.tar')
+     file.extractall(f'{os.getcwd()}/data/raw/{tar_file}')
+     file.close()
+     os.remove(f'{os.getcwd()}/data/raw/{tar_file}.tar')
+
+
+ def make_dir(target_dir):
+     """Creates a directory, removing it first if it already exists."""
+     if Path(target_dir).exists() and Path(target_dir).is_dir():
+         shutil.rmtree(Path(target_dir))
+     os.makedirs(target_dir, exist_ok=True)
+
+
+ def combine_dirs(source_dirs, target_dir):
+     """Copies every .jpg from the extracted source directories into target_dir, then deletes the sources."""
+     for source_dir in source_dirs:
+         for subdir, dirs, files in os.walk(os.getcwd() + '/data/raw/' + source_dir):
+             for file in files:
+                 filepath = subdir + os.sep + file
+
+                 if filepath.find('.jpg') != -1:
+                     shutil.copy(filepath, target_dir)
+
+         if Path(os.getcwd() + '/data/raw/' + source_dir).exists():
+             shutil.rmtree(Path(os.getcwd() + '/data/raw/' + source_dir))
+
+
+ def unzip_file(zip_file_path, extract_to):
+     """Extracts a zip archive into the given directory, creating it if needed."""
+     os.makedirs(extract_to, exist_ok=True)
+
+     # Open the zip file and extract all contents to the specified directory
+     with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+         zip_ref.extractall(extract_to)
+
+
+ if __name__ == '__main__':
+     make_dir(os.getcwd() + '/data/raw')
+     make_dir(os.getcwd() + '/data/processed')
+     make_dir(os.getcwd() + '/data/outputs')
+     make_dir(os.getcwd() + '/models')
+
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/labels.tar.gz', "label")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/train-0.tar.gz', "train0")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/train-1.tar.gz', "train1")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/train-2.tar.gz', "train2")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/train-3.tar.gz', "train3")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/train-4.tar.gz', "train4")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/train-5.tar.gz', "train5")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/train-6.tar.gz', "train6")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/val.tar.gz', "val")
+     get_archive(os.getcwd() + '/data/raw', 'https://dax-cdn.cdn.appdomain.cloud/dax-publaynet/1.0.0/test.tar.gz', "test")
+
+     extract_tar("train0")
+     extract_tar("train1")
+     extract_tar("train2")
+     extract_tar("train3")
+     extract_tar("train4")
+     extract_tar("train5")
+     extract_tar("train6")
+     extract_tar("label")
+     extract_tar("val")
+     extract_tar("test")
+
+     target_dir = os.getcwd() + '/data/raw/train/publaynet/train/'
+     make_dir(target_dir)
+
+     combine_dirs(['train0', 'train1', 'train2', 'train3', 'train4', 'train5', 'train6'], target_dir)
+     combine_dirs(['val', 'test'], target_dir)
+
+     unzip_file('hand_labeled_tables.zip', os.getcwd() + '/data/processed/hand_labeled_tables')
scripts/model.py ADDED
@@ -0,0 +1,193 @@
+ import os
+ import json
+
+ import numpy as np
+ import pandas as pd
+ import cv2
+ import pytesseract
+ import torch
+ from PIL import Image, ImageEnhance
+ from ultralytics import YOLO
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
+ from datasets import load_dataset
+
+ # Ensure you have installed Tesseract OCR and set the path
+ pytesseract.pytesseract.tesseract_cmd = r'C:/Program Files/Tesseract-OCR/tesseract.exe'  # Update this path for your system
+
+
+ def ocr_core(image):
+     """Runs Tesseract on a preprocessed image and reassembles the words into rough csv-like text."""
+     data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
+     df = pd.DataFrame(data)
+     df = df[df['conf'] != -1]
+     # Estimate the horizontal gap between consecutive words to guess column breaks
+     df['left_diff'] = df.groupby('block_num')['left'].diff().fillna(0).astype(int)
+     df['prev_width'] = df['width'].shift(1).fillna(0).astype(int)
+     df['spacing'] = (df['left_diff'] - df['prev_width']).fillna(0).astype(int)
+     # Start a new line at the first word of each new block, and a new cell after a large gap
+     df['text'] = df.apply(lambda x: '\n' + x['text'] if (x['word_num'] == 1) & (x['block_num'] != 1) else x['text'], axis=1)
+     df['text'] = df.apply(lambda x: ',' + x['text'] if x['spacing'] > 100 else x['text'], axis=1)
+     ocr_text = ""
+     for text in df['text']:
+         ocr_text += text + ' '
+     return ocr_text
+
+
+ def improve_ocr_accuracy(img):
+     """Upscales, contrast-enhances, and thresholds an image before OCR."""
+     # Read image with PIL (for color preservation)
+     img = Image.open(img)
+
+     # Increase image size (can improve accuracy for small text)
+     img = img.resize((img.width * 4, img.height * 4))
+
+     # Increase contrast
+     enhancer = ImageEnhance.Contrast(img)
+     img = enhancer.enhance(2)
+
+     _, thresh = cv2.threshold(np.array(img), 127, 255, cv2.THRESH_BINARY_INV)
+
+     return thresh
+
+
+ def create_ocr_outputs():
+     """OCRs every hand labeled table image and writes the text to data/processed/annotations."""
+     directory_path = os.getcwd() + '/data/processed/hand_labeled_tables/hand_labeled_tables'
+     os.makedirs(os.getcwd() + '/data/processed/annotations', exist_ok=True)
+
+     for root, dirs, files in os.walk(directory_path):
+         # Print the current directory
+         print(f"Current directory: {root}")
+
+         # Print all subdirectories in the current directory
+         print("Subdirectories:")
+         for dir in dirs:
+             print(f"- {dir}")
+
+         # Print all files in the current directory
+         print("Files:")
+         for image_path in files:
+             print(f"- {image_path}")
+             full_path = os.path.join(root, image_path)
+             # Preprocess the image
+             preprocessed_image = improve_ocr_accuracy(full_path)
+
+             ocr_text = ocr_core(preprocessed_image)
+             with open(os.getcwd() + f"/data/processed/annotations/{image_path.split('.')[0]}.txt", 'wb') as f:
+                 f.write(ocr_text.encode('utf-8'))
+
+         print("\n")  # Add a blank line for readability
+
+
+ def prepare_dataset(ocr_dir, csv_dir, output_file):
+     """Pairs each OCR text file with its ground-truth table file and writes prompt/completion JSONL records."""
+     with open(output_file, 'w', encoding='utf-8') as jsonl_file:
+         for filename in os.listdir(ocr_dir):
+             if filename.endswith('.txt'):
+                 ocr_path = os.path.join(ocr_dir, filename)
+                 csv_path = os.path.join(csv_dir, filename)  # ground-truth table text with the same basename
+
+                 if not os.path.exists(csv_path):
+                     print(f"Warning: Corresponding CSV file not found for {ocr_path}")
+                     continue
+
+                 with open(ocr_path, 'r', encoding='utf-8') as ocr_file:
+                     ocr_text = ocr_file.read()
+
+                 with open(csv_path, 'r', encoding='utf-8') as csv_file:
+                     csv_text = csv_file.read()
+
+                 json_object = {
+                     "prompt": ocr_text,
+                     "completion": csv_text
+                 }
+                 jsonl_file.write(json.dumps(json_object) + '\n')
+
+
+ def tokenize_function(examples):
+     # Concatenate prompt (OCR text) and completion (csv) so the model learns to continue one into the other
+     texts = [p + '\n' + c for p, c in zip(examples['prompt'], examples['completion'])]
+     inputs = tokenizer(texts, truncation=True, padding='max_length', max_length=1012)
+
+     # For causal language modeling the labels are the input ids themselves
+     inputs['labels'] = inputs['input_ids'].copy()
+     return inputs
+
+
+ if __name__ == '__main__':
+
+     # Use CUDA if available
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Using device: {device}")
+
+     # Load a pretrained YOLOv8 model
+     model = YOLO('yolov8l.pt')
+
+     # Train the model on your custom dataset
+     results = model.train(
+         data='config.yaml',
+         epochs=10,
+         imgsz=640,
+         batch=8,
+         name='yolov8l_custom',
+         device=device
+     )
+
+     # Evaluate the model's performance
+     metrics = model.val()
+     print(metrics.box.map)  # print the mean Average Precision
+     torch.save(model, os.getcwd() + '/models/trained_yolov8.pt')
+
+     create_ocr_outputs()
+
+     # Build the prompt/completion dataset
+     ocr_dir = os.getcwd() + '/data/processed/annotations'
+     csv_dir = os.getcwd() + '/data/processed/hand_labeled_tables'
+     output_file = 'dataset.jsonl'
+     prepare_dataset(ocr_dir, csv_dir, output_file)
+
+     # Load the dataset
+     dataset = load_dataset('json', data_files={'train': 'dataset.jsonl'})
+     dataset = dataset['train'].train_test_split(test_size=0.1)
+
+     # Tokenization
+     model_name = 'gpt2'  # You can choose other models like 'gpt2-medium', 'gpt2-large', etc.
+     tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+
+     # Add a new pad token
+     tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+
+     tokenized_dataset = dataset.map(tokenize_function, batched=True)
+
+     # Load the model
+     model = GPT2LMHeadModel.from_pretrained(model_name)
+
+     # Resize the model embeddings to accommodate the new pad token
+     model.resize_token_embeddings(len(tokenizer))
+
+     training_args = TrainingArguments(
+         output_dir='./results',
+         num_train_epochs=3,
+         per_device_train_batch_size=2,
+         per_device_eval_batch_size=2,
+         warmup_steps=500,
+         weight_decay=0.01,
+         logging_dir='./logs',
+         logging_steps=10,
+         evaluation_strategy="epoch",  # Evaluate at the end of each epoch
+         save_strategy="epoch",  # Save at the end of each epoch
+         load_best_model_at_end=True,  # Load the best model when finished training (based on evaluation)
+         metric_for_best_model="eval_loss",  # Use eval_loss to determine the best model
+     )
+
+     # Trainer
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=tokenized_dataset['train'],
+         eval_dataset=tokenized_dataset['test'],
+     )
+
+     # Train the model
+     trainer.train()
+
+     # Evaluate the model
+     eval_results = trainer.evaluate()
+     print(f"Evaluation results: {eval_results}")
+
+     # Save the model
+     model.save_pretrained(os.getcwd() + '/models/gpt')
+     tokenizer.save_pretrained(os.getcwd() + '/models/gpt')
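scripts/model.py fine-tunes GPT-2 on the prompt/completion pairs in dataset.jsonl (OCR text as the prompt, the table's csv text as the completion) but stops after saving the checkpoint. A minimal hedged sketch of the generation step, assuming the checkpoint under models/gpt and an OCR text file produced by create_ocr_outputs (the filename is hypothetical):

```python
# Hedged sketch: generate csv text from OCR output with the checkpoint saved under models/gpt.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('models/gpt')
model = GPT2LMHeadModel.from_pretrained('models/gpt')

with open('data/processed/annotations/example.txt', encoding='utf-8') as f:   # hypothetical file
    ocr_text = f.read()

inputs = tokenizer(ocr_text, return_tensors='pt', truncation=True, max_length=1012)
output = model.generate(**inputs, max_new_tokens=256, pad_token_id=tokenizer.pad_token_id)

# Drop the prompt tokens and keep only the generated continuation
csv_text = tokenizer.decode(output[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(csv_text)
```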
setup.py ADDED
@@ -0,0 +1,14 @@
+ import subprocess
+ import sys
+
+ script = 'make_dataset.py'
+ command = f'{sys.executable} scripts/{script}'
+ subprocess.run(command, shell=True)
+
+ script = 'build_features.py'
+ command = f'{sys.executable} scripts/{script}'
+ subprocess.run(command, shell=True)
+
+ script = 'model.py'
+ command = f'{sys.executable} scripts/{script}'
+ subprocess.run(command, shell=True)
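Since the three stages are fixed script paths, an equivalent form avoids the shell entirely and stops the pipeline if any stage fails; a small hedged alternative using only the standard library:

```python
# Hedged alternative to setup.py: argument lists avoid shell quoting issues,
# and check=True raises if any stage exits with an error.
import subprocess
import sys

for script in ('make_dataset.py', 'build_features.py', 'model.py'):
    subprocess.run([sys.executable, f'scripts/{script}'], check=True)
```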