Vincent-luo commited on
Commit
eb4334e
·
1 Parent(s): 8527cc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -3
app.py CHANGED
@@ -6,12 +6,71 @@ from PIL import Image
6
  from argparse import Namespace
7
  import gradio as gr
8
 
 
 
 
 
 
 
 
 
9
  from diffusers import (
10
  FlaxControlNetModel,
11
  FlaxStableDiffusionControlNetPipeline,
12
  )
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  args = Namespace(
16
  pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
17
  revision="non-ema",
@@ -53,7 +112,8 @@ def infer(prompt, negative_prompt, image):
53
  prompt_ids = pipeline.prepare_text_inputs(prompts)
54
  prompt_ids = shard(prompt_ids)
55
 
56
- validation_image = Image.fromarray(image).convert("RGB")
 
57
  processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
58
  processed_image = shard(processed_image)
59
 
@@ -73,7 +133,8 @@ def infer(prompt, negative_prompt, image):
73
 
74
  images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
75
 
76
- return images[0]
 
77
 
78
 
79
  with gr.Blocks(theme='gradio/soft') as demo:
@@ -84,7 +145,7 @@ with gr.Blocks(theme='gradio/soft') as demo:
84
  prompt_input = gr.Textbox(label="Prompt")
85
  negative_prompt = gr.Textbox(label="Negative Prompt")
86
  input_image = gr.Image(label="Input Image")
87
- output_image = gr.Image(label="Output Image")
88
  submit_btn = gr.Button(value = "Submit")
89
  inputs = [prompt_input, negative_prompt, input_image]
90
  submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])
 
6
  from argparse import Namespace
7
  import gradio as gr
8
 
9
+ import numpy as np
10
+ import mediapipe as mp
11
+ from mediapipe import solutions
12
+ from mediapipe.framework.formats import landmark_pb2
13
+ from mediapipe.tasks import python
14
+ from mediapipe.tasks.python import vision
15
+ import cv2
16
+
17
  from diffusers import (
18
  FlaxControlNetModel,
19
  FlaxStableDiffusionControlNetPipeline,
20
  )
21
 
22
 
23
+ # mediapipe annotation
24
+ MARGIN = 10 # pixels
25
+ FONT_SIZE = 1
26
+ FONT_THICKNESS = 1
27
+ HANDEDNESS_TEXT_COLOR = (88, 205, 54) # vibrant green
28
+
29
+ def draw_landmarks_on_image(rgb_image, detection_result):
30
+ hand_landmarks_list = detection_result.hand_landmarks
31
+ handedness_list = detection_result.handedness
32
+ annotated_image = np.zeros_like(rgb_image)
33
+
34
+ # Loop through the detected hands to visualize.
35
+ for idx in range(len(hand_landmarks_list)):
36
+ hand_landmarks = hand_landmarks_list[idx]
37
+ handedness = handedness_list[idx]
38
+
39
+ # Draw the hand landmarks.
40
+ hand_landmarks_proto = landmark_pb2.NormalizedLandmarkList()
41
+ hand_landmarks_proto.landmark.extend([
42
+ landmark_pb2.NormalizedLandmark(x=landmark.x, y=landmark.y, z=landmark.z) for landmark in hand_landmarks
43
+ ])
44
+ solutions.drawing_utils.draw_landmarks(
45
+ annotated_image,
46
+ hand_landmarks_proto,
47
+ solutions.hands.HAND_CONNECTIONS,
48
+ solutions.drawing_styles.get_default_hand_landmarks_style(),
49
+ solutions.drawing_styles.get_default_hand_connections_style())
50
+
51
+ return annotated_image
52
+
53
+ def generate_annotation(img):
54
+ """img(input): numpy array
55
+ annotated_image(output): numpy array
56
+ """
57
+ # STEP 2: Create an HandLandmarker object.
58
+ base_options = python.BaseOptions(model_asset_path='hand_landmarker.task')
59
+ options = vision.HandLandmarkerOptions(base_options=base_options,
60
+ num_hands=2)
61
+ detector = vision.HandLandmarker.create_from_options(options)
62
+
63
+ # STEP 3: Load the input image.
64
+ image = mp.Image(
65
+ image_format=mp.ImageFormat.SRGB, data=img)
66
+
67
+ # STEP 4: Detect hand landmarks from the input image.
68
+ detection_result = detector.detect(image)
69
+
70
+ # STEP 5: Process the classification result. In this case, visualize it.
71
+ annotated_image = draw_landmarks_on_image(image.numpy_view(), detection_result)
72
+ return annotated_image
73
+
74
  args = Namespace(
75
  pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
76
  revision="non-ema",
 
112
  prompt_ids = pipeline.prepare_text_inputs(prompts)
113
  prompt_ids = shard(prompt_ids)
114
 
115
+ annotated_image = generate_annotation(image)
116
+ validation_image = Image.fromarray(annotated_image).convert("RGB")
117
  processed_image = pipeline.prepare_image_inputs(num_samples * [validation_image])
118
  processed_image = shard(processed_image)
119
 
 
133
 
134
  images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
135
 
136
+ results = [i for i in images]
137
+ return [annotated_image] + results
138
 
139
 
140
  with gr.Blocks(theme='gradio/soft') as demo:
 
145
  prompt_input = gr.Textbox(label="Prompt")
146
  negative_prompt = gr.Textbox(label="Negative Prompt")
147
  input_image = gr.Image(label="Input Image")
148
+ output_image = gr.Gallery(label='Output Image', show_label=False, elem_id="gallery").style(grid=3, height='auto')
149
  submit_btn = gr.Button(value = "Submit")
150
  inputs = [prompt_input, negative_prompt, input_image]
151
  submit_btn.click(fn=infer, inputs=inputs, outputs=[output_image])