Spaces:
Runtime error
Runtime error
File size: 1,587 Bytes
f0533a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import torch
import torch.nn as nn
import math
import torch.distributed as dist
def _all_to_all(
input_: torch.Tensor,
world_size: int,
group: dist.ProcessGroup,
scatter_dim: int,
gather_dim: int,
):
if world_size == 1:
return input_
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
class _AllToAll(torch.autograd.Function):
    """Autograd-aware wrapper around ``_all_to_all``.

    The forward pass scatters along ``scatter_dim`` and gathers along
    ``gather_dim``. The backward pass sends the incoming gradient through the
    same exchange with the two dimensions swapped.
    """

    @staticmethod
    def forward(ctx, input_, process_group, world_size, scatter_dim, gather_dim):
        """Run the exchange and stash the settings needed by ``backward``."""
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = world_size
        return _all_to_all(
            input_, ctx.world_size, process_group, scatter_dim, gather_dim
        )

    @staticmethod
    def backward(ctx, grad_output):
        """Route the gradient back with scatter and gather dimensions swapped."""
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,  # the forward gather dim is scattered here
            ctx.scatter_dim,  # the forward scatter dim is gathered here
        )
        # One slot per forward argument; only ``input_`` receives a gradient.
        return grad_input, None, None, None, None
def all_to_all(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    world_size: int = 1,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Differentiable all-to-all exchange of ``input_`` across ``process_group``.

    Thin entry point that forwards its arguments to ``_AllToAll.apply``.

    Args:
        input_: Tensor to redistribute.
        process_group: Process group used for the exchange.
        world_size: Number of chunks ``input_`` is split into. Defaults to 1.
        scatter_dim: Dimension along which ``input_`` is split. Defaults to 2.
        gather_dim: Dimension along which received chunks are concatenated.
            Defaults to 1.

    Returns:
        The tensor returned by ``_AllToAll.apply``.
    """
    # autograd.Function.apply takes positional arguments only.
    return _AllToAll.apply(input_, process_group, world_size, scatter_dim, gather_dim)