|
|
|
|
|
from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel, FlaxDPMSolverMultistepScheduler |
|
from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed |
|
from flax.training.common_utils import shard |
|
from flax.jax_utils import replicate |
|
from diffusers.utils import load_image |
|
import jax.numpy as jnp |
|
import jax |
|
import cv2 |
|
from PIL import Image |
|
import numpy as np |
|
import gradio as gr |
|
|
|
def create_key(seed=0):
    """Build a JAX pseudo-random number generator key from an integer seed.

    Args:
        seed: Integer used to seed the generator. Defaults to 0.

    Returns:
        The key produced by ``jax.random.PRNGKey(seed)``.
    """
    prng_key = jax.random.PRNGKey(seed)
    return prng_key
|
|
|
def load_controlnet(controlnet_version, controlnet_path="Baptlem/baptlem-controlnet"):
    """Load a Flax ControlNet model and its parameters.

    Args:
        controlnet_version: Name of the sub-folder inside the repository that
            holds the desired ControlNet variant.
        controlnet_path: Repository id (or path) handed to
            ``FlaxControlNetModel.from_pretrained``. Defaults to the
            repository this script was written for, so existing callers that
            pass only ``controlnet_version`` behave as before.

    Returns:
        A ``(controlnet, controlnet_params)`` tuple as returned by
        ``FlaxControlNetModel.from_pretrained``.
    """
    controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
        controlnet_path,
        subfolder=controlnet_version,
        from_flax=True,
        dtype=jnp.float32,
    )
    return controlnet, controlnet_params
|
|
|
|
|
def load_sb_pipe(controlnet_version, sb_path="runwayml/stable-diffusion-v1-5"):
    """Assemble the Stable Diffusion + ControlNet Flax pipeline.

    Args:
        controlnet_version: ControlNet variant forwarded to
            ``load_controlnet``.
        sb_path: Repository id (or path) of the base Stable Diffusion model.
            Both the scheduler and the pipeline are loaded from it.

    Returns:
        A ``(pipe, params)`` tuple. ``pipe`` has its scheduler replaced by a
        ``FlaxDPMSolverMultistepScheduler`` and ``params`` carries the
        matching ``"controlnet"`` and ``"scheduler"`` entries.
    """
    controlnet, controlnet_params = load_controlnet(controlnet_version)

    # The scheduler configuration lives in the base model repository, so it is
    # loaded from the same location as the pipeline itself.
    scheduler, scheduler_params = FlaxDPMSolverMultistepScheduler.from_pretrained(
        sb_path,
        subfolder="scheduler",
    )

    pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
        sb_path,
        controlnet=controlnet,
        revision="flax",
        dtype=jnp.bfloat16,
    )

    # Swap in the DPM-Solver scheduler and register the externally loaded
    # parameter trees under the keys the pipeline reads them from.
    pipe.scheduler = scheduler
    params["controlnet"] = controlnet_params
    params["scheduler"] = scheduler_params
    return pipe, params
|
|
|
|
|
|
|
# Thresholds passed to cv2.Canny by preprocess_canny.
low_threshold, high_threshold = 100, 200

# ControlNet variant (used as the repository sub-folder) and the repository id.
# controlnet_path holds the same literal that load_controlnet passes to
# from_pretrained; it is not referenced elsewhere in the visible code.
controlnet_version = "coyo-500k"
controlnet_path = "Baptlem/baptlem-controlnet"

# Load the pipeline once at import time so every request reuses it.
pipe, params = load_sb_pipe(controlnet_version)
|
|
|
|
|
|
|
|
|
|
|
def pipe_inference(
    image,
    prompt,
    is_canny=False,
    num_samples=4,
    resolution=128,
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=0,
    negative_prompt="",
):
    """Generate images with the module-level ControlNet pipeline.

    Args:
        image: Input picture as a numpy array (anything else is converted
            with ``np.array``).
        prompt: Text prompt.
        is_canny: When true, ``image`` is treated as an edge map already and
            the Canny step is skipped.
        num_samples: Number of images to return.
        resolution: Target size forwarded to ``resize_image``.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        seed: Integer seed for the PRNG key.
        negative_prompt: Negative text prompt.

    Returns:
        A list starting with the original input array, followed by the
        computed edge map (only when ``is_canny`` is false), followed by
        ``num_samples`` generated PIL images.
    """
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    # Coerce to int so list repetition, resizing and PRNG seeding still work
    # when the caller hands over integral floats.
    num_samples = int(num_samples)
    resolution = int(resolution)
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)

    resized_image = resize_image(image, resolution)

    if is_canny:
        # prepare_image_inputs works on PIL images, so wrap the raw array.
        control_image = Image.fromarray(resized_image)
    else:
        # preprocess_canny returns (input as PIL, edge map as PIL); only the
        # edge map conditions the ControlNet.
        _, control_image = preprocess_canny(resized_image, resolution)

    # shard() splits the batch evenly across devices, so round the batch up to
    # the next multiple of the device count and trim the surplus afterwards.
    n_devices = jax.device_count()
    batch_size = -(-num_samples // n_devices) * n_devices

    rng = jax.random.split(create_key(seed), n_devices)

    prompt_ids = pipe.prepare_text_inputs([prompt] * batch_size)
    negative_prompt_ids = pipe.prepare_text_inputs([negative_prompt] * batch_size)
    processed_image = pipe.prepare_image_inputs([control_image] * batch_size)

    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    negative_prompt_ids = shard(negative_prompt_ids)
    processed_image = shard(processed_image)

    output = pipe(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    )

    # With jit=True the result carries a leading per-device axis; flatten it
    # to one image per row, drop any padding samples and convert to PIL.
    generated = np.asarray(output.images)
    generated = generated.reshape((-1,) + generated.shape[-3:])[:num_samples]
    generated_images = pipe.numpy_to_pil(generated)

    all_outputs = [image]
    if not is_canny:
        all_outputs.append(control_image)
    all_outputs.extend(generated_images)
    return all_outputs
|
|
|
def resize_image(image, resolution):
    """Resize an image so its shorter side equals ``resolution``.

    The aspect ratio is kept (up to integer truncation of the longer side)
    and nearest-neighbour interpolation is used.

    Args:
        image: Numpy image array whose first two axes are height and width;
            an optional trailing channel axis is allowed.
        resolution: Target length of the shorter side, in pixels.

    Returns:
        The resized array produced by ``cv2.resize``.
    """
    # Only the first two axes are spatial; slicing keeps colour images
    # (height, width, channels) from breaking the unpacking.
    h, w = image.shape[:2]
    resolution = int(resolution)
    ratio = w / h

    # cv2.resize takes the target size as (width, height).
    if ratio > 1:
        target_size = (int(resolution * ratio), resolution)
    elif ratio < 1:
        target_size = (resolution, int(resolution / ratio))
    else:
        target_size = (resolution, resolution)

    return cv2.resize(image, target_size, interpolation=cv2.INTER_NEAREST)
|
|
|
|
|
def preprocess_canny(image, resolution=128):
    """Compute a three-channel Canny edge map for ``image``.

    Edge detection uses the module-level ``low_threshold`` and
    ``high_threshold`` values.

    Args:
        image: Numpy image array to run edge detection on.
        resolution: Unused; kept so existing call sites remain valid.

    Returns:
        A ``(input_image, edge_image)`` tuple of PIL images: the input
        converted with ``Image.fromarray`` and the edge map replicated over
        three channels.
    """
    edges = cv2.Canny(image, low_threshold, high_threshold)

    # cv2.Canny yields a single-channel map; repeat it along a new last axis
    # to obtain a three-channel image.
    edges = edges[:, :, None]
    edges = np.concatenate([edges, edges, edges], axis=2)

    return Image.fromarray(image), Image.fromarray(edges)
|
|
|
|
|
def create_demo(process, max_images=12, default_num_images=4):
    """Build the Gradio Blocks UI for the Canny ControlNet demo.

    Args:
        process: Callback run on prompt submit and on the Run button. It is
            called with, in order: input image, prompt, is-canny flag, number
            of images, resolution, number of steps, guidance scale, seed and
            negative prompt. Its return value is shown in the output gallery.
        max_images: Upper bound of the "Images" slider.
        default_num_images: Initial value of the "Images" slider.

    Returns:
        The constructed ``gr.Blocks`` instance, ready for ``queue()`` /
        ``launch()``.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown('## Control Stable Diffusion with Canny Edge Maps')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(source='upload', type='numpy')
                prompt = gr.Textbox(label='Prompt')
                run_button = gr.Button(label='Run')
                with gr.Accordion('Advanced options', open=False):
                    is_canny = gr.Checkbox(label='Is canny', value=False)
                    num_samples = gr.Slider(label='Images',
                                            minimum=1,
                                            maximum=max_images,
                                            value=default_num_images,
                                            step=1)
                    # Canny thresholds are not exposed in the UI; the
                    # module-level low_threshold / high_threshold are used.
                    resolution = gr.Slider(label='Resolution',
                                           minimum=128,
                                           maximum=128,
                                           value=128,
                                           step=1)
                    num_steps = gr.Slider(label='Steps',
                                          minimum=1,
                                          maximum=100,
                                          value=20,
                                          step=1)
                    guidance_scale = gr.Slider(label='Guidance Scale',
                                               minimum=0.1,
                                               maximum=30.0,
                                               value=7.5,
                                               step=0.1)
                    seed = gr.Slider(label='Seed',
                                     minimum=-1,
                                     maximum=2147483647,
                                     step=1,
                                     randomize=True)
                    n_prompt = gr.Textbox(
                        label='Negative Prompt',
                        value=
                        'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'
                    )
            with gr.Column():
                result = gr.Gallery(label='Output',
                                    show_label=False,
                                    elem_id='gallery').style(grid=2,
                                                             height='auto')

        # Order must match the positional parameters of ``process``.
        inputs = [
            input_image,
            prompt,
            is_canny,
            num_samples,
            resolution,
            num_steps,
            guidance_scale,
            seed,
            n_prompt,
        ]
        prompt.submit(fn=process, inputs=inputs, outputs=result)
        run_button.click(fn=process,
                         inputs=inputs,
                         outputs=result,
                         api_name='canny')

    # Hand the app back to the caller, which calls queue()/launch() on it.
    return demo
|
|
|
|
|
if __name__ == '__main__':
    # Build the UI around the inference callback and serve it with request
    # queuing enabled.
    demo = create_demo(pipe_inference)
    demo.queue().launch()
|
|
|
|