Spaces:
Running
Running
Colin Leong
commited on
Commit
·
6dbb7a6
1
Parent(s):
6e290b7
Updates to the app such that downloading and visualization work more intuitively
Browse files
app.py
CHANGED
@@ -7,7 +7,6 @@ from pose_format.pose_visualizer import PoseVisualizer
|
|
7 |
from pathlib import Path
|
8 |
from pyzstd import decompress
|
9 |
from PIL import Image
|
10 |
-
import cv2
|
11 |
import mediapipe as mp
|
12 |
|
13 |
mp_holistic = mp.solutions.holistic
|
@@ -18,6 +17,10 @@ FACEMESH_CONTOURS_POINTS = [
|
|
18 |
)
|
19 |
]
|
20 |
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def pose_normalization_info(pose_header):
|
23 |
if pose_header.components[0].name == "POSE_LANDMARKS":
|
@@ -76,12 +79,12 @@ def get_pose_frames(pose: Pose, transparency: bool = False):
|
|
76 |
return frames, images
|
77 |
|
78 |
|
79 |
-
def get_pose_gif(pose: Pose, step: int = 1, fps: int = None):
|
80 |
if fps is not None:
|
81 |
pose.body.fps = fps
|
82 |
v = PoseVisualizer(pose)
|
83 |
frames = [frame_data for frame_data in v.draw()]
|
84 |
-
frames = frames[
|
85 |
return v.save_gif(None, frames=frames)
|
86 |
|
87 |
|
@@ -110,28 +113,64 @@ if uploaded_file is not None:
|
|
110 |
"How to select components?", options=["manual", "signclip"]
|
111 |
)
|
112 |
|
|
|
|
|
|
|
|
|
|
|
113 |
if component_selection == "manual":
|
114 |
-
|
115 |
-
|
|
|
|
|
116 |
)
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
|
120 |
elif component_selection == "signclip":
|
121 |
st.write("Selected landmarks used for SignCLIP.")
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
|
127 |
# Filter button logic
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
with st.expander("Show header"):
|
131 |
-
st.write(
|
132 |
with st.expander("Show body"):
|
133 |
-
st.write(
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
135 |
pose_file_out = Path(uploaded_file.name).with_suffix(".pose")
|
136 |
with pose_file_out.open("wb") as f:
|
137 |
pose.write(f)
|
@@ -139,22 +178,19 @@ if uploaded_file is not None:
|
|
139 |
with pose_file_out.open("rb") as f:
|
140 |
st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
# Visualization button logic
|
143 |
if st.button("Visualize"):
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
st.image(get_pose_gif(pose=
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
# st.write(pose.body.data.shape)
|
154 |
-
|
155 |
-
# st.write(visualize_pose(pose=pose)) # bunch of ndarrays
|
156 |
-
# st.write([Image.fromarray(v.cv2.cvtColor(frame, cv_code)) for frame in frames])
|
157 |
-
|
158 |
-
# for i, image in enumerate(images[::n]):
|
159 |
-
# print(f"i={i}")
|
160 |
-
# st.image(image=image, width=width)
|
|
|
7 |
from pathlib import Path
|
8 |
from pyzstd import decompress
|
9 |
from PIL import Image
|
|
|
10 |
import mediapipe as mp
|
11 |
|
12 |
mp_holistic = mp.solutions.holistic
|
|
|
17 |
)
|
18 |
]
|
19 |
|
20 |
+
# Initialize session state
|
21 |
+
# if "filtered_pose" not in st.session_state:
|
22 |
+
# st.session_state.filtered_pose = None
|
23 |
+
|
24 |
|
25 |
def pose_normalization_info(pose_header):
|
26 |
if pose_header.components[0].name == "POSE_LANDMARKS":
|
|
|
79 |
return frames, images
|
80 |
|
81 |
|
82 |
+
def get_pose_gif(pose: Pose, step: int = 1, start_frame:int=None, end_frame:int=None, fps: int = None):
|
83 |
if fps is not None:
|
84 |
pose.body.fps = fps
|
85 |
v = PoseVisualizer(pose)
|
86 |
frames = [frame_data for frame_data in v.draw()]
|
87 |
+
frames = frames[start_frame:end_frame:step]
|
88 |
return v.save_gif(None, frames=frames)
|
89 |
|
90 |
|
|
|
113 |
"How to select components?", options=["manual", "signclip"]
|
114 |
)
|
115 |
|
116 |
+
component_names = [c.name for c in pose.header.components]
|
117 |
+
chosen_component_names = []
|
118 |
+
points_dict = {}
|
119 |
+
hide_legs = False
|
120 |
+
|
121 |
if component_selection == "manual":
|
122 |
+
|
123 |
+
|
124 |
+
chosen_component_names = st.pills(
|
125 |
+
"Select components to visualize", options=component_names, default=component_names,selection_mode="multi"
|
126 |
)
|
127 |
+
|
128 |
+
for component in pose.header.components:
|
129 |
+
if component.name in chosen_component_names:
|
130 |
+
with st.expander(f"Points for {component.name}"):
|
131 |
+
selected_points = st.multiselect(
|
132 |
+
f"Select points for component {component.name}:",
|
133 |
+
options=component.points,
|
134 |
+
default=component.points,
|
135 |
+
)
|
136 |
+
if selected_points != component.points: # Only add entry if not all points are selected
|
137 |
+
points_dict[component.name] = selected_points
|
138 |
+
|
139 |
+
|
140 |
|
141 |
elif component_selection == "signclip":
|
142 |
st.write("Selected landmarks used for SignCLIP.")
|
143 |
+
chosen_component_names = ["POSE_LANDMARKS", "FACE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]
|
144 |
+
points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}
|
145 |
+
|
146 |
+
|
147 |
|
148 |
# Filter button logic
|
149 |
+
# Filter section
|
150 |
+
st.write("### Filter .pose File")
|
151 |
+
filtered = st.button("Filter")
|
152 |
+
if filtered:
|
153 |
+
pose = pose.get_components(chosen_component_names, points=points_dict if points_dict else None)
|
154 |
+
if hide_legs:
|
155 |
+
pose = pose_hide_legs(pose)
|
156 |
+
|
157 |
+
st.session_state.filtered_pose = pose
|
158 |
+
|
159 |
+
filtered_pose = st.session_state.get('filtered_pose', pose)
|
160 |
+
if filtered_pose:
|
161 |
+
filtered_pose = st.session_state.get('filtered_pose', pose)
|
162 |
+
st.write(f"#### Filtered .pose file")
|
163 |
+
st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
|
164 |
with st.expander("Show header"):
|
165 |
+
st.write(filtered_pose.header)
|
166 |
with st.expander("Show body"):
|
167 |
+
st.write(filtered_pose.body)
|
168 |
+
# with st.expander("Show data:"):
|
169 |
+
# for frame in filtered_pose.body.data:
|
170 |
+
# st.write(f"Frame:{frame}")
|
171 |
+
# for person in frame:
|
172 |
+
# st.write(person)
|
173 |
+
|
174 |
pose_file_out = Path(uploaded_file.name).with_suffix(".pose")
|
175 |
with pose_file_out.open("wb") as f:
|
176 |
pose.write(f)
|
|
|
178 |
with pose_file_out.open("rb") as f:
|
179 |
st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)
|
180 |
|
181 |
+
|
182 |
+
st.write("### Visualization")
|
183 |
+
step = st.select_slider("Step value to select every nth image", list(range(1, len(frames))), value=1)
|
184 |
+
fps = st.slider("FPS for visualization", min_value=1.0, max_value=filtered_pose.body.fps, value=filtered_pose.body.fps)
|
185 |
+
start_frame, end_frame = st.slider(
|
186 |
+
"Select Frame Range",
|
187 |
+
0,
|
188 |
+
len(frames),
|
189 |
+
(0, len(frames)), # Default range
|
190 |
+
)
|
191 |
# Visualization button logic
|
192 |
if st.button("Visualize"):
|
193 |
+
# Load filtered pose if it exists; otherwise, use the unfiltered pose
|
194 |
+
|
195 |
+
|
196 |
+
st.image(get_pose_gif(pose=filtered_pose, step=step, start_frame=start_frame, end_frame=end_frame, fps=fps))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|