jbetker committed on
Commit
468a8be
·
1 Parent(s): 26611b1

classifier proto

Browse files
Files changed (1) hide show
  1. models/classifier.py +153 -0
models/classifier.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
class ResBlock(nn.Module):
    """Residual convolution block with optional up- or down-sampling.

    The block runs ``in_layers`` (normalization -> SiLU -> conv) followed by
    ``out_layers`` (normalization -> SiLU -> dropout -> conv) and adds the
    result to a skip path derived from the input.

    Args:
        channels: Number of input channels.
        dropout: Drop probability handed to ``nn.Dropout``.
        out_channels: Number of output channels. Falls back to ``channels``
            when omitted (or otherwise falsy).
        use_conv: When the channel count changes, use a ``kernel_size``
            convolution on the skip path instead of a size-1 convolution.
        use_scale_shift_norm: Stored on the instance; not used by this block.
        dims: Forwarded to ``conv_nd`` / ``Upsample`` / ``Downsample``
            (presumably the number of spatial dimensions -- confirm against
            those helpers).
        up: Resample both branches with ``Upsample``. Takes precedence over
            ``down`` when both are set.
        down: Resample both branches with ``Downsample``.
        kernel_size: Kernel size of the main convolutions. Odd sizes are
            expected so that the residual sum lines up.
        do_checkpoint: Route ``forward`` through the project's ``checkpoint``
            helper (presumably activation checkpointing -- confirm).
    """

    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        do_checkpoint=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.do_checkpoint = do_checkpoint
        # Previously hard-coded as "1 if kernel_size == 3 else 2", which is
        # only correct for kernel sizes 3 and 5.
        padding = self._same_padding(kernel_size)

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # NOTE(review): zero_module presumably zero-initialises the conv
            # so the block starts out as its skip path -- confirm.
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    @staticmethod
    def _same_padding(kernel_size):
        """Return the padding that keeps the length unchanged for an odd kernel.

        Assumes ``conv_nd`` builds a stride-1, dilation-1 convolution.
        """
        return kernel_size // 2

    def forward(self, x):
        """Apply the block to ``x``, via ``checkpoint`` when enabled."""
        if self.do_checkpoint:
            return checkpoint(
                self._forward, x
            )
        else:
            return self._forward(x)

    def _forward(self, x):
        """Compute ``skip_connection(x) + out_layers(in_layers(x))``."""
        if self.updown:
            # Resample after normalization/activation but before the first
            # convolution, and resample the skip input the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h
81
+
82
+
83
class AudioMiniEncoder(nn.Module):
    """Small convolution + attention encoder producing one vector per input.

    Pipeline: an initial convolution, ``depth`` stages of ``resnet_blocks``
    ``ResBlock`` modules each followed by a channel-doubling ``Downsample``,
    a projection to ``embedding_dim``, and ``attn_blocks`` ``AttentionBlock``
    modules.

    Args:
        spec_dim: Channel count of the input fed to the first convolution.
        embedding_dim: Channel count of the output (also exposed as ``dim``).
        base_channels: Channel count after the first convolution.
        depth: Number of ResBlock/Downsample stages.
        resnet_blocks: ResBlocks per stage.
        attn_blocks: Number of attention blocks after the projection.
        num_attn_heads: Head count passed to each ``AttentionBlock``.
        dropout: Dropout value passed to each ``ResBlock``.
        downsample_factor: ``factor`` passed to each ``Downsample``.
        kernel_size: Kernel size passed to each ``ResBlock``.
    """

    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 base_channels=128,
                 depth=2,
                 resnet_blocks=2,
                 attn_blocks=4,
                 num_attn_heads=4,
                 dropout=0,
                 downsample_factor=2,
                 kernel_size=3):
        super().__init__()
        self.init = nn.Sequential(conv_nd(1, spec_dim, base_channels, 3, padding=1))

        self.layers = depth
        width = base_channels
        stages = []
        for _stage in range(depth):
            stages += [
                ResBlock(width, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size)
                for _block in range(resnet_blocks)
            ]
            # Each stage ends by downsampling and doubling the channel count.
            stages.append(
                Downsample(width, use_conv=True, dims=1, out_channels=width * 2, factor=downsample_factor)
            )
            width = width * 2
        self.res = nn.Sequential(*stages)

        self.final = nn.Sequential(
            normalization(width),
            nn.SiLU(),
            conv_nd(1, width, embedding_dim, 1),
        )

        attention = [
            AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)
            for _block in range(attn_blocks)
        ]
        self.attn = nn.Sequential(*attention)
        self.dim = embedding_dim

    def forward(self, x):
        """Encode ``x`` and return index 0 along the last axis of the result.

        ``x`` is presumably shaped (batch, spec_dim, time) -- confirm with
        callers.
        """
        feats = self.init(x)
        feats = sequential_checkpoint(self.res, self.layers, feats)
        feats = self.final(feats)
        for attn_block in self.attn:
            feats = checkpoint(attn_block, feats)
        return feats[:, :, 0]
126
+
127
+
128
class AudioMiniEncoderWithClassifierHead(nn.Module):
    """``AudioMiniEncoder`` followed by a linear classification head.

    Args:
        classes: Number of output classes.
        distribute_zero_label: When true, samples labelled 0 are trained
            against a softened target (see ``_soften_zero_labels``) instead
            of a hard one.
        **kwargs: Forwarded unchanged to ``AudioMiniEncoder``.
    """

    def __init__(self, classes, distribute_zero_label=True, **kwargs):
        super().__init__()
        self.enc = AudioMiniEncoder(**kwargs)
        self.head = nn.Linear(self.enc.dim, classes)
        self.num_classes = classes
        self.distribute_zero_label = distribute_zero_label

    @staticmethod
    def _soften_zero_labels(labels, num_classes, mass=.2):
        """Build float class-probability targets from integer ``labels``.

        Rows whose label is 0 give up ``mass`` of their probability, spread
        evenly over the remaining classes; all other rows stay one-hot. With
        fewer than two classes there is nothing to spread the mass over, so
        plain one-hot targets are returned (this previously divided by zero).
        """
        targets = nn.functional.one_hot(labels, num_classes=num_classes).float()
        if num_classes < 2:
            return targets
        is_zero = (labels == 0).unsqueeze(-1)
        shift = torch.full_like(targets, mass / (num_classes - 1))
        shift[..., 0] = -mass
        # The boolean mask zeroes the shift for every non-zero label.
        return targets + shift * is_zero

    def forward(self, x, labels=None):
        """Return logits, or the cross-entropy loss when ``labels`` is given.

        Args:
            x: Input passed to the encoder.
            labels: Optional integer class indices, one per sample.
        """
        h = self.enc(x)
        logits = self.head(h)
        if labels is None:
            return logits

        if self.distribute_zero_label:
            # Distribute 20% of the probability mass on all classes when zero
            # is specified, to compensate for dataset noise.
            targets = self._soften_zero_labels(labels, self.num_classes)
        else:
            targets = labels
        return nn.functional.cross_entropy(logits, targets)