import gradio as gr
from transformers import Pix2StructProcessor, Pix2StructForConditionalGeneration
from PIL import Image
import torch
import warnings
import os
import pandas as pd
import logging
import subprocess
import spaces  # kept for Hugging Face Spaces (ZeroGPU) compatibility
import openai
import base64
# Run `git lfs pull` to fetch the model weights tracked by Git LFS.
result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)

# Report the result of the command (optional).
if result.returncode == 0:
    print("LFS files downloaded successfully.")
else:
    print(f"Error while pulling LFS files: {result.stderr}")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()
warnings.filterwarnings('ignore')

MAX_PATCHES = 512

# Load the models and processor.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Path to the fine-tuned ko-deplot checkpoint.
ko_deplot_model_path = './deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch1.bin'
# Load the first model: ko-deplot with the fine-tuned weights.
def load_model1():
    processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
    model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
    model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
    model1.to(device)  # use the detected device instead of assuming CUDA is available
    return processor1, model1

processor1, model1 = load_model1()
# Format the model output: ko-deplot emits "<0x0A>" tokens in place of newlines.
def format_output(prediction):
    return prediction.replace('<0x0A>', '\n')
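# Illustrative only: assuming a typical prediction string in which the tokenizer
# emitted "<0x0A>" in place of newlines (exact tokens depend on the checkpoint):
# >>> format_output("TITLE | Sales<0x0A>year | value<0x0A>2023 | 10")
# 'TITLE | Sales\nyear | value\n2023 | 10'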
# First model prediction: ko-deplot.
def predict_model1(image):
    images = [image]
    inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}  # move tensors to the model's device
    model1.eval()
    with torch.no_grad():
        predictions = model1.generate(**inputs, max_new_tokens=4096)
    outputs = [processor1.decode(pred, skip_special_tokens=True) for pred in predictions]
    formatted_output = format_output(outputs[0])
    return formatted_output
# Read the OpenAI API key from the environment; never commit a plain-text key.
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Encode an image file as a base64 string for the OpenAI vision API.
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
# Second model prediction: gpt-4o-mini (legacy openai<1.0 SDK interface).
def predict_model2(image):
    # Encode the uploaded image to base64.
    image_data = encode_image(image)

    # Prepare and send the request.
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "please extract chart title and chart data manually and present them as a table. you should only provide title and table without adding any additional comments such as **Chart Title:** ."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_data}"
                        }
                    }
                ]
            }
        ]
    )

    # Return the table text from the response.
    return response.choices[0]["message"]["content"]
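# Note: the call above uses the legacy openai<1.0 interface. A minimal sketch of
# the same request with the openai>=1.0 client, assuming that SDK version is
# installed, would look like:
#
#   from openai import OpenAI
#   client = OpenAI()  # reads OPENAI_API_KEY from the environment
#   response = client.chat.completions.create(model="gpt-4o-mini", messages=messages)
#   table_text = response.choices[0].message.content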
# Convert the text generated by ko-deplot into a pandas DataFrame.
# The first line is expected to hold the title, the second the header (or legend)
# row, and the remaining lines the data rows, with " | " as the column separator.
def ko_deplot_convert_to_dataframe(label_table_str):
    lines = label_table_str.strip().split("\n")
    data = []
    title = lines[0].split(" | ")[1]
    if len(lines[1].split("|")) == len(lines[2].split("|")):
        # The header row and the data rows have the same number of columns.
        headers = lines[1].split(" | ")
        for line in lines[2:]:
            data.append(line.split(" | "))
        df = pd.DataFrame(data, columns=headers)
        return df, title
    else:
        # The second line is a legend row one column short of the data rows;
        # prepend a blank cell so the column counts match.
        legend_row = lines[1].split("|")
        legend_row.insert(0, " ")
        for line in lines[2:]:
            data.append(line.split(" | "))
        df = pd.DataFrame(data, columns=legend_row)
        return df, title
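# Illustrative only: with an assumed input in the flat-table format
#   "TITLE | Monthly Sales\ncategory | Jan | Feb\nA | 10 | 20\nB | 30 | 40"
# the title parses to "Monthly Sales" and the remaining lines become a
# DataFrame with columns ["category", "Jan", "Feb"] and two data rows.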
# Convert the markdown-style table generated by gpt-4o-mini into a pandas DataFrame.
# Errors are left to propagate: real_time_check() catches them and falls back to
# showing the raw text. (The original version swallowed the exception and returned
# a bare string, which broke the caller's two-value unpacking.)
def gpt_convert_to_dataframe(table_text):
    # Expected layout: title, blank line, header row, markdown separator row, data rows.
    lines = table_text.strip().split("\n")
    title = lines[0]
    lines.pop(1)  # drop the blank line after the title
    lines.pop(2)  # drop the markdown separator row (index 3 before the first pop)
    # Split each row on "|" and drop the empty first/last items.
    data = [line.split("|")[1:-1] for line in lines[1:]]
    # Use the first row as headers.
    dataframe = pd.DataFrame(data[1:], columns=[col.strip() for col in data[0]])
    return dataframe, title
def real_time_check(image_file):
    image = Image.open(image_file)
    ko_deplot_generated_txt = predict_model1(image)
    # Drop the last line of the ko-deplot output, which may be truncated mid-row.
    parts = ko_deplot_generated_txt.split("\n")
    del parts[-1]
    ko_deplot_generated_txt = "\n".join(parts)
    gpt_generated_txt = predict_model2(image_file)
    try:
        ko_deplot_generated_df, ko_deplot_generated_title = ko_deplot_convert_to_dataframe(ko_deplot_generated_txt)
        gpt_generated_df, gpt_generated_title = gpt_convert_to_dataframe(gpt_generated_txt)
        return gr.DataFrame(ko_deplot_generated_df, label=ko_deplot_generated_title), gr.DataFrame(gpt_generated_df, label=gpt_generated_title), None, None, 0
    except Exception:
        # Conversion failed; return the raw texts and set the error flag.
        return None, None, ko_deplot_generated_txt, gpt_generated_txt, 1
flag = 0  # set to 1 when generated text cannot be converted to a pandas DataFrame
def inference(image_uploader, mode_selector):
    # Both modes run the same pipeline; only the image source differs.
    if mode_selector == "File upload":
        source = image_uploader
    else:
        source = image_files[current_image_index]
    ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, flag = real_time_check(source)
    if flag == 1:
        return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt, visible=True), gr.Text(gpt_generated_txt, visible=True)
    else:
        return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False), gr.update(visible=False)
def toggle_model(selected_models, flag):
    # Visibility list for the six output components, in the order used by
    # model_type.change(): [ko_deplot_generated_df, gpt_generated_df,
    # ko_deplot_generated_txt, gpt_generated_txt, md1, md2].
    visibility = [False] * 6

    # Update visibility based on the selected models.
    if "VAIV_DePlot" in selected_models:
        visibility[4] = True
        if flag:
            visibility[2] = True
        else:
            visibility[0] = True
    if "gpt-4o-mini" in selected_models:
        visibility[5] = True
        if flag:
            visibility[3] = True
        else:
            visibility[1] = True
    if "all" in selected_models:
        visibility[4] = True
        visibility[5] = True
        if flag:
            visibility[2] = True
            visibility[3] = True
        else:
            visibility[0] = True
            visibility[1] = True

    # Return a gr.update for each component with its visibility status.
    return tuple(gr.update(visible=v) for v in visibility)
def toggle_mode(mode):
    if mode == "File upload":
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
def display_image(image_file):
    image = Image.open(image_file)
    return image, os.path.basename(image_file)
# State for stepping through the images of an uploaded folder sequentially.
image_files = []
current_image_index = 0
image_files_cnt = 0

def display_folder_images(image_file_path_list):
    global image_files, current_image_index, image_files_cnt
    image_files = image_file_path_list
    image_files_cnt = len(image_files)
    current_image_index = 0
    if image_files:
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=False), gr.update(interactive=True)
    # Return one value per output component even when the folder is empty.
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
def next_image():
    global current_image_index
    if image_files:
        current_image_index += 1
        prev_disabled = current_image_index == 0
        next_disabled = current_image_index == (len(image_files) - 1)
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive=not next_disabled)
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)

def prev_image():
    global current_image_index
    if image_files:
        current_image_index -= 1
        prev_disabled = current_image_index == 0
        next_disabled = current_image_index == (len(image_files) - 1)
        return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive=not next_disabled)
    return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
css = """ | |
.dataframe-class { | |
overflow-y: auto !important; /* μ€ν¬λ‘€μ κ°λ₯νκ² */ | |
height: 250px | |
} | |
""" | |
with gr.Blocks(css=css) as iface:
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center;'>SKKU-VAIV Automatic chart understanding evaluation tool</h1>")
    gr.Markdown("<hr style='border: 1px solid #ddd;' />")
    with gr.Row():
        with gr.Column():
            mode_selector = gr.Radio(["File upload", "Folder upload"], label="Upload Mode", value="File upload")
            image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
            folder_uploader = gr.File(file_count="directory", file_types=["image"], visible=False, height=50)
            model_type = gr.Dropdown(["VAIV_DePlot", "gpt-4o-mini", "all"], value="VAIV_DePlot", label="model", multiselect=True)
            image_displayer = gr.Image(visible=True)
            image_name = gr.Text("", visible=True)
            with gr.Row():
                prev_button = gr.Button("Previous", visible=False, interactive=False)
                next_button = gr.Button("Next", visible=False, interactive=False)
            inference_button = gr.Button("Inference")
        with gr.Column():
            md1 = gr.Markdown("# VAIV_DePlot Inference Result")
            ko_deplot_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
            ko_deplot_generated_txt = gr.Text(visible=False)
        with gr.Column():
            md2 = gr.Markdown("# gpt-4o-mini Inference Result", visible=False)
            gpt_generated_df = gr.DataFrame(visible=False, elem_classes="dataframe-class")
            gpt_generated_txt = gr.Text(visible=False)
        # label_df = gr.DataFrame(visible=False, label="Ground Truth Table", elem_classes="dataframe-class", scale=1)
    model_type.change(
        toggle_model,
        inputs=[model_type, gr.State(flag)],
        outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt, md1, md2]
    )
    mode_selector.change(
        toggle_mode,
        inputs=[mode_selector],
        outputs=[image_uploader, folder_uploader, prev_button, next_button]
    )
    image_uploader.upload(display_image, inputs=[image_uploader], outputs=[image_displayer, image_name])
    folder_uploader.upload(display_folder_images, inputs=[folder_uploader], outputs=[image_displayer, image_name, prev_button, next_button])
    prev_button.click(prev_image, outputs=[image_displayer, image_name, prev_button, next_button])
    next_button.click(next_image, outputs=[image_displayer, image_name, prev_button, next_button])
    inference_button.click(inference, inputs=[image_uploader, mode_selector], outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt])
if __name__ == "__main__":
    iface.launch(share=True)