Spaces:
Running
Running
from langchain.document_loaders.unstructured import UnstructuredFileLoader | |
from typing import List | |
import tqdm | |
def _shape_sort_key(shape):
    """Return a sort key ordering shapes top-to-bottom, then left-to-right.

    python-pptx may report ``top``/``left`` as ``None`` for a shape that has
    no explicit position of its own. Such values are treated as 0, so
    ``sorted`` never has to compare ``None`` with an int (a ``TypeError``).
    """
    return (shape.top or 0, shape.left or 0)


def _ppt2text(filepath) -> str:
    """Extract the text of a .pptx presentation, OCR-ing embedded pictures.

    Slides are visited in order. Within a slide, and within each group shape,
    shapes are visited top-to-bottom, then left-to-right. The collected pieces
    are:

    - the text of each text frame;
    - every paragraph of every table cell;
    - the lines recognised by RapidOCR in each picture.

    Args:
        filepath: Path to the presentation, passed to ``pptx.Presentation``.

    Returns:
        All collected pieces, each terminated by a newline.
    """
    # Heavy optional dependencies are imported lazily so that importing this
    # module does not require them.
    from io import BytesIO

    import numpy as np
    from PIL import Image
    from pptx import Presentation
    from pptx.enum.shapes import MSO_SHAPE_TYPE
    from rapidocr_onnxruntime import RapidOCR

    ocr = RapidOCR()
    prs = Presentation(filepath)
    pieces = []

    def extract_text(shape):
        """Append the text found in ``shape`` (recursing into groups) to ``pieces``."""
        if shape.has_text_frame:
            pieces.append(shape.text.strip())
        if shape.has_table:
            for row in shape.table.rows:
                for cell in row.cells:
                    pieces.extend(
                        paragraph.text.strip()
                        for paragraph in cell.text_frame.paragraphs
                    )
        if shape.shape_type == MSO_SHAPE_TYPE.PICTURE:
            # Close the decoded image as soon as its pixels have been copied out.
            with Image.open(BytesIO(shape.image.blob)) as image:
                pixels = np.array(image)
            result, _ = ocr(pixels)
            if result:
                # Element [1] of each OCR result entry is used as the recognised text.
                pieces.extend(entry[1] for entry in result)
        elif shape.shape_type == MSO_SHAPE_TYPE.GROUP:
            # Visit group members in the same reading order as top-level shapes.
            for child in sorted(shape.shapes, key=_shape_sort_key):
                extract_text(child)

    # The context manager guarantees the progress bar is closed, even on error.
    with tqdm.tqdm(total=len(prs.slides),
                   desc="RapidOCRPPTLoader slide index: 1") as progress:
        for slide_number, slide in enumerate(prs.slides, start=1):
            # set_description() refreshes the bar by default.
            progress.set_description(
                "RapidOCRPPTLoader slide index: {}".format(slide_number))
            for shape in sorted(slide.shapes, key=_shape_sort_key):
                extract_text(shape)
            progress.update(1)

    # Terminate every piece, including each OCR line, with a newline so that
    # consecutive pieces never run together.
    return "".join(piece + "\n" for piece in pieces)


class RapidOCRPPTLoader(UnstructuredFileLoader):
    """LangChain loader for PowerPoint (.pptx) files that also OCRs pictures.

    Text frames, tables and, via RapidOCR, embedded pictures are converted to
    plain text. ``unstructured``'s ``partition_text`` then splits that text
    into elements.
    """

    def _get_elements(self) -> List:
        """Extract text from ``self.file_path`` and partition it into elements.

        ``self.unstructured_kwargs`` is forwarded to ``partition_text``.
        """
        text = _ppt2text(self.file_path)
        from unstructured.partition.text import partition_text
        return partition_text(text=text, **self.unstructured_kwargs)
if __name__ == '__main__':
    # Demo: load the bundled sample presentation and print the documents.
    from pathlib import Path

    # Resolve the sample relative to this script, not the current working
    # directory, so the demo works no matter where it is launched from.
    sample = (Path(__file__).resolve().parent.parent
              / "tests" / "samples" / "ocr_test.pptx")
    loader = RapidOCRPPTLoader(file_path=str(sample))
    docs = loader.load()
    print(docs)