yonatanbitton's picture
Update app.py
d5bd6a6 verified
raw
history blame
No virus
2.95 kB
import os
import gradio as gr
from datasets import load_dataset
# Access token for the Hugging Face Hub, read from the environment (None if unset).
# NOTE(review): newer `datasets` releases deprecate `use_auth_token` in favour of
# `token` — confirm against the pinned version before switching.
auth_token = os.environ.get("auth_token")
visit_bench_all = load_dataset("mlfoundations/VisIT-Bench", use_auth_token=auth_token)
print('visit_bench_all')
print(visit_bench_all)

# The viewer only works when the dataset exposes exactly one split.
print('dataset keys:')
print(visit_bench_all.keys())
dataset_keys = list(visit_bench_all.keys())
assert len(dataset_keys) == 1
dataset_key = dataset_keys[0]
visit_bench = visit_bench_all[dataset_key]
print('first item:')
print(visit_bench[0])

# Work on a shuffled pandas copy so every launch starts with a random order.
df = visit_bench.to_pandas()
print(f"Got {len(df)} items in dataframe")
df = df.sample(frac=1)


def _image_cell(image):
    """Build the HTML for one image cell: a 400-wide thumbnail linking to itself.

    ``image`` is the raw cell value of the ``image`` column; only its
    ``'path'`` entry is used.
    """
    path = str(image['path'])
    return f'<a href= "{path}" target="_blank"> <img src= "{path}" width="400"/> </a>'


df['image'] = df['image'].apply(_image_cell)

# Move the rendered image column to the front, keeping the others in order.
cols = list(df.columns)
cols.remove('image')
cols.insert(0, 'image')
df = df.reindex(columns=cols)

# Number of rows shown per page of the viewer.
LINES_NUMBER = 20
def display_df():
    """Return the first ``LINES_NUMBER`` rows of the module-level ``df``."""
    return df.head(LINES_NUMBER)
def display_next(dataframe, end):
    """Return the next page of rows from the module-level ``df``.

    Args:
        dataframe: The page currently displayed. Only its length is used, as
            the starting offset when ``end`` is empty (``None`` or ``0``).
        end: Exclusive end offset of the page currently displayed, as stored
            in the hidden number component. May be ``None``, ``0`` or a float.

    Returns:
        A ``(page, new_end)`` tuple: the rows ``df.iloc[start:new_end]`` and
        the exclusive end offset to pass back in on the next call.

    When the next page would run past the last row (or the offset is
    negative), ``df`` is reshuffled in place and paging restarts from row 0.
    """
    global df

    start = int(end or len(dataframe))
    end = start + int(LINES_NUMBER)

    # `end` is an exclusive bound, so a page ending exactly at len(df) is still
    # a full page; only wrap when the slice would actually fall off the end.
    # A negative start would make iloc slice from the tail, so wrap then too.
    if start < 0 or end > len(df):
        start = 0
        end = LINES_NUMBER
        df = df.sample(frac=1)
        print("Shuffle")

    # After the guard above the page is LINES_NUMBER rows long, unless the
    # whole dataset is shorter than one page, in which case show what exists.
    df_images = df.iloc[start:end]
    return df_images, end
# First page, computed once at import time and used to pre-fill the table.
initial_dataframe = display_df()

# One display type per table column, in column order. The leading "markdown"
# entry is presumably what lets the HTML thumbnail column render — confirm
# against the Gradio Dataframe documentation.
_column_datatypes = [
    "markdown", "str", "str", "str", "bool",
    "bool", "bool", "str", "str", "str",
]

# Gradio Blocks
with gr.Blocks() as demo:
    gr.Markdown("<h1><center>VisIT-Bench Dataset Viewer</center></h1>")

    with gr.Row():
        # Hidden component that carries the end offset between "Next Rows" clicks.
        num_end = gr.Number(visible=False)
        b1 = gr.Button("Get Initial dataframe")
        b2 = gr.Button("Next Rows")

    with gr.Row():
        out_dataframe = gr.Dataframe(
            initial_dataframe,
            wrap=True,
            row_count=LINES_NUMBER,
            interactive=False,
            datatype=_column_datatypes,
        )

    # NOTE(review): this handler does not write to `num_end`, so a later
    # "Next Rows" click continues from the previous offset — confirm intended.
    b1.click(fn=display_df, outputs=out_dataframe, api_name="initial_dataframe")
    b2.click(
        fn=display_next,
        inputs=[out_dataframe, num_end],
        outputs=[out_dataframe, num_end],
        api_name="next_rows",
    )

demo.launch(debug=True, show_error=True)