ccdv committed
Commit 0cebb6c
1 Parent(s): c262c80

bos_token + readme

Files changed (1)
  1. modeling_lsg_camembert.py +39 -12
modeling_lsg_camembert.py CHANGED
@@ -53,16 +53,16 @@ class LSGCamembertConfig(CamembertConfig):
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
-        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
                     setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
-            if self.sparsity_factor > self.encoder_attention_heads:
+            if self.sparsity_factor > self.num_attention_heads:
                 logger.warning(
-                    "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
+                    "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
                 )
 
         if self.num_global_tokens < 1:
@@ -497,15 +497,16 @@ class LSGSelfAttention(BaseSelfAttention):
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
             "block_stride": self.get_sparse_tokens_with_block_stride,
+            "bos_pooling": self.get_sparse_tokens_with_bos_pooling
         }
 
         self.sparsity_type = config.sparsity_type
-        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
+        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
 
-    def get_sparse_tokens_with_norm(self, keys, values, mask):
+    def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -533,7 +534,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_pooling(self, keys, values, mask):
+    def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -556,7 +557,7 @@ class LSGSelfAttention(BaseSelfAttention):
         mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
-    def get_sparse_tokens_with_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -572,7 +573,7 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -592,11 +593,14 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_lsh(self, keys, values, mask):
+    def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
+        if self.sparsity_factor == self.sparse_block_size:
+            return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
+
         block_size = min(self.block_size, self.sparse_block_size)
         keys = self.chunk(keys, block_size)
         values = self.chunk(values, block_size)
@@ -644,6 +648,29 @@ class LSGSelfAttention(BaseSelfAttention):
 
         return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
 
+    def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        queries = queries.unsqueeze(-3)
+        mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
+        keys = self.chunk(keys, self.sparsity_factor)
+        values = self.chunk(values, self.sparsity_factor)
+
+        n, h, b, t, d = keys.size()
+        scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
+        if mask is not None:
+            scores = scores + mask
+
+        scores = torch.softmax(scores, dim=-1)
+        keys = scores @ keys
+        values = scores @ values
+        mask = mask.mean(dim=-1)
+        mask[mask != torch.finfo(mask.dtype).min] = 0
+
+        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
+
     def forward(
         self,
         hidden_states,
@@ -765,7 +792,7 @@ class LSGSelfAttention(BaseSelfAttention):
         # Get sparse idx
         sparse_key, sparse_value, sparse_mask = (None, None, None)
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
@@ -838,7 +865,7 @@ class LSGSelfAttention(BaseSelfAttention):
         sparse_key, sparse_value, sparse_mask = (None, None, None)
 
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
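
For context on what this commit adds: the new "bos_pooling" mode compresses each block of sparsity_factor keys/values into a single sparse token by attending to the block with the first (BOS) query, so t tokens per head yield roughly t / sparsity_factor sparse key/value pairs. The snippet below is a minimal usage sketch, not part of this commit: it assumes an LSG CamemBERT checkpoint such as ccdv/lsg-camembert-base-4096 is available on the Hub and that, as with other LSG checkpoints, config fields like sparsity_type and sparsity_factor can be overridden through from_pretrained keyword arguments with trust_remote_code=True.

# Usage sketch (checkpoint id and kwarg overrides are assumptions, not part of this commit)
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained(
    "ccdv/lsg-camembert-base-4096",  # assumed checkpoint id
    trust_remote_code=True,          # loads the custom modeling_lsg_camembert.py code
    sparsity_type="bos_pooling",     # new sparsity mode added by this commit
    sparsity_factor=4,               # each block of 4 tokens is pooled into one sparse token
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-camembert-base-4096")

inputs = tokenizer("Paris est la capitale de la France.", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, hidden_size)

Note that the sparse selection functions now also take the query tensor (get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)), and get_sparse_tokens_with_lsh falls back to the BOS pooling path when sparsity_factor equals sparse_block_size.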