from typing import Any, Iterable, Optional, Union
import torch
def B_to_GiB(num_bytes: Union[int, float]) -> float:
    return num_bytes / 2**30
def get_tensor_bytes(tensor: torch.Tensor) -> int:
"""
Returns the bytes of storage a given tensor takes up. If `tensor` is a view of a larger tensor,
this function only returns the bytes associated with the view.
"""
tensor_bytes = tensor.numel() * tensor.element_size()
return tensor_bytes
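# For illustration of `get_tensor_bytes` (assuming float32, i.e. 4 bytes per element): a view only
# accounts for its own elements, while `untyped_storage().nbytes()` reports the full buffer:
#
#     base = torch.randn(4, 4)           # 16 elements * 4 bytes = 64 bytes of storage
#     view = base[:2]                    # a view sharing `base`'s storage
#     get_tensor_bytes(view)             # 32: only the 8 elements visible through the view
#     view.untyped_storage().nbytes()    # 64: the full underlying storage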
class AllocatedMemContext:
"""
Context manager which captures the allocated GPU memory at context exit and the change between
enter and exit.

    Only includes the `allocated_bytes.all.`-prefixed keys from `torch.cuda.memory_stats()`, with
    the prefix stripped. All readings are in bytes; `B_to_GiB` can be used to convert them.

    Example:

    ```python
    with AllocatedMemContext() as mem:
        x = torch.randn(2**20, device="cuda")
    mem.delta["current"]  # change in currently allocated bytes between enter and exit
    ```
"""
def __init__(self) -> None:
# Ensure CUDA libraries are loaded:
torch.cuda.current_blas_handle()
self.before: dict[str, int] = {}
self.after: dict[str, int] = {}
self.delta: dict[str, int] = {}
self._mem_key_prefix = "allocated_bytes.all."
def _get_mem_dict(self) -> dict[str, int]:
return {
k.replace(self._mem_key_prefix, ""): v
for k, v in torch.cuda.memory_stats().items()
if self._mem_key_prefix in k
}
def __enter__(self) -> "AllocatedMemContext":
self.before = self._get_mem_dict()
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
self.after = self._get_mem_dict()
self.delta = {k: v - self.before[k] for k, v in self.after.items()}
class SavedTensorContext:
"""
    Context manager which captures all tensors that autograd registers as saved for the backward
    pass within the context window. Does not work with `meta`-device tensors.
All saved tensors are stored in the `saved_tensor_dict` attr, which is an instance of torch's
WeakTensorKeyDictionary with tensor/data_ptr key/value pairs. Some of these tensors may be
views of the same underlying storage. The total memory of all saved tensors in bytes, accounting
for redundant views, can be accessed through `saved_tensor_mem`.
    Use:

    ```python
    model = ...
    with SavedTensorContext(ignored_tensors=model.parameters()) as saved:
        # Do some computation with `model`, capturing saved tensors which are not model weights.
        ...

    saved.saved_tensor_dict  # WeakTensorKeyDictionary of all saved tensors.
    saved.saved_tensor_mem   # Total bytes of all saved tensors (activation memory).
    ```
    """
def __init__(
self,
ignored_tensors: Optional[Iterable[torch.Tensor]] = None,
) -> None:
# Track ignored tensors by their storage's data_ptr. Important to use storage's data_ptr,
# not just the data_ptr of the tensor itself.
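        # (For example, a slice like `p[1:]` reports a different `.data_ptr()` than `p`, yet
        # shares the same `.untyped_storage().data_ptr()`.)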
self._ignored_data_ptrs = (
set()
if ignored_tensors is None
else {t.untyped_storage().data_ptr() for t in ignored_tensors}
)
        # Use a WeakTensorKeyDictionary so that recording a tensor here does not keep it alive:
        # entries are dropped once the only remaining references to the tensor are inside this object.
self.saved_tensor_dict = torch.utils.weak.WeakTensorKeyDictionary()
def pack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
data_ptr = saved_tensor.untyped_storage().data_ptr()
if data_ptr not in self._ignored_data_ptrs:
self.saved_tensor_dict[saved_tensor] = data_ptr
return saved_tensor
def unpack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
return saved_tensor
self._saved_tensors_hook = torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook)
def __enter__(self) -> "SavedTensorContext":
self._saved_tensors_hook.__enter__()
return self
def __exit__(self, *args: Any, **kwargs: Any) -> None:
self._saved_tensors_hook.__exit__(*args, **kwargs)
@property
def saved_tensor_mem(self) -> int:
"""
The memory in bytes of all saved tensors, accounting for views into the same storage.
"""
accounted_for = self._ignored_data_ptrs.copy()
total_bytes = 0
for t in self.saved_tensor_dict:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in accounted_for:
print(f"Tensor ptr: {t.untyped_storage().data_ptr()}, "
f"shape: {t.shape}, "
f"dtype: {t.dtype}, "
f"device: {t.device}"
)
total_bytes += t.untyped_storage().nbytes()
accounted_for.add(data_ptr)
return total_bytes
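

# Minimal usage sketch (assumes a CUDA device is available; the model and sizes below are
# arbitrary choices for illustration): compare the allocator's view of the memory consumed
# inside the context with the bytes captured by the saved-tensor hooks.
if __name__ == "__main__":
    if torch.cuda.is_available():
        device = torch.device("cuda")
        model = torch.nn.Sequential(
            torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024)
        ).to(device)
        inputs = torch.randn(8, 1024, device=device)

        with AllocatedMemContext() as mem, SavedTensorContext(
            ignored_tensors=model.parameters()
        ) as saved:
            out = model(inputs)

        print(f"Allocator delta: {B_to_GiB(mem.delta['current']):.6f} GiB")
        print(f"Saved activation memory: {B_to_GiB(saved.saved_tensor_mem):.6f} GiB")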