# ultimate_gradio / app.py — Hugging Face Space by Sangjun2
# Commit fb56a77 (verified): "new_new_new_vaiv_app.py"
import gradio as gr
from transformers import AutoProcessor, Pix2StructForConditionalGeneration, T5Tokenizer, T5ForConditionalGeneration, Pix2StructProcessor, BartConfig,ViTConfig,VisionEncoderDecoderConfig, DonutProcessor, VisionEncoderDecoderModel, AutoTokenizer, AutoModel
from PIL import Image
import torch
import warnings
import re
import json
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import argparse
from scipy import optimize
from typing import Optional
import dataclasses
import editdistance
import itertools
import sys
import time
import logging
import subprocess
import spaces
import openai
import base64
from io import StringIO
# Run `git lfs pull` so that files tracked by Git LFS (presumably including the
# fine-tuned ko-deplot weights loaded below) are present before the models load.
# Best-effort: failures are only reported; torch.load() in load_model1() will
# still raise if the weights file is really missing.
try:
    result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)
except OSError as err:  # e.g. FileNotFoundError when git is not installed
    print(f"git lfs pull could not be run: {err}")
else:
    # Print the command result (optional).
    if result.returncode == 0:
        print("LFS 파일이 μ„±κ³΅μ μœΌλ‘œ λ‹€μš΄λ‘œλ“œλ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
    else:
        print(f"였λ₯˜ λ°œμƒ: {result.stderr}")
# Global logging/warning configuration and module-level constants.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger()  # the root logger configured just above
warnings.filterwarnings('ignore')  # suppress every Python warning process-wide
# NOTE(review): MAX_PATCHES is not referenced anywhere else in the visible code.
MAX_PATCHES = 512

# Device used for model inputs: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Local fine-tuned ko-deplot state dict, applied on top of the Hub checkpoint by load_model1().
ko_deplot_model_path = './deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch1.bin'
# Load the first model (ko-deplot): Hub checkpoint + local fine-tuned weights.
def load_model1():
    """Build the ko-deplot processor and model.

    The ``nuua/ko-deplot`` processor and model are fetched with
    ``from_pretrained``. The model weights are then replaced by the local state
    dict at ``ko_deplot_model_path``, and the model is moved to ``device``.

    Returns:
        tuple: ``(processor, model)``.
    """
    processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
    model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
    # map_location="cpu" lets the checkpoint deserialize without a GPU; the model
    # is moved to the target device afterwards.
    model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
    # Use the module-level `device` (CUDA if available, else CPU). This is the
    # same device predict_model1() moves its inputs to. Hard-coding "cuda" here
    # crashed on CPU-only hosts and mismatched the inputs there.
    model1.to(device)
    return processor1, model1


processor1, model1 = load_model1()
# Turn ko-deplot's escaped newline tokens back into real line breaks.
def format_output(prediction):
    """Return *prediction* with every literal ``<0x0A>`` token replaced by ``"\\n"``."""
    pieces = prediction.split('<0x0A>')
    return '\n'.join(pieces)
# First model prediction: ko-deplot
def predict_model1(image):
    """Run ko-deplot on a single image and return its decoded, newline-formatted text.

    Args:
        image: the image object passed to the processor (real_time_check()
            passes the result of ``Image.open``).

    Returns:
        str: the decoded generation for the image, with ``<0x0A>`` tokens
        turned into newlines by format_output().
    """
    batch = processor1(
        images=[image],
        text="What is the title of the chart",
        return_tensors="pt",
        padding=True,
    )
    # Put every input tensor on the same device as the module-level `device`.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    model1.eval()
    with torch.no_grad():
        generated = model1.generate(**batch, max_new_tokens=4096)
    decoded = [processor1.decode(sequence, skip_special_tokens=True) for sequence in generated]
    return format_output(decoded[0])
# OpenAI API key, read from the environment rather than hard-coded in source.
# NOTE(review): a secret key was previously committed in this file; it stays in
# the git history and should be revoked/rotated.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    logger.warning("OPENAI_API_KEY is not set; gpt-4o-mini requests will fail.")
# Read an image file and encode its bytes as base64 text.
def encode_image(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
# Second model prediction: gpt-4o-mini
def predict_model2(image):
    """Ask gpt-4o-mini to transcribe the chart in an image file as a title + table.

    Args:
        image: path of the image file. It is read from disk and sent inline as a
            base64 data URL.

    Returns:
        str: the message content of the first choice. The prompt asks for a
        title and a table; the shape of the answer is not validated here.

    NOTE(review): ``openai.ChatCompletion`` and dict-style
    ``choices[0]["message"]`` access are the pre-1.0 ``openai`` SDK interface;
    confirm the pinned openai version before upgrading.
    """
    import mimetypes  # local import keeps this function self-contained

    # Encode the uploaded image to base64.
    image_data = encode_image(image)
    # Label the data URL with the file's actual image type (e.g. PNG uploads)
    # instead of always claiming JPEG. Fall back to JPEG when the file name does
    # not identify an image type.
    mime_type, _ = mimetypes.guess_type(str(image))
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "please extract chart title and chart data manually and present them as a table. you should only provide title and table without adding any additional comments such as **Chart Title:** ."
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{image_data}"
                        }
                    }
                ]
            }
        ]
    )
    # Return the table text from the response.
    return response.choices[0]["message"]["content"]
def ko_deplot_convert_to_dataframe(label_table_str):  # converts ko-deplot text output to a pandas DataFrame
    """Parse ko-deplot's table text into ``(DataFrame, title)``.

    Expected layout, one record per line::

        TITLE | <chart title>
        <header cells separated by "|">
        <data rows, cells separated by " | ">

    If the header line has a different number of ``|``-separated cells than the
    first data row, it is treated as a legend row. A placeholder column name
    ``" "`` is then prepended for the row-label column. All cells are stripped
    of surrounding whitespace.

    Raises:
        IndexError: if the title line has no ``" | "`` or there are fewer than
            three lines.
        ValueError: if the rows do not fit the header (raised by pandas).
        The caller (real_time_check) catches these and shows the raw text.
    """
    lines = label_table_str.strip().split("\n")
    # maxsplit=1 keeps any " | " that is part of the title itself.
    title = lines[0].split(" | ", 1)[1].strip()

    header_line, data_lines = lines[1], lines[2:]
    # Split headers on "|" in both branches, so that legend names no longer
    # keep stray spaces (previously only the legend branch split on bare "|").
    headers = [cell.strip() for cell in header_line.split("|")]
    if len(headers) != len(data_lines[0].split("|")):
        # Legend-only header: add a column name for the leading row labels.
        headers.insert(0, " ")
    data = [[cell.strip() for cell in line.split(" | ")] for line in data_lines]
    return pd.DataFrame(data, columns=headers), title
def gpt_convert_to_dataframe(table_text):  # converts text generated by GPT to a pandas DataFrame
    """Parse GPT's "title line + markdown table" answer into ``(DataFrame, title)``.

    The first line is taken as the chart title. Every later line starting with
    ``|`` is a markdown table row. The first such row gives the column names.
    Separator rows made only of ``-`` / ``:`` (e.g. ``|---|:--:|``) are
    skipped. Blank lines and any other text around the table are ignored, so the
    parse no longer depends on the table sitting at fixed line positions. Cell
    text is stripped of surrounding whitespace.

    Raises:
        ValueError: if no table row is found, or the rows do not match the
            header. Previously an error *string* was returned here, which
            broke callers that unpack a 2-tuple.
    """
    lines = [line.strip() for line in table_text.strip().split("\n")]
    title = lines[0]

    rows = []
    for line in lines[1:]:
        if not line.startswith("|"):
            continue  # blank line or commentary around the table
        parts = line.split("|")
        # Drop the empty piece before the leading "|" (and after a trailing one).
        cells = [cell.strip() for cell in (parts[1:-1] if line.endswith("|") else parts[1:])]
        if not cells:
            continue
        if all(cell and set(cell) <= set("-:") for cell in cells):
            continue  # markdown header/body separator row
        rows.append(cells)

    if not rows:
        raise ValueError("Error converting table to DataFrame: no markdown table rows found")
    dataframe = pd.DataFrame(rows[1:], columns=rows[0])  # first row supplies the headers
    return dataframe, title
def real_time_check(image_file):
    """Run both models on one image file and try to tabulate their outputs.

    Args:
        image_file: path of the image. It is opened with PIL for ko-deplot and
            passed as-is to predict_model2().

    Returns:
        5-tuple ``(deplot_df, gpt_df, deplot_txt, gpt_txt, flag)``.
        On success the first two are ``gr.DataFrame`` components labelled with
        the parsed chart titles, the text slots are None, and ``flag`` is 0.
        If either output cannot be converted, the DataFrame slots are None, the
        two raw texts are returned, and ``flag`` is 1.
    """
    # Context manager so the PIL file handle is released once ko-deplot has
    # consumed the image (it was previously never closed).
    with Image.open(image_file) as image:
        ko_deplot_generated_txt = predict_model1(image)
    # Always discard the last line of the ko-deplot output, as the original
    # logic does. NOTE(review): presumably a trailing/incomplete line; confirm.
    ko_deplot_generated_txt = "\n".join(ko_deplot_generated_txt.split("\n")[:-1])
    gpt_generated_txt = predict_model2(image_file)
    try:
        ko_deplot_generated_df, ko_deplot_generated_title = ko_deplot_convert_to_dataframe(ko_deplot_generated_txt)
        gpt_generated_df, gpt_generated_title = gpt_convert_to_dataframe(gpt_generated_txt)
        return (
            gr.DataFrame(ko_deplot_generated_df, label=ko_deplot_generated_title),
            gr.DataFrame(gpt_generated_df, label=gpt_generated_title),
            None,
            None,
            0,
        )
    except Exception:
        # Best-effort: unparseable model output is shown to the user as raw text.
        # Log why, instead of swallowing the error silently.
        logger.exception("Could not convert generated text to a DataFrame")
        return None, None, ko_deplot_generated_txt, gpt_generated_txt, 1
# Parse-failure flag: 1 means generated text could not be converted to a pandas
# DataFrame. The UI reads this module-level value once, via gr.State(flag);
# inference() below does not update it.
flag = 0


def inference(image_uploader, mode_selector):
    """Run both models on the current image and update the four result components.

    Args:
        image_uploader: path from the single-file uploader (used in file mode).
        mode_selector: the upload-mode radio value. The folder mode uses the
            currently displayed image from the module-level ``image_files``.

    Returns:
        4-tuple for (ko-deplot DataFrame, gpt DataFrame, ko-deplot Text, gpt Text).
        The DataFrames are shown when both outputs parse; otherwise they are
        hidden and the raw texts are shown.

    Raises:
        gr.Error: if no image (file mode) or no folder (folder mode) has been uploaded.
    """
    if mode_selector == "파일 μ—…λ‘œλ“œ":
        if image_uploader is None:
            raise gr.Error("Please upload an image file first.")
        image_path = image_uploader
    else:
        if not image_files:
            raise gr.Error("Please upload a folder of images first.")
        image_path = image_files[current_image_index]

    ko_df, gpt_df, ko_txt, gpt_txt, parse_failed = real_time_check(image_path)
    if parse_failed == 1:
        return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_txt, visible=True), gr.Text(gpt_txt, visible=True)
    return ko_df, gpt_df, gr.update(visible=False), gr.update(visible=False)
def toggle_model(selected_models, flag):
    """Compute visibility updates for the result components of the selected models.

    Args:
        selected_models: the model dropdown value; it is tested with ``in`` for
            "VAIV_DePlot", "gpt-4o-mini" and "all".
        flag: when truthy, the raw-text components are shown instead of the
            DataFrames.

    Returns:
        6-tuple of ``gr.update(visible=...)``, in order: VAIV_DePlot DataFrame,
        gpt DataFrame, VAIV_DePlot Text, gpt Text, VAIV_DePlot heading, gpt heading.
    """
    want_deplot = ("VAIV_DePlot" in selected_models) or ("all" in selected_models)
    want_gpt = ("gpt-4o-mini" in selected_models) or ("all" in selected_models)
    show_text = bool(flag)

    shown = (
        want_deplot and not show_text,  # ko-deplot DataFrame
        want_gpt and not show_text,     # gpt DataFrame
        want_deplot and show_text,      # ko-deplot raw text
        want_gpt and show_text,         # gpt raw text
        want_deplot,                    # ko-deplot heading
        want_gpt,                       # gpt heading
    )
    return tuple(gr.update(visible=is_visible) for is_visible in shown)
def toggle_mode(mode):
    """Switch the input widgets between single-file and folder upload.

    Returns visibility updates for (file uploader, folder uploader, prev button,
    next button). The file uploader is shown only in file mode; the other three
    are shown only otherwise.
    """
    single_file = mode == "파일 μ—…λ‘œλ“œ"
    folder_widgets = not single_file
    return (
        gr.update(visible=single_file),
        gr.update(visible=folder_widgets),
        gr.update(visible=folder_widgets),
        gr.update(visible=folder_widgets),
    )
def display_image(image_file):
    """Open an uploaded image file; return ``(PIL image, base file name)`` for the preview."""
    opened = Image.open(image_file)
    shown_name = os.path.basename(image_file)
    return opened, shown_name
# Folder-mode state: uploaded image paths, index of the image being shown, and
# their count. NOTE(review): these are module-level, so presumably shared by all
# sessions served by this process.
image_files = []
current_image_index = 0
image_files_cnt = 0


def display_folder_images(image_file_path_list):
    """Store the uploaded folder's image paths and show the first image.

    Returns:
        (image, file name, prev-button update, next-button update), one value per
        output component. "Next" is enabled only when there is a second image.
    """
    global image_files, current_image_index, image_files_cnt
    image_files = image_file_path_list or []  # tolerate a cleared upload (None)
    image_files_cnt = len(image_files)
    current_image_index = 0
    if not image_files:
        # Must still return four values (one per output) and keep navigation disabled.
        return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
    first_path = image_files[current_image_index]
    return (
        Image.open(first_path),
        os.path.basename(first_path),
        gr.update(interactive=False),                # nothing before the first image
        gr.update(interactive=image_files_cnt > 1),  # previously always enabled, so one image -> IndexError
    )
def next_image():
    """Advance to the next folder image.

    Returns:
        (image, file name, prev-button update, next-button update), one value per
        output component wired to the navigation buttons.
    """
    global current_image_index
    if not image_files:
        # Four values are required (one per output); keep both buttons disabled.
        return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
    last_index = len(image_files) - 1
    # Clamp to the valid range: a stale click while already on the last image
    # previously indexed past the end and raised IndexError.
    current_image_index = min(max(current_image_index + 1, 0), last_index)
    current_path = image_files[current_image_index]
    return (
        Image.open(current_path),
        os.path.basename(current_path),
        gr.update(interactive=current_image_index > 0),
        gr.update(interactive=current_image_index < last_index),
    )
def prev_image():
    """Step back to the previous folder image.

    Returns:
        (image, file name, prev-button update, next-button update), one value per
        output component wired to the navigation buttons.
    """
    global current_image_index
    if not image_files:
        # Four values are required (one per output); keep both buttons disabled.
        return None, "No images found", gr.update(interactive=False), gr.update(interactive=False)
    last_index = len(image_files) - 1
    # Clamp to the valid range: a stale click on the first image previously moved
    # to index -1, which silently wrapped around to the last image.
    current_image_index = min(max(current_image_index - 1, 0), last_index)
    current_path = image_files[current_image_index]
    return (
        Image.open(current_path),
        os.path.basename(current_path),
        gr.update(interactive=current_image_index > 0),
        gr.update(interactive=current_image_index < last_index),
    )
css = """
.dataframe-class {
overflow-y: auto !important; /* μŠ€ν¬λ‘€μ„ κ°€λŠ₯ν•˜κ²Œ */
height: 250px
}
"""
# Gradio UI: a title header, then one row with three columns:
#   1. inputs: upload mode, single-file / folder uploader, model selector,
#      image preview + file name, prev/next navigation, inference button;
#   2. VAIV_DePlot result: DataFrame for a parsed table, Text fallback otherwise;
#   3. gpt-4o-mini result: the same pair of components (hidden initially).
# NOTE(review): the Korean UI strings below appear mis-encoded (mojibake). The
# Radio choices are compared verbatim in inference() and toggle_mode(), so any
# re-encoding must update all of those literals together.
with gr.Blocks(css=css) as iface:
    with gr.Row():
        gr.Markdown("<h1 style='text-align: center;'>SKKU-VAIV Automatic chart understanding evaluation tool</h1>")
    gr.Markdown("<hr style='border: 1px solid #ddd;' />")
    with gr.Row():
        # Column 1: inputs and controls.
        with gr.Column():
            # Single-file vs folder upload; widget visibility is switched by toggle_mode().
            mode_selector = gr.Radio(["파일 μ—…λ‘œλ“œ", "폴더 μ—…λ‘œλ“œ"], label="Upload Mode", value="파일 μ—…λ‘œλ“œ")
            image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
            folder_uploader = gr.File(file_count="directory", file_types=["image"], visible=False, height=50)
            model_type=gr.Dropdown(["VAIV_DePlot","gpt-4o-mini","all"],value="VAIV_DePlot",label="model",multiselect=True)
            image_displayer = gr.Image(visible=True)
            image_name = gr.Text("", visible=True)
            # Folder-mode navigation: shown by toggle_mode(), enabled/disabled by
            # display_folder_images(), prev_image() and next_image().
            with gr.Row():
                prev_button = gr.Button("이전", visible=False, interactive=False)
                next_button = gr.Button("λ‹€μŒ", visible=False, interactive=False)
            inference_button = gr.Button("μΆ”λ‘ ")
        # Column 2: VAIV_DePlot output (DataFrame, or raw Text when parsing fails).
        with gr.Column():
            md1 = gr.Markdown("# VAIV_DePlot Inference Result")
            ko_deplot_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
            ko_deplot_generated_txt = gr.Text(visible=False)
        # Column 3: gpt-4o-mini output; visibility controlled by toggle_model().
        with gr.Column():
            md2 = gr.Markdown("# gpt-4o-mini Inference Result", visible=False)
            gpt_generated_df = gr.DataFrame(visible=False, elem_classes="dataframe-class")
            gpt_generated_txt = gr.Text(visible=False)
    #label_df = gr.DataFrame(visible=False, label="Ground Truth Table", elem_classes="dataframe-class",scale=1)
    # Show/hide the result components of the selected models.
    # NOTE(review): gr.State(flag) captures the module-level flag (0) once, at
    # build time. inference() never updates it, so toggle_model always receives
    # 0 and re-shows the DataFrames even after a parse failure.
    model_type.change(
        toggle_model,
        inputs=[model_type, gr.State(flag)],
        outputs=[ko_deplot_generated_df,gpt_generated_df,ko_deplot_generated_txt,gpt_generated_txt,md1,md2]
    )
    # Switch between single-file and folder upload widgets.
    mode_selector.change(
        toggle_mode,
        inputs=[mode_selector],
        outputs=[image_uploader, folder_uploader, prev_button, next_button]
    )
    # Preview the uploaded image (file mode) or the first image of the folder (folder mode).
    image_uploader.upload(display_image,inputs=[image_uploader],outputs=[image_displayer,image_name])
    folder_uploader.upload(display_folder_images, inputs=[folder_uploader], outputs=[image_displayer, image_name, prev_button, next_button])
    # Folder-mode navigation.
    prev_button.click(prev_image, outputs=[image_displayer, image_name, prev_button, next_button])
    next_button.click(next_image, outputs=[image_displayer, image_name, prev_button, next_button])
    # Run both models and fill the four result components.
    inference_button.click(inference,inputs=[image_uploader,mode_selector],outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt])
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    iface.launch(share=True)