# alps / deepdoc / vision / ragFlow.py
# Uploaded by: yumikimi381
# Commit message: Upload folder using huggingface_hub
# Commit: daf0288 (verified)
import copy
import time
import os
from huggingface_hub import snapshot_download
from .operators import *
import numpy as np
import onnxruntime as ort
import logging
from .postprocess import build_post_process
from typing import List
def get_deepdoc_directory():
    """Return the absolute path of the parent of this module's directory.

    The module path is resolved with ``os.path.realpath`` (so symlinks are
    followed) before stepping one level up.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    parent_dir = os.path.join(module_dir, os.pardir)
    return os.path.abspath(parent_dir)
def transform(data, ops=None):
    """Feed ``data`` through every operator in ``ops``, in order.

    Each operator receives the previous operator's output. As soon as one of
    them returns None the pipeline stops and None is returned. With no
    operators the input is returned unchanged.
    """
    pipeline = [] if ops is None else ops
    result = data
    for step in pipeline:
        result = step(result)
        if result is None:
            # An operator rejected the sample; skip the remaining steps.
            return None
    return result
def create_operators(op_param_list, global_config=None):
    """
    Instantiate the operators described by ``op_param_list``.

    Args:
        op_param_list (list): list of single-key dicts ``{op_name: params}``.
            ``params`` is a dict of keyword arguments for the operator, or
            None for "no arguments".
        global_config (dict, optional): extra keyword arguments merged into
            every operator's parameters; on key clashes these values win.

    Returns:
        list: the created operator objects, in the same order as the config.

    Raises:
        AssertionError: if ``op_param_list`` is not a list, or an entry is
            not a dict with exactly one key.
    """
    assert isinstance(
        op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        raw_param = operator[op_name]
        # Work on a copy: merging global_config must not modify the caller's
        # config dict (it would otherwise leak into later calls).
        param = {} if raw_param is None else dict(raw_param)
        if global_config is not None:
            param.update(global_config)
        # NOTE(review): eval() resolves the operator name against this
        # module's namespace. It will execute arbitrary expressions, so
        # op_param_list must never come from untrusted input.
        op = eval(op_name)(**param)
        ops.append(op)
    return ops
def load_model(model_dir, nm):
    """Open ``<model_dir>/<nm>.onnx`` as an ONNX Runtime session on the CPU.

    Args:
        model_dir: directory that contains the model file.
        nm: model file name without the ``.onnx`` extension.

    Returns:
        tuple: ``(session, first_input)`` where ``first_input`` is
        ``session.get_inputs()[0]``.

    Raises:
        ValueError: if the model file does not exist.
    """
    model_file_path = os.path.join(model_dir, nm + ".onnx")
    if not os.path.exists(model_file_path):
        raise ValueError("not find model file path {}".format(
            model_file_path))

    options = ort.SessionOptions()
    options.enable_cpu_mem_arena = False
    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    options.intra_op_num_threads = 2
    options.inter_op_num_threads = 2

    # The CUDA branch was guarded by `if False and ...` and therefore never
    # ran; the session has always been created with the CPU provider only.
    sess = ort.InferenceSession(
        model_file_path,
        options=options,
        providers=['CPUExecutionProvider'])

    # Diagnostics go through logging rather than stdout so that library
    # users decide whether they are shown.
    logger = logging.getLogger(__name__)
    logger.info("%s", model_file_path)
    logger.info("%s", sess.get_modelmeta().description)
    return sess, sess.get_inputs()[0]
class RagFlowTextDetector:
    """
    Text detector built on the ``det`` ONNX model.

    Calling an instance preprocesses the image, runs the model, applies the
    ``DBPostProcess`` post-processor and returns the filtered text boxes
    together with the elapsed time.
    """

    # predictor.run() is attempted at most this many times ...
    _MAX_RUN_ATTEMPTS = 4
    # ... with this pause (seconds) between failed attempts.
    _RETRY_DELAY_SECONDS = 5

    def __init__(self, model_dir):
        """Load ``det.onnx`` from ``model_dir`` and build the pre/post-processing steps.

        Raises whatever ``load_model`` raises (e.g. ValueError when the model
        file is missing).
        """
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': 960,
                'limit_type': "max",
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
        postprocess_params = {
            "name": "DBPostProcess",
            "thresh": 0.3,
            "box_thresh": 0.5,
            "max_candidates": 1000,
            "unclip_ratio": 1.5,
            "use_dilation": False,
            "score_mode": "fast",
            "box_type": "quad",
        }

        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor = load_model(model_dir, 'det')

        # When the model declares a concrete (non-symbolic, positive) input
        # height/width, resize to exactly that shape instead of limiting the
        # longest side. Symbolic dimensions are reported as strings.
        img_h, img_w = self.input_tensor.shape[2:]
        if isinstance(img_h, str) or isinstance(img_w, str):
            pass
        elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'image_shape': [img_h, img_w]
                }
            }
        self.preprocess_op = create_operators(pre_process_list)

    def order_points_clockwise(self, pts):
        """Return the four points of ``pts`` as a float32 (4, 2) array in a fixed order.

        Row 0 is the point with the smallest coordinate sum, row 2 the one
        with the largest. The two remaining points are ordered by the
        difference (second column - first column): smallest goes to row 1,
        largest to row 3. With (x, y) image coordinates this yields
        top-left, top-right, bottom-right, bottom-left.
        """
        rect = np.zeros((4, 2), dtype="float32")
        s = pts.sum(axis=1)
        rect[0] = pts[np.argmin(s)]
        rect[2] = pts[np.argmax(s)]
        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
        diff = np.diff(np.array(tmp), axis=1)
        rect[1] = tmp[np.argmin(diff)]
        rect[3] = tmp[np.argmax(diff)]
        return rect

    def clip_det_res(self, points, img_height, img_width):
        """Clamp every point into the image and truncate to whole pixels.

        ``points`` is modified in place and also returned. Column 0 is
        limited to ``[0, img_width - 1]``, column 1 to ``[0, img_height - 1]``.
        """
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

    def filter_tag_det_res(self, dt_boxes, image_shape):
        """Order, clip and size-filter detected boxes.

        Each box is put into the order produced by ``order_points_clockwise``,
        clipped to the image, and dropped when its width or height (edge
        lengths truncated to int) is 3 pixels or less.

        Returns:
            np.ndarray built from the surviving boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        return np.array(dt_boxes_new)

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        """Clip every box to the image without reordering or size filtering.

        Returns:
            np.ndarray built from the clipped boxes.
        """
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            if isinstance(box, list):
                box = np.array(box)
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        return np.array(dt_boxes_new)

    def __call__(self, img):
        """Detect text boxes in ``img``.

        Returns:
            tuple: ``(dt_boxes, elapsed_seconds)``; ``(None, 0)`` when the
            preprocessing pipeline yields no data.

        Raises:
            Exception: the last error from the predictor when all
            ``_MAX_RUN_ATTEMPTS`` attempts fail.
        """
        ori_im = img.copy()
        data = {'image': img}

        st = time.time()
        data = transform(data, self.preprocess_op)
        # transform() returns None when an operator rejects the input, so the
        # result must be checked before it is unpacked.
        if data is None:
            return None, 0
        img, shape_list = data
        if img is None:
            return None, 0

        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()
        input_dict = {self.input_tensor.name: img}

        # Retry on failure, pausing between attempts; the final failure is
        # re-raised unchanged.
        for attempt in range(self._MAX_RUN_ATTEMPTS):
            try:
                outputs = self.predictor.run(None, input_dict)
                break
            except Exception:
                if attempt >= self._MAX_RUN_ATTEMPTS - 1:
                    raise
                time.sleep(self._RETRY_DELAY_SECONDS)

        post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
        dt_boxes = post_result[0]['points']
        dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        return dt_boxes, time.time() - st
class RagFlow():
    """Convenience wrapper that owns a ``RagFlowTextDetector`` and returns its boxes in reading order."""

    def __init__(self, model_dir=None):
        """Create the text detector.

        Args:
            model_dir: directory holding the model files. When omitted, the
                ``models`` directory under ``get_deepdoc_directory()`` is
                tried first; if the detector cannot be built from it, the
                ``InfiniFlow/deepdoc`` snapshot is downloaded into that
                directory and the detector is built from the download.
        """
        if model_dir:
            # An explicitly supplied directory is used as-is; errors propagate.
            self.text_detector = RagFlowTextDetector(model_dir)
        else:
            default_dir = os.path.join(get_deepdoc_directory(), "models")
            try:
                self.text_detector = RagFlowTextDetector(default_dir)
            except Exception:
                # Best-effort fallback: local models missing or unusable, so
                # fetch them and try once more.
                logging.getLogger(__name__).info(
                    "could not load models from %s; downloading", default_dir)
                model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                              local_dir=default_dir,
                                              local_dir_use_symlinks=False)
                self.text_detector = RagFlowTextDetector(model_dir)

        self.drop_score = 0.5
        self.crop_image_res_index = 0

    def get_rotate_crop_image(self, img, points):
        """Cut the quadrilateral ``points`` out of ``img`` and rectify it.

        The output size is the longer of each pair of opposite edges
        (truncated to int). A perspective warp maps the four points onto the
        corners of that rectangle. If the result is at least 1.5 times taller
        than wide it is rotated with ``np.rot90``.

        Args:
            img: source image.
            points: four 2-D points; passed directly to
                ``cv2.getPerspectiveTransform``.

        Raises:
            AssertionError: if ``points`` does not contain exactly 4 entries.
        """
        assert len(points) == 4, "shape of points must be 4*2"
        img_crop_width = int(
            max(
                np.linalg.norm(points[0] - points[1]),
                np.linalg.norm(points[2] - points[3])))
        img_crop_height = int(
            max(
                np.linalg.norm(points[0] - points[3]),
                np.linalg.norm(points[1] - points[2])))
        pts_std = np.float32([[0, 0], [img_crop_width, 0],
                              [img_crop_width, img_crop_height],
                              [0, img_crop_height]])
        M = cv2.getPerspectiveTransform(points, pts_std)
        dst_img = cv2.warpPerspective(
            img,
            M, (img_crop_width, img_crop_height),
            borderMode=cv2.BORDER_REPLICATE,
            flags=cv2.INTER_CUBIC)
        dst_img_height, dst_img_width = dst_img.shape[0:2]
        if dst_img_height * 1.0 / dst_img_width >= 1.5:
            dst_img = np.rot90(dst_img)
        return dst_img

    def sorted_boxes(self, dt_boxes):
        """
        Sort text boxes from top to bottom, left to right.

        Boxes are first sorted by the (y, x) of their first point. Then
        neighbouring boxes whose first-point y values differ by less than 10
        are reordered so the one with the smaller x comes first.

        args:
            dt_boxes(array): detected text boxes, one [4, 2] box per entry
        return:
            list of the same boxes in sorted order
        """
        _boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
        num_boxes = len(_boxes)

        for i in range(num_boxes - 1):
            # Bubble box i+1 backwards while it sits on the same text line as,
            # and to the left of, its predecessor.
            for j in range(i, -1, -1):
                same_line = abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10
                if same_line and _boxes[j + 1][0][0] < _boxes[j][0][0]:
                    _boxes[j], _boxes[j + 1] = _boxes[j + 1], _boxes[j]
                else:
                    break
        return _boxes

    def detect(self, img):
        """Detect text boxes and pair each one with an empty ``("", 0)`` result.

        Returns:
            ``(None, None, time_dict)`` when ``img`` is None or the detector
            yields no boxes object; otherwise a ``zip`` of
            ``(box, ("", 0))`` pairs in sorted order.
        """
        time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}

        if img is None:
            return None, None, time_dict

        dt_boxes, elapse = self.text_detector(img)
        time_dict['det'] = elapse

        # The detector reports (None, 0) when preprocessing produced nothing.
        if dt_boxes is None:
            return None, None, time_dict

        boxes = self.sorted_boxes(dt_boxes)
        return zip(boxes, [("", 0) for _ in boxes])

    def recognize(self, ori_im, box):
        """Crop ``box`` from ``ori_im`` and return its recognized text.

        Returns an empty string when the recognition score is below
        ``self.drop_score``.
        """
        img_crop = self.get_rotate_crop_image(ori_im, box)

        # NOTE(review): self.text_recognizer is not assigned anywhere in the
        # visible part of this class; it must be set by the caller before
        # this method is used — confirm.
        rec_res, elapse = self.text_recognizer([img_crop])
        text, score = rec_res[0]
        if score < self.drop_score:
            return ""
        return text

    def predict(self, img: np.ndarray = None) -> List[np.ndarray]:
        """
        Return the detected bounding boxes in sorted order.

        Each box holds 4 points of 2 coordinates. An empty list is returned
        when ``img`` is None or the detector yields no boxes object.
        """
        if img is None:
            return []

        dt_boxes, _ = self.text_detector(img)
        if dt_boxes is None:
            return []

        return self.sorted_boxes(dt_boxes)