mp-02's picture
Upload folder using huggingface_hub
6d1caf6 verified
raw
history blame
2.48 kB
import json
import os

import gradio as gr

from cord_inference import prediction as cord_prediction
from sroie_inference import prediction as sroie_prediction
def prediction(image_path: str):
    """Run the SROIE model, then the CORD model, and merge their results.

    The SROIE fine-tune is applied to the original image. The image it
    returns is written next to the input as ``<name>_blurred<ext>`` and that
    file is then handed to the CORD fine-tune.

    Args:
        image_path: Path of the receipt image on disk.

    Returns:
        A 5-tuple ``(d, image, d1, image1, k)``: the SROIE result and image,
        the CORD result and image, and ``k``, the two results merged into a
        single dict (CORD values win on duplicate keys).

    Note:
        Assumes both inference helpers return ``(dict, image)`` where the
        image has a PIL-style ``save(path)`` method -- TODO confirm against
        ``sroie_inference`` / ``cord_inference``.
    """
    # Step 1: SROIE fine-tune. Per the original notes this returns the parsed
    # fields plus a copy of the image with the matched boxes blurred
    # (presumably Theivaprakasham/layoutlmv3-finetuned-sroie -- confirm in
    # sroie_inference).
    d, image = sroie_prediction(image_path)

    # Save the intermediate image so the second model can read it from disk.
    # splitext only splits at the final extension, so dots elsewhere in the
    # path (e.g. "./img.jpg", "a.b/img.jpg", "scan.2023.jpg") are preserved.
    root, ext = os.path.splitext(image_path)
    if not ext:
        # No extension to infer an output format from; fall back to PNG.
        ext = ".png"
    image_path_blurred = f"{root}_blurred{ext}"
    image.save(image_path_blurred)

    # Step 2: CORD fine-tune on the intermediate image (presumably
    # mp-02/layoutlmv3-finetuned-cord -- confirm in cord_inference).
    d1, image1 = cord_prediction(image_path_blurred)

    # Merge as dicts rather than by slicing JSON text at '{' / '}': the text
    # approach produced invalid JSON when d1 was empty and truncated output
    # when any value was nested or contained a brace.
    k = {**d, **d1}
    return d, image, d1, image1, k
# Quick manual check (prediction returns five values):
# d, image, d1, image1, k = prediction("11990982-img.png")
# print(d)
# Static text and styling handed to the Gradio interface.
title = "Interactive demo: LayoutLMv3 for receipts"

# Adjacent literals are concatenated at compile time; the resulting string
# is exactly the single-line original (including its two "\n" breaks).
description = (
    "Demo for Microsoft's LayoutLMv3, a Transformer for state-of-the-art "
    "document image understanding tasks. This particular model is fine-tuned "
    "on CORD and SROIE, which are datasets of receipts.\n"
    " It firsts uses the fine-tune on SROIE to extract date, company and "
    "address, then the fine-tune on CORD for the other info.\n"
    " To use it, simply upload an image or use the example image below. "
    "Results will show up in a few seconds."
)

# One example row with a single input: the sample image file.
examples = [["image.jpg"]]

# Force a fixed height on the input and output image panels.
css = ".output_image, .input_image {height: 600px !important}"
# Gradio UI: the uploaded image is passed to `prediction` as a file path and
# the five values it returns are shown in order -- first-pass JSON, the
# intermediate (blurred) image, second-pass JSON, the annotated image, and
# the merged JSON -- so the intermediate steps are visible too.
# NOTE(review): all three JSON panels share the label "json parsing", which
# makes them hard to tell apart in the UI; consider distinct labels.
iface = gr.Interface(
    fn=prediction,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.JSON(label="json parsing"),
        gr.Image(type="pil", label="blurred image"),
        gr.JSON(label="json parsing"),
        gr.Image(type="pil", label="annotated image"),
        gr.JSON(label="json parsing"),
    ],
    title=title,
    description=description,
    examples=examples,
    css=css,
)

# Start the web server as soon as the script is run.
iface.launch()