# Page header captured together with this file (not part of the module):
#   Spaces: Running on A10G
#   File size: 1,203 Bytes
#   Revision (presumably a commit hash): 0a3525d
#   Viewer line-number gutter: 1-45
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
class Discriminator(nn.Module):
def __init__(self):
super().__init__()
blocks = []
convs = [
(1, 64, (3, 9), 1, (1, 4)),
(64, 128, (3, 9), (1, 2), (1, 4)),
(128, 256, (3, 9), (1, 2), (1, 4)),
(256, 512, (3, 9), (1, 2), (1, 4)),
(512, 1024, (3, 3), 1, (1, 1)),
(1024, 1, (3, 3), 1, (1, 1)),
]
for idx, (in_channels, out_channels, kernel_size, stride, padding) in enumerate(
convs
):
blocks.append(
weight_norm(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
)
)
if idx != len(convs) - 1:
blocks.append(nn.SiLU(inplace=True))
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
return self.blocks(x[:, None])[:, 0]
if __name__ == "__main__":
    # Smoke test: build the model, report its size, and run one random batch.
    model = Discriminator()

    # Parameter count in millions.
    print(sum(p.numel() for p in model.parameters()) / 1_000_000)

    # Random 3-D input; forward() adds the channel axis itself.
    x = torch.randn(1, 128, 1024)
    y = model(x)
    print(y.shape)
    print(y)
|