Upload modeling_moonshine.py

Add full support for batching. Update decoding loop and input mask logic.

modeling_moonshine.py  CHANGED  (+47 -22)
@@ -113,11 +113,11 @@ class MultiHeadCrossAttentionWithKVCache(MultiHeadAttention):
     def __init__(self, dim, inner_dim, n_head):
         super().__init__(dim, inner_dim, n_head)
 
-    def forward(self, q, k_cache, v_cache):
+    def forward(self, q, k_cache, v_cache, mask):
         q = self.to_q(q)
         q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
 
-        return super().sdp_attention(q, k_cache, v_cache)
+        return super().sdp_attention(q, k_cache, v_cache, mask=mask)
 
 
 class FFLinearGelu(nn.Module):
@@ -162,10 +162,10 @@ class EncoderLayer(nn.Module):
 
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, rot_pos_emb):
+    def forward(self, x, rot_pos_emb, mask):
         _x = x
         x = self.norm1(x)
-        x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb)
+        x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb, mask=mask)
         x = x + _x
 
         _x = x
@@ -187,12 +187,12 @@ class Encoder(nn.Module):
         )
         self.post_norm = nn.LayerNorm(dim, bias=False)
 
-    def forward(self, x):
-        pos = torch.arange(x.shape[
+    def forward(self, x, mask):
+        pos = torch.arange(x.shape[-2], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
 
-        for layer in self.layers:
-            x = layer(x, rot_pos_emb=rot_pos_emb)
+        for idx, layer in enumerate(self.layers):
+            x = layer(x, rot_pos_emb=rot_pos_emb, mask=mask)
         return self.post_norm(x)
 
 
@@ -214,7 +214,7 @@ class DecoderLayer(nn.Module):
         self.norm3 = nn.LayerNorm(dim, bias=False)
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb):
+    def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb, input_mask):
         dim = x.size()[1]
         causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
         _x = x
@@ -232,7 +232,7 @@ class DecoderLayer(nn.Module):
 
         _x = x
         x = self.norm2(x)
-        x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache)
+        x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache, mask=input_mask)
         x = x + _x
 
         _x = x
@@ -259,7 +259,7 @@ class Decoder(nn.Module):
         self.final_norm = nn.LayerNorm(dim, bias=False)
         self.token_embedding = nn.Embedding(dec_voc_size, dim)
 
-    def forward(self, x, *args):
+    def forward(self, x, input_mask, *args):
         pos = torch.arange(x.shape[1], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
         x = self.token_embedding(x)
@@ -279,6 +279,7 @@ class Decoder(nn.Module):
                 x_attn_k_cache=x_attn_k_cache[idx],
                 x_attn_v_cache=x_attn_v_cache[idx],
                 rot_pos_emb=rot_pos_emb,
+                input_mask=input_mask,
             )
             k_cache_new.append(new_k_line)
             v_cache_new.append(new_v_line)
@@ -306,7 +307,7 @@ class InitialDecoderLayer(nn.Module):
         self.norm3 = nn.LayerNorm(dim, bias=False)
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, context, rot_pos_emb):
+    def forward(self, x, context, rot_pos_emb, input_mask):
         dim = x.size()[1]
         causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
         _x = x
@@ -323,7 +324,7 @@ class InitialDecoderLayer(nn.Module):
         _x = x
         x = self.norm2(x)
         x, x_attn_k_cache, x_attn_v_cache = self.cross_attention(
-            q=x, k=context, v=context
+            q=x, k=context, v=context, mask=input_mask,
         )
         x = x + _x
 
@@ -345,7 +346,7 @@ class DecoderInitial(Decoder):
             ]
         )
 
-    def forward(self, x, enc_src):
+    def forward(self, x, enc_src, input_mask):
         pos = torch.arange(x.shape[1], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
         x = self.token_embedding(x)
@@ -362,6 +363,7 @@ class DecoderInitial(Decoder):
                 x,
                 enc_src,
                 rot_pos_emb,
+                input_mask,
             )
 
             k_cache.append(new_k_line)
@@ -429,16 +431,34 @@ class MoonshineModelTorch(nn.Module):
         self.n_head = n_head
         self.d_head = inner_dim // n_head
 
-    def generate(self, src):
+    def generate(self, src, mask):
         preprocessed = self.preprocessor(src)
-        enc = self.encoder(preprocessed)
+        batch_size = preprocessed.shape[0]
+
+        # Get max sequence length based on number of unmasked inputs for each sample in batch.
+        token_limit_factor = 6.5 / 16000.0 # Maximum of 6.5 tokens per second.
+        if mask is not None:
+            seq_lens = torch.sum(mask, dim=-1, keepdim=True) * token_limit_factor
+        else:
+            token_limit = torch.tensor([src.shape[-1] * token_limit_factor])
+            seq_lens = torch.stack([token_limit for _ in range(batch_size)])
+        seq_lens = seq_lens.to(torch.int32).to(src.device).squeeze()
+
+        # Preprocess mask so that it matches preprocessed audio.
+        if mask is not None:
+            mask = mask[..., :-127:64][..., :-7:3][..., :-3:2].to(torch.bool)
+            mask = ~mask.reshape((batch_size, 1, 1, -1))
+            mask = torch.nn.functional.pad(mask, (0, preprocessed.shape[-2] - mask.shape[-1]))
+
+        enc = self.encoder(preprocessed, mask)
+
         sot_token = 1
         eot_token = 2
 
-        sot_array = [[sot_token] for _ in range(
+        sot_array = [[sot_token] for _ in range(batch_size)]
         seq = torch.as_tensor(sot_array).to(src.device)
 
-        vals = self.decoder_initial(x=seq, enc_src=enc)
+        vals = self.decoder_initial(x=seq, enc_src=enc, input_mask=mask)
         logits = vals[0]
         k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
             vals[i : i + self.dec_depth]
@@ -448,10 +468,11 @@ class MoonshineModelTorch(nn.Module):
         sample = logits[:, -1].argmax(dim=-1, keepdim=True)
         seq = torch.cat((seq, sample), dim=-1)
 
-
-        while
+        eot_mask = torch.zeros((batch_size), dtype=torch.bool).to(src.device)
+        while not torch.all(eot_mask):
             vals = self.decoder(
                 seq,
+                mask,
                 *k_cache,
                 *v_cache,
                 *x_attn_k_cache,
@@ -462,6 +483,10 @@ class MoonshineModelTorch(nn.Module):
             v_cache = vals[self.dec_depth + 1 :]
             logits = logits[:, -1] # get last token
             sample = logits.argmax(dim=-1, keepdim=True)
+            # For each sample in batch detect EOT or token limit reached.
+            eot_mask = eot_mask | (sample.squeeze() == eot_token)
+            eot_mask = eot_mask | (seq.shape[-1] >= seq_lens)
+            sample = sample.masked_fill(eot_mask.reshape((-1, 1)), eot_token)
             seq = torch.cat((seq, sample), dim=-1)
 
         return seq
@@ -483,5 +508,5 @@ class MoonshineModel(PreTrainedModel):
             dec_ff_swiglu = config.dec_ff_swiglu,
         )
 
-    def forward(self, tensor):
-        return self.model.generate(tensor)
+    def forward(self, tensor, input_mask=None):
+        return self.model.generate(tensor, input_mask)
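With this change, MoonshineModel.forward(tensor, input_mask=None) accepts a padded batch plus a sample-level mask (1 for real audio, 0 for padding), and generate() caps each sequence at 6.5 tokens per second of unmasked audio. Below is a minimal calling sketch; the pad_batch helper and the names wav_a, wav_b, and model are illustrative, not part of this repository, and it assumes model is an already-instantiated MoonshineModel fed 16 kHz mono waveforms (the rate implied by the token-limit comment in the diff).

import torch

def pad_batch(waveforms):
    # Right-pad 1-D waveforms to a common length and build the sample-level
    # mask generate() expects: 1 = real audio, 0 = padding.
    max_len = max(w.shape[-1] for w in waveforms)
    batch = torch.zeros(len(waveforms), max_len)
    mask = torch.zeros(len(waveforms), max_len)
    for i, w in enumerate(waveforms):
        batch[i, : w.shape[-1]] = w
        mask[i, : w.shape[-1]] = 1
    return batch, mask

# Hypothetical usage, assuming `model` is a loaded MoonshineModel:
#   batch, mask = pad_batch([wav_a, wav_b])
#   tokens = model(batch, input_mask=mask)  # (batch, seq) token ids

Once a row emits EOT or hits its per-sample token limit, the decoding loop keeps writing EOT for that row, so shorter transcripts come back EOT-padded to the length of the longest sequence in the batch.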
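A note on the mask preprocessing step: the chained strided slices mask[..., :-127:64][..., :-7:3][..., :-3:2] appear to mirror the preprocessor's three strided convolutions, downsampling the sample-level mask to roughly one entry per encoder frame before the trailing pad call aligns it exactly with preprocessed.shape[-2]. The kernel/stride pairs used below (127/64, 7/3, 3/2) are inferred from the slice pattern rather than read from the preprocessor, and conv_out_len / sliced_len are illustrative helpers, so treat this as a sketch of the length arithmetic only.

def conv_out_len(n, kernel, stride):
    # Output length of an unpadded 1-D convolution.
    return (n - kernel) // stride + 1

def sliced_len(n, kernel, stride):
    # Length of x[..., :-kernel:stride] when the last dimension has size n.
    return len(range(n)[:-kernel:stride])

for n in (16000, 80000, 160000):
    conv = conv_out_len(conv_out_len(conv_out_len(n, 127, 64), 7, 3), 3, 2)
    sliced = sliced_len(sliced_len(sliced_len(n, 127, 64), 7, 3), 3, 2)
    print(n, conv, sliced)  # the two differ by at most a frame or two

Any residual off-by-one between the sliced mask and the actual feature length is what the torch.nn.functional.pad call absorbs, so the mask handed to the encoder always matches the preprocessed sequence length.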