|
""" |
|
Attribution: https://github.com/AIPI540/AIPI540-Deep-Learning-Applications/ |
|
|
|
Jon Reifschneider |
|
Brinnae Bent |
|
|
|
""" |
|
|
|
import streamlit as st |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
import numpy as np |
|
import pandas as pd |
|
import pandas as pd |
|
import os |
|
import json |
|
import pandas as pd |
|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import matplotlib.pyplot as plt |
|
from ultralytics import YOLO |
|
from PIL import Image, ImageDraw, ImageFont |
|
import numpy as np |
|
import cv2 |
|
import pytesseract |
|
from PIL import ImageEnhance |
|
import numpy as np |
|
import os |
|
import json |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments |
|
from datasets import load_dataset |
|
from transformers import DataCollatorForLanguageModeling |
|
from PIL import Image, ImageEnhance |
|
from io import StringIO |
|
|
|
|
|
def crop_image(model, original_image, table_class=3):
    """
    Crop the region of interest (table) from an image using a YOLO model.

    The image is run through ``model`` and the first detected box whose class
    index equals ``table_class`` is used to crop ``original_image``.

    Inputs:
        model (YOLO): The YOLO model used for object detection.
        original_image (PIL.Image): The image to be processed.
        table_class (int): Class index the model uses for tables. Defaults to 3,
            the index this app was written against.

    Returns:
        PIL.Image or None: The crop of the first detected table, or None when
        no box of class ``table_class`` is found.
    """
    # Detect on the image that was passed in (not on any global). The PIL image
    # is handed to the model directly: Ultralytics interprets PIL input as RGB,
    # whereas raw numpy arrays are assumed to be BGR.
    results = model(original_image)

    for result in results:
        for box in result.boxes:
            if int(box.cls) != table_class:
                continue
            # xyxy[0] holds the (x1, y1, x2, y2) corners of this single box.
            x1, y1, x2, y2 = (int(coord) for coord in box.xyxy[0])
            return original_image.crop((x1, y1, x2, y2))

    return None
|
|
|
def process_image(model, image):
    """
    Run the YOLO layout model on an image and draw labelled, colour-coded boxes.

    Inputs:
        model: The trained YOLO model.
        image (PIL.Image): The uploaded document image. A numpy array is also
            accepted and is converted with ``Image.fromarray``.

    Returns:
        PIL.Image: An RGB copy of the image with one rectangle and one filled
        label per detection. Classes named title/text/figure/table/list
        (case-insensitive) get fixed colours; any other class is drawn in white.
    """
    colors = {'title': (255, 0, 0),
              'text': (0, 255, 0),
              'figure': (0, 0, 255),
              'table': (255, 255, 0),
              'list': (0, 255, 255)}
    font, font_scale, thickness = cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1

    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))

    # Draw on a 3-channel RGB copy so the RGB colour tuples above always match
    # the array layout. RGBA/palette/greyscale uploads would otherwise give cv2
    # a 4- or 1-channel array (on RGBA the drawn pixels could end up transparent).
    canvas = np.array(image.convert('RGB'))

    # Give the model the PIL image: Ultralytics treats PIL input as RGB but
    # assumes raw numpy arrays are BGR.
    results = model(image)

    for result in results:
        for box in result.boxes.cpu().numpy():
            # Plain Python ints: cv2 point arguments are parsed most reliably as int tuples.
            x1, y1, x2, y2 = (int(v) for v in box.xyxy[0].astype(int))
            label = result.names[int(box.cls)]
            color = colors.get(label.lower(), (255, 255, 255))

            cv2.rectangle(canvas, (x1, y1), (x2, y2), color, 2)

            (text_w, text_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)
            # The label normally sits on top of the box; when the box touches the
            # top edge, shift it down so it stays inside the image.
            label_bottom = max(y1, text_h + baseline)
            cv2.rectangle(canvas, (x1, label_bottom - text_h - baseline),
                          (x1 + text_w, label_bottom), color, cv2.FILLED)
            cv2.putText(canvas, label, (x1, label_bottom - baseline),
                        font, font_scale, (0, 0, 0), thickness)

    return Image.fromarray(canvas)
|
|
|
def improve_ocr_accuracy(img):
    """
    Prepare a (table) image for Tesseract.

    Steps:
        1. Upscale to four times the original width and height.
        2. Double the contrast with ``ImageEnhance.Contrast``.
        3. Apply an inverted binary threshold at 127: pixels above 127 become 0,
           all others become 255 (so dark ink ends up white on black).

    Inputs:
        img (PIL.Image): The image to preprocess.

    Returns:
        numpy.ndarray: The thresholded image. The channel count follows the
        mode of ``img``, since the threshold is applied element-wise.
    """
    upscaled = img.resize((4 * img.width, 4 * img.height))
    high_contrast = ImageEnhance.Contrast(upscaled).enhance(2)

    pixels = np.array(high_contrast)
    _retval, binary = cv2.threshold(pixels, 127, 255, cv2.THRESH_BINARY_INV)
    return binary
|
|
|
def ocr_core(image):
    """
    Perform OCR on the given image and lay the words out as rough CSV-like text.

    pytesseract's word-level data is filtered to recognised words
    (``conf != -1``), then each word may be prefixed:
      * with a newline when it is the first word of any block after block 1;
      * with a comma when the horizontal gap from the previous word (the
        ``left`` difference within the block minus the previous row's
        ``width``) exceeds 80 pixels, which is treated as a column break.
    Every word is followed by a single space in the result.

    Inputs:
        image (numpy.ndarray): The preprocessed image as a numpy array.

    Returns:
        str: The formatted OCR text, or "" when Tesseract reports no words.
    """
    data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    df = pd.DataFrame(data)
    if df.empty or 'conf' not in df.columns:
        return ""

    # Depending on the pytesseract/Tesseract version, ``conf`` may come back as
    # numbers or as strings such as '-1'; normalise so the filter works either way.
    conf = pd.to_numeric(df['conf'], errors='coerce')
    words = df[conf != -1]
    if words.empty:
        # Nothing recognised; the vectorised steps below would be no-ops anyway,
        # but returning early keeps the contract obvious.
        return ""

    left_diff = words.groupby('block_num')['left'].diff().fillna(0).astype(int)
    prev_width = words['width'].shift(1).fillna(0).astype(int)
    spacing = left_diff - prev_width

    text = words['text'].astype(str)
    starts_new_block = (words['word_num'] == 1) & (words['block_num'] != 1)
    text = text.where(~starts_new_block, '\n' + text)
    # Applied after the newline prefix, so a word can become ",\nword".
    text = text.where(spacing <= 80, ',' + text)

    return ''.join(word + ' ' for word in text)
|
|
|
def generate_csv_from_text(tokenizer, model, ocr_text):
    """
    Ask the fine-tuned GPT model to turn OCR text into CSV-formatted text.

    The OCR text is tokenised into a PyTorch tensor and passed to
    ``model.generate`` with ``max_length=1000`` and a single returned sequence.
    That sequence is decoded back to a string with special tokens skipped.

    Inputs:
        tokenizer: Tokenizer matching ``model``.
        model: The GPT language model used to produce CSV text.
        ocr_text (str): Text extracted by OCR.

    Returns:
        str: The decoded generation. For a decoder-only model this presumably
        includes the prompt tokens too (confirm against the fine-tuning format).
    """
    prompt_ids = tokenizer.encode(ocr_text, return_tensors='pt')
    generated = model.generate(prompt_ids, max_length=1000, num_return_sequences=1)
    best_sequence = generated[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)
|
|
|
if __name__ == '__main__':

    @st.cache_resource
    def _load_models():
        """
        Load the YOLO layout model and the GPT-2 model/tokenizer from ./models.

        Wrapped in st.cache_resource so the weights are loaded once instead of
        on every Streamlit rerun (each widget interaction re-executes this script).
        """
        models_dir = os.path.join(os.getcwd(), 'models')
        gpt_dir = os.path.join(models_dir, 'gpt_model')
        yolo_model = YOLO(os.path.join(models_dir, 'trained_yolov8.pt'))
        return (yolo_model,
                GPT2LMHeadModel.from_pretrained(gpt_dir),
                GPT2Tokenizer.from_pretrained(gpt_dir))

    def _show_error(message, err=None):
        """Render an error section; include the exception details when given."""
        st.divider()
        st.subheader("Error:")
        st.write(message)
        if err is not None:
            st.exception(err)

    model, gpt_model, tokenizer = _load_models()

    st.header('''
Intelligent Document Processing: Table Extraction
''')

    header_img = Image.open('assets/header_img.png')
    st.image(header_img, use_column_width=True)

    st.subheader("Please upload an image of a scanned document with a table using the sidebar")

    with st.sidebar:
        user_image = st.file_uploader("Upload an image of a scanned document", type=["png", "jpg", "jpeg"])

    if user_image is not None:
        st.divider()
        image = Image.open(user_image)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        st.divider()
        st.subheader("Document Classes:")
        processed_image = process_image(model, image)
        st.image(processed_image, caption='Processed Image', use_column_width=True)

        try:
            cropped_table = crop_image(model, image)
            if cropped_table is None:
                # crop_image returns None when no table box was detected.
                _show_error("Please upload a scanned document with a table")
            else:
                st.divider()
                st.subheader("Table Cropped Image:")
                st.image(cropped_table, caption='Cropped Table', use_column_width=True)

                improved_image = improve_ocr_accuracy(cropped_table)
                st.divider()
                st.subheader("Improved Table Image:")
                st.image(improved_image, caption='Improved Table Image', use_column_width=True)

                ocr_text = ocr_core(improved_image)
                st.divider()
                st.subheader("OCR Text:")
                # st.text keeps the line breaks ocr_core inserts; markdown would collapse them.
                st.text(ocr_text)

                csv_output = generate_csv_from_text(tokenizer, gpt_model, ocr_text)
                st.divider()
                st.subheader("CSV Output:")
                # Show the CSV as text rather than the repr of its UTF-8 bytes.
                st.text(csv_output)
        except Exception as err:
            # Top-level UI boundary: surface the real failure instead of hiding
            # every error behind the "no table" message.
            _show_error("Please upload a scanned document with a table", err)
|
|
|
|
|
|
|
|