File size: 2,426 Bytes
f52a776
 
 
 
 
 
 
 
 
108dbc4
f52a776
 
 
 
0394a07
f52a776
 
 
 
 
 
 
2b2d98d
 
 
f52a776
 
9f23c08
f52a776
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f23c08
f52a776
 
 
 
 
 
 
 
 
9f23c08
f52a776
9f23c08
f52a776
 
 
9f23c08
f52a776
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Adapted from https://huggingface.co/spaces/stabilityai/stable-diffusion
"""

import time
import keras_cv
import gradio as gr
from tensorflow import keras
from share_btn import community_icon_html, loading_icon_html, share_js
from constants import css, img_height, img_width, num_images_to_gen, unconditional_guidance_scale
# Use mixed precision (float16 compute, float32 variables). The policy is set
# before the model is constructed so that its layers pick it up.
keras.mixed_precision.set_global_policy("mixed_float16")

# Fetch the fine-tuned diffusion weights (keras.utils.get_file downloads the
# file and caches it locally), then build the pipeline at the configured size.
weights_path = keras.utils.get_file(
    origin=(
        "https://huggingface.co/clint-greene/magic-the-gathering-sd"
        "/resolve/main/magic-the-gathering-sd.h5"
    ),
)
magic_model = keras_cv.models.StableDiffusion(img_width=img_width, img_height=img_height)

# Only the diffusion sub-model receives the fine-tuned weights; the text
# encoder and decoder keep the stock KerasCV weights.
magic_model.diffusion_model.load_weights(weights_path)

# Enable XLA for each sub-model, in the same order as before:
# diffusion model, then decoder, then text encoder.
for _name in ("diffusion_model", "decoder", "text_encoder"):
    getattr(magic_model, _name).compile(jit_compile=True)
del _name

# Warm-up: one throwaway generation at startup. Presumably this is so that the
# XLA compilation cost is paid here rather than on the first user request.
_ = magic_model.text_to_image(
    "flying dragons",
    num_steps=15,
    batch_size=num_images_to_gen,
)

def generate_image_fn(prompt: str, steps: int) -> list:
    """Generate ``num_images_to_gen`` images for ``prompt`` with the fine-tuned model.

    The wall-clock generation time is printed to stdout.

    Args:
        prompt: Text prompt forwarded to ``magic_model.text_to_image``.
        steps: Number of diffusion steps, taken from the Gradio slider. It is
            cast to ``int`` in case the UI delivers the number as a float.

    Returns:
        A list with one array per generated image, the format ``gr.Gallery``
        accepts (reference: https://gradio.app/docs/#gallery).
    """
    # perf_counter is monotonic, so the measurement is immune to system clock changes.
    start_time = time.perf_counter()
    images = magic_model.text_to_image(
        prompt,
        batch_size=num_images_to_gen,
        num_steps=int(steps),
        unconditional_guidance_scale=unconditional_guidance_scale,
    )
    elapsed = time.perf_counter() - start_time
    print(f"Time taken: {elapsed} seconds.")
    # `images` is a batched np.ndarray; split it into one ndarray per image.
    return list(images)


description = "This Space demonstrates a fine-tuned Stable Diffusion model. You can use it for generating custom Magic the Gathering cards. To get started, either enter a prompt or pick one from the examples below. For details on the fine-tuning procedure, refer to [this tutorial](https://gpuopen.com/)."
article = "We use mixed-precision and XLA to speed up the inference latency."

# Build and launch the web UI. The prompt textbox and step slider are passed
# positionally to generate_image_fn as (prompt, steps); the returned list of
# images is shown in a gallery.
gr.Interface(
    generate_image_fn,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            max_lines=1,
            placeholder="Jedi",
        ),
        # Number of diffusion steps (becomes `steps` in generate_image_fn).
        gr.Slider(value=30, minimum=10, maximum=100, step=1),
    ],
    outputs=gr.Gallery().style(height="auto"),
    title="Generate custom magic the gathering cards",
    description=description,
    article=article,
    examples=[["Yoda", 30], ["Lisa Su", 30]],
    # Gradio 3.x expects a string here ("never" / "manual" / "auto"). Passing a
    # boolean is deprecated and only mapped to "never" with a warning.
    allow_flagging="never",
).launch(enable_queue=True)