# Copyright (c) Meta Platforms, Inc. and affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from __future__ import annotations # so we can refer to class Type inside class import numpy as np import numpy.typing as npt import logging from typing import Union, Iterable, List, Tuple from animated_drawings.model.vectors import Vectors import math from animated_drawings.utils import TOLERANCE from functools import reduce class Quaternions: """ Wrapper class around ndarray interpreted as one or more quaternions. Quaternion order is [w, x, y, z] When passing in existing Quaternions, new Quaternions object will share the underlying nparray, so be careful. Strongly influenced by Daniel Holden's excellent Quaternions class. """ def __init__(self, qs: Union[Iterable[Union[int, float]], npt.NDArray[np.float32], Quaternions]) -> None: self.qs: npt.NDArray[np.float32] if isinstance(qs, np.ndarray): if not qs.shape[-1] == 4: msg = f'Final dimension passed to Quaternions must be 4. Found {qs.shape[-1]}' logging.critical(msg) assert False, msg if len(qs.shape) == 1: qs = np.expand_dims(qs, axis=0) self.qs = qs elif isinstance(qs, tuple) or isinstance(qs, list): try: qs = np.array(qs) assert qs.shape[-1] == 4 except Exception: msg = 'Could not convert quaternion data to ndarray with shape[-1] == 4' logging.critical(msg) assert False, msg if len(qs.shape) == 1: qs = np.expand_dims(qs, axis=0) self.qs = qs elif isinstance(qs, Quaternions): self.qs = qs.qs else: msg = 'Quaternions must be constructed from Quaternions or numpy array' logging.critical(msg) assert False, msg self.normalize() def normalize(self) -> None: self.qs = self.qs / np.expand_dims(np.sum(self.qs ** 2.0, axis=-1) ** 0.5, axis=-1) def to_rotation_matrix(self) -> npt.NDArray[np.float32]: """ From Ken Shoemake https://www.ljll.math.upmc.fr/~frey/papers/scientific%20visualisation/Shoemake%20K.,%20Quaternions.pdf :return: 4x4 rotation matrix representation of quaternions """ w = self.qs[..., 0].squeeze() x = self.qs[..., 1].squeeze() y = self.qs[..., 2].squeeze() z = self.qs[..., 3].squeeze() xx, yy, zz = x**2, y**2, z**2 wx, wy, wz = w*x, w*y, w*z xy, xz = x*y, x*z # no yz = y*z # Row 1 r00 = 1 - 2 * (yy + zz) r01 = 2 * (xy - wz) r02 = 2 * (xz + wy) # Row 2 r10 = 2 * (xy + wz) r11 = 1 - 2 * (xx + zz) r12 = 2 * (yz - wx) # Row 3 r20 = 2 * (xz - wy) r21 = 2 * (yz + wx) r22 = 1 - 2 * (xx + yy) return np.array([[r00, r01, r02, 0.0], [r10, r11, r12, 0.0], [r20, r21, r22, 0.0], [0.0, 0.0, 0.0, 1.0]], dtype=np.float32) @classmethod def rotate_between_vectors(cls, v1: Vectors, v2: Vectors) -> Quaternions: """ Computes quaternion rotating from v1 to v2. """ xyz: List[float] = v1.cross(v2).vs.squeeze().tolist() w: float = math.sqrt((v1.length**2) * (v2.length**2)) + np.dot(v1.vs.squeeze(), v2.vs.squeeze()) ret_q = Quaternions([w, *xyz]) ret_q.normalize() return ret_q @classmethod def from_angle_axis(cls, angles: npt.NDArray[np.float32], axes: Vectors) -> Quaternions: axes.norm() if len(angles.shape) == 1: angles = np.expand_dims(angles, axis=0) ss = np.sin(angles / 2.0) cs = np.cos(angles / 2.0) return Quaternions(np.concatenate([cs, axes.vs * ss], axis=-1)) @classmethod def identity(cls, ret_shape: Tuple[int]) -> Quaternions: qs = np.broadcast_to(np.array([1.0, 0.0, 0.0, 0.0]), [*ret_shape, 4]) return Quaternions(qs) @classmethod def from_euler_angles(cls, order: str, angles: npt.NDArray[np.float32]) -> Quaternions: """ Applies a series of euler angle rotations. Angles applied from right to left :param order: string comprised of x, y, and/or z :param angles: angles in degrees """ if len(angles.shape) == 1: angles = np.expand_dims(angles, axis=0) if len(order) != angles.shape[-1]: msg = 'length of orders and angles does not match' logging.critical(msg) assert False, msg _quats = [Quaternions.identity(angles.shape[:-1])] for axis_char, pos in zip(order, range(len(order))): angle = angles[..., pos] * np.pi / 180 angle = np.expand_dims(angle, axis=1) axis_char = axis_char.lower() if axis_char not in 'xyz': msg = f'order contained unsupported char:{axis_char}' logging.critical(msg) assert False, msg axis = np.zeros([*angles.shape[:-1], 3]) axis[..., ord(axis_char) - ord('x')] = 1.0 _quats.insert(0, Quaternions.from_angle_axis(angle, Vectors(axis))) ret_q = reduce(lambda a, b: b * a, _quats) return ret_q @classmethod def from_rotation_matrix(cls, M: npt.NDArray[np.float32]) -> Quaternions: """ As described here: https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf """ is_orthogonal = np.isclose(M @ M.T, np.identity(4), atol=TOLERANCE) if not is_orthogonal.all(): msg = "attempted to create quaternion from non-orthogonal rotation matrix" logging.critical(msg) assert False, msg if not np.isclose(np.linalg.det(M), 1.0): msg = "attempted to create quaternion from rotation matrix with det != 1" logging.critical(msg) assert False, msg # Note: Mike Day's article uses row vectors, whereas we used column, so here use transpose of matrix MT = M.T m00, m01, m02 = MT[0, 0], MT[0, 1], MT[0, 2] m10, m11, m12 = MT[1, 0], MT[1, 1], MT[1, 2] m20, m21, m22 = MT[2, 0], MT[2, 1], MT[2, 2] if m22 < 0: if m00 > m11: t = 1 + m00 - m11 - m22 q = np.array([m12-m21, t, m01+m10, m20+m02]) else: t = 1 - m00 + m11 - m22 q = np.array([m20-m02, m01+m10, t, m12+m21]) else: if m00 < -m11: t = 1 - m00 - m11 + m22 q = np.array([m01-m10, m20+m02, m12+m21, t]) else: t = 1 + m00 + m11 + m22 q = np.array([ t, m12-m21, m20-m02, m01-m10]) q *= (0.5 / math.sqrt(t)) ret_q = Quaternions(q) ret_q.normalize() return ret_q def __mul__(self, other: Quaternions): """ From https://danceswithcode.net/engineeringnotes/quaternions/quaternions.html """ s0 = self.qs[..., 0] s1 = self.qs[..., 1] s2 = self.qs[..., 2] s3 = self.qs[..., 3] r0 = other.qs[..., 0] r1 = other.qs[..., 1] r2 = other.qs[..., 2] r3 = other.qs[..., 3] t = np.empty(self.qs.shape) t[..., 0] = r0*s0 - r1*s1 - r2*s2 - r3*s3 t[..., 1] = r0*s1 + r1*s0 - r2*s3 + r3*s2 t[..., 2] = r0*s2 + r1*s3 + r2*s0 - r3*s1 t[..., 3] = r0*s3 - r1*s2 + r2*s1 + r3*s0 return Quaternions(t) def __neg__(self): return Quaternions(self.qs * np.array([1, -1, -1, -1])) def __str__(self): return f"Quaternions({str(self.qs)})" def __repr__(self): return f"Quaternions({str(self.qs)})"