from typing import Any, Iterable, Optional, Union

import torch


def B_to_GiB(bytes: Union[int, float]) -> float:
    return bytes / 2**30


def get_tensor_bytes(tensor: torch.Tensor) -> int:
    """
    Returns the bytes of storage a given tensor takes up. If `tensor` is a view of a larger tensor,
    this function only returns the bytes associated with the view.
    """
    tensor_bytes = tensor.numel() * tensor.element_size()
    return tensor_bytes
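

# Hypothetical demo helper (not part of the original utilities): a minimal sketch showing that
# `get_tensor_bytes` counts only the elements of the tensor it is given, so a view of a larger
# tensor reports less than the parent's full storage.
def _demo_get_tensor_bytes() -> None:
    base = torch.randn(1024, 1024)  # float32: 1024 * 1024 * 4 bytes.
    view = base[:, :512]  # A view covering half of `base`'s columns.
    assert get_tensor_bytes(view) == get_tensor_bytes(base) // 2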


class AllocatedMemContext:
    """
    Context manager which captures the allocated GPU memory at context enter and exit, and the
    change between the two, exposed through the `before`, `after`, and `delta` dicts.

    Only includes `allocated_bytes.all.`-prefixed keys from `torch.cuda.memory_stats`, with the
    prefix stripped. All readings are reported in bytes; use `B_to_GiB` to convert.

    Example:

    ```python
    with AllocatedMemContext() as mem:
        t = torch.randn(2**16, device="cuda")
    print(mem.delta["current"])  # Bytes newly allocated within the context.
    ```
    """

    def __init__(self) -> None:
        # Ensure CUDA libraries (e.g. cuBLAS) are loaded up front, so that lazy initialization
        # does not pollute the memory readings taken inside the context:
        torch.cuda.current_blas_handle()

        self.before: dict[str, int] = {}
        self.after: dict[str, int] = {}
        self.delta: dict[str, int] = {}

        self._mem_key_prefix = "allocated_bytes.all."

    def _get_mem_dict(self) -> dict[str, int]:
        return {
            k.replace(self._mem_key_prefix, ""): v
            for k, v in torch.cuda.memory_stats().items()
            if self._mem_key_prefix in k
        }

    def __enter__(self) -> "AllocatedMemContext":
        self.before = self._get_mem_dict()
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self.after = self._get_mem_dict()
        self.delta = {k: v - self.before[k] for k, v in self.after.items()}
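

# Hypothetical demo helper (not part of the original utilities): a sketch of the intended usage
# of `AllocatedMemContext`, assuming a CUDA device is available. Keys such as "current" and
# "peak" come from `torch.cuda.memory_stats` with the "allocated_bytes.all." prefix stripped.
def _demo_allocated_mem_context() -> None:
    if not torch.cuda.is_available():
        return
    with AllocatedMemContext() as mem:
        t = torch.randn(2**20, device="cuda")  # 4 MiB of float32.
    # The caching allocator may round individual allocations up, so the two readings below can
    # differ slightly.
    print(f"newly allocated: {mem.delta['current']} B, tensor storage: {get_tensor_bytes(t)} B")
    print(f"peak allocated: {B_to_GiB(mem.after['peak']):.4f} GiB")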


class SavedTensorContext:
    """
    Context manager which captures all tensors registered as saved for backward within the
    context window. Does not work with `meta`-device tensors.

    All saved tensors are stored in the `saved_tensor_dict` attr, an instance of torch's
    `WeakTensorKeyDictionary` with tensor/data_ptr key/value pairs. Some of these tensors may be
    views into the same underlying storage. The total memory of all saved tensors in bytes,
    accounting for redundant views, is available through the `saved_tensor_mem` property.

    Use:
    ```
    model = ...
    with SavedTensorContext(ignored_tensors=model.parameters()) as saved:
        # Do some computation with `model`, capturing saved tensors which are not model weights.
        ...
    saved.saved_tensor_dict  # WeakTensorKeyDictionary of all saved tensors.
    saved.saved_tensor_mem  # Bytes from all saved tensors (activation memory).
    ```
    """

    def __init__(
        self,
        ignored_tensors: Optional[Iterable[torch.Tensor]] = None,
    ) -> None:
        # Track ignored tensors by their storage's data_ptr. Important to use storage's data_ptr,
        # not just the data_ptr of the tensor itself.
        self._ignored_data_ptrs = (
            set()
            if ignored_tensors is None
            else {t.untyped_storage().data_ptr() for t in ignored_tensors}
        )

        # Use a WeakTensorKeyDictionary to record the saved tensors: it will not keep a tensor
        # alive if the only remaining references to it live within this object.
        self.saved_tensor_dict = torch.utils.weak.WeakTensorKeyDictionary()

        def pack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
            data_ptr = saved_tensor.untyped_storage().data_ptr()
            if data_ptr not in self._ignored_data_ptrs:
                self.saved_tensor_dict[saved_tensor] = data_ptr
            return saved_tensor

        def unpack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
            return saved_tensor

        self._saved_tensors_hook = torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook)

    def __enter__(self) -> "SavedTensorContext":
        self._saved_tensors_hook.__enter__()
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self._saved_tensors_hook.__exit__(*args, **kwargs)

    @property
    def saved_tensor_mem(self) -> int:
        """
        The memory in bytes of all saved tensors, accounting for views into the same storage.
        """
        accounted_for = self._ignored_data_ptrs.copy()
        total_bytes = 0
        for t in self.saved_tensor_dict:
            data_ptr = t.untyped_storage().data_ptr()
            if data_ptr not in accounted_for:
                total_bytes += t.untyped_storage().nbytes()
                accounted_for.add(data_ptr)
        return total_bytes
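

# Hypothetical demo helper (not part of the original utilities): a sketch of measuring activation
# memory for a small model, assuming a CUDA device. Passing the model's parameters as
# `ignored_tensors` means only non-weight saved tensors (i.e. activations) are counted.
def _demo_saved_tensor_context() -> None:
    if not torch.cuda.is_available():
        return
    model = torch.nn.Linear(1024, 1024, device="cuda")
    inputs = torch.randn(32, 1024, device="cuda")
    with SavedTensorContext(ignored_tensors=model.parameters()) as saved:
        outputs = model(inputs)
    # For a single Linear layer only the input is saved for backward (the weight is ignored),
    # so this should match get_tensor_bytes(inputs).
    print(f"saved activation bytes: {saved.saved_tensor_mem}")
    outputs.sum().backward()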