ygauravyy's picture
Upload 722 files
d49f7bc verified
raw
history blame
11.8 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations # so we can refer to class Type inside class
import logging
from pathlib import Path
from typing import List, Tuple, Optional
import numpy as np
import numpy.typing as npt
from animated_drawings.model.transform import Transform
from animated_drawings.model.box import Box
from animated_drawings.model.quaternions import Quaternions
from animated_drawings.model.vectors import Vectors
from animated_drawings.model.joint import Joint
from animated_drawings.model.time_manager import TimeManager
from animated_drawings.utils import resolve_ad_filepath
class BVH_Joint(Joint):
    """
    Joint class with channel order attribute and specialized vis widget
    """

    def __init__(self, channel_order: Optional[List[str]] = None, widget: bool = True, **kwargs) -> None:
        """
        :param channel_order: channel names of this joint, in the order listed on its BVH CHANNELS line.
            None (the default) gives a fresh empty list; a shared `[]` default would be aliased by every
            joint created without channels.
        :param widget: if True, create a Box, attach it as a child, and draw it from _draw().
        :param kwargs: forwarded unchanged to Joint.__init__ (BVH._parse_skeleton passes name, offset, children).
        """
        super().__init__(**kwargs)
        self.channel_order: List[str] = [] if channel_order is None else channel_order

        self.widget: Optional[Transform] = None
        if widget:
            self.widget = Box()
            self.add_child(self.widget)

    def _draw(self, **kwargs):
        # Only the optional visualization widget is drawn here; kwargs are passed through to it.
        if self.widget:
            self.widget.draw(**kwargs)
class BVH(Transform, TimeManager):
    """
    Class to encapsulate BVH (Biovision Hierarchy) animation data.
    Include a single skeletal hierarchy defined in the BVH, frame count and speed,
    and skeletal pos/rot data for each frame
    """

    def __init__(self,
                 name: str,
                 root_joint: BVH_Joint,
                 frame_max_num: int,
                 frame_time: float,
                 pos_data: npt.NDArray[np.float32],
                 rot_data: npt.NDArray[np.float32]
                 ) -> None:
        """
        Don't recommend calling this method directly. Instead, use BVH.from_file().

        :param name: name of the animation (from_file() passes the BVH file name).
        :param root_joint: root of the skeleton; it is attached as a child of this transform.
        :param frame_max_num: number of frames held in pos_data / rot_data.
        :param frame_time: duration of one frame (the BVH 'Frame Time' value); update() divides the
            TimeManager time by it to pick a frame.
        :param pos_data: per-frame root position; from_file() builds it with shape (frames, 3).
        :param rot_data: per-frame joint quaternions; from_file() builds it with shape (frames, joint_count, 4).
        """
        super().__init__()
        self.name: str = name
        self.frame_max_num: int = frame_max_num
        self.frame_time: float = frame_time
        self.pos_data: npt.NDArray[np.float32] = pos_data
        self.rot_data: npt.NDArray[np.float32] = rot_data

        self.root_joint = root_joint
        self.add_child(self.root_joint)
        self.joint_num = self.root_joint.joint_count()

        self.cur_frame = 0  # initialize skeleton pose to first frame
        self.apply_frame(self.cur_frame)

    def get_joint_names(self) -> List[str]:
        """ Get names of joints in skeleton in the order in which BVH rotation data is stored. """
        return self.root_joint.get_chain_joint_names()

    def update(self) -> None:
        """Based upon internal time, determine which frame should be displayed and apply it"""
        cur_time: float = self.get_time()
        # Wrap around so the animation loops once the last frame has been shown.
        cur_frame = round(cur_time / self.frame_time) % self.frame_max_num
        self.apply_frame(cur_frame)

    def apply_frame(self, frame_num: int) -> None:
        """ Apply root position and joint rotation data for specified frame_num """
        self.root_joint.set_position(self.pos_data[frame_num])
        self._apply_frame_rotations(self.root_joint, frame_num, ptr=np.array(0))

    def _apply_frame_rotations(self, joint: BVH_Joint, frame_num: int, ptr: npt.NDArray[np.int32]) -> None:
        """
        Depth-first walk of the BVH_Joint hierarchy, setting each joint's rotation from
        rot_data[frame_num, ptr]. This is the same traversal order _pose_ea_to_q() used to fill rot_data.

        :param ptr: 0-d numpy array used as a counter shared across the recursion; `ptr += 1`
            increments it in place, so every recursive call sees the advanced value.
        """
        q = Quaternions(self.rot_data[frame_num, ptr])
        joint.set_rotation(q)

        ptr += 1

        for c in joint.get_children():
            if not isinstance(c, BVH_Joint):
                continue  # skip non-joint children such as the visualization widget
            self._apply_frame_rotations(c, frame_num, ptr)

    def get_skeleton_fwd(self, forward_perp_vector_joint_names: List[Tuple[str, str]], update: bool = True) -> Vectors:
        """
        Get current forward vector of skeleton in world coords. If update=True, ensure skeleton transforms are current.
        Input forward_perp_vector_joint_names, a list of pairs of joint names (e.g. [[leftshould, rightshoulder], [lefthip, righthip]])
        Finds average of vectors between joint pairs, then returns vector perpendicular to their average.

        :raises AssertionError: if any named joint cannot be found in the skeleton.
        """
        if update:
            self.root_joint.update_transforms(update_ancestors=True)

        vectors_cw_perpendicular_to_fwd: List[Vectors] = []
        for (start_joint_name, end_joint_name) in forward_perp_vector_joint_names:
            start_joint = self.root_joint.get_transform_by_name(start_joint_name)
            if not start_joint:
                raise BVH._critical_error(f'Could not find BVH joint with name: {start_joint_name}')

            end_joint = self.root_joint.get_transform_by_name(end_joint_name)
            if not end_joint:
                raise BVH._critical_error(f'Could not find BVH joint with name: {end_joint_name}')

            bone_vector: Vectors = Vectors(end_joint.get_world_position()) - Vectors(start_joint.get_world_position())
            bone_vector.norm()  # NOTE(review): return value unused, so presumably normalizes in place — confirm
            vectors_cw_perpendicular_to_fwd.append(bone_vector)

        return Vectors(vectors_cw_perpendicular_to_fwd).average().perpendicular()

    @staticmethod
    def _critical_error(msg: str) -> AssertionError:
        """
        Log msg at CRITICAL level and return an AssertionError carrying it, for the caller to raise.

        Raising explicitly (instead of `assert False, msg`) keeps these checks active under `python -O`.
        It still raises the same exception type callers previously saw.
        """
        logging.critical(msg)
        return AssertionError(msg)

    @staticmethod
    def _malformed_msg(lines: List[str]) -> str:
        """Build a 'malformed BVH' message quoting only the next unparsed line, not the whole remaining file."""
        next_line = lines[0] if lines else '<end of file>'
        return f'Malformed BVH in line preceding {next_line!r}'

    @classmethod
    def from_file(cls, bvh_fn: str, start_frame_idx: int = 0, end_frame_idx: Optional[int] = None) -> BVH:
        """
        Given a path to a .bvh, constructs and returns BVH object.

        :param bvh_fn: path to the BVH file; located via resolve_ad_filepath().
        :param start_frame_idx: index of the first frame to keep.
        :param end_frame_idx: index one past the last frame to keep. None (or 0) keeps every frame through the
            end of the file. Values above the file's frame count are clamped to it, with a warning.
        :return: BVH holding the skeleton and the selected frame range.
        :raises AssertionError: if the file is malformed or the requested frame range is empty or out of bounds.
        """
        # search for the BVH file specified
        bvh_p: Path = resolve_ad_filepath(bvh_fn, 'bvh file')
        logging.info(f'Using BVH file located at {bvh_p.resolve()}')

        with open(str(bvh_p), 'r') as f:
            # Blank lines carry no information in BVH. Dropping them here keeps the parsers below from
            # tripping over stray empty lines, e.g. trailing newlines after the motion data.
            lines = [line for line in f.read().splitlines() if line.strip()]

        if not lines or lines.pop(0).strip() != 'HIERARCHY':
            raise BVH._critical_error(BVH._malformed_msg(lines))

        # Parse the skeleton (consumes its lines from `lines` in place)
        root_joint: BVH_Joint = BVH._parse_skeleton(lines)

        if not lines or lines.pop(0).strip() != 'MOTION':
            raise BVH._critical_error(BVH._malformed_msg(lines))

        # Parse motion metadata ('Frames: N' and 'Frame Time: T')
        if len(lines) < 2:
            raise BVH._critical_error('Malformed BVH. Missing Frames / Frame Time lines after MOTION')
        frame_max_num = int(lines.pop(0).split(':')[-1])
        frame_time = float(lines.pop(0).split(':')[-1])

        # Parse motion data; split() with no argument tolerates tabs and repeated spaces between values.
        frames = [list(map(float, line.split())) for line in lines]

        if len(frames) != frame_max_num:
            raise BVH._critical_error(f'framenum specified ({frame_max_num}) and found ({len(frames)}) do not match')

        # Split logically distinct root position data from joint euler angle rotation data
        pos_data: npt.NDArray[np.float32]
        rot_data: npt.NDArray[np.float32]
        pos_data, rot_data = BVH._process_frame_data(root_joint, frames)

        # Set end_frame if not passed in (0 is also treated as "not passed", as before)
        if not end_frame_idx:
            end_frame_idx = frame_max_num

        # Ensure end_frame_idx <= frame_max_num
        if frame_max_num < end_frame_idx:
            msg = f'config specified end_frame_idx > bvh frame_max_num ({end_frame_idx} > {frame_max_num}). Replacing with frame_max_num.'
            logging.warning(msg)
            end_frame_idx = frame_max_num

        # An empty or negative range would leave frame_max_num <= 0 (modulo-by-zero in update()), or a
        # frame count that disagrees with the sliced data.
        if not 0 <= start_frame_idx < end_frame_idx:
            raise BVH._critical_error(
                f'invalid frame range: start_frame_idx={start_frame_idx}, end_frame_idx={end_frame_idx} '
                f'(bvh frame_max_num={frame_max_num})')

        # slice position and rotation data using start and end frame indices
        pos_data = pos_data[start_frame_idx:end_frame_idx, :]
        rot_data = rot_data[start_frame_idx:end_frame_idx, :]

        # new frame_max_num is end_frame_idx minus start_frame_idx
        frame_max_num = end_frame_idx - start_frame_idx

        return BVH(bvh_p.name, root_joint, frame_max_num, frame_time, pos_data, rot_data)

    @classmethod
    def _parse_skeleton(cls, lines: List[str]) -> BVH_Joint:
        """
        Called recursively to parse and construct skeleton from BVH
        :param lines: partially-processed contents of BVH file. Is modified in-place.
        :return: Joint
        :raises AssertionError: if the hierarchy is malformed or ends prematurely.
        """
        if not lines:
            raise BVH._critical_error('Malformed BVH. Unexpected end of file while parsing skeleton')

        # Get the joint name
        header = lines[0].strip()
        if header.startswith(('ROOT', 'JOINT')):
            # maxsplit=1 tolerates repeated whitespace after the keyword and keeps multi-word names intact
            _, joint_name = lines.pop(0).strip().split(maxsplit=1)
        elif header.startswith('End Site'):
            joint_name = lines.pop(0).strip()
        else:
            raise BVH._critical_error(f'Malformed BVH. Line: {lines[0]}')

        if not lines or lines.pop(0).strip() != '{':
            raise BVH._critical_error(BVH._malformed_msg(lines))

        # Get offset
        if not lines or not lines[0].strip().startswith('OFFSET'):
            raise BVH._critical_error(BVH._malformed_msg(lines))
        _, *xyz = lines.pop(0).split()
        offset = Vectors(list(map(float, xyz)))

        # Get channels (End Sites have no CHANNELS line)
        channel_order: List[str] = []
        channel_num = 0
        if lines and lines[0].strip().startswith('CHANNELS'):
            _, channel_num_token, *channel_order = lines.pop(0).split()
            channel_num = int(channel_num_token)
        if channel_num != len(channel_order):
            raise BVH._critical_error(BVH._malformed_msg(lines))

        # Recurse for children until this joint's closing brace
        children: List[BVH_Joint] = []
        while lines and lines[0].strip() != '}':
            children.append(BVH._parse_skeleton(lines))
        if not lines:
            raise BVH._critical_error('Malformed BVH. Unexpected end of file while parsing skeleton')
        lines.pop(0)  # }

        return BVH_Joint(name=joint_name, offset=offset, channel_order=channel_order, children=children)

    @classmethod
    def _process_frame_data(cls, skeleton: BVH_Joint, frames: List[List[float]]) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        """ Given skeleton and frame data, return root position data and joint quaternion data, separately"""

        def _get_frame_channel_order(joint: BVH_Joint, channels: Optional[List[str]] = None) -> List[str]:
            # Depth-first concatenation of every BVH_Joint's channel names: the column order of a motion row.
            if channels is None:
                channels = []
            channels.extend(joint.channel_order)
            for child in joint.get_children():
                if isinstance(child, BVH_Joint):
                    _get_frame_channel_order(child, channels)
            return channels

        channels = _get_frame_channel_order(skeleton)

        # create a mask so we retain only joint rotations and root position
        mask = np.array(['rotation' in channel for channel in channels], dtype=bool)
        mask[:3] = True  # hack to make sure we keep root position (assumes root's first 3 channels are positions)

        # reshape keeps the array 2-D even with zero frames, and rejects rows whose width differs from the
        # declared channel count.
        frame_arr = np.array(frames, dtype=np.float32).reshape(len(frames), len(channels))[:, mask]

        # split root pose data and joint euler angle data
        pos_data, ea_rots = np.split(frame_arr, [3], axis=1)

        # quaternion rot data will go here
        rot_data = np.empty([len(frames), skeleton.joint_count(), 4], dtype=np.float32)
        BVH._pose_ea_to_q(skeleton, ea_rots, rot_data)

        return pos_data, rot_data

    @classmethod
    def _pose_ea_to_q(cls, joint: BVH_Joint, ea_rots: npt.NDArray[np.float32], q_rots: npt.NDArray[np.float32], p1: int = 0, p2: int = 0) -> Tuple[int, int]:
        """
        Given joint and array of euler angle rotation data, converts to quaternions and stores in q_rots.
        Only called by _process_frame_data(). Modifies q_rots inplace.
        :param p1: pointer to find where in ea_rots to read euler angles from
        :param p2: pointer to determine where in q_rots to input quaternion
        :return: (p1, p2) advanced past this joint and all of its BVH_Joint descendants
        """
        # Rotation axis order comes from the channel names, e.g. ['Zrotation', 'Xrotation', 'Yrotation'] -> 'zxy'
        axis_chars = "".join([c[0].lower() for c in joint.channel_order if c.endswith('rotation')])

        q_rots[:, p2] = Quaternions.from_euler_angles(axis_chars, ea_rots[:, p1:p1+len(axis_chars)]).qs
        p1 += len(axis_chars)
        p2 += 1

        # Same depth-first order as _get_frame_channel_order and _apply_frame_rotations
        for child in joint.get_children():
            if isinstance(child, BVH_Joint):
                p1, p2 = BVH._pose_ea_to_q(child, ea_rots, q_rots, p1, p2)

        return p1, p2