File size: 4,075 Bytes
1e1a292
c495d6b
 
 
 
7ad3aff
5fef1f4
7b96e45
c495d6b
 
 
7b96e45
 
1e1a292
 
 
c495d6b
7b96e45
 
 
 
 
 
 
c495d6b
 
1e1a292
 
 
7b96e45
 
1e1a292
7b96e45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c495d6b
1e1a292
c495d6b
a8ce530
 
 
 
1e1a292
 
 
a8ce530
 
 
1e1a292
a8ce530
 
 
1e1a292
 
 
a8ce530
 
7b96e45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8ce530
1e1a292
 
a8ce530
7b96e45
1e1a292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8ce530
 
 
 
 
 
 
 
 
1e1a292
a8ce530
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
from huggingface_hub import from_pretrained_keras
from keras_cv import models
from tensorflow import keras

# Hugging Face Hub repository ids offered in the model dropdown of the app.
keras_model_list = [
    "keras-dreambooth/keras_diffusion_lowpoly_world",
    "keras-dreambooth/keras-diffusion-traditional-furniture",
]

# Prompt presets; the first entry is the prompt textbox default and the
# example prompt.
stable_prompt_list = ["photo of lowpoly_world", "photo of traditional_furniture"]

# Negative-prompt presets; the first entry is the negative-prompt textbox
# default and the example negative prompt.
stable_negative_prompt_list = [
    "bad, ugly",
    "deformed",
]

# Set the global Keras mixed-precision policy to "mixed_float16" before the
# pipeline below is constructed.
keras.mixed_precision.set_global_policy("mixed_float16")
# Shared 512x512 Stable Diffusion pipeline, built once at import time.
# keras_stable_diffusion() overwrites its `_diffusion_model` attribute with the
# model selected by the user on every call.
dreambooth_model = models.StableDiffusion(
    img_width=512,
    img_height=512,
    jit_compile=True,
)


def keras_stable_diffusion(
    model_path: str,
    prompt: str,
    negative_prompt: str,
    num_imgs_to_gen: int,
    num_steps: int,
):
    """
    Generate images from a text prompt with a fine-tuned Keras DreamBooth model.

    The model identified by ``model_path`` is loaded with
    ``from_pretrained_keras`` and installed as the diffusion model of the
    module-level ``dreambooth_model`` pipeline before generation runs.

    Args:
        model_path (str): Identifier passed to ``from_pretrained_keras`` to
            load the fine-tuned diffusion model.
        prompt (str): Text describing the images to generate.
        negative_prompt (str): Text forwarded as the pipeline's negative prompt.
        num_imgs_to_gen (int): Number of images to generate; forwarded as
            ``batch_size``.
        num_steps (int): Number of denoising steps.
    Returns:
        The images produced by ``dreambooth_model.text_to_image``.
    """
    # Swap the selected fine-tuned network into the shared pipeline.
    dreambooth_model._diffusion_model = from_pretrained_keras(model_path)

    generation_options = {
        "negative_prompt": negative_prompt,
        "batch_size": num_imgs_to_gen,
        "num_steps": num_steps,
    }
    return dreambooth_model.text_to_image(prompt, **generation_options)


def keras_stable_diffusion_app():
    """
    Declare the Gradio components for the Keras DreamBooth text-to-image demo.

    Builds a two-column layout: the left column holds the model dropdown,
    prompt and negative-prompt textboxes, the image-count and step-count
    sliders and the trigger button; the right column holds the output gallery.
    The button and the example row both invoke ``keras_stable_diffusion``.

    The input components are wired positionally, so their order must match the
    callback signature:
    ``(model_path, prompt, negative_prompt, num_imgs_to_gen, num_steps)``.

    NOTE(review): the ``gr.Blocks`` context is neither bound to a name nor
    returned, so this is presumably meant to be called inside an outer
    ``gr.Blocks`` layout -- confirm against the caller.

    Returns:
        None
    """
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                keras_text2image_model_path = gr.Dropdown(
                    choices=keras_model_list,
                    value=keras_model_list[0],
                    label="Text-Image Model Id",
                )

                keras_text2image_prompt = gr.Textbox(
                    lines=1, value=stable_prompt_list[0], label="Prompt"
                )

                keras_text2image_negative_prompt = gr.Textbox(
                    lines=1,
                    value=stable_negative_prompt_list[0],
                    label="Negative Prompt",
                )

                # The callback's 4th positional parameter is the number of
                # images (used as the batch size), so this control must yield
                # a whole number. It replaces a fractional "Guidance Scale"
                # slider whose value was being consumed as the image count.
                keras_text2image_num_images = gr.Slider(
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=1,
                    label="Number of Images",
                )

                keras_text2image_num_inference_step = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="Num Inference Step",
                )

                keras_text2image_predict = gr.Button(value="Generator")

            with gr.Column():
                output_image = gr.Gallery(label="Outputs").style(grid=(1, 2))

        # Single source of truth for the positional inputs, shared by the
        # examples widget and the button so the two cannot drift apart.
        generation_inputs = [
            keras_text2image_model_path,
            keras_text2image_prompt,
            keras_text2image_negative_prompt,
            keras_text2image_num_images,
            keras_text2image_num_inference_step,
        ]

        gr.Examples(
            fn=keras_stable_diffusion,
            inputs=generation_inputs,
            outputs=[output_image],
            examples=[
                # One value per input component, in the same order.
                [
                    keras_model_list[0],
                    stable_prompt_list[0],
                    stable_negative_prompt_list[0],
                    1,
                    50,
                ],
            ],
            label="Keras Stable Diffusion Example",
            cache_examples=False,
        )

        keras_text2image_predict.click(
            fn=keras_stable_diffusion,
            inputs=generation_inputs,
            outputs=output_image,
        )