cv_quality / yolo_text_extraction.py
Nassiraaa's picture
Update yolo_text_extraction.py
7dde6db verified
from ultralytics import YOLO
from PIL import Image, ImageDraw
import numpy as np
from PIL import ImageFilter
from dotenv import load_dotenv
from ocr_functions import paddle_ocr, textract_ocr, tesseract_ocr
from pdf2image import convert_from_path
# Load the YOLO weights once at import time; text_image() reuses this
# module-level model for every prediction.
model = YOLO("yolo_model/best.pt")
def check_intersection(bbox1, bbox2):
    """Return True when two ``[x1, y1, x2, y2]`` boxes overlap or touch.

    Boxes that merely share an edge or a corner count as intersecting.
    """
    left_a, top_a, right_a, bottom_a = bbox1
    left_b, top_b, right_b, bottom_b = bbox2
    # The boxes are disjoint only if one lies strictly beyond the other
    # along at least one axis.
    apart_horizontally = left_b > right_a or right_b < left_a
    apart_vertically = top_b > bottom_a or bottom_b < top_a
    return not (apart_horizontally or apart_vertically)
def check_inclusion(bbox1, bbox2):
    """Return True when ``bbox1`` lies entirely within ``bbox2``.

    Both boxes are ``[x1, y1, x2, y2]``; identical boxes count as included.
    """
    inner_left, inner_top, inner_right, inner_bottom = bbox1
    outer_left, outer_top, outer_right, outer_bottom = bbox2
    starts_inside = inner_left >= outer_left and inner_top >= outer_top
    ends_inside = inner_right <= outer_right and inner_bottom <= outer_bottom
    return starts_inside and ends_inside
def union_bbox(bbox1, bbox2):
    """Return the smallest ``[x1, y1, x2, y2]`` list enclosing both boxes."""
    left, top = (min(bbox1[i], bbox2[i]) for i in (0, 1))
    right, bottom = (max(bbox1[i], bbox2[i]) for i in (2, 3))
    return [left, top, right, bottom]
def filter_bboxes(bboxes):
    """Merge overlapping boxes so the result is pairwise non-intersecting.

    Each incoming ``[x1, y1, x2, y2]`` box is unioned with every already-kept
    box it intersects. The previous implementation computed that union but
    threw it away, so an overlapping box was silently dropped and whatever it
    covered outside the kept box was lost.

    A box fully contained in another also intersects it, so containment needs
    no separate check.

    Args:
        bboxes: Iterable of ``[x1, y1, x2, y2]`` boxes.

    Returns:
        A new list of boxes. Boxes that overlapped nothing are returned as the
        original objects; merged boxes are new lists. A merged box takes the
        position of the earliest kept box it absorbed.
    """
    merged = []
    for bbox in bboxes:
        current = bbox
        slot = len(merged)  # default: append at the end
        while True:
            hit = next(
                (
                    index
                    for index, kept in enumerate(merged)
                    if check_intersection(current, kept)
                ),
                None,
            )
            if hit is None:
                break
            # Growing the box may make it reach other kept boxes, so keep
            # absorbing until nothing overlaps any more.
            current = union_bbox(current, merged.pop(hit))
            slot = min(slot, hit)
        merged.insert(slot, current)
    return merged
def draw_bboxes(image, bboxes):
    """Outline every box on ``image`` in red, 2 px wide.

    The image is modified in place; nothing is returned. Coordinates are
    truncated to integers before drawing.
    """
    canvas = ImageDraw.Draw(image)
    red = (255, 0, 0)
    for left, top, right, bottom in bboxes:
        left, top, right, bottom = int(left), int(top), int(right), int(bottom)
        canvas.rectangle([(left, top), (right, bottom)], outline=red, width=2)
def extract_image(image, box):
    """Crop ``image`` to ``box`` and return the cropped region.

    The crop used to be computed and then discarded, so callers always got
    ``None``. Returning it is backward-compatible with callers that ignore
    the result.

    Args:
        image: Object exposing ``crop((x1, y1, x2, y2))``, e.g. a PIL image.
        box: Sequence of exactly four coordinates ``x1, y1, x2, y2``.

    Returns:
        Whatever ``image.crop`` returns for the requested region.
    """
    x1, y1, x2, y2 = box
    return image.crop((x1, y1, x2, y2))
def text_image(image):
    """Detect layout regions on a page image and OCR each of them.

    The image is converted to RGB, median-filtered, and passed to the
    module-level YOLO ``model``. The detected boxes are de-duplicated with
    ``filter_bboxes`` and each remaining box is sent to ``textract_ocr``.
    An annotated copy of the page is written to ``output.png``, overwriting
    any previous file.

    Args:
        image: PIL image of a single page.

    Returns:
        The OCR text of every region, each preceded by a
        ``------section-------`` separator line.
    """
    image = image.convert("RGB")
    image = image.filter(ImageFilter.MedianFilter(3))

    result = model.predict(source=np.array(image), conf=0.10, save=False)
    # Move to CPU first: calling .numpy() on a CUDA tensor raises.
    data = result[0].boxes.data.cpu().numpy()

    bboxes_filter = filter_bboxes(data[:, 0:4].tolist())

    # Column 5 is the class id in predict() output. Class 11 is presumably
    # this model's picture/photo class -- TODO confirm against result.names.
    image_box = data[data[:, 5] == 11]
    if len(image_box) > 0:
        # Guarded: indexing row 0 raised IndexError on pages with no such box.
        extract_image(image, image_box[0, 0:4])

    # Draw the debug overlay on a copy so the red outlines do not end up in
    # the pixels handed to OCR below.
    annotated = image.copy()
    draw_bboxes(annotated, bboxes_filter)
    annotated.save("output.png")

    texts = [textract_ocr(image, bbox) for bbox in bboxes_filter]
    return "\n------section-------\n" + "\n------section-------\n".join(texts)
def pdf_to_text(pdf_file):
    """Return the extracted text of every page of ``pdf_file``.

    Pages are rendered with ``convert_from_path`` and processed in order by
    ``text_image``; each page's text is followed by a newline.
    """
    pages = convert_from_path(pdf_file)
    return "".join(text_image(page) + "\n" for page in pages)