# Author: MuhammadHanif
# Commit: 94ed80b - "adding more examples"
# Size: 2.98 kB (Hugging Face file view; malware scan: "No virus")
import gradio as gr
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionPipeline
def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key
# Hugging Face Hub repository holding the fine-tuned SD 1.5 checkpoint.
_MODEL_REPO = "MuhammadHanif/stable-diffusion-v1-5-high-res"

# Load the pipeline and its parameter tree once at import time, in bfloat16,
# with the pipeline's memory-efficient attention option enabled.
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    _MODEL_REPO,
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,
)
# Cache for the device-replicated copy of `params`, populated on first use.
_replicated_params = None


def _get_replicated_params():
    """Return ``params`` replicated across local devices, replicating only once.

    Replicating copies the full parameter tree to every device. Doing it once
    avoids repeating that transfer on every request.
    """
    global _replicated_params
    if _replicated_params is None:
        _replicated_params = replicate(params)
    return _replicated_params


def infer(prompts, negative_prompts, width=1088, height=1088, inference_steps=30, seed=0):
    """Generate an image for ``prompts`` with the Flax Stable Diffusion pipeline.

    Args:
        prompts: Text prompt string.
        negative_prompts: Negative prompt string; may be empty.
        width: Output width in pixels.
        height: Output height in pixels.
        inference_steps: Number of inference steps passed to the pipeline.
        seed: Integer-valued seed for the JAX PRNG.

    Returns:
        The first generated image, as returned by ``pipe.numpy_to_pil``.
    """
    n_devices = jax.device_count()
    # `shard` splits the leading batch axis across devices, so the batch size
    # must be a multiple of the device count. Generate one sample per device.
    # On a single-device host this is 1, matching the previous behaviour.
    num_samples = n_devices

    # UI components may deliver these as floats. The pipeline uses them for
    # latent shapes and the step count, so coerce them to int (as with seed).
    width = int(width)
    height = int(height)
    inference_steps = int(inference_steps)

    # One PRNG key per device for the jit=True path.
    rng = jax.random.split(create_key(int(seed)), n_devices)

    prompt_ids = shard(pipe.prepare_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(pipe.prepare_inputs([negative_prompts] * num_samples))

    output = pipe(
        prompt_ids=prompt_ids,
        params=_get_replicated_params(),
        height=height,
        width=width,
        prng_seed=rng,
        num_inference_steps=inference_steps,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the sharded leading axes back into a single batch axis.
    images = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(images)[0]
# Input components for the demo UI. These use the top-level Gradio component
# classes with `value=`; the legacy `gr.inputs.*` namespace (with `default=`)
# is deprecated in Gradio 3.x and removed in Gradio 4.
prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="a highly detailed mansion in the autumn by studio ghibli, makoto shinkai",
)
neg_prompt_input = gr.Textbox(
    label="Negative Prompt",
    placeholder="",
)
# Slider steps are multiples of 64, so every selectable size is a multiple of 8.
width_slider = gr.Slider(
    minimum=512, maximum=2048, value=1088, step=64, label="width"
)
height_slider = gr.Slider(
    minimum=512, maximum=2048, value=1088, step=64, label="height"
)
inf_steps_input = gr.Slider(
    minimum=1, maximum=100, value=30, step=1, label="Inference Steps"
)
seed_input = gr.Number(value=0, label="Seed")
# Wire `infer` to the input components and launch the web UI.
# Each example row supplies: prompt, negative prompt, width, height,
# inference steps, seed.
app = gr.Interface(
    fn=infer,
    inputs=[prompt_input, neg_prompt_input, width_slider, height_slider, inf_steps_input, seed_input],
    outputs="image",
    title="Stable Diffusion High Resolution",
    description=(
        "Based on Stable Diffusion 1.5 and fine-tuned on images from 576x576 up to 1088x1088, "
        "Stable Diffusion High Resolution is compatible with other SD1.5 models and can be merged with them, "
        "enabling those models to generate high-resolution images without an upscaler."
    ),
    examples=[
        ["a highly detailed mansion in the autumn by studio ghibli, makoto shinkai", "", 1088, 1088, 30, 0],
        ["best high quality landscape, in the morning light, Overlooking TOKYO beautiful city with Fujiyama, from a tall house, by greg rutkowski and thomas kinkade, Trending on artstation makoto shinkai style", "", 1088, 576, 30, 0],
        [" assassin's creed black flag, hd, 4k, dlsr ", "", 960, 960, 30, 4154731],
    ],
)
app.launch()