|
import torch |
|
|
|
def zero_module(module):
    """Set every parameter of ``module`` to zero, in place.

    Args:
        module: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        The same ``module`` object, so the call can be used inline.
    """
    params = module.parameters()
    for param in params:
        # ``detach()`` shares storage with ``param``; zeroing the detached
        # view overwrites the parameter without recording an autograd op.
        param.detach().zero_()
    return module
|
|
|
class StackedRandomGenerator:
    """Draw batched random tensors using one independent generator per sample.

    Each element along the first (batch) dimension is sampled from its own
    ``torch.Generator``. The values for element ``i`` therefore depend only on
    ``seeds[i]`` (and on how many draws that generator has already served),
    not on which other seeds share the batch.

    Args:
        device: Device passed to ``torch.Generator`` for every generator.
            Draws default to each generator's device when the caller does
            not pass ``device`` explicitly.
        seeds: Iterable of integer-convertible seeds, one per batch element.
            Each is reduced modulo ``2**32`` before seeding, so ``s`` and
            ``s + 2**32`` produce the same stream.
    """

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = [
            torch.Generator(device).manual_seed(int(seed) % (1 << 32))
            for seed in seeds
        ]

    def _check_batch(self, size):
        """Raise ``ValueError`` unless ``size[0]`` matches the generator count."""
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``. A mismatched ``size[0]`` would then be silently ignored,
        # and the result would have len(self.generators) rows instead.
        if size[0] != len(self.generators):
            raise ValueError(
                f"size[0] must equal the number of generators "
                f"({len(self.generators)}), got {size[0]}"
            )

    @staticmethod
    def _kwargs_for(gen, kwargs):
        """Return a copy of ``kwargs`` whose ``device`` defaults to ``gen.device``."""
        # Without this, omitting ``device`` makes torch use the default device,
        # which can fail when the generator lives on a different device.
        # An explicit caller-supplied ``device`` still takes precedence.
        return {"device": gen.device, **kwargs}

    def randn(self, size, **kwargs):
        """Return standard-normal samples of shape ``size``, one row per generator.

        Args:
            size: Sequence whose first entry must equal the number of
                generators; ``size[1:]`` is the per-sample shape.
            **kwargs: Forwarded to ``torch.randn`` (e.g. ``dtype``, ``device``).

        Returns:
            Tensor of shape ``size``. Row ``i`` is drawn from generator ``i``.

        Raises:
            ValueError: If ``size[0]`` differs from the number of generators.
        """
        self._check_batch(size)
        return torch.stack([
            torch.randn(size[1:], generator=gen, **self._kwargs_for(gen, kwargs))
            for gen in self.generators
        ])

    def randn_like(self, input):
        """Return standard-normal samples matching ``input``'s shape, dtype, layout and device.

        Raises:
            ValueError: If ``input.shape[0]`` differs from the number of generators.
        """
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        """Return random integers of shape ``size``, one row per generator.

        Args:
            *args: Forwarded positionally to ``torch.randint`` (``high`` or
                ``low, high``).
            size: Sequence whose first entry must equal the number of
                generators; ``size[1:]`` is the per-sample shape.
            **kwargs: Forwarded to ``torch.randint`` (e.g. ``dtype``, ``device``).

        Returns:
            Tensor of shape ``size``. Row ``i`` is drawn from generator ``i``.

        Raises:
            ValueError: If ``size[0]`` differs from the number of generators.
        """
        self._check_batch(size)
        return torch.stack([
            torch.randint(*args, size=size[1:], generator=gen, **self._kwargs_for(gen, kwargs))
            for gen in self.generators
        ])
|
|