MuhammadHanif commited on
Commit
8776e89
1 Parent(s): e02d7a5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import jax
3
+ import numpy as np
4
+ import jax.numpy as jnp
5
+ from flax.jax_utils import replicate
6
+ from flax.training.common_utils import shard
7
+ from PIL import Image
8
+ from diffusers import FlaxStableDiffusionPipeline
9
+
10
+ def create_key(seed=0):
11
+ return jax.random.PRNGKey(seed)
12
+
13
+
14
+ pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
15
+ "MuhammadHanif/stable-diffusion-v1-5-high-res",
16
+ dtype=jnp.bfloat16,
17
+ use_memory_efficient_attention=True
18
+ )
19
+
20
+ def infer(prompts, negative_prompts):
21
+
22
+ num_samples = 1 #jax.device_count()
23
+ rng = create_key(0)
24
+ rng = jax.random.split(rng, jax.device_count())
25
+
26
+ prompt_ids = pipe.prepare_inputs([prompts] * num_samples)
27
+ negative_prompt_ids = pipe.prepare_inputs([negative_prompts] * num_samples)
28
+
29
+ p_params = replicate(params)
30
+ prompt_ids = shard(prompt_ids)
31
+ negative_prompt_ids = shard(negative_prompt_ids)
32
+
33
+ output = pipe(
34
+ prompt_ids=prompt_ids,
35
+ params=p_params,
36
+ height=1088,
37
+ width=1088,
38
+ prng_seed=rng,
39
+ num_inference_steps=50,
40
+ neg_prompt_ids=negative_prompt_ids,
41
+ jit=True,
42
+ ).images
43
+
44
+ output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
45
+ return output_images
46
+
47
+ gr.Interface(infer, inputs=["text", "text"], outputs="gallery").launch()