File size: 2,539 Bytes
2692b6e
822d597
 
 
 
 
 
 
 
 
049b8a4
 
2692b6e
 
dbc73d9
822d597
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
049b8a4
 
 
 
 
 
 
 
 
e28884c
049b8a4
e28884c
049b8a4
822d597
 
 
 
 
 
1dd294a
 
822d597
 
 
049b8a4
 
822d597
 
 
 
 
 
 
 
 
 
 
 
 
5bed8cd
822d597
 
 
 
 
 
 
5bed8cd
822d597
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import gradio as gr
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
import cv2



# UI text shown by the Gradio interface.
title = "ControlNet for Cartoon-ifying"
description = "This is a demo on ControlNet for changing images of people into cartoons of different styles."
# NOTE(review): this list is not passed to gr.Interface (the interface builds
# its own inline examples). Its column order (input image, prompt, output
# image) also does not match the ["text", "image"] inputs; confirm whether it
# is still needed.
examples = [["./simpsons_human_1.jpg", "turn into a simpsons character", "./simpsons_animated_1.jpg"]]


# Constants
# Presumably Canny edge-detection thresholds.
# NOTE(review): not referenced anywhere in this file's visible code.
low_threshold = 100
high_threshold = 200

# Hugging Face Hub model IDs.
base_model_path = "runwayml/stable-diffusion-v1-5"
controlnet_path = "lmattingly/controlnet-uncanny-simpsons"
# Alternative ControlNet checkpoint: "JFoz/dog-cat-pose"

# Models
# Both models are loaded once at import time in bfloat16. The ControlNet
# parameters are returned separately from the pipeline parameter tree.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    controlnet_path, dtype=jnp.bfloat16
)
# Use the base_model_path constant rather than repeating the literal, so the
# two cannot drift apart.
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
)

def resize_image(im, max_size):
    """Scale an image so that its longer side is ``max_size`` pixels.

    The aspect ratio is preserved. Images smaller than ``max_size`` are
    enlarged. The channel order of ``im`` is left unchanged.

    Args:
        im: Image as a numpy array of shape (height, width) or
            (height, width, channels).
        max_size: Target length, in pixels, of the longer side.

    Returns:
        The resized image as a numpy array.
    """
    # No colour conversion here. Gradio's numpy image input is already RGB,
    # so a BGR->RGB swap would exchange the red and blue channels.
    height, width = im.shape[:2]

    scale_factor = max_size / max(height, width)

    # Clamp each side to at least 1 pixel. Without this, extreme aspect
    # ratios can truncate to a zero-sized target, which cv2.resize rejects.
    new_width = max(1, int(width * scale_factor))
    new_height = max(1, int(height * scale_factor))

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(im, (new_width, new_height))

def create_key(seed=0):
    """Return a JAX PRNG key derived from the integer ``seed`` (default 0)."""
    key = jax.random.PRNGKey(seed)
    return key

def infer(prompts, image):
    """Generate cartoon-style images from a prompt and a conditioning image.

    Args:
        prompts: The text prompt (a single string from the Gradio "text"
            input, despite the plural name), used for every sample.
        image: The uploaded image from the Gradio "image" input. This is
            presumably an RGB uint8 numpy array (Gradio's default
            ``type="numpy"``), or ``None`` when nothing was uploaded.

    Returns:
        A list of PIL images, one per sample. One sample is produced per
        JAX device.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # gr.Error is shown to the user in the UI instead of a stack trace.
        raise gr.Error("Please upload an image.")

    # Attach the ControlNet weights to the pipeline's parameter tree.
    params["controlnet"] = controlnet_params

    # The Flax pipeline's prepare_image_inputs() expects PIL images (its
    # preprocessing calls .convert("RGB") on each one), so the resized numpy
    # array is converted back to a PIL image.
    conditioning_image = Image.fromarray(resize_image(image, 500))

    # shard() splits the leading batch axis across devices, so the batch size
    # must be a multiple of jax.device_count(). One sample per device
    # satisfies this on single- and multi-device hosts alike.
    num_samples = jax.device_count()
    rng = jax.random.split(create_key(0), num_samples)

    prompt_ids = shard(pipe.prepare_text_inputs([prompts] * num_samples))
    processed_image = shard(pipe.prepare_image_inputs([conditioning_image] * num_samples))

    # NOTE(review): this replicates the full parameter tree to every device
    # on each request. Consider replicating once at startup.
    p_params = replicate(params)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=5,
        jit=True,
    ).images

    # Collapse the (devices, per-device batch) leading axes into one batch
    # axis before converting to PIL.
    flat = np.asarray(output.reshape((num_samples,) + output.shape[-3:]))
    return pipe.numpy_to_pil(flat)


# Build the web UI: a text prompt plus an uploaded image in, a gallery of
# generated images out. The app is launched immediately on import.
demo = gr.Interface(
    fn=infer,
    inputs=["text", "image"],
    outputs="gallery",
    title=title,
    description=description,
    theme='gradio/soft',
    examples=[["a simpsons cartoon character", "simpsons_human_1.jpg"]],
)
demo.launch()