njeffrie committed
Commit 3ab5772 · verified · 1 Parent(s): 2ef02e9

Upload modeling_moonshine.py

Add full support for batching. Update decoding loop and input mask logic.
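For context, a minimal sketch of how the batched entry point might be called after this change. The clip lengths, the padding loop, and the model instance are illustrative assumptions and not part of this commit; only the forward(tensor, input_mask=None) signature comes from the diff below.

import torch

# Two mono 16 kHz clips of different lengths, zero-padded to a common length (hypothetical data).
clip_a = torch.randn(16000 * 3)  # 3 s of audio
clip_b = torch.randn(16000 * 5)  # 5 s of audio
max_len = max(clip_a.shape[-1], clip_b.shape[-1])

batch = torch.zeros(2, max_len)
input_mask = torch.zeros(2, max_len)
for i, clip in enumerate((clip_a, clip_b)):
    batch[i, : clip.shape[-1]] = clip
    input_mask[i, : clip.shape[-1]] = 1  # 1 = real audio, 0 = padding

# `model` would be a loaded MoonshineModel instance (not constructed here):
# tokens = model(batch, input_mask=input_mask)  # -> (batch, generated_len) token ids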

Files changed (1)
  1. modeling_moonshine.py +47 -22
modeling_moonshine.py CHANGED
@@ -113,11 +113,11 @@ class MultiHeadCrossAttentionWithKVCache(MultiHeadAttention):
     def __init__(self, dim, inner_dim, n_head):
         super().__init__(dim, inner_dim, n_head)
 
-    def forward(self, q, k_cache, v_cache):
+    def forward(self, q, k_cache, v_cache, mask):
         q = self.to_q(q)
         q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
 
-        return super().sdp_attention(q, k_cache, v_cache)
+        return super().sdp_attention(q, k_cache, v_cache, mask=mask)
 
 
 class FFLinearGelu(nn.Module):
@@ -162,10 +162,10 @@ class EncoderLayer(nn.Module):
 
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, rot_pos_emb):
+    def forward(self, x, rot_pos_emb, mask):
         _x = x
         x = self.norm1(x)
-        x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb)
+        x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb, mask=mask)
         x = x + _x
 
         _x = x
@@ -187,12 +187,12 @@ class Encoder(nn.Module):
         )
         self.post_norm = nn.LayerNorm(dim, bias=False)
 
-    def forward(self, x):
-        pos = torch.arange(x.shape[1], device=x.device)
+    def forward(self, x, mask):
+        pos = torch.arange(x.shape[-2], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
 
-        for layer in self.layers:
-            x = layer(x, rot_pos_emb=rot_pos_emb)
+        for idx, layer in enumerate(self.layers):
+            x = layer(x, rot_pos_emb=rot_pos_emb, mask=mask)
         return self.post_norm(x)
 
 
@@ -214,7 +214,7 @@ class DecoderLayer(nn.Module):
         self.norm3 = nn.LayerNorm(dim, bias=False)
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb):
+    def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb, input_mask):
         dim = x.size()[1]
         causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
         _x = x
@@ -232,7 +232,7 @@ class DecoderLayer(nn.Module):
 
         _x = x
         x = self.norm2(x)
-        x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache)
+        x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache, mask=input_mask)
         x = x + _x
 
         _x = x
@@ -259,7 +259,7 @@ class Decoder(nn.Module):
         self.final_norm = nn.LayerNorm(dim, bias=False)
         self.token_embedding = nn.Embedding(dec_voc_size, dim)
 
-    def forward(self, x, *args):
+    def forward(self, x, input_mask, *args):
         pos = torch.arange(x.shape[1], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
         x = self.token_embedding(x)
@@ -279,6 +279,7 @@ class Decoder(nn.Module):
                 x_attn_k_cache=x_attn_k_cache[idx],
                 x_attn_v_cache=x_attn_v_cache[idx],
                 rot_pos_emb=rot_pos_emb,
+                input_mask=input_mask,
             )
             k_cache_new.append(new_k_line)
             v_cache_new.append(new_v_line)
@@ -306,7 +307,7 @@ class InitialDecoderLayer(nn.Module):
         self.norm3 = nn.LayerNorm(dim, bias=False)
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, context, rot_pos_emb):
+    def forward(self, x, context, rot_pos_emb, input_mask):
         dim = x.size()[1]
         causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
         _x = x
@@ -323,7 +324,7 @@ class InitialDecoderLayer(nn.Module):
         _x = x
         x = self.norm2(x)
         x, x_attn_k_cache, x_attn_v_cache = self.cross_attention(
-            q=x, k=context, v=context
+            q=x, k=context, v=context, mask=input_mask,
         )
         x = x + _x
 
@@ -345,7 +346,7 @@ class DecoderInitial(Decoder):
             ]
         )
 
-    def forward(self, x, enc_src):
+    def forward(self, x, enc_src, input_mask):
         pos = torch.arange(x.shape[1], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
         x = self.token_embedding(x)
@@ -362,6 +363,7 @@ class DecoderInitial(Decoder):
                 x,
                 enc_src,
                 rot_pos_emb,
+                input_mask,
             )
 
             k_cache.append(new_k_line)
@@ -429,16 +431,34 @@ class MoonshineModelTorch(nn.Module):
         self.n_head = n_head
         self.d_head = inner_dim // n_head
 
-    def generate(self, src):
+    def generate(self, src, mask):
         preprocessed = self.preprocessor(src)
-        enc = self.encoder(preprocessed)
+        batch_size = preprocessed.shape[0]
+
+        # Get max sequence length based on number of unmasked inputs for each sample in batch.
+        token_limit_factor = 6.5 / 16000.0 # Maximum of 6.5 tokens per second.
+        if mask is not None:
+            seq_lens = torch.sum(mask, dim=-1, keepdim=True) * token_limit_factor
+        else:
+            token_limit = torch.tensor([src.shape[-1] * token_limit_factor])
+            seq_lens = torch.stack([token_limit for _ in range(batch_size)])
+        seq_lens = seq_lens.to(torch.int32).to(src.device).squeeze()
+
+        # Preprocess mask so that it matches preprocessed audio.
+        if mask is not None:
+            mask = mask[..., :-127:64][..., :-7:3][..., :-3:2].to(torch.bool)
+            mask = ~mask.reshape((batch_size, 1, 1, -1))
+            mask = torch.nn.functional.pad(mask, (0, preprocessed.shape[-2] - mask.shape[-1]))
+
+        enc = self.encoder(preprocessed, mask)
+
         sot_token = 1
         eot_token = 2
 
-        sot_array = [[sot_token] for _ in range(enc.shape[0])]
+        sot_array = [[sot_token] for _ in range(batch_size)]
         seq = torch.as_tensor(sot_array).to(src.device)
 
-        vals = self.decoder_initial(x=seq, enc_src=enc)
+        vals = self.decoder_initial(x=seq, enc_src=enc, input_mask=mask)
         logits = vals[0]
         k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
             vals[i : i + self.dec_depth]
@@ -448,10 +468,11 @@ class MoonshineModelTorch(nn.Module):
         sample = logits[:, -1].argmax(dim=-1, keepdim=True)
         seq = torch.cat((seq, sample), dim=-1)
 
-        seq_len = int(src.shape[-1] * 6.5 / 16000)
-        while any([eot_token not in sub_seq for sub_seq in seq]) and seq.shape[-1] <= seq_len:
+        eot_mask = torch.zeros((batch_size), dtype=torch.bool).to(src.device)
+        while not torch.all(eot_mask):
             vals = self.decoder(
                 seq,
+                mask,
                 *k_cache,
                 *v_cache,
                 *x_attn_k_cache,
@@ -462,6 +483,10 @@ class MoonshineModelTorch(nn.Module):
             v_cache = vals[self.dec_depth + 1 :]
             logits = logits[:, -1] # get last token
            sample = logits.argmax(dim=-1, keepdim=True)
+            # For each sample in batch detect EOT or token limit reached.
+            eot_mask = eot_mask | (sample.squeeze() == eot_token)
+            eot_mask = eot_mask | (seq.shape[-1] >= seq_lens)
+            sample = sample.masked_fill(eot_mask.reshape((-1, 1)), eot_token)
             seq = torch.cat((seq, sample), dim=-1)
 
         return seq
@@ -483,5 +508,5 @@ class MoonshineModel(PreTrainedModel):
             dec_ff_swiglu = config.dec_ff_swiglu,
         )
 
-    def forward(self, tensor):
-        return self.model.generate(tensor)
+    def forward(self, tensor, input_mask=None):
+        return self.model.generate(tensor, input_mask)
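A note on the mask preprocessing added to generate(): the chained strided slices shrink the sample-level mask down to roughly the frame rate of the preprocessed audio before it is inverted and padded out to preprocessed.shape[-2]. A standalone sketch of just that slicing on a toy mask (the 16000-sample length and the padding split are made up for illustration):

import torch

samples = 16000                # 1 s of 16 kHz audio (hypothetical)
mask = torch.ones(1, samples)
mask[:, 12000:] = 0            # pretend the last quarter is padding

# Same slicing as in generate(): each step drops a fixed-size tail and keeps
# every stride-th element, tracking the length reduction of the audio preprocessor.
down = mask[..., :-127:64][..., :-7:3][..., :-3:2].to(torch.bool)
print(mask.shape, "->", down.shape)  # torch.Size([1, 16000]) -> torch.Size([1, 39])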