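"""Gradio demo: Stable Diffusion with hand control.

A user-supplied photo is run through MediaPipe hand-landmark detection; the
resulting skeleton image conditions a Flax Stable Diffusion ControlNet
pipeline, so the generated images follow the detected hand pose.
"""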
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax.training.common_utils import shard
from PIL import Image
from argparse import Namespace
import gradio as gr
import numpy as np
import mediapipe as mp
from mediapipe import solutions
from mediapipe.framework.formats import landmark_pb2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from diffusers import (
    FlaxControlNetModel,
    FlaxStableDiffusionControlNetPipeline,
)
# MediaPipe hand annotation
def draw_landmarks_on_image(rgb_image, detection_result):
    """Render the detected hand landmarks onto a black canvas.

    The skeleton is drawn on a blank image of the same shape as the input
    (rather than on the photo itself) because the result is used as the
    ControlNet conditioning image, which should contain only the hand pose.
    """
    hand_landmarks_list = detection_result.hand_landmarks
    annotated_image = np.zeros_like(rgb_image)

    # Loop through the detected hands and draw the landmarks of each one.
    for hand_landmarks in hand_landmarks_list:
        hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
        hand_landmarks_proto.landmark.extend([
            landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z)
            for landmark in hand_landmarks
        ])
        solutions.drawing_utils.draw_landmarks(
            annotated_image,
            hand_landmarks_proto,
            solutions.hands.HAND_CONNECTIONS,
            solutions.drawing_styles.get_default_hand_landmarks_style(),
            solutions.drawing_styles.get_default_hand_connections_style())
    return annotated_image
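# NOTE: generate_annotation expects the MediaPipe hand-landmarker model file
# ("hand_landmarker.task") to sit next to this script; it is not bundled with
# the mediapipe package and must be downloaded separately. The detector is
# rebuilt on every call, which is acceptable for a demo.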
def generate_annotation(img):
    """Detect hand landmarks in `img` (a NumPy RGB array) and return the
    annotated hand-skeleton image as a NumPy array of the same shape."""
    # Create a HandLandmarker object.
    base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
    options = vision.HandLandmarkerOptions(base_options=base_options,
                                           num_hands=2)
    detector = vision.HandLandmarker.create_from_options(options)

    # Load the input image.
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)

    # Detect hand landmarks from the input image.
    detection_result = detector.detect(image)

    # Visualize the detection result.
    annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result)
    return annotated_image
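# Model configuration: Stable Diffusion v1.5 as the base model, plus a
# ControlNet trained on hand-skeleton conditioning images.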
args = Namespace(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    revision="non-ema",
    from_pt=True,
    controlnet_model_name_or_path="Vincent-luo/controlnet-hands",
    controlnet_revision=None,
    controlnet_from_pt=False,
)
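# Load the ControlNet weights (a Flax checkpoint) in bfloat16.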
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    args.controlnet_model_name_or_path,
    revision=args.controlnet_revision,
    from_pt=args.controlnet_from_pt,
    dtype=jnp.bfloat16,
)
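# Assemble the Flax Stable Diffusion + ControlNet pipeline. The safety checker
# is disabled, and the base weights are converted from the PyTorch "non-ema"
# checkpoint.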
pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    controlnet=controlnet,
    safety_checker=None,
    dtype=jnp.bfloat16,
    revision=args.revision,
    from_pt=args.from_pt,
)
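# The pipeline looks up the ControlNet weights under params["controlnet"];
# replicate the full parameter tree once per local device for pmap.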
pipeline_params["controlnet"] = controlnet_params
pipeline_params = jax_utils.replicate(pipeline_params)
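# Generate one image per device: split the seed so every device gets its own
# PRNG key, and replicate the prompt num_samples times to match.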
rng = jax.random.PRNGKey(0)
num_samples = jax.device_count()
prng_seed = jax.random.split(rng, jax.device_count())
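# End-to-end inference: extract a hand-skeleton conditioning image from the
# input photo, then run the jitted (pmapped) pipeline on it.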
def infer(prompt, negative_prompt, image):
    # Tokenize the prompt and shard it across devices.
    prompts = num_samples * [prompt]
    prompt_ids = pipeline.prepare_text_inputs(prompts)
    prompt_ids = shard(prompt_ids)

    # Build the conditioning image from the detected hand landmarks.
    annotated_image = generate_annotation(image)
    validation_image = Image.fromarray(annotated_image).convert("RGB")
    processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
    processed_image = shard(processed_image)

    negative_prompt_ids = pipeline.prepare_text_inputs([negative_prompt] * num_samples)
    negative_prompt_ids = shard(negative_prompt_ids)

    images = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=pipeline_params,
        prng_seed=prng_seed,
        num_inference_steps=50,
        neg_prompt_ids=negative_prompt_ids,
        jit=True,
    ).images

    # Collapse the (devices, per-device batch) leading axes into one batch axis.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    results = [i for i in images]

    # Show the annotation first, followed by the generated samples.
    return [annotated_image] + results
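# Gradio UI: a prompt, an optional negative prompt, and an input photo in;
# the hand annotation plus the generated images out, shown as a gallery.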
with gr.Blocks(theme='gradio/soft') as demo:
    gr.Markdown("## Stable Diffusion with Hand Control")
    with gr.Column():
        prompt_input = gr.Textbox(label="Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
        input_image = gr.Image(label="Input Image")
        output_image = gr.Gallery(
            label='Output Image', show_label=False, elem_id="gallery"
        ).style(grid=3, height='auto')
        submit_btn = gr.Button(value="Submit")
        inputs = [prompt_input, negative_prompt, input_image]
        submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])

demo.launch()