Magix / app.py
Singularity666's picture
Update app.py
2c282ba verified
import gradio as gr
import os
import shutil
from main import fine_tune_model
from diffusers import StableDiffusionPipeline, DDIMScheduler
import torch
# Base checkpoint identifier handed to fine_tune_model as the model to start from
# (the "org/name" form is a Hugging Face Hub repo id).
MODEL_NAME: str = "runwayml/stable-diffusion-v1-5"
# Directory passed to fine_tune_model as its output location and later loaded
# with StableDiffusionPipeline.from_pretrained for generation.
OUTPUT_DIR: str = "/home/user/app/stable_diffusion_weights/custom_model"
def fine_tune(instance_prompt, image1, image2=None, instance_data_dir="/home/user/app/instance_images"):
    """Fine-tune the base model on one or two uploaded instance images.

    The images are written as PNGs into a freshly emptied *instance_data_dir*,
    then ``fine_tune_model`` is called with ``MODEL_NAME`` as the base model and
    ``OUTPUT_DIR`` as the output location.

    Args:
        instance_prompt: Prompt passed to ``fine_tune_model`` alongside the images.
        image1: First instance image (a PIL image when called from the
            ``gr.Image(type="pil")`` input). Required.
        image2: Optional second instance image.
        instance_data_dir: Scratch directory for the instance images. Any
            existing contents are deleted before the new images are saved.

    Returns:
        A status message for the UI: success, a validation message, or a
        description of the error raised during fine-tuning.
    """
    import logging

    # Validate before touching the filesystem. This way a missing upload does not
    # wipe the previous instance images, and the user sees a clear message instead
    # of "'NoneType' object has no attribute 'save'".
    if image1 is None:
        return "Please upload at least one image."

    try:
        # Start from an empty directory so images from an earlier run are not reused.
        if os.path.exists(instance_data_dir):
            shutil.rmtree(instance_data_dir)
        os.makedirs(instance_data_dir, exist_ok=True)

        images = [image1] if image2 is None else [image1, image2]
        for index, image in enumerate(images):
            image.save(os.path.join(instance_data_dir, f"instance_{index}.png"))

        fine_tune_model(instance_data_dir, instance_prompt, MODEL_NAME, OUTPUT_DIR)
    except Exception as e:
        # UI boundary: keep the full traceback in the server log, and report a
        # non-empty message to the user (str(e) alone can be "").
        logging.getLogger(__name__).exception("Fine-tuning failed")
        return f"Fine-tuning failed: {type(e).__name__}: {e}"
    return "Model fine-tuning complete."
def generate_images(prompt, num_samples, height, width, num_inference_steps, guidance_scale, seed=1337):
    """Generate images for *prompt* with the fine-tuned model stored in ``OUTPUT_DIR``.

    Loads the pipeline in float16 on CUDA, swaps in a DDIM scheduler built from
    the saved scheduler config, and samples with a seeded CUDA generator.

    Args:
        prompt: Text prompt passed to the pipeline.
        num_samples: Number of images to generate for the prompt.
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Guidance scale passed to the pipeline.
        seed: Seed for the CUDA generator. The fixed default keeps results
            reproducible across calls.

    Returns:
        The list of images produced by the pipeline (displayed by a ``gr.Gallery``).

    Raises:
        gr.Error: If the model directory is missing, a numeric argument is
            invalid, or loading/generation fails. Gradio shows the message in
            the UI; returning a plain string to a Gallery output cannot be rendered.
    """
    import logging

    if not os.path.exists(OUTPUT_DIR):
        raise gr.Error("The model path does not exist.")

    # gr.Number / gr.Slider deliver floats by default. The pipeline needs ints here:
    # they become tensor shapes (height, width, batch) and scheduler step counts.
    try:
        num_samples = int(num_samples)
        height = int(height)
        width = int(width)
        num_inference_steps = int(num_inference_steps)
        seed = int(seed)
    except (TypeError, ValueError, OverflowError) as e:
        raise gr.Error("Samples, height, width, steps and seed must be whole numbers.") from e
    if num_samples < 1 or num_inference_steps < 1:
        raise gr.Error("Number of samples and steps must be at least 1.")

    try:
        pipe = StableDiffusionPipeline.from_pretrained(
            OUTPUT_DIR, safety_checker=None, torch_dtype=torch.float16
        ).to("cuda")
        # Replace the saved scheduler with DDIM, reusing its configuration.
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        g_cuda = torch.Generator(device="cuda").manual_seed(seed)
        with torch.autocast("cuda"), torch.inference_mode():
            images = pipe(
                prompt,
                height=height,
                width=width,
                num_images_per_prompt=num_samples,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=g_cuda,
            ).images
    except Exception as e:
        # UI boundary: log the traceback, then surface a readable error in the UI.
        logging.getLogger(__name__).exception("Image generation failed")
        raise gr.Error(f"Image generation failed: {type(e).__name__}: {e}") from e
    return images
def gradio_app():
    """Build the Gradio UI and launch it with ``demo.launch()``.

    The UI has two tabs:
    - "Fine-Tune Model" calls ``fine_tune`` with a prompt and one or two images.
    - "Generate Images" calls ``generate_images`` and shows the results in a gallery.
    """
    with gr.Blocks() as demo:
        with gr.Tab("Fine-Tune Model"):
            with gr.Row():
                with gr.Column():
                    instance_prompt = gr.Textbox(label="Instance Prompt")
                    image1 = gr.Image(label="Upload Image 1", type="pil")
                    image2 = gr.Image(label="Upload Image 2 (Optional)", type="pil")
                    fine_tune_button = gr.Button("Fine-Tune Model")
                    output_text = gr.Textbox(label="Output")
            fine_tune_button.click(fine_tune, inputs=[instance_prompt, image1, image2], outputs=output_text)
        with gr.Tab("Generate Images"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt")
                    # precision=0 makes gr.Number return an int rather than a float.
                    # These values are used as image counts and pixel sizes.
                    num_samples = gr.Number(label="Number of Samples", value=1, precision=0)
                    guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
                    height = gr.Number(label="Height", value=512, precision=0)
                    width = gr.Number(label="Width", value=512, precision=0)
                    # step=1 keeps the step count a whole number.
                    num_inference_steps = gr.Slider(label="Steps", value=50, minimum=1, maximum=100, step=1)
                    generate_button = gr.Button("Generate Images")
                with gr.Column():
                    gallery = gr.Gallery(label="Generated Images")
            generate_button.click(
                generate_images,
                inputs=[prompt, num_samples, height, width, num_inference_steps, guidance_scale],
                outputs=gallery,
            )
    demo.launch()
if __name__ == "__main__":
    # Only start the web UI when run as a script, not when imported.
    gradio_app()