Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations # so we can refer to class Type inside class | |
import numpy as np | |
import numpy.typing as npt | |
import logging | |
from typing import Union, Iterable, Tuple | |
from numbers import Number | |
from copy import copy | |
from animated_drawings.utils import TOLERANCE | |
class Vectors():
    """
    Wrapper class around ndarray interpreted as one or more vectors of equal dimensionality.

    The wrapped array is stored in `self.vs`. One-dimensional input is promoted to shape
    (1, n), and every operation treats the last axis as the vector components.

    When passing in existing Vectors (or an ndarray), the new Vectors object will share the
    underlying ndarray, so be careful.

    Validation failures are logged at CRITICAL level and raised as AssertionError.
    """

    def __init__(self, vs_: Union[Iterable[Union[float, int, Vectors, npt.NDArray[np.float32]]], Vectors]) -> None:  # noqa: C901
        """
        Build a Vectors object from one of the supported inputs.

        Args:
            vs_: one of
                - a single ndarray (a 1-D array is promoted to shape (1, n)),
                - a non-empty tuple/list of numbers (a single vector),
                - a non-empty tuple/list of ndarrays (stacked along a new first axis),
                - a non-empty tuple/list of Vectors (each squeezed, then stacked),
                - a single Vectors (the underlying array is shared, not copied).
              For tuples/lists only the first element is inspected to choose the branch.

        Raises:
            AssertionError: if vs_ is of an unsupported type, is an empty tuple/list,
                or cannot be converted into a single ndarray.
        """
        self.vs: npt.NDArray[np.float32]

        # initialize from single ndarray
        if isinstance(vs_, np.ndarray):
            self.vs = np.expand_dims(vs_, axis=0) if vs_.ndim == 1 else vs_
            return

        # initialize from single Vectors (shares the underlying array)
        if isinstance(vs_, Vectors):
            self.vs = vs_.vs
            return

        # initialize from a non-empty tuple or list; the first element selects the conversion.
        # The emptiness check keeps vs_[0] from raising a bare IndexError.
        if isinstance(vs_, (tuple, list)) and len(vs_) > 0:
            first = vs_[0]
            try:
                if isinstance(first, Number):
                    arr = np.array(vs_)
                    self.vs = np.expand_dims(arr, axis=0) if arr.ndim == 1 else arr
                    return
                if isinstance(first, np.ndarray):
                    self.vs = np.stack(vs_)  # pyright: ignore[reportGeneralTypeIssues]
                    return
                if isinstance(first, Vectors):
                    self.vs = np.stack([v.vs.squeeze() for v in vs_])  # pyright: ignore[reportGeneralTypeIssues]
                    return
            except Exception as e:
                # Any conversion failure (ragged shapes, mixed element types, ...) is
                # reported uniformly as a construction error.
                msg = f'Error initializing Vectors: {str(e)}'
                raise self._critical(msg) from e

        msg = 'Vectors must be constructed from Vectors, ndarray, or Tuples/List of floats/ints or Vectors'
        raise self._critical(msg)

    @staticmethod
    def _critical(msg: str) -> AssertionError:
        """
        Log msg at CRITICAL level and return an AssertionError for the caller to raise.

        An explicit exception is used instead of `assert False` so that validation
        still happens when Python runs with -O (which strips assert statements).
        """
        logging.critical(msg)
        return AssertionError(msg)

    def norm(self) -> None:
        """
        Normalize the vectors in place (self.vs is rebound to a new array).

        Norms smaller than TOLERANCE are replaced with TOLERANCE before dividing,
        so near-zero vectors do not cause a division by zero.
        """
        ns: npt.NDArray[np.float64] = np.linalg.norm(self.vs, axis=-1)

        # The size check keeps np.min from raising on an empty collection of vectors.
        if ns.size > 0 and np.min(ns) < TOLERANCE:
            logging.info(f"Encountered values close to zero in vector norm. Replacing with {TOLERANCE}")
            ns[ns < TOLERANCE] = TOLERANCE

        self.vs = self.vs / np.expand_dims(ns, axis=-1)

    def cross(self, v2: Vectors) -> Vectors:
        """
        Cross product of a series of 2 or 3 dimensional vectors. All dimensions of vs must match.

        Args:
            v2: right-hand operand; its array must have exactly the same shape as self.vs.

        Returns:
            A new Vectors wrapping np.cross(self.vs, v2.vs).

        Raises:
            AssertionError: if the shapes differ or the last axis is not of size 2 or 3.
        """
        if self.vs.shape != v2.vs.shape:
            msg = f'Cannot cross product different sized vectors: {self.vs.shape} {v2.vs.shape}.'
            raise self._critical(msg)

        if self.vs.shape[-1] not in (2, 3):
            msg = f'Cannot cross product vectors of size: {self.vs.shape[-1]}. Must be 2 or 3.'
            raise self._critical(msg)

        return Vectors(np.cross(self.vs, v2.vs))

    def perpendicular(self, ccw: bool = True) -> Vectors:
        """
        Returns normalized Vectors perpendicular to the original ones.
        Only 2D and 3D vectors are supported.

        3D vectors: the result is the cross product of the up axis (0, 1, 0) with each vector.
        2D vectors: the result is each vector rotated by 90 degrees in the plane,
        i.e. (x, y) -> (-y, x).

        By default returns the counter clockwise vector, but passing ccw=False returns clockwise.

        Raises:
            AssertionError: if the last axis is not of size 2 or 3.
        """
        dim = self.vs.shape[-1]
        if dim not in (2, 3):
            msg = f'Cannot get perpendicular of vectors of size: {dim}. Must be 2 or 3.'
            raise self._critical(msg)

        if dim == 2:
            # A 3D up-vector cannot be crossed with 2D input (the shapes differ),
            # so rotate in the plane instead.
            x, y = self.vs[..., 0], self.vs[..., 1]
            v_perp = Vectors(np.stack([-y, x], axis=-1))
        else:
            # Read the shape from the array itself: `shape` is a method on this class,
            # so subscripting `self.shape` would raise a TypeError.
            v_up: Vectors = Vectors(np.tile([0.0, 1.0, 0.0], [*self.vs.shape[:-1], 1]))
            v_perp = v_up.cross(self)

        v_perp.norm()

        if not ccw:
            v_perp = v_perp * -1

        return v_perp

    def average(self) -> Vectors:
        """ Return the average of a collection of vectors, along the first axis"""
        return Vectors(np.mean(self.vs, axis=0))

    def copy(self) -> Vectors:
        """ Return a shallow copy (see __copy__): the new object shares the underlying array."""
        return copy(self)

    def shape(self) -> Tuple[int, ...]:
        """ Return the shape of the underlying ndarray."""
        return self.vs.shape

    def length(self) -> npt.NDArray[np.float32]:
        """ Return the Euclidean length of each vector (norm over the last axis) as float32."""
        return np.linalg.norm(self.vs, axis=-1).astype(np.float32)

    def __mul__(self, val: float) -> Vectors:
        """ Return a new Vectors with every component multiplied by val."""
        return Vectors(self.vs * val)

    def __truediv__(self, scale: Union[int, float]) -> Vectors:
        """ Return a new Vectors with every component divided by scale."""
        return Vectors(self.vs / scale)

    def __sub__(self, other: Vectors) -> Vectors:
        """
        Element-wise difference of two Vectors of identical shape.

        Raises:
            AssertionError: if the shapes differ.
        """
        if self.vs.shape != other.vs.shape:
            msg = 'Attempted to subtract Vectors with different dimensions'
            raise self._critical(msg)

        return Vectors(np.subtract(self.vs, other.vs))

    def __add__(self, other: Vectors) -> Vectors:
        """
        Element-wise sum of two Vectors of identical shape.

        Raises:
            AssertionError: if the shapes differ.
        """
        if self.vs.shape != other.vs.shape:
            msg = 'Attempted to add Vectors with different dimensions'
            raise self._critical(msg)

        return Vectors(np.add(self.vs, other.vs))

    def __copy__(self) -> Vectors:
        """ Shallow copy hook used by copy.copy: the new Vectors shares the underlying array."""
        return Vectors(self)

    def __str__(self) -> str:
        return f"Vectors({str(self.vs)})"

    def __repr__(self) -> str:
        return f"Vectors({str(self.vs)})"