Spaces:
Runtime error
Runtime error
JustinLin610
committed on
Commit
•
2915058
1
Parent(s):
7883098
debug
Browse files
- app.py +9 -9
- data/mm_data/ocr_dataset.py +10 -4
app.py
CHANGED
@@ -70,7 +70,7 @@ def get_images(img: str, reader: ReaderLite, **kwargs):
|
|
70 |
return results
|
71 |
|
72 |
|
73 |
-
def draw_boxes(image, bounds, color='red', width=
|
74 |
draw = ImageDraw.Draw(image)
|
75 |
for i, bound in enumerate(bounds):
|
76 |
p0, p1, p2, p3 = bound
|
@@ -102,7 +102,7 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
|
|
102 |
_patch_resize_transform = transforms.Compose(
|
103 |
[
|
104 |
lambda image: ocr_resize(
|
105 |
-
image, patch_image_size, is_document=is_document
|
106 |
),
|
107 |
transforms.ToTensor(),
|
108 |
transforms.Normalize(mean=mean, std=std),
|
@@ -113,7 +113,7 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
|
|
113 |
|
114 |
|
115 |
reader = ReaderLite()
|
116 |
-
overrides={"eval_cider": False, "beam":
|
117 |
"orig_patch_image_size": 224, "no_repeat_ngram_size": 0, "seed": 7}
|
118 |
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
119 |
utils.split_paths('checkpoints/ocr_general_clean.pt'),
|
@@ -163,9 +163,9 @@ def apply_half(t):
|
|
163 |
return t
|
164 |
|
165 |
|
166 |
-
def ocr(
|
167 |
-
out_img = Image.open(
|
168 |
-
results = get_images(
|
169 |
box_list, image_list = zip(*results)
|
170 |
draw_boxes(out_img, box_list)
|
171 |
|
@@ -191,9 +191,9 @@ description = "Gradio Demo for OFA-OCR. Upload your own image or click any one o
|
|
191 |
article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
|
192 |
"Repo</a></p> "
|
193 |
examples = [['lihe.png']]
|
194 |
-
io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath'),
|
195 |
-
outputs=[gr.outputs.Image(type='pil'), gr.outputs.Textbox(label="OCR result")],
|
196 |
title=title, description=description, article=article, examples=examples,
|
197 |
-
allow_flagging=
|
198 |
io.launch(cache_examples=True)
|
199 |
|
|
|
70 |
return results
|
71 |
|
72 |
|
73 |
+
def draw_boxes(image, bounds, color='red', width=10):
|
74 |
draw = ImageDraw.Draw(image)
|
75 |
for i, bound in enumerate(bounds):
|
76 |
p0, p1, p2, p3 = bound
|
|
|
102 |
_patch_resize_transform = transforms.Compose(
|
103 |
[
|
104 |
lambda image: ocr_resize(
|
105 |
+
image, patch_image_size, is_document=is_document, split='test',
|
106 |
),
|
107 |
transforms.ToTensor(),
|
108 |
transforms.Normalize(mean=mean, std=std),
|
|
|
113 |
|
114 |
|
115 |
reader = ReaderLite()
|
116 |
+
overrides={"eval_cider": False, "beam": 4, "max_len_b": 32, "patch_image_size": 480,
|
117 |
"orig_patch_image_size": 224, "no_repeat_ngram_size": 0, "seed": 7}
|
118 |
models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
|
119 |
utils.split_paths('checkpoints/ocr_general_clean.pt'),
|
|
|
163 |
return t
|
164 |
|
165 |
|
166 |
+
def ocr(Image):
|
167 |
+
out_img = Image.open(Image)
|
168 |
+
results = get_images(Image, reader, link_threshold=0.2)
|
169 |
box_list, image_list = zip(*results)
|
170 |
draw_boxes(out_img, box_list)
|
171 |
|
|
|
191 |
article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
|
192 |
"Repo</a></p> "
|
193 |
examples = [['lihe.png']]
|
194 |
+
io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='filepath', label='Image'),
|
195 |
+
outputs=[gr.outputs.Image(type='pil', label='Image'), gr.outputs.Textbox(label="OCR result")],
|
196 |
title=title, description=description, article=article, examples=examples,
|
197 |
+
allow_flagging='never', allow_screenshot=False)
|
198 |
io.launch(cache_examples=True)
|
199 |
|
data/mm_data/ocr_dataset.py
CHANGED
@@ -82,7 +82,7 @@ def collate(samples, pad_idx, eos_idx):
|
|
82 |
return batch
|
83 |
|
84 |
|
85 |
-
def ocr_resize(img, patch_image_size, is_document=False):
|
86 |
img = img.convert("RGB")
|
87 |
width, height = img.size
|
88 |
|
@@ -92,13 +92,19 @@ def ocr_resize(img, patch_image_size, is_document=False):
|
|
92 |
if width >= height:
|
93 |
new_width = max(64, patch_image_size)
|
94 |
new_height = max(64, int(patch_image_size * (height / width)))
|
95 |
-
|
|
|
|
|
|
|
96 |
bottom = patch_image_size - new_height - top
|
97 |
left, right = 0, 0
|
98 |
else:
|
99 |
new_height = max(64, patch_image_size)
|
100 |
new_width = max(64, int(patch_image_size * (width / height)))
|
101 |
-
|
|
|
|
|
|
|
102 |
right = patch_image_size - new_width - left
|
103 |
top, bottom = 0, 0
|
104 |
|
@@ -151,7 +157,7 @@ class OcrDataset(OFADataset):
|
|
151 |
self.patch_resize_transform = transforms.Compose(
|
152 |
[
|
153 |
lambda image: ocr_resize(
|
154 |
-
image, patch_image_size, is_document=is_document
|
155 |
),
|
156 |
transforms.ToTensor(),
|
157 |
transforms.Normalize(mean=mean, std=std),
|
|
|
82 |
return batch
|
83 |
|
84 |
|
85 |
+
def ocr_resize(img, patch_image_size, is_document=False, split='train'):
|
86 |
img = img.convert("RGB")
|
87 |
width, height = img.size
|
88 |
|
|
|
92 |
if width >= height:
|
93 |
new_width = max(64, patch_image_size)
|
94 |
new_height = max(64, int(patch_image_size * (height / width)))
|
95 |
+
if split != 'train':
|
96 |
+
top = int((patch_image_size - new_height) // 2)
|
97 |
+
else:
|
98 |
+
top = random.randint(0, patch_image_size - new_height)
|
99 |
bottom = patch_image_size - new_height - top
|
100 |
left, right = 0, 0
|
101 |
else:
|
102 |
new_height = max(64, patch_image_size)
|
103 |
new_width = max(64, int(patch_image_size * (width / height)))
|
104 |
+
if split != 'train':
|
105 |
+
left = int((patch_image_size - new_width) // 2)
|
106 |
+
else:
|
107 |
+
left = random.randint(0, patch_image_size - new_width)
|
108 |
right = patch_image_size - new_width - left
|
109 |
top, bottom = 0, 0
|
110 |
|
|
|
157 |
self.patch_resize_transform = transforms.Compose(
|
158 |
[
|
159 |
lambda image: ocr_resize(
|
160 |
+
image, patch_image_size, is_document=is_document, split=split,
|
161 |
),
|
162 |
transforms.ToTensor(),
|
163 |
transforms.Normalize(mean=mean, std=std),
|