"""Gradio demo for a Keras DreamBooth Stable Diffusion model fine-tuned on dosas.

Builds keras_cv's StableDiffusion pipeline, swaps in the fine-tuned diffusion
model downloaded from the Hugging Face Hub, and serves a text-to-image UI.
"""
from huggingface_hub import from_pretrained_keras
from keras_cv import models
import gradio as gr

# Hub repository holding the DreamBooth fine-tuned diffusion model.
MODEL_REPO_ID = "keras-dreambooth/dreambooth_dosa"

sd_dreambooth_model = models.StableDiffusion(img_width=512, img_height=512)
db_diffusion_model = from_pretrained_keras(MODEL_REPO_ID)
# Replace the pipeline's diffusion model with the fine-tuned one. This writes a
# private keras_cv attribute, so it depends on library internals
# (NOTE(review): assumes text_to_image picks up `_diffusion_model` once set).
sd_dreambooth_model._diffusion_model = db_diffusion_model


def generate_images(prompt, negative_prompt, num_imgs_to_gen, num_steps, guidance_scale):
    """Generate images for *prompt* with the fine-tuned Stable Diffusion model.

    Args:
        prompt: Text prompt; include "bhr dosa" to invoke the learned concept.
        negative_prompt: Text describing content to steer the images away from.
        num_imgs_to_gen: Number of images to generate (used as the batch size).
        num_steps: Number of diffusion steps to run.
        guidance_scale: Value passed as keras_cv's ``unconditional_guidance_scale``.

    Returns:
        Whatever ``text_to_image`` returns, handed directly to the output gallery.
    """
    # Slider values may arrive as floats (e.g. 1.0); the batch size and step
    # count are used as counts, so coerce them to int before generating.
    return sd_dreambooth_model.text_to_image(
        prompt,
        negative_prompt=negative_prompt,
        batch_size=int(num_imgs_to_gen),
        num_steps=int(num_steps),
        unconditional_guidance_scale=guidance_scale,
    )


with gr.Blocks() as demo:
    gr.HTML(
        '<h2 style="text-align: center;">Keras Dreambooth - The Humble Dosa</h2>'
    )
    gr.HTML(
        "<p>This model has been fine-tuned to learn the concept of a dosa.<br>"
        "To use this demo, insert the string <code>bhr dosa</code> in your prompt</p>"
    )
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", lines=1, value="bhr dosa")
            negative_prompt = gr.Textbox(label="Negative Prompt", lines=1, value="deformed")
            samples = gr.Slider(
                label="Number of Images", minimum=1, maximum=4, value=1, step=1
            )
            num_steps = gr.Slider(
                label="Inference Steps", minimum=25, maximum=100, value=50, step=1
            )
            guidance_scale = gr.Slider(
                label="Guidance Scale", minimum=1, maximum=12, value=7.5, step=0.5
            )
            run = gr.Button(value="Run")
        with gr.Column():
            gallery = gr.Gallery(label="Outputs").style(grid=(2, 2))

    # Same component order as generate_images' parameters; shared by the Run
    # button and the examples so the two cannot drift apart.
    generation_inputs = [prompt, negative_prompt, samples, num_steps, guidance_scale]
    run.click(fn=generate_images, inputs=generation_inputs, outputs=gallery)

    # cache_examples=True makes Gradio precompute each example's gallery output
    # by calling generate_images, so the model runs for every example up front.
    gr.Examples(
        [
            ["realistic picture of a man eating a bhr dosa", "home", 1, 50, 7.5],
            ["realistic picture of a bhr dosa on a plate", "chutney", 1, 50, 7.5],
            ["realistic picture of a bhr dosa in a restaurant", "sambar", 1, 50, 7.5],
        ],
        generation_inputs,
        gallery,
        generate_images,
        cache_examples=True,
    )
    gr.Markdown(
        "Demo created by [Bharat Raghunathan](https://huggingface.co/bharat-raghunathan/)"
    )

# Queue requests, allowing up to two generations to run concurrently.
demo.queue(concurrency_count=2).launch()