# Hugging Face Space file page for app.py, author MuhammadHanif.
# Commit 8776e89 ("Create app.py"); file size 1.36 kB; Hub scan result: "No virus".
import functools

import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
def create_key(seed=0):
    """Return a JAX PRNG key derived from an integer seed.

    Args:
        seed: Integer seed passed to ``jax.random.PRNGKey``; defaults to 0.

    Returns:
        The key object produced by ``jax.random.PRNGKey(seed)``.
    """
    key = jax.random.PRNGKey(seed)
    return key
# Load the Flax Stable Diffusion pipeline object and its weight pytree from the
# Hugging Face Hub. Both are module-level globals used by ``infer``. The call
# passes ``dtype=jnp.bfloat16`` and enables memory-efficient attention.
pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
    "MuhammadHanif/stable-diffusion-v1-5-high-res",
    use_memory_efficient_attention=True,
    dtype=jnp.bfloat16,
)
@functools.lru_cache(maxsize=None)
def _replicated_params():
    """Return the module-level ``params`` replicated to local devices, built once.

    ``replicate`` copies the whole weight pytree onto every local device.
    Doing that inside ``infer`` would repeat the transfer on every request,
    so the replicated copy is cached for the lifetime of the process.
    """
    return replicate(params)


def infer(prompts, negative_prompts, seed=0):
    """Generate 1088x1088 images with the module-level Flax pipeline.

    Args:
        prompts: Prompt text (the Gradio "text" input supplies a string). It
            is repeated once per generated sample.
        negative_prompts: Negative-prompt text, repeated the same way.
        seed: Integer seed for the PRNG. Defaults to 0, which reproduces the
            previous fixed-seed output.

    Returns:
        The result of ``pipe.numpy_to_pil`` on the generated batch, presumably
        a list of PIL images (one per local device) for the Gradio gallery.
    """
    # ``shard`` reshapes the batch to (jax.local_device_count(), -1, ...), so
    # the batch must be a multiple of the local device count. Generating one
    # sample per local device keeps a single-device run identical to the old
    # behaviour (1 image) and avoids a reshape error on multi-device hosts.
    num_devices = jax.local_device_count()
    num_samples = num_devices

    # One PRNG key per device, matching the leading axis of the sharded inputs.
    rng = jax.random.split(create_key(seed), num_devices)

    prompt_ids = shard(pipe.prepare_inputs([prompts] * num_samples))
    negative_prompt_ids = shard(
        pipe.prepare_inputs([negative_prompts] * num_samples)
    )

    output = pipe(
        prompt_ids=prompt_ids,
        params=_replicated_params(),
        height=1088,
        width=1088,
        prng_seed=rng,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the leading device/batch axes into one flat batch of images.
    images = output.reshape((num_samples,) + output.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(images))
# Build the Gradio UI (two text inputs fed to ``infer``, gallery output) and
# start the server. The interface is kept in a module-level name ``demo``.
demo = gr.Interface(infer, inputs=["text", "text"], outputs="gallery")
demo.launch()