File size: 2,155 Bytes
2692b6e
822d597
 
 
 
 
 
 
 
 
2692b6e
 
dbc73d9
822d597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bed8cd
822d597
 
 
 
 
 
 
5bed8cd
822d597
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2

# Text shown at the top of the Gradio page.
title = "ControlNet for Cartoon-ifying"
description = "This is a demo on ControlNet for changing images of people into cartoons of different styles."
# NOTE(review): this list is not passed to the interface built at the bottom
# of this file (which supplies its own inline examples), and each row here
# holds three entries while the interface takes two inputs — confirm whether
# it is still needed.
examples = [["./simpsons_human_1.jpg", "turn into a simpsons character", "./simpsons_animated_1.jpg"]]



# Constants
# NOTE(review): the two thresholds are not read anywhere in the visible file;
# confirm they are unused before removing them.
low_threshold = 100
high_threshold = 200

base_model_path = "runwayml/stable-diffusion-v1-5"
controlnet_path = "lmattingly/controlnet-uncanny-simpsons"
#controlnet_path = "JFoz/dog-cat-pose"

# Models
# Both checkpoints are loaded once, at import time, in bfloat16.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    controlnet_path, dtype=jnp.bfloat16
)
# Load from the `base_model_path` constant rather than a repeated literal so
# that changing the constant actually changes which base model is used.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
)

def create_key(seed=0):
    """Return a JAX PRNG key built from ``seed`` (0 when omitted)."""
    prng_key = jax.random.PRNGKey(seed)
    return prng_key

def infer(prompts, image):
    """Run the ControlNet pipeline on one prompt and one conditioning image.

    Args:
        prompts: Text prompt; it is repeated once per generated sample.
        image: Conditioning image in a form ``Image.fromarray`` accepts
            (presumably the NumPy array supplied by the Gradio image input —
            confirm against the caller).

    Returns:
        The list of PIL images produced by ``pipe.numpy_to_pil``, one per
        sample (one sample per JAX device).

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    # Fail with a readable message before any shared state is touched,
    # instead of letting Image.fromarray choke on None further down.
    if image is None:
        raise ValueError("An input image is required to run inference.")

    # Attach the ControlNet weights to the shared pipeline parameters
    # (re-assigned on every call; the assignment is idempotent).
    params["controlnet"] = controlnet_params

    # One sample per device: the RNG below is split per device and the
    # inputs are sharded across devices, so the batch size must match the
    # device count. On a single-device host this is still 1.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(0), num_samples)

    pil_image = Image.fromarray(image)
    prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
    processed_image = pipe.prepare_image_inputs([pil_image] * num_samples)

    # Replicate the weights and shard the inputs for the jitted pipeline call.
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=5,
        jit=True,
    ).images

    # Collapse any leading device/batch axes into a single batch axis, keeping
    # the last three axes per image, before converting to PIL.
    flat_images = output.reshape((num_samples,) + output.shape[-3:])
    output_images = pipe.numpy_to_pil(np.asarray(flat_images))
    return output_images


# Build the demo UI (text prompt + image in, gallery out) and start serving it.
demo = gr.Interface(
    fn=infer,
    inputs=["text", "image"],
    outputs="gallery",
    title=title,
    description=description,
    theme='gradio/soft',
    examples=[["a simpsons cartoon character", "simpsons_human_1.jpg"]],
)
demo.launch()