Spaces:
Sleeping
Sleeping
import gradio as gr | |
import os | |
import shutil | |
from main import fine_tune_model | |
from diffusers import StableDiffusionPipeline, DDIMScheduler | |
import torch | |
# Base Stable Diffusion checkpoint (Hugging Face Hub id) passed to fine_tune_model().
MODEL_NAME: str = "runwayml/stable-diffusion-v1-5"
# Location of the fine-tuned weights: handed to fine_tune_model() by fine_tune()
# and loaded back by generate_images().
OUTPUT_DIR: str = "/home/user/app/stable_diffusion_weights/custom_model"
def fine_tune(instance_prompt, image1, image2=None, instance_data_dir="/home/user/app/instance_images"):
    """Fine-tune the base model on one or two uploaded instance images.

    The images are saved as ``instance_0.png`` / ``instance_1.png`` into a freshly
    emptied ``instance_data_dir``. That directory is then passed to
    ``fine_tune_model`` together with ``instance_prompt``, ``MODEL_NAME`` and
    ``OUTPUT_DIR``.

    Args:
        instance_prompt: Prompt describing the instance; must be non-blank.
        image1: First instance image (an object with ``.save(path)``; the UI
            supplies a PIL image because the component uses ``type="pil"``).
            Required.
        image2: Optional second instance image.
        instance_data_dir: Scratch directory for the instance images. Any
            existing contents are deleted before the new images are written.

    Returns:
        A status string for the UI. It is either the success message, a
        validation message, or the text of any exception raised while preparing
        the images or training.
    """
    # Validate before touching the filesystem, so an incomplete request does
    # not wipe the images from a previous run.
    if not instance_prompt or not str(instance_prompt).strip():
        return "Please provide an instance prompt."
    if image1 is None:
        return "Please upload at least one image."

    try:
        # Start from an empty directory so stale images from earlier runs are
        # not included in this training run.
        if os.path.exists(instance_data_dir):
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)

        image1.save(os.path.join(instance_data_dir, "instance_0.png"))
        if image2 is not None:
            image2.save(os.path.join(instance_data_dir, "instance_1.png"))

        fine_tune_model(instance_data_dir, instance_prompt, MODEL_NAME, OUTPUT_DIR)
        return "Model fine-tuning complete."
    except Exception as e:
        # UI boundary: report the failure in the output textbox instead of
        # crashing the event handler.
        return str(e)
def generate_images(prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed=1337):
    """Generate images with the fine-tuned model stored in ``OUTPUT_DIR``.

    Args:
        prompt: Text prompt.
        num_samples: Number of images to produce (``num_images_per_prompt``).
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps (DDIM scheduler).
        guidance_scale: Classifier-free guidance scale.
        seed: Seed for the CUDA generator. The default keeps the original fixed
            seed, so identical inputs reproduce identical images.

    Returns:
        The list of images from ``pipe(...).images``, for display in a gallery.

    Raises:
        gr.Error: If the model directory is missing, the numeric settings are
            invalid, or loading/generation fails. Gradio shows the message in
            the UI; a plain string returned to a Gallery output cannot be rendered.
    """
    if not os.path.exists(OUTPUT_DIR):
        raise gr.Error("The model path does not exist.")

    # gr.Number delivers floats unless precision=0; the pipeline needs integer
    # counts and sizes.
    try:
        num_samples = int(num_samples)
        height = int(height)
        width = int(width)
        num_inference_steps = int(num_inference_steps)
        guidance_scale = float(guidance_scale)
        seed = int(seed)
    except (TypeError, ValueError) as e:
        raise gr.Error(f"Invalid generation settings: {e}") from e
    if num_samples < 1 or num_inference_steps < 1:
        raise gr.Error("Number of samples and steps must be at least 1.")

    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            OUTPUT_DIR, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        generator = torch.Generator(device="cuda").manual_seed(seed)
        with torch.autocast("cuda"), torch.inference_mode():
            result = pipe(
                prompt,
                height=height,
                width=width,
                num_images_per_prompt=num_samples,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=generator,
            )
    except Exception as e:
        # Surface loading/generation failures in the UI rather than returning
        # an error string to the Gallery output.
        raise gr.Error(str(e)) from e
    return result.images
def gradio_app():
    """Build the two-tab Gradio UI and launch it.

    "Fine-Tune Model" tab: takes an instance prompt and one or two images, calls
    ``fine_tune`` and shows its status string in a textbox.
    "Generate Images" tab: takes a prompt and sampling settings, calls
    ``generate_images`` and shows the result in a gallery.
    """
    with gr.Blocks() as demo:
        with gr.Tab("Fine-Tune Model"):
            with gr.Row():
                with gr.Column():
                    instance_prompt = gr.Textbox(label="Instance Prompt")
                    image1 = gr.Image(label="Upload Image 1", type="pil")
                    image2 = gr.Image(label="Upload Image 2 (Optional)", type="pil")
                    fine_tune_button = gr.Button("Fine-Tune Model")
                    output_text = gr.Textbox(label="Output")
            fine_tune_button.click(
                fine_tune,
                inputs=[instance_prompt, image1, image2],
                outputs=output_text,
            )

        with gr.Tab("Generate Images"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt")
                    # precision=0 makes these fields deliver ints; the pipeline
                    # expects integer counts and sizes.
                    num_samples = gr.Number(label="Number of Samples", value=1, precision=0)
                    guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                    height = gr.Number(label="Height", value=512, precision=0)
                    width = gr.Number(label="Width", value=512, precision=0)
                    num_inference_steps = gr.Slider(
                        label="Steps", value=50, minimum=1, maximum=100, step=1
                    )
                    generate_button = gr.Button("Generate Images")
                with gr.Column():
                    gallery = gr.Gallery(label="Generated Images")
            generate_button.click(
                generate_images,
                inputs=[prompt, num_samples, height, width, num_inference_steps, guidance_scale],
                outputs=gallery,
            )

    # Fine-tuning is a long-running handler; Gradio's request queue is the
    # recommended way to serve such events without request timeouts.
    demo.queue()
    demo.launch()
if __name__ == "__main__":
    # Script entry point: build the UI and start serving it.
    gradio_app()