Colin Leong commited on
Commit
6dbb7a6
·
1 Parent(s): 6e290b7

Updates to the app such that downloading and visualization work more intuitively

Browse files
Files changed (1) hide show
  1. app.py +69 -33
app.py CHANGED
@@ -7,7 +7,6 @@ from pose_format.pose_visualizer import PoseVisualizer
7
  from pathlib import Path
8
  from pyzstd import decompress
9
  from PIL import Image
10
- import cv2
11
  import mediapipe as mp
12
 
13
  mp_holistic = mp.solutions.holistic
@@ -18,6 +17,10 @@ FACEMESH_CONTOURS_POINTS = [
18
  )
19
  ]
20
 
 
 
 
 
21
 
22
  def pose_normalization_info(pose_header):
23
  if pose_header.components[0].name == "POSE_LANDMARKS":
@@ -76,12 +79,12 @@ def get_pose_frames(pose: Pose, transparency: bool = False):
76
  return frames, images
77
 
78
 
79
- def get_pose_gif(pose: Pose, step: int = 1, fps: int = None):
80
  if fps is not None:
81
  pose.body.fps = fps
82
  v = PoseVisualizer(pose)
83
  frames = [frame_data for frame_data in v.draw()]
84
- frames = frames[::step]
85
  return v.save_gif(None, frames=frames)
86
 
87
 
@@ -110,28 +113,64 @@ if uploaded_file is not None:
110
  "How to select components?", options=["manual", "signclip"]
111
  )
112
 
 
 
 
 
 
113
  if component_selection == "manual":
114
- chosen_component_names = st.multiselect(
115
- "Select components to visualize", options=[c.name for c in pose.header.components]
 
 
116
  )
117
- if chosen_component_names:
118
- pose = pose.get_components(chosen_component_names)
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  elif component_selection == "signclip":
121
  st.write("Selected landmarks used for SignCLIP.")
122
- pose = pose.get_components(
123
- ["POSE_LANDMARKS", "FACE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]
124
- )
125
- pose = pose_hide_legs(pose)
126
 
127
  # Filter button logic
128
- if st.button("Filter Components/Points"):
129
- st.write("### Filtered .pose file")
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  with st.expander("Show header"):
131
- st.write(pose.header)
132
  with st.expander("Show body"):
133
- st.write(pose.body)
134
-
 
 
 
 
 
135
  pose_file_out = Path(uploaded_file.name).with_suffix(".pose")
136
  with pose_file_out.open("wb") as f:
137
  pose.write(f)
@@ -139,22 +178,19 @@ if uploaded_file is not None:
139
  with pose_file_out.open("rb") as f:
140
  st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)
141
 
 
 
 
 
 
 
 
 
 
 
142
  # Visualization button logic
143
  if st.button("Visualize"):
144
- st.write("### Visualization")
145
- step = st.select_slider("Step value to select every nth image", list(range(1, len(frames))), value=1)
146
- fps = st.slider("FPS for visualization", min_value=1.0, max_value=pose.body.fps, value=pose.body.fps)
147
- st.image(get_pose_gif(pose=pose, step=step, fps=fps))
148
-
149
-
150
-
151
-
152
-
153
- # st.write(pose.body.data.shape)
154
-
155
- # st.write(visualize_pose(pose=pose)) # bunch of ndarrays
156
- # st.write([Image.fromarray(v.cv2.cvtColor(frame, cv_code)) for frame in frames])
157
-
158
- # for i, image in enumerate(images[::n]):
159
- # print(f"i={i}")
160
- # st.image(image=image, width=width)
 
7
  from pathlib import Path
8
  from pyzstd import decompress
9
  from PIL import Image
 
10
  import mediapipe as mp
11
 
12
  mp_holistic = mp.solutions.holistic
 
17
  )
18
  ]
19
 
20
+ # Initialize session state
21
+ # if "filtered_pose" not in st.session_state:
22
+ # st.session_state.filtered_pose = None
23
+
24
 
25
  def pose_normalization_info(pose_header):
26
  if pose_header.components[0].name == "POSE_LANDMARKS":
 
79
  return frames, images
80
 
81
 
82
+ def get_pose_gif(pose: Pose, step: int = 1, start_frame:int=None, end_frame:int=None, fps: int = None):
83
  if fps is not None:
84
  pose.body.fps = fps
85
  v = PoseVisualizer(pose)
86
  frames = [frame_data for frame_data in v.draw()]
87
+ frames = frames[start_frame:end_frame:step]
88
  return v.save_gif(None, frames=frames)
89
 
90
 
 
113
  "How to select components?", options=["manual", "signclip"]
114
  )
115
 
116
+ component_names = [c.name for c in pose.header.components]
117
+ chosen_component_names = []
118
+ points_dict = {}
119
+ hide_legs = False
120
+
121
  if component_selection == "manual":
122
+
123
+
124
+ chosen_component_names = st.pills(
125
+ "Select components to visualize", options=component_names, default=component_names,selection_mode="multi"
126
  )
127
+
128
+ for component in pose.header.components:
129
+ if component.name in chosen_component_names:
130
+ with st.expander(f"Points for {component.name}"):
131
+ selected_points = st.multiselect(
132
+ f"Select points for component {component.name}:",
133
+ options=component.points,
134
+ default=component.points,
135
+ )
136
+ if selected_points != component.points: # Only add entry if not all points are selected
137
+ points_dict[component.name] = selected_points
138
+
139
+
140
 
141
  elif component_selection == "signclip":
142
  st.write("Selected landmarks used for SignCLIP.")
143
+ chosen_component_names = ["POSE_LANDMARKS", "FACE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"]
144
+ points_dict = {"FACE_LANDMARKS": FACEMESH_CONTOURS_POINTS}
145
+
146
+
147
 
148
  # Filter button logic
149
+ # Filter section
150
+ st.write("### Filter .pose File")
151
+ filtered = st.button("Filter")
152
+ if filtered:
153
+ pose = pose.get_components(chosen_component_names, points=points_dict if points_dict else None)
154
+ if hide_legs:
155
+ pose = pose_hide_legs(pose)
156
+
157
+ st.session_state.filtered_pose = pose
158
+
159
+ filtered_pose = st.session_state.get('filtered_pose', pose)
160
+ if filtered_pose:
161
+ filtered_pose = st.session_state.get('filtered_pose', pose)
162
+ st.write(f"#### Filtered .pose file")
163
+ st.write(f"Pose data shape: {filtered_pose.body.data.shape}")
164
  with st.expander("Show header"):
165
+ st.write(filtered_pose.header)
166
  with st.expander("Show body"):
167
+ st.write(filtered_pose.body)
168
+ # with st.expander("Show data:"):
169
+ # for frame in filtered_pose.body.data:
170
+ # st.write(f"Frame:{frame}")
171
+ # for person in frame:
172
+ # st.write(person)
173
+
174
  pose_file_out = Path(uploaded_file.name).with_suffix(".pose")
175
  with pose_file_out.open("wb") as f:
176
  pose.write(f)
 
178
  with pose_file_out.open("rb") as f:
179
  st.download_button("Download Filtered Pose", f, file_name=pose_file_out.name)
180
 
181
+
182
+ st.write("### Visualization")
183
+ step = st.select_slider("Step value to select every nth image", list(range(1, len(frames))), value=1)
184
+ fps = st.slider("FPS for visualization", min_value=1.0, max_value=filtered_pose.body.fps, value=filtered_pose.body.fps)
185
+ start_frame, end_frame = st.slider(
186
+ "Select Frame Range",
187
+ 0,
188
+ len(frames),
189
+ (0, len(frames)), # Default range
190
+ )
191
  # Visualization button logic
192
  if st.button("Visualize"):
193
+ # Load filtered pose if it exists; otherwise, use the unfiltered pose
194
+
195
+
196
+ st.image(get_pose_gif(pose=filtered_pose, step=step, start_frame=start_frame, end_frame=end_frame, fps=fps))