# app.py — Hugging Face Space by throaway2854
# Commit 208758b ("Update app.py", verified); file size 10.1 kB
import base64
import html
import json
import mimetypes
import os
import posixpath
import tempfile
import uuid
import zipfile
from io import BytesIO

import gradio as gr
from PIL import Image
# Every PNG file starts with this 8-byte signature.
_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def _safe_folder_name(name):
    """Return *name* as a single zip path component (no separators, not '.'/'..')."""
    cleaned = str(name).replace("/", "_").replace("\\", "_").strip(". ")
    return cleaned or "dataset"


def _to_png_bytes(image_data):
    """Decode a base64 image (data URL or bare payload) and return PNG bytes."""
    # split(",", 1)[-1] keeps the payload whether or not a "data:...;base64," prefix is present.
    raw = base64.b64decode(image_data.split(",", 1)[-1])
    if raw.startswith(_PNG_SIGNATURE):
        # Already PNG: store the bytes unchanged instead of decoding and re-encoding.
        return raw
    out = BytesIO()
    Image.open(BytesIO(raw)).save(out, format="PNG")
    return out.getvalue()


def save_dataset_to_zip(dataset_name, dataset):
    """Package a dataset as an in-memory zip archive.

    Archive layout (a single top-level folder, so the loader can recover the
    dataset name from it)::

        <dataset_name>/annotations.jsonl   one JSON object per line:
                                           {"file_name": "images/<hex>.png", "text": <prompt>}
        <dataset_name>/images/<hex>.png

    Args:
        dataset_name: Name of the top-level folder. Path separators are
            replaced with '_' so the name cannot create extra directory levels.
        dataset: List of ``{'image': str, 'prompt': str}`` dicts. ``image`` is a
            base64 data URL (``data:<mime>;base64,<payload>``); a bare base64
            payload is accepted too. Non-PNG images are converted to PNG.

    Returns:
        A ``BytesIO`` positioned at offset 0 holding the zip file.
    """
    folder = _safe_folder_name(dataset_name)
    annotation_lines = []
    zip_buffer = BytesIO()
    # Write members straight into the archive; no temporary files on disk.
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for entry in dataset:
            # Zip member names always use '/', independent of the host OS.
            rel_name = f"images/{uuid.uuid4().hex}.png"
            zipf.writestr(f"{folder}/{rel_name}", _to_png_bytes(entry['image']))
            annotation_lines.append(json.dumps({'file_name': rel_name, 'text': entry['prompt']}))
        zipf.writestr(f"{folder}/annotations.jsonl", "".join(line + "\n" for line in annotation_lines))
    zip_buffer.seek(0)
    return zip_buffer
def load_dataset_from_zip(zip_file):
    """Read a dataset zip into memory without extracting it to disk.

    Two archive layouts are accepted:

    * ``<name>/annotations.jsonl`` + ``<name>/images/...``: the dataset name is
      the folder name.
    * ``annotations.jsonl`` + ``images/...`` at the archive root: the dataset
      name is the zip file name without its extension.

    Members under ``__MACOSX/`` (Finder metadata) are ignored.

    Args:
        zip_file: Path to the zip (``str`` / ``os.PathLike``), or an object whose
            ``.name`` attribute holds that path (e.g. an upload tempfile).

    Returns:
        ``(dataset_name, dataset)``. ``dataset`` is a list of
        ``{'image': <data URL>, 'prompt': str}`` dicts. When no
        ``annotations.jsonl`` is present the list is empty.

    Raises:
        zipfile.BadZipFile: the file is not a zip archive.
        KeyError: an annotation lacks ``file_name``/``text`` or points at a
            member missing from the archive.
        ValueError: a non-blank annotation line is not valid UTF-8 JSON.
    """
    path = os.fspath(zip_file) if isinstance(zip_file, (str, os.PathLike)) else zip_file.name
    fallback_name = os.path.splitext(os.path.basename(path))[0]
    with zipfile.ZipFile(path, 'r') as zf:
        names = [n for n in zf.namelist() if not n.startswith('__MACOSX/')]
        annotations_name = _find_annotations(names)
        if annotations_name is None:
            top_level = sorted({n.split('/', 1)[0] for n in names if '/' in n})
            return (top_level[0] if top_level else fallback_name), []
        base_dir = posixpath.dirname(annotations_name)
        dataset_name = posixpath.basename(base_dir) or fallback_name
        dataset = []
        for line in zf.read(annotations_name).decode('utf-8').splitlines():
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL file
            ann = json.loads(line)
            # file_name is relative to the folder holding annotations.jsonl.
            member = posixpath.normpath(posixpath.join(base_dir, ann['file_name'].replace('\\', '/')))
            encoded = base64.b64encode(zf.read(member)).decode('ascii')
            mime_type = mimetypes.guess_type(member)[0] or 'image/png'
            dataset.append({
                'image': f"data:{mime_type};base64,{encoded}",
                'prompt': ann['text'],
            })
    return dataset_name, dataset


def _find_annotations(names):
    """Return the shallowest ``annotations.jsonl`` member in *names*, or None."""
    candidates = [n for n in names if posixpath.basename(n) == 'annotations.jsonl']
    return min(candidates, key=lambda n: (n.count('/'), n), default=None)
def display_dataset_html(dataset):
    """Render dataset entries as HTML rows for the ``gr.HTML`` preview.

    Each row shows the entry's list index, its image and its prompt. The
    prompt and the image source are HTML-escaped, so user text such as
    ``<b>`` or ``&`` is displayed literally rather than interpreted as markup.

    Args:
        dataset: List of ``{'image': str, 'prompt': str}`` dicts, or a falsy
            value (``None`` / empty list).

    Returns:
        An HTML string, or a placeholder ``<div>`` when there are no entries.
    """
    if not dataset:
        return "<div>No entries in dataset.</div>"
    rows = []
    for idx, entry in enumerate(dataset):
        src = html.escape(str(entry['image']), quote=True)
        prompt = html.escape(str(entry['prompt']))
        rows.append(
            '<div style="display: flex; align-items: center; margin-bottom: 10px;">'
            f'<div style="width: 50px;">{idx}</div>'
            f'<img src="{src}" alt="Image {idx}" style="max-height: 100px; margin-right: 10px;"/>'
            f'<div>{prompt}</div>'
            '</div>'
        )
    # join() instead of repeated += keeps building linear in the entry count.
    return "\n".join(rows)
def _image_to_data_url(image):
    """Encode an uploaded image as a PNG ``data:`` URL.

    The preview and the zip export store images as base64 data URLs, so
    uploads are normalised to that form. A ``str`` is assumed to be a data
    URL already and is passed through; anything else must have a PIL-style
    ``save(buffer, format=...)`` method (``gr.Image(type="pil")`` output).
    """
    if isinstance(image, str):
        return image
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/png;base64,{encoded}"


def _zip_download_path(dataset_name, payload):
    """Write *payload* to ``<new temp dir>/<dataset_name>.zip`` and return the path.

    ``gr.File`` outputs take a file path, not raw bytes. The temporary
    directory is left in place on purpose so Gradio can still serve the file.
    """
    safe_name = dataset_name.replace("/", "_").replace("\\", "_").strip(". ") or "dataset"
    path = os.path.join(tempfile.mkdtemp(), f"{safe_name}.zip")
    with open(path, "wb") as fh:
        fh.write(payload)
    return path


with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>Dataset Builder</h1>")
    # Per-session state: {dataset name: [{'image': data URL, 'prompt': str}, ...]}
    datasets = gr.State({})
    # Name of the dataset picked in the dropdown ("" when none).
    current_dataset_name = gr.State("")
    dataset_selector = gr.Dropdown(label="Select Dataset", interactive=True)
    dataset_html = gr.HTML()
    message_box = gr.Textbox(interactive=False, label="Message")

    with gr.Tab("Create / Upload Dataset"):
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Create a New Dataset")
                dataset_name_input = gr.Textbox(label="New Dataset Name")
                create_button = gr.Button("Create Dataset")
            with gr.Column():
                gr.Markdown("### Upload Existing Dataset")
                upload_input = gr.File(label="Upload Dataset Zip", file_types=['.zip'])
                upload_button = gr.Button("Upload Dataset")

        def create_dataset(name, store):
            """Add an empty dataset called *name* and select it."""
            name = (name or "").strip()
            if not name:
                return store, gr.update(), "Please enter a dataset name."
            if name in store:
                return store, gr.update(), f"Dataset '{name}' already exists."
            store[name] = []
            return store, gr.update(choices=list(store.keys()), value=name), f"Dataset '{name}' created."

        create_button.click(
            create_dataset,
            inputs=[dataset_name_input, datasets],
            outputs=[datasets, dataset_selector, message_box],
        )

        def upload_dataset(zip_file, store):
            """Load an uploaded zip into the session and select it; report load errors."""
            if zip_file is None:
                return store, gr.update(), "Please upload a zip file."
            try:
                dataset_name, dataset = load_dataset_from_zip(zip_file)
            except (zipfile.BadZipFile, KeyError, ValueError, OSError) as err:
                return store, gr.update(), f"Could not load dataset: {err}"
            if dataset_name in store:
                return store, gr.update(), f"Dataset '{dataset_name}' already exists."
            store[dataset_name] = dataset
            return (
                store,
                gr.update(choices=list(store.keys()), value=dataset_name),
                f"Dataset '{dataset_name}' uploaded.",
            )

        upload_button.click(
            upload_dataset,
            inputs=[upload_input, datasets],
            outputs=[datasets, dataset_selector, message_box],
        )

        def select_dataset(dataset_name, store):
            """Sync the selected-name state and preview with the dropdown."""
            # gr.State outputs take the plain new value (State has no .update()).
            if dataset_name in store:
                return dataset_name, gr.update(value=display_dataset_html(store[dataset_name]))
            return "", gr.update(value="<div>Select a dataset.</div>")

        # message_box is deliberately not an output: this also fires after
        # create/upload set the dropdown value and must not erase their message.
        dataset_selector.change(
            select_dataset,
            inputs=[dataset_selector, datasets],
            outputs=[current_dataset_name, dataset_html],
        )

    with gr.Tab("Add Entry"):
        with gr.Row():
            # type="pil" so the handler can encode the upload as a data URL.
            image_input = gr.Image(label="Upload Image", type="pil")
            prompt_input = gr.Textbox(label="Prompt")
        add_button = gr.Button("Add Entry")

        def add_entry(image_data, prompt, selected_name, store):
            """Append an (image, prompt) entry to the selected dataset."""
            if not selected_name or selected_name not in store:
                return store, gr.update(), "No dataset selected."
            if image_data is None or not prompt:
                return store, gr.update(), "Please provide both an image and a prompt."
            entries = store[selected_name]
            entries.append({'image': _image_to_data_url(image_data), 'prompt': prompt})
            return store, gr.update(value=display_dataset_html(entries)), f"Entry added to dataset '{selected_name}'."

        add_button.click(
            add_entry,
            inputs=[image_input, prompt_input, current_dataset_name, datasets],
            outputs=[datasets, dataset_html, message_box],
        )

    with gr.Tab("Edit / Delete Entry"):
        index_input = gr.Number(label="Entry Index", precision=0)
        new_prompt_input = gr.Textbox(label="New Prompt (for Edit)")
        with gr.Row():
            edit_button = gr.Button("Edit Entry")
            delete_button = gr.Button("Delete Entry")

        def edit_entry(index, new_prompt, selected_name, store):
            """Replace the prompt of entry *index* in the selected dataset."""
            if not selected_name or selected_name not in store:
                return store, gr.update(), "No dataset selected."
            if index is None or new_prompt is None or new_prompt.strip() == '':
                return store, gr.update(), "Please provide both index and new prompt."
            index = int(index)
            entries = store[selected_name]
            if not 0 <= index < len(entries):
                return store, gr.update(), "Invalid index."
            entries[index]['prompt'] = new_prompt
            return store, gr.update(value=display_dataset_html(entries)), f"Entry {index} updated."

        edit_button.click(
            edit_entry,
            inputs=[index_input, new_prompt_input, current_dataset_name, datasets],
            outputs=[datasets, dataset_html, message_box],
        )

        def delete_entry(index, selected_name, store):
            """Remove entry *index* from the selected dataset."""
            if not selected_name or selected_name not in store:
                return store, gr.update(), "No dataset selected."
            if index is None:
                return store, gr.update(), "Please provide an index."
            index = int(index)
            entries = store[selected_name]
            if not 0 <= index < len(entries):
                return store, gr.update(), "Invalid index."
            del entries[index]
            return store, gr.update(value=display_dataset_html(entries)), f"Entry {index} deleted."

        delete_button.click(
            delete_entry,
            inputs=[index_input, current_dataset_name, datasets],
            outputs=[datasets, dataset_html, message_box],
        )

    with gr.Tab("Download Dataset"):
        download_button = gr.Button("Download Dataset")
        download_output = gr.File(label="Download Zip")

        def download_dataset(selected_name, store):
            """Zip the selected dataset and return a file path for gr.File."""
            if not selected_name or selected_name not in store:
                return None, "No dataset selected."
            zip_buffer = save_dataset_to_zip(selected_name, store[selected_name])
            path = _zip_download_path(selected_name, zip_buffer.getvalue())
            return path, f"Dataset '{selected_name}' is ready for download."

        download_button.click(
            download_dataset,
            inputs=[current_dataset_name, datasets],
            outputs=[download_output, message_box],
        )

    # Populate the dropdown choices from the session's datasets on page load.
    def initialize_datasets(store):
        return gr.update(choices=list(store.keys()))

    demo.load(
        initialize_datasets,
        inputs=[datasets],
        outputs=[dataset_selector],
    )

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()