# NOTE: web-page residue from the hosting site, not part of the program: uploader "ygauravyy", "Upload 722 files", commit d49f7bc (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations # so we can refer to class Type inside class
import numpy as np
import numpy.typing as npt
from animated_drawings.model.vectors import Vectors
from animated_drawings.model.quaternions import Quaternions
import logging
from typing import Union, Optional, List, Tuple
class Transform():
    """Base class from which all other scene objects descend.

    A Transform is a node in a tree (one optional parent, any number of
    children). It stores separate 4x4 translate, rotate and scale matrices,
    combines them into a local matrix, and combines that with the parent's
    world matrix to obtain its own world matrix. The combined matrices are
    cached and recomputed lazily, driven by ``dirty_bit``.
    """

    def __init__(self,
                 parent: Optional[Transform] = None,
                 name: Optional[str] = None,
                 children: Optional[List[Transform]] = None,
                 offset: Union[npt.NDArray[np.float32], Vectors, None] = None,
                 **kwargs
                 ) -> None:
        """
        Create a transform.

        parent: stored as this transform's parent. Note that this does NOT
            register self in parent's child list; use parent.add_child for that.
        name: optional label, used by get_transform_by_name.
        children: transforms to attach via add_child. None means no children.
        offset: optional initial translation, applied via offset().
        kwargs: forwarded to the next __init__ in the MRO.
        """
        super().__init__(**kwargs)

        self._parent: Optional[Transform] = parent

        # Always build a fresh list so instances never share child storage.
        self._children: List[Transform] = []
        if children is not None:
            for child in children:
                self.add_child(child)

        self.name: Optional[str] = name

        self._translate_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._rotate_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._scale_m: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)

        if offset is not None:
            self.offset(offset)

        self._local_transform: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self._world_transform: npt.NDArray[np.float32] = np.identity(4, dtype=np.float32)
        self.dirty_bit: bool = True  # are world/local transforms stale?

    def update_transforms(self, parent_dirty_bit: bool = False, recurse_on_children: bool = True, update_ancestors: bool = False) -> None:
        """
        Updates transforms if stale.
        If update_ancestors is true, first walk up to the root ancestor and call update_transforms upon it.
        If own dirty bit is set, recompute local matrix.
        If own or parent's dirty bit is set, recompute world matrix.
        If recurse_on_children is true, recurse on every child (children may be dirty even when
        this transform is not), passing down whether this transform or its parent was dirty.
        Set dirty bit back to false.
        """
        if update_ancestors:
            # Climb to the root, then refresh from the top down.
            ancestor, ancestor_parent = self, self.get_parent()
            while ancestor_parent is not None:
                ancestor, ancestor_parent = ancestor_parent, ancestor_parent.get_parent()
            ancestor.update_transforms()

        if self.dirty_bit:
            self.compute_local_transform()

        stale: bool = self.dirty_bit or parent_dirty_bit
        if stale:
            self.compute_world_transform()

        if recurse_on_children:
            for c in self.get_children():
                c.update_transforms(stale)

        self.dirty_bit = False

    def compute_local_transform(self) -> None:
        """ Recompute the local matrix as translate @ rotate @ scale """
        self._local_transform = self._translate_m @ self._rotate_m @ self._scale_m

    def compute_world_transform(self) -> None:
        """ Recompute the world matrix from the cached local matrix and the parent's cached world matrix, if any """
        self._world_transform = self._local_transform
        if self._parent is not None:
            self._world_transform = self._parent._world_transform @ self._world_transform

    def get_world_transform(self, update_ancestors: bool = True) -> npt.NDArray[np.float32]:
        """
        Get a copy of the transform's world matrix.
        If update_ancestors is true, refresh transforms starting at the root ancestor first,
        so the returned world_transform is current.
        """
        if update_ancestors:
            self.update_transforms(update_ancestors=True)
        return np.copy(self._world_transform)

    def set_scale(self, scale: float) -> None:
        """ Set a uniform scale factor on the upper-left 3x3 of the scale matrix """
        self._scale_m[:-1, :-1] = scale * np.identity(3, dtype=np.float32)
        self.dirty_bit = True

    def set_position(self, pos: Union[npt.NDArray[np.float32], Vectors]) -> None:
        """
        Set the absolute values of the translational elements of transform.
        pos must have shape (3,) or (1, 3); raises AssertionError otherwise.
        """
        if isinstance(pos, Vectors):
            pos = pos.vs

        if pos.shape == (1, 3):
            pos = np.squeeze(pos)
        elif pos.shape != (3,):
            msg = f'bad vector dim passed to set_position. Found: {pos.shape}'
            logging.critical(msg)
            # Raised explicitly (not `assert False`) so the check survives `python -O`.
            raise AssertionError(msg)

        self._translate_m[:-1, -1] = pos
        self.dirty_bit = True

    def get_local_position(self) -> npt.NDArray[np.float32]:
        """ Ensure local transform is up-to-date and return a copy of the local xyz coordinates """
        if self.dirty_bit:
            # Only the local matrix is refreshed here; dirty_bit stays set so the
            # world matrix is still recomputed on the next update_transforms.
            self.compute_local_transform()
        return np.copy(self._local_transform[:-1, -1])

    def get_world_position(self, update_ancestors: bool = True) -> npt.NDArray[np.float32]:
        """
        Return a copy of the world xyz coordinates.
        If update_ancestors is true, update ancestor transforms to ensure
        up-to-date world_transform before returning.
        """
        if update_ancestors:
            self.update_transforms(update_ancestors=True)
        return np.copy(self._world_transform[:-1, -1])

    def offset(self, pos: Union[npt.NDArray[np.float32], Vectors]) -> None:
        """ Translational offset by the specified amount, relative to the current translation """
        if isinstance(pos, Vectors):
            pos = pos.vs[0]
        assert isinstance(pos, np.ndarray)
        self.set_position(self._translate_m[:-1, -1] + pos)

    def look_at(self, fwd_: Union[npt.NDArray[np.float32], Vectors, None]) -> None:
        """
        Given a forward vector, rotate the transform to face that position.
        If fwd_ is None, the transform's own world position is used as the forward vector.
        Raises AssertionError if the forward vector does not have shape (1, 3).
        """
        if fwd_ is None:
            fwd_ = Vectors(self.get_world_position())
        elif isinstance(fwd_, np.ndarray):
            fwd_ = Vectors(fwd_)
        fwd: Vectors = fwd_.copy()  # norming will change the vector

        if fwd.vs.shape != (1, 3):
            msg = f'look_at fwd_ vector must have shape [1,3]. Found: {fwd.vs.shape}'
            logging.critical(msg)
            # Raised explicitly (not `assert False`) so the check survives `python -O`.
            raise AssertionError(msg)

        tmp: Vectors = Vectors([0.0, 1.0, 0.0])

        # if fwd and tmp are same vector, modify tmp to avoid collapse
        if np.isclose(fwd.vs, tmp.vs).all() or np.isclose(fwd.vs, -tmp.vs).all():
            # NOTE(review): assuming tmp.vs has shape (1, 3), this nudges all three components - confirm intended.
            tmp.vs[0] += 0.001

        right: Vectors = tmp.cross(fwd)
        up: Vectors = fwd.cross(right)

        fwd.norm()
        right.norm()
        up.norm()

        # Columns of the rotation matrix are the right, up and forward axes.
        rotate_m = np.identity(4, dtype=np.float32)
        rotate_m[:-1, 0] = np.squeeze(right.vs)
        rotate_m[:-1, 1] = np.squeeze(up.vs)
        rotate_m[:-1, 2] = np.squeeze(fwd.vs)

        self._rotate_m = rotate_m
        self.dirty_bit = True

    def get_right_up_fwd_vectors(self) -> Tuple[npt.NDArray[np.float32], npt.NDArray[np.float32], npt.NDArray[np.float32]]:
        """ Return the first three columns (xyz rows only) of the inverse of the up-to-date world matrix as (right, up, fwd) """
        inverted: npt.NDArray[np.float32] = np.linalg.inv(self.get_world_transform())
        right: npt.NDArray[np.float32] = inverted[:-1, 0]
        up: npt.NDArray[np.float32] = inverted[:-1, 1]
        fwd: npt.NDArray[np.float32] = inverted[:-1, 2]

        return right, up, fwd

    def set_rotation(self, q: Quaternions) -> None:
        """ Replace the rotation with that of q, which must hold exactly one quaternion (shape (1, 4)) """
        if q.qs.shape != (1, 4):
            msg = f'set_rotate q must have dimension (1, 4). Found: {q.qs.shape}'
            logging.critical(msg)
            # Raised explicitly (not `assert False`) so the check survives `python -O`.
            raise AssertionError(msg)
        self._rotate_m = q.to_rotation_matrix()
        self.dirty_bit = True

    def rotation_offset(self, q: Quaternions) -> None:
        """ Compose q (shape (1, 4)) with the current rotation as q * current and store the result """
        if q.qs.shape != (1, 4):
            msg = f'rotation_offset q must have dimension (1, 4). Found: {q.qs.shape}'
            logging.critical(msg)
            # Raised explicitly (not `assert False`) so the check survives `python -O`.
            raise AssertionError(msg)
        self._rotate_m = (q * Quaternions.from_rotation_matrix(self._rotate_m)).to_rotation_matrix()
        self.dirty_bit = True

    def add_child(self, child: Transform) -> None:
        """ Append child to this transform's children and make this transform its parent """
        self._children.append(child)
        child.set_parent(self)

    def get_children(self) -> List[Transform]:
        """ Return the internal list of children (not a copy) """
        return self._children

    def set_parent(self, parent: Transform) -> None:
        """ Set the parent reference and mark transforms stale. Does not modify any child list. """
        self._parent = parent
        self.dirty_bit = True

    def get_parent(self) -> Optional[Transform]:
        """ Return the parent transform, or None if there is none """
        return self._parent

    def get_transform_by_name(self, name: str) -> Optional[Transform]:
        """ Search self and children for transform with matching name. Return it if found, None otherwise. """

        # are we match?
        if self.name == name:
            return self

        # recurse to check if a child is match
        for c in self.get_children():
            transform_or_none = c.get_transform_by_name(name)
            if transform_or_none is not None:  # if we found it
                return transform_or_none

        # no match
        return None

    def draw(self, recurse: bool = True, **kwargs) -> None:
        """ Draw this transform and recurse on children """
        self._draw(**kwargs)

        if recurse:
            for child in self.get_children():
                child.draw(**kwargs)

    def _draw(self, **kwargs) -> None:
        """Transforms default to not being drawn. Subclasses must implement how they appear"""