# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
import contextlib
from typing import AsyncIterator, Dict, Sequence

import torch
from hivemind import DHT, P2PContext, TensorDescriptor, deserialize_torch_tensor, nested_flatten, serialize_torch_tensor
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils import as_aiter
from hivemind.utils.asyncio import anext
from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import CHAIN_DELIMITER, ModuleUID
from src.server.backend import MAX_LENGTH, TransformerBackend

logger = get_logger(__name__)

class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(self, dht: DHT, module_backends: Dict[str, TransformerBackend]):
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)

    async def rpc_inference(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
            logger.debug("Opened rpc_inference()")
            request = await anext(requests)
            requested_uids = self._check_header(request)
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

            cache_metadata = torch.tensor([[-1, -1]], dtype=torch.int64)  # [cache_handle, prefix_length]
            prefix_length = 0

            async with self._allocate_caches(requested_backends) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

                    # run request tensors through all requested modules, update caches
                    for backend, cache_handle in zip(requested_backends, cache_handles):
                        cache_metadata[0, 0], cache_metadata[0, 1] = cache_handle, prefix_length
                        assert (
                            len(hidden_states) == 1 and hidden_states[0].ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

                        hidden_states = await backend.inference_pool.submit_task(cache_metadata, *hidden_states)
                        assert isinstance(hidden_states, (list, tuple))
                        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3

                    # serialize and send last layer outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                hidden_states, nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # prepare for next step
                    prefix_length += hidden_states[0].shape[1]
                    request = await anext(requests)
        finally:
            logger.debug("Closed rpc_inference()")

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        """Run the requested chain of transformer blocks on the given hidden states (single-message forward)."""
        # Parse request and prepare backends
        hidden_states = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a chain of requested backends
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        # Serialize the overall output and respond
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )
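
    # rpc_forward keeps the whole exchange in single protobuf messages; as a rough, hedged rule of
    # thumb, clients should switch to rpc_forward_stream below once the serialized tensors approach
    # DEFAULT_MAX_MSG_SIZE, since larger payloads must be split into parts for the p2p transport.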

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Same as rpc_forward, but receives inputs and sends outputs as multi-part streams."""
        # Parse requests and prepare backends
        uids_header, hidden_states = await self._gather_inputs(requests, context)
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a chain of requested backends
        for backend in requested_backends:
            assert isinstance(hidden_states, (list, tuple))
            assert (
                len(hidden_states) == 1 and hidden_states[0].ndim == 3
            ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
            hidden_states = await backend.forward_pool.submit_task(*hidden_states)

        # Serialize the overall output
        assert len(hidden_states) == 1 and hidden_states[0].ndim == 3
        serialized_output = [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip(hidden_states, nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized output into DEFAULT_MAX_MSG_SIZE parts and stream them back
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])
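
    # A client consuming rpc_forward_stream is expected to concatenate the streamed parts before
    # deserializing; a minimal sketch, assuming hivemind's combine_from_streaming (the counterpart
    # of the split_for_streaming call used above):
    #   parts = [tensor async for msg in response_stream for tensor in msg.tensors]
    #   outputs = deserialize_torch_tensor(combine_from_streaming(parts))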

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        """Recompute activations along the requested chain, then backpropagate gradients through it in reverse."""
        # Parse request and prepare backends
        inputs, grads = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_header(request)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a forward chain to collect intermediate inputs
        # Note that we do not forward for the last module since we do not need its output
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Run a backward chain for requested backends
        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
            inputs_and_grads = [inp, grads]
            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result, proto.compression, allow_inplace=True)
                for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
            ]
        )
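
    # Worked example of rpc_backward for a hypothetical chain of three blocks A -> B -> C:
    #   the forward loop recomputes inter_inputs = [x, A(x), B(A(x))] (C's output is never needed),
    #   then the backward loop calls C.backward(B(A(x)), grad_out), B.backward(A(x), ...) and
    #   A.backward(x, ...), so the value returned to the client is the gradient w.r.t. x,
    #   matching grad_inputs_schema of the first requested backend.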

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Same as rpc_backward, but receives inputs and sends gradients as multi-part streams."""
        # Parse requests and prepare backends
        uids_header, inputs_and_grads = await self._gather_inputs(requests, context)
        inputs, grads = inputs_and_grads
        requested_uids = self._check_header_str(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)

        # Run a forward chain to collect intermediate inputs
        # Note that we do not forward for the last module since we do not need its outputs
        inter_inputs = [inputs]
        for backend in requested_backends[:-1]:
            assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
            inputs = await backend.forward_pool.submit_task(inputs)
            assert isinstance(inputs, (list, tuple)) and len(inputs) == 1
            inputs = inputs[0]
            inter_inputs.append(inputs)

        # Run a backward chain for requested backends
        for inp, backend in zip(inter_inputs[::-1], requested_backends[::-1]):
            inputs_and_grads = [inp, grads]
            grads = await backend.backward_pool.submit_task(*inputs_and_grads)
            assert isinstance(grads, (list, tuple)) and len(grads) == 1
            grads = grads[0]

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result, proto.compression, allow_inplace=True)
            for result, proto in zip([grads], nested_flatten(requested_backends[0].grad_inputs_schema))
        ]

        # Split the serialized gradients into DEFAULT_MAX_MSG_SIZE parts and stream them back
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_header(self, request: runtime_pb2.ExpertRequest) -> Sequence[ModuleUID]:
        """Check that request.uid describes a valid chain of modules served by this peer; return the uids."""
        uids = (request.uid or "").split(CHAIN_DELIMITER)
        if not uids or not all(uids):
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)

    def _check_header_str(self, header: str) -> Sequence[ModuleUID]:
        """Same as _check_header, but takes the uid chain as a raw string (used by the streaming handlers)."""
        uids = (header or "").split(CHAIN_DELIMITER)
        if not uids or not all(uids):
            raise RuntimeError("User did not provide any uids")
        for uid in uids:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return tuple(uids)
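
    # Example of a header accepted by the checks above (module names are hypothetical): a client
    # requesting blocks 3-5 of some model could send
    #   uid = CHAIN_DELIMITER.join(["mymodel.block3", "mymodel.block4", "mymodel.block5"])
    # and every listed uid must be present in self.module_backends on this particular server.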

    @contextlib.asynccontextmanager
    async def _allocate_caches(self, backends: Sequence[TransformerBackend]) -> AsyncIterator[Sequence[int]]:
        """Allocate memory caches for each transformer block, return cache handles"""
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                cache_descriptor = TensorDescriptor(size=(2, 1, MAX_LENGTH, num_heads, head_dim), dtype=torch.float32)
                # [key_or_value, batch_size, max_length, num_heads, head_dim]

                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles
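
    # Rough cache footprint per requested block (numbers are illustrative assumptions, not a real
    # config): with MAX_LENGTH = 2048, num_heads = 32, head_dim = 128 and float32 entries, one
    # descriptor above reserves 2 * 1 * 2048 * 32 * 128 * 4 bytes = 64 MiB of key/value memory
    # for a single sequence.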