feat: implemented positional interpolation
Browse files- modeling_bert.py +2 -1
modeling_bert.py
CHANGED
@@ -787,7 +787,8 @@ class JinaBertEncoder(nn.Module):
|
|
787 |
# Device catch-up
|
788 |
self.alibi = self.alibi.to(hidden_states.device)
|
789 |
|
790 |
-
|
|
|
791 |
if self.gradient_checkpointing and self.training:
|
792 |
if use_cache:
|
793 |
logger.warning_once(
|
|
|
787 |
# Device catch-up
|
788 |
self.alibi = self.alibi.to(hidden_states.device)
|
789 |
|
790 |
+
unpadded_seqlens = torch.sum(attention_mask, dim=1).unsqueeze(1).unsqueeze(1).unsqueeze(1)
|
791 |
+
alibi_bias = self.alibi[:, :, :seqlen, :seqlen] * 512 / unpadded_seqlens
|
792 |
if self.gradient_checkpointing and self.training:
|
793 |
if use_cache:
|
794 |
logger.warning_once(
|