Spaces:
Build error
Build error
from transformers import AutoProcessor, Pix2StructForConditionalGeneration | |
import gradio as gr | |
import torch | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from typing import Tuple | |
from PIL import Image | |
import os | |
import sys | |
# Fetch the DiT dependencies at startup: detectron2 built from source, and a
# pinned checkout of microsoft/unilm (which provides `add_vit_config` and the
# PubLayNet configs used below).
import subprocess

# Use the running interpreter's pip so the package lands in this environment.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/detectron2.git"],
    check=True,
)
if not os.path.isdir("unilm"):
    # Clone only once; on restarts the existing checkout is simply re-pinned.
    subprocess.run(["git", "clone", "https://github.com/microsoft/unilm.git"], check=True)
subprocess.run(["git", "checkout", "9102ed91f8e56baa31d7ae7e09e0ec98e77d779c"], cwd="unilm", check=True)
sys.path.append("unilm")
from unilm.dit.object_detection.ditod import add_vit_config | |
from detectron2.config import CfgNode as CN | |
from detectron2.config import get_cfg | |
from detectron2.data import MetadataCatalog | |
from detectron2.engine import DefaultPredictor | |
# Plot settings: seaborn "darkgrid" style with the pastel palette. The
# `palette` object is kept as a module global because the pie chart reuses it
# for its wedge colours. Figures are drawn with the non-interactive Agg backend.
palette = sns.color_palette("pastel")
sns.set_style("darkgrid")
sns.set_palette(palette)
plt.switch_backend("Agg")
# DiT layout detector: start from detectron2's default config, add the ViT
# options, then apply the cascade DiT-base PubLayNet config from the unilm checkout.
cfg = get_cfg()
add_vit_config(cfg)
cfg.merge_from_file("unilm/dit/object_detection/publaynet_configs/cascade/cascade_dit_base.yaml")
# Fine-tuned checkpoint location, read from cfg by DefaultPredictor below.
cfg.MODEL.WEIGHTS = "https://layoutlm.blob.core.windows.net/dit/dit-fts/publaynet_dit-b_cascade.pth"
# Run on the GPU when one is visible; the DePlot model below uses the same device.
if torch.cuda.is_available():
    cfg.MODEL.DEVICE = "cuda"
else:
    cfg.MODEL.DEVICE = "cpu"
predictor = DefaultPredictor(cfg)

# DePlot (Pix2Struct) processor and model for chart-to-table generation.
processor = AutoProcessor.from_pretrained("google/deplot")
model = Pix2StructForConditionalGeneration.from_pretrained("google/deplot").to(cfg.MODEL.DEVICE)
def crop_figure(img: Image.Image, threshold: float = 0.5, margin: float = 0.1) -> "Image.Image | None":
    """Crop the most confident figure out of a document page using the DiT detector.

    Args:
        img (Image.Image): Input document image. ``None`` (no image supplied)
            yields ``None``.
        threshold (float, optional): Minimum detection score a figure must
            reach to be cropped. Defaults to 0.5.
        margin (float, optional): White border added around the crop, as a
            fraction of the figure's width and height (split evenly between
            opposite sides). Defaults to 0.1.
    Returns:
        (Image.Image | None): The highest-scoring figure pasted centred on a
            white canvas, or ``None`` if no image was given or no figure
            scored at least ``threshold``.
    """
    if img is None:
        return None
    # Name the PubLayNet classes on the test dataset's metadata; index 4 is "figure".
    MetadataCatalog.get(cfg.DATASETS.TEST[0]).set(thing_classes=["text", "title", "list", "table", "figure"])
    instances = predictor(np.array(img))["instances"]
    boxes = instances.pred_boxes.tensor.cpu().numpy()
    scores = instances.scores.cpu().numpy()
    classes = instances.pred_classes.cpu().numpy()
    # Keep only figure detections that clear the requested confidence threshold.
    keep = (classes == 4) & (scores >= threshold)
    if not keep.any():
        return None
    boxes, scores = boxes[keep], scores[keep]
    # Boxes are (x0, y0, x1, y1), the same layout PIL's crop() expects.
    figure = img.crop(tuple(boxes[np.argmax(scores)]))
    # Size the canvas from the actual crop so it is never smaller than the figure.
    width, height = figure.size
    pad_x, pad_y = int(width * margin), int(height * margin)
    canvas = Image.new("RGB", (width + pad_x, height + pad_y), (255, 255, 255))
    canvas.paste(figure, (pad_x // 2, pad_y // 2))
    return canvas
def _parse_deplot_output(decoded: str) -> Tuple[str, pd.DataFrame, str]:
    """Turn DePlot's linearised table text into (title, table, HTML).

    The text is split into rows on the literal "<0x0A>" token and into cells
    on " | ". If the first row starts with "title", its second cell is used as
    the title and the next row as the header; otherwise the first row is the
    header and the title is "Table".

    Args:
        decoded (str): Decoded DePlot output.
    Returns:
        Tuple[str, pd.DataFrame, str]: The title, the table, and the table as
            HTML. If the text cannot be shaped into a table, returns
            ("Table", an empty DataFrame, the raw text with real newlines).
    """
    # Skip blank rows (e.g. from a trailing separator) and trim stray padding
    # around each cell so headers and values are clean.
    rows = [
        [cell.strip() for cell in row.split(" | ")]
        for row in decoded.split("<0x0A>")
        if row.strip()
    ]
    try:
        if rows[0][0].lower().startswith("title"):
            title = rows[0][1]
            table = pd.DataFrame(rows[2:], columns=rows[1])
        else:
            title = "Table"
            table = pd.DataFrame(rows[1:], columns=rows[0])
    except (IndexError, ValueError):
        # Empty output, missing header, or rows/header width mismatch:
        # fall back to showing the raw text so the user can still copy it.
        return "Table", pd.DataFrame(), decoded.replace("<0x0A>", "\n")
    return title, table, table.to_html()


def extract_tables(image: Image.Image) -> Tuple[str, pd.DataFrame, str]:
    """Extract a data table from a chart image using the DePlot model.

    Args:
        image (Image.Image): Input figure image. ``None`` (no image supplied)
            yields an empty result.
    Returns:
        Tuple[str, pd.DataFrame, str]: The table title, the table as a pandas
            DataFrame, and the table as an HTML string. When the model output
            cannot be parsed, the DataFrame is empty and the third element is
            the raw model output.
    """
    if image is None:
        return "Table", pd.DataFrame(), ""
    inputs = processor(
        image,
        text="Generate a data table using only the data you see in the graph below: ",
        return_tensors="pt",
    ).to(cfg.MODEL.DEVICE)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=512)
    decoded = processor.decode(outputs[0], skip_special_tokens=True)
    return _parse_deplot_output(decoded)
def _to_numeric_if_possible(column: pd.Series) -> pd.Series:
    """Convert *column* to a numeric dtype, or return it unchanged if any value fails to parse."""
    # Same result as the deprecated pd.to_numeric(column, errors="ignore").
    try:
        return pd.to_numeric(column)
    except (ValueError, TypeError):
        return column


def update(df: pd.DataFrame, plot_type: str) -> plt.Figure:
    """Re-plot the extracted table with the selected plot type (Gradio callback).

    With more than one column, the first column becomes the x axis / labels
    and the remaining columns are plotted as series.

    Args:
        df (pd.DataFrame): The extracted table data.
        plot_type (str): One of "Line", "Bar", "Scatter" or "Pie"; any other
            value yields empty axes.
    Returns:
        plt.Figure: The new figure. If plotting fails, the figure shows the
            error message instead of a chart.
    """
    # Release figures from earlier calls before creating a new one.
    plt.close("all")
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)
    if df is None:
        return fig
    df = df.apply(_to_numeric_if_possible)
    cols = df.columns
    if len(cols) == 0:
        return fig
    if len(cols) > 1:
        df = df.set_index(cols[0])
    try:
        if plot_type == "Line":
            sns.lineplot(data=df, ax=ax)
        elif plot_type == "Bar":
            if len(cols) == 1:
                # No index column was set: plot the values against row position.
                # (reset_index() here would insert an "index" column and plot that instead.)
                sns.barplot(x=df.index, y=df[cols[0]], ax=ax)
            elif len(cols) == 2:
                df = df.reset_index()
                sns.barplot(x=df[cols[0]], y=df[cols[1]], ax=ax)
            else:
                # Long format: one bar per (label, series) pair, coloured by series.
                long_df = df.reset_index().melt(id_vars=cols[0], value_vars=cols[1:], value_name="Value")
                sns.barplot(x=long_df[cols[0]], y=long_df["Value"], hue=long_df["variable"], ax=ax)
        elif plot_type == "Scatter":
            sns.scatterplot(data=df, ax=ax)
        elif plot_type == "Pie":
            ax.pie(df[df.columns[0]], labels=df.index, autopct="%1.1f%%", colors=palette)
            ax.axis("equal")
    except Exception as err:  # UI boundary: report the failure instead of crashing the callback
        ax.clear()
        ax.set_axis_off()
        ax.text(
            0.5,
            0.5,
            f"Could not draw a {plot_type} plot:\n{err}",
            ha="center",
            va="center",
            wrap=True,
            transform=ax.transAxes,
        )
    fig.tight_layout()
    return fig
with gr.Blocks() as demo:
    gr.Markdown("<h1 align=center>Data extraction from charts</h1>")
    gr.Markdown(
        "This Space illustrates an experimental extraction pipeline using two pretrained models:"
        " DiT is used to find figures in a document and crop them."
        " DePlot is used to extract the data from the plot and convert it to tabular format."
        " Alternatively, you can paste a figure directly into the right Image field for data extraction."
        " Finally, you can re-plot the extracted table using the Plot Type selector,"
        " and copy the HTML code to paste it elsewhere."
    )
    with gr.Row():
        # Named so it does not shadow the builtin `input`.
        document_image = gr.Image(image_mode="RGB", label="Document Page", type="pil")
        cropped = gr.Image(image_mode="RGB", label="Cropped Image", type="pil")
    with gr.Row():
        crop_btn = gr.Button("Crop Figure")
        extract_btn = gr.Button("Extract")
    with gr.Row():
        gr.Examples(["./2304.08069_2.png"], document_image)
        gr.Examples(["./chartVQA.png"], cropped)
    title = gr.Textbox(label="Title")
    with gr.Row():
        with gr.Column():
            tab_data = gr.DataFrame(label="Table")
            # `value` sets the initially selected plot type.
            plot_type = gr.Radio(["Line", "Bar", "Scatter", "Pie"], label="Plot Type", value="Line")
            plot_btn = gr.Button("Plot")
        display = gr.Plot()
    with gr.Row():
        # NOTE(review): `.style()` exists only in Gradio 3.x; confirm the pinned Gradio version.
        html_data = gr.Textbox(label="HTML copy-paste").style(
            show_copy_button=True, copy_button_label="Copy to clipboard"
        )
    crop_btn.click(crop_figure, document_image, [cropped])
    extract_btn.click(extract_tables, cropped, [title, tab_data, html_data])
    plot_btn.click(update, [tab_data, plot_type], display)

if __name__ == "__main__":
    demo.launch()