Gael Le Lan
Initial commit
9d0d223
raw
history blame
4.49 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Streaming module API that should be implemented by all streaming components.
"""
from contextlib import contextmanager
import typing as tp
from torch import nn
import torch
# A streaming state: a mapping from state name to tensor. By convention
# (see StreamingModule), the first dimension of each tensor is the batch size.
State = tp.Dict[str, torch.Tensor]
class StreamingModule(nn.Module):
    """Base class shared by every component that supports streaming.

    A streaming component keeps a streaming state: a plain ``dict[str, Tensor]``
    stored in ``self._streaming_state``. By convention, the first dimension of
    every tensor in that dict is the batch size. Keys must not contain dots,
    because dotted names are used to address sub-modules (as in ``state_dict``).

    While ``self._is_streaming`` is True, a component should read and update
    its state in ``self._streaming_state``.

    Streaming mode is entered through a context manager::

        with module.streaming():
            ...

    Entering the context switches streaming on for every ``StreamingModule``
    in the sub-tree; leaving it switches streaming off again and clears their
    streaming state.

    Components may additionally override ``flush``. This is trickier: for a
    flush to propagate correctly, every parent module must also be a
    ``StreamingModule`` implementing ``flush`` (see ``StreamingSequential``).
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Any):
        # Walk this module and all its descendants (in ``named_modules`` order)
        # and call ``fn(qualified_name, module)`` on the streaming ones only.
        for qualified_name, child in self.named_modules():
            if not isinstance(child, StreamingModule):
                continue
            fn(qualified_name, child)

    def _set_streaming(self, streaming: bool):
        # Propagate the streaming flag to the whole streaming sub-tree.
        self._apply_named_streaming(
            lambda _qualified_name, child: setattr(child, "_is_streaming", streaming))

    @contextmanager
    def streaming(self):
        """Run the enclosed code in streaming mode.

        On entry, streaming is switched on for the whole streaming sub-tree.
        On exit (also when an exception propagates), it is switched off again
        and every streaming state is cleared.
        """
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Clear the streaming state of this module and all streaming sub-modules."""
        self._apply_named_streaming(
            lambda _qualified_name, child: child._streaming_state.clear())

    def get_streaming_state(self) -> State:
        """Collect the streaming state of this module and its sub-modules.

        Returns a flat dict. Entries of this module keep their key unchanged;
        entries of a sub-module are prefixed with its dotted module name
        followed by a dot.
        """
        collected: State = {}

        def _gather(qualified_name: str, child: "StreamingModule") -> None:
            # The root module has an empty name, so its keys get no prefix.
            prefix = f"{qualified_name}." if qualified_name else ""
            collected.update(
                {prefix + key: tensor for key, tensor in child._streaming_state.items()})

        self._apply_named_streaming(_gather)
        return collected

    def set_streaming_state(self, state: State):
        """Load a flat streaming state, in the format of ``get_streaming_state``.

        Every streaming module in the sub-tree has its state cleared and then
        receives the entries addressed to it. Leftover entries that match no
        streaming module trigger a failing ``assert`` listing their keys (this
        check is skipped under ``python -O``). The caller's dict is not mutated.
        """
        pending = dict(state)  # private copy; matched entries are popped from it

        def _assign(qualified_name: str, child: "StreamingModule") -> None:
            prefix = f"{qualified_name}." if qualified_name else ""
            child._streaming_state.clear()
            # Scans every pending key for each module: O(#modules * #keys),
            # acceptable for the small states involved.
            for full_key in list(pending):
                if not full_key.startswith(prefix):
                    continue
                local_key = full_key[len(prefix):]
                if "." in local_key:
                    continue  # addressed to a deeper sub-module
                child._streaming_state[local_key] = pending.pop(full_key)

        self._apply_named_streaming(_assign)
        assert not pending, list(pending.keys())

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Emit any output still waiting for completion once the stream ends.

        Typically, for convolutions, this adds the final padding and processes
        the last buffer. ``x`` is an already-flushed buffer handed over by an
        earlier module in the streaming pipeline, when there is one.

        This default implementation returns None when ``x`` is None and
        ``self(x)`` otherwise.
        """
        return None if x is None else self(x)
class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming-compatible alternative to `nn.Sequential`.

    Accepts the same constructor arguments as `nn.Sequential`: either the
    sub-modules as positional arguments, or a single `OrderedDict` mapping
    names to sub-modules.
    """

    def __init__(self, *args: tp.Any) -> None:
        # Without this override, the MRO resolves `__init__` to
        # `StreamingModule.__init__`, which takes no arguments, so
        # `StreamingSequential(a, b)` (and nn.Sequential slicing, which
        # rebuilds `self.__class__(OrderedDict(...))`) raised a TypeError.
        # Run the cooperative chain StreamingModule -> nn.Sequential ->
        # nn.Module with no arguments, which also sets up the streaming
        # attributes. Then register the layers the same way nn.Sequential does.
        super().__init__()
        from collections import OrderedDict
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def flush(self, x: tp.Optional[torch.Tensor] = None) -> tp.Optional[torch.Tensor]:
        """Flush every layer in order, chaining the flushed outputs.

        Streaming layers have their own `flush` called with the buffer
        produced so far. Plain (non-streaming) layers are applied to that
        buffer only when it is not None.

        Args:
            x: optional already-flushed buffer from an upstream module.
        Returns:
            The output of the last layer, or None if no buffer was produced.
        """
        for module in self:
            if isinstance(module, StreamingModule):
                x = module.flush(x)
            elif x is not None:
                x = module(x)
        return x