Update README.md

#1
by merve HF staff - opened
Files changed (1) hide show
  1. app.py +7 -49
app.py CHANGED
@@ -17,10 +17,10 @@ pipe, params = FlaxStableDiffusionPipeline.from_pretrained(
17
  use_memory_efficient_attention=True
18
  )
19
 
20
- def infer(prompts, negative_prompts, width=1088, height=1088, inference_steps=30, seed=0):
21
 
22
  num_samples = 1 #jax.device_count()
23
- rng = create_key(int(seed))
24
  rng = jax.random.split(rng, jax.device_count())
25
 
26
  prompt_ids = pipe.prepare_inputs([prompts] * num_samples)
@@ -33,57 +33,15 @@ def infer(prompts, negative_prompts, width=1088, height=1088, inference_steps=30
33
  output = pipe(
34
  prompt_ids=prompt_ids,
35
  params=p_params,
36
- height=height,
37
- width=width,
38
  prng_seed=rng,
39
- num_inference_steps=inference_steps,
40
  neg_prompt_ids=negative_prompt_ids,
41
  jit=True,
42
  ).images
43
 
44
  output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
45
- return output_images[0]
46
 
47
- prompt_input = gr.inputs.Textbox(
48
- label="Prompt",
49
- placeholder="a highly detailed mansion in the autumn by studio ghibli, makoto shinkai"
50
- )
51
- neg_prompt_input = gr.inputs.Textbox(
52
- label="Negative Prompt",
53
- placeholder=""
54
- )
55
-
56
- width_slider = gr.inputs.Slider(
57
- minimum=512, maximum=2048, default=1088, step=64, label="width"
58
- )
59
-
60
- height_slider = gr.inputs.Slider(
61
- minimum=512, maximum=2048, default=1088, step=64, label="height"
62
- )
63
-
64
- inf_steps_input = gr.inputs.Slider(
65
- minimum=1, maximum=100, default=30, step=1, label="Inference Steps"
66
- )
67
-
68
-
69
- seed_input = gr.inputs.Number(default=0, label="Seed")
70
-
71
- app = gr.Interface(
72
- fn=infer,
73
- inputs=[prompt_input, neg_prompt_input, width_slider, height_slider, inf_steps_input, seed_input],
74
- outputs="image",
75
- title="Stable Diffusion High Resolution",
76
- description=(
77
- "Based on stable diffusion 1.5 and fine-tuned on 576x576 up to 1088x1088 images, "
78
- "Stable Diffusion High Resolution is compartible with another SD1.5 model and mergeable with other SD1.5 model, "
79
- "giving other model to generate high resolution images without using upscaler."
80
- ),
81
- examples=[
82
- ["a highly detailed mansion in the autumn by studio ghibli, makoto shinkai","", 1088, 1088, 30, 0],
83
- ["best high quality landscape, in the morning light, Overlooking TOKYO beautiful city with Fujiyama, from a tall house, by greg rutkowski and thomas kinkade, Trending on artstation makoto shinkai style","", 1088, 576, 30, 0],
84
- [" assassin's creed black flag, hd, 4k, dlsr ","", 960, 960, 30, 4154731],
85
- ],
86
-
87
- )
88
-
89
- app.launch()
 
17
  use_memory_efficient_attention=True
18
  )
19
 
20
+ def infer(prompts, negative_prompts):
21
 
22
  num_samples = 1 #jax.device_count()
23
+ rng = create_key(0)
24
  rng = jax.random.split(rng, jax.device_count())
25
 
26
  prompt_ids = pipe.prepare_inputs([prompts] * num_samples)
 
33
  output = pipe(
34
  prompt_ids=prompt_ids,
35
  params=p_params,
36
+ height=1088,
37
+ width=1088,
38
  prng_seed=rng,
39
+ num_inference_steps=50,
40
  neg_prompt_ids=negative_prompt_ids,
41
  jit=True,
42
  ).images
43
 
44
  output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
45
+ return output_images
46
 
47
+ gr.Interface(infer, inputs=["text", "text"], outputs="gallery").launch()