File size: 5,825 Bytes
d49f7bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations  # so we can refer to class Type inside class
import numpy as np
import numpy.typing as npt
import logging
from typing import Union, Iterable, Tuple
from numbers import Number
from copy import copy
from animated_drawings.utils import TOLERANCE


class Vectors():
    """
    Wrapper class around ndarray interpreted as one or more vectors of equal dimensionality.

    The wrapped array is stored in ``self.vs``. A single vector (1-D input) is promoted to a
    collection holding one vector, so ``self.vs`` has a leading "which vector" axis and the
    components along the last axis.

    When passing in an existing Vectors or ndarray, the new Vectors object will share the
    underlying nparray, so be careful.
    """

    def __init__(self, vs_: Union[Iterable[Union[float, int, Vectors, npt.NDArray[np.float32]]], Vectors]) -> None:  # noqa: C901
        """
        Build ``self.vs`` from one of the supported inputs:

        - a single ndarray (a 1-D array is expanded to shape (1, d); memory is shared),
        - a single Vectors (the underlying array is shared, not copied),
        - a non-empty tuple/list of numbers (becomes one vector of shape (1, d)),
        - a non-empty tuple/list of ndarrays (stacked along a new first axis),
        - a non-empty tuple/list of Vectors (each is squeezed, then stacked).

        For tuples/lists only the first element is inspected to pick the conversion.

        Raises:
            AssertionError: if vs_ has an unsupported type, is an empty tuple/list, or numpy
                fails to convert/stack the elements. The message is also logged as critical.
        """
        unsupported_msg = 'Vectors must be constructed from Vectors, ndarray, or Tuples/List of floats/ints or Vectors'

        self.vs: npt.NDArray[np.float32]

        # initialize from single ndarray
        if isinstance(vs_, np.ndarray):
            self.vs = np.expand_dims(vs_, axis=0) if vs_.ndim == 1 else vs_
            return

        # initialize from single Vectors
        if isinstance(vs_, Vectors):
            self.vs = vs_.vs
            return

        if not isinstance(vs_, (tuple, list)):
            raise self._error(unsupported_msg)

        # vs_[0] below would otherwise fail with an uninformative IndexError
        if len(vs_) == 0:
            raise self._error('Cannot construct Vectors from an empty tuple or list')

        first = vs_[0]
        if not isinstance(first, (Number, np.ndarray, Vectors)):
            raise self._error(unsupported_msg)

        try:
            if isinstance(first, Number):
                # initialize from tuple or list of numbers
                stacked = np.array(vs_)
                if stacked.ndim == 1:
                    stacked = np.expand_dims(stacked, axis=0)
            elif isinstance(first, np.ndarray):
                # initialize from tuple or list of ndarrays
                stacked = np.stack(vs_)  # pyright: ignore[reportGeneralTypeIssues]
            else:
                # initialize from tuple or list of Vectors
                stacked = np.stack([v.vs.squeeze() for v in vs_])  # pyright: ignore[reportGeneralTypeIssues]
        except Exception as e:
            # any numpy failure (ragged input, mismatched shapes, mixed element types)
            # is reported the same way as the other construction errors
            raise self._error(f'Error initializing Vectors: {str(e)}') from e

        self.vs = stacked  # pyright: ignore[reportGeneralTypeIssues]

    @staticmethod
    def _error(msg: str) -> AssertionError:
        """
        Log msg as critical and return an AssertionError carrying it, for the caller to raise.

        Raising explicitly (rather than `assert False, msg`) keeps the validation active
        when Python runs with -O, which strips assert statements.
        """
        logging.critical(msg)
        return AssertionError(msg)

    def norm(self) -> None:
        """
        Normalize every vector to unit length along the last axis.

        self.vs is replaced by a new array; the previous array is not modified.
        Norms smaller than TOLERANCE are replaced with TOLERANCE before dividing,
        which avoids dividing by (near) zero.
        """
        ns: npt.NDArray[np.float64] = np.linalg.norm(self.vs, axis=-1)

        # np.min raises on an empty array, so only clamp when there is at least one norm
        if ns.size > 0 and np.min(ns) < TOLERANCE:
            logging.info(f"Encountered values close to zero in vector norm. Replacing with {TOLERANCE}")
            ns[ns < TOLERANCE] = TOLERANCE

        self.vs = self.vs / np.expand_dims(ns, axis=-1)

    def cross(self, v2: Vectors) -> Vectors:
        """
        Cross product of a series of 2 or 3 dimensional vectors. All dimensions of vs must match.

        Returns a new Vectors wrapping np.cross(self.vs, v2.vs).

        Raises:
            AssertionError: if the two shapes differ, or the last axis is not of size 2 or 3.
        """
        if self.vs.shape != v2.vs.shape:
            raise self._error(f'Cannot cross product different sized vectors: {self.vs.shape} {v2.vs.shape}.')

        if self.vs.shape[-1] not in (2, 3):
            raise self._error(f'Cannot cross product vectors of size: {self.vs.shape[-1]}. Must be 2 or 3.')

        # NOTE(review): newer numpy releases deprecate np.cross on 2-component vectors - confirm
        # against the numpy version this project pins.
        return Vectors(np.cross(self.vs, v2.vs))

    def perpendicular(self, ccw: bool = True) -> Vectors:
        """
        Returns Vectors of unit length perpendicular to the original ones.
        Only 2D and 3D vectors are supported.

        - 3D: the result is the normalized cross product (0, 1, 0) x v.
        - 2D: the result is the normalized (-y, x), i.e. v rotated by +90 degrees.

        By default returns the counter clockwise vector, but passing ccw=False returns clockwise
        (the negated result).

        Raises:
            AssertionError: if the last axis is not of size 2 or 3.
        """
        dims = self.vs.shape[-1]
        if dims not in (2, 3):
            raise self._error(f'Cannot get perpendicular of vectors of size: {dims}. Must be 2 or 3.')

        if dims == 2:
            # A 3-component up vector cannot be crossed with 2D input (cross() requires equal
            # shapes), so rotate in the plane directly.
            v_perp = Vectors(np.stack([-self.vs[..., 1], self.vs[..., 0]], axis=-1))
        else:
            # one up vector per input vector, matching the leading axes of self.vs
            v_up: Vectors = Vectors(np.tile([0.0, 1.0, 0.0], [*self.shape[:-1], 1]))
            v_perp = v_up.cross(self)

        v_perp.norm()

        if not ccw:
            v_perp = v_perp * -1

        return v_perp

    def average(self) -> Vectors:
        """ Return the average of a collection of vectors, along the first axis"""
        return Vectors(np.mean(self.vs, axis=0))

    def copy(self) -> Vectors:
        """ Return a new Vectors via copy.copy; see __copy__ (the underlying array is shared)."""
        return copy(self)

    @property
    def shape(self) -> Tuple[int, ...]:
        """ Shape of the wrapped array."""
        return self.vs.shape

    @property
    def length(self) -> npt.NDArray[np.float32]:
        """ Euclidean length of each vector (norm along the last axis), as float32."""
        return np.linalg.norm(self.vs, axis=-1).astype(np.float32)

    def __mul__(self, val: float) -> Vectors:
        """ Return a new Vectors with every component multiplied by val."""
        return Vectors(self.vs * val)

    def __truediv__(self, scale: Union[int, float]) -> Vectors:
        """ Return a new Vectors with every component divided by scale."""
        return Vectors(self.vs / scale)

    def __sub__(self, other: Vectors) -> Vectors:
        """ Element-wise difference; raises AssertionError if the shapes differ."""
        if self.vs.shape != other.vs.shape:
            raise self._error('Attempted to subtract Vectors with different dimensions')
        return Vectors(np.subtract(self.vs, other.vs))

    def __add__(self, other: Vectors) -> Vectors:
        """ Element-wise sum; raises AssertionError if the shapes differ."""
        if self.vs.shape != other.vs.shape:
            raise self._error('Attempted to add Vectors with different dimensions')
        return Vectors(np.add(self.vs, other.vs))

    def __copy__(self) -> Vectors:
        """ Shallow copy: the new Vectors wraps the same underlying array as this one."""
        return Vectors(self)

    def __str__(self) -> str:
        return f"Vectors({str(self.vs)})"

    def __repr__(self) -> str:
        return f"Vectors({str(self.vs)})"