pcuenq HF staff commited on
Commit
75a8299
1 Parent(s): 7a74ef3

replicate params once

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -24,10 +24,10 @@ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
24
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
25
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
26
  )
 
 
27
 
28
  def infer(prompts, negative_prompts, image):
29
- params["controlnet"] = controlnet_params
30
-
31
  num_samples = 1 #jax.device_count()
32
  rng = create_key(0)
33
  rng = jax.random.split(rng, jax.device_count())
@@ -38,7 +38,6 @@ def infer(prompts, negative_prompts, image):
38
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
39
  processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
40
 
41
- p_params = replicate(params)
42
  prompt_ids = shard(prompt_ids)
43
  negative_prompt_ids = shard(negative_prompt_ids)
44
  processed_image = shard(processed_image)
 
24
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
25
  "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
26
  )
27
+ params["controlnet"] = controlnet_params
28
+ p_params = replicate(params)
29
 
30
  def infer(prompts, negative_prompts, image):
 
 
31
  num_samples = 1 #jax.device_count()
32
  rng = create_key(0)
33
  rng = jax.random.split(rng, jax.device_count())
 
38
  negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
39
  processed_image = pipe.prepare_image_inputs([canny_image] * num_samples)
40
 
 
41
  prompt_ids = shard(prompt_ids)
42
  negative_prompt_ids = shard(negative_prompt_ids)
43
  processed_image = shard(processed_image)