jupyterjazz committed
Commit 6e55444
Parent(s): 0f0bed6
style: removing unused files, black, isort
Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
- block.py +5 -4
- embedding.py +27 -13
- mha.py +101 -42
- mlp.py +33 -15
- modeling_lora.py +30 -18
- modeling_xlm_roberta.py +116 -194
- modeling_xlm_roberta_for_glue.py +0 -109
- rotary.py +43 -16
- stochastic_depth.py +1 -1
- xlm_padding.py +24 -10
block.py
CHANGED
@@ -8,15 +8,14 @@ from typing import Optional

 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch import Tensor

-from .stochastic_depth import StochasticDepth
 from .mha import MHA
 from .mlp import Mlp
+from .stochastic_depth import StochasticDepth

 try:
-    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
+    from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn
 except ImportError:
     layer_norm_fn, RMSNorm = None, None

@@ -233,7 +232,9 @@ class Block(nn.Module):
                 is_rms_norm=isinstance(self.norm1, RMSNorm),
             )
             if not isinstance(self.mlp, nn.Identity):
-                mlp_out = self.mlp(hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask"))
+                mlp_out = self.mlp(
+                    hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask")
+                )
                 if self.return_residual:  # mlp out is actually a pair here
                     mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:
embedding.py
CHANGED
@@ -5,10 +5,8 @@

 import torch
 import torch.nn as nn
-from …
-
-
-from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
+from transformers.models.xlm_roberta.modeling_xlm_roberta import \
+    create_position_ids_from_input_ids


 class XLMRobertaEmbeddings(nn.Module):
@@ -38,20 +36,29 @@ class XLMRobertaEmbeddings(nn.Module):
                 max_position_embeddings, embed_dim, **factory_kwargs
             )
         if self.type_vocab_size > 0:
-            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
+            self.token_type_embeddings = nn.Embedding(
+                type_vocab_size, embed_dim, **factory_kwargs
+            )

-    def forward(self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None):
+    def forward(
+        self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None
+    ):
         """
         input_ids: (batch, seqlen)
         position_ids: (batch, seqlen)
         token_type_ids: (batch, seqlen)
+        adapter_mask: (batch, 1)
         """
         batch_size, seqlen = input_ids.shape
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             embedding_dtype = next(self.word_embeddings.parameters()).dtype
-            embeddings = torch.empty(
-                *input_ids.shape, self.word_embeddings.embedding_dim, dtype=embedding_dtype, device=input_ids.device)
+            embeddings = torch.empty(
+                *input_ids.shape,
+                self.word_embeddings.embedding_dim,
+                dtype=embedding_dtype,
+                device=input_ids.device
+            )
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_input_ids = input_ids[task_indices]
@@ -61,20 +68,27 @@ class XLMRobertaEmbeddings(nn.Module):
             embeddings = self.word_embeddings(input_ids)
         if self.max_position_embeddings > 0:
             if position_ids is None:
-                position_ids = create_position_ids_from_input_ids(
-                    input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
+                position_ids = create_position_ids_from_input_ids(
+                    input_ids, padding_idx=self.word_embeddings.padding_idx
+                ).to(input_ids.device)
             position_embeddings = self.position_embeddings(position_ids)
             embeddings = embeddings + position_embeddings
         if self.type_vocab_size > 0:
             if token_type_ids is None:
-                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
+                token_type_ids = torch.zeros(
+                    seqlen, dtype=torch.long, device=input_ids.device
+                )

             if adapter_mask is not None:
                 unique_tasks = torch.unique(adapter_mask)
                 for task_id in unique_tasks:
-                    task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
+                    task_token_type_embeddings = self.token_type_embeddings(
+                        token_type_ids, task_id=task_id
+                    )
                     task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
-                    embeddings[task_indices] = embeddings[task_indices] + task_token_type_embeddings
+                    embeddings[task_indices] = (
+                        embeddings[task_indices] + task_token_type_embeddings
+                    )
             else:
                 token_type_embeddings = self.token_type_embeddings(token_type_ids)
                 embeddings = embeddings + token_type_embeddings
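The embedding.py hunks above all follow the same per-task routing pattern: allocate an output buffer with torch.empty, then fill the rows belonging to each task with the output of that task's adapter. The snippet below is a minimal standalone sketch of that pattern only; it uses a plain nn.Embedding per task purely for illustration, whereas the real module dispatches on task_id inside a single LoRA-parametrized embedding.

import torch
import torch.nn as nn

# Hypothetical per-task embeddings, standing in for the LoRA-parametrized one.
embeddings_per_task = nn.ModuleList([nn.Embedding(100, 16) for _ in range(3)])
input_ids = torch.randint(0, 100, (4, 7))   # (batch, seqlen)
adapter_mask = torch.tensor([0, 2, 2, 1])   # one task id per example in the batch

out = torch.empty(*input_ids.shape, 16)
for task_id in torch.unique(adapter_mask):
    # rows of the batch that belong to this task
    task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
    out[task_indices] = embeddings_per_task[int(task_id)](input_ids[task_indices])

print(out.shape)  # torch.Size([4, 7, 16])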
mha.py
CHANGED
@@ -1,5 +1,8 @@
+# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
+# Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
+# Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
+
 # Copyright (c) 2023, Tri Dao.
-# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556

 import math
 from functools import partial
@@ -9,20 +12,19 @@ import torch.nn as nn
 from einops import rearrange, repeat

 try:
-    from flash_attn import (
-        flash_attn_kvpacked_func,
-        flash_attn_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_qkvpacked_func,
-        flash_attn_with_kvcache,
-    )
+    from flash_attn import (flash_attn_kvpacked_func,
+                            flash_attn_qkvpacked_func,
+                            flash_attn_varlen_kvpacked_func,
+                            flash_attn_varlen_qkvpacked_func,
+                            flash_attn_with_kvcache)
 except ImportError:
     flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
     flash_attn_with_kvcache = None

 try:
-    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
+    from flash_attn.ops.fused_dense import (ColumnParallelLinear, FusedDense,
+                                            RowParallelLinear)
 except ImportError:
     FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

@@ -42,7 +44,9 @@ def get_alibi_slopes(nheads):
         closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
         return (
             get_slopes_power_of_2(closest_power_of_2)
-            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
+            + get_alibi_slopes(2 * closest_power_of_2)[0::2][
+                : nheads - closest_power_of_2
+            ]
         )


@@ -67,7 +71,9 @@ class FlashSelfAttention(nn.Module):
         deterministic=False,
     ):
         super().__init__()
-        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
+        assert (
+            flash_attn_varlen_qkvpacked_func is not None
+        ), "FlashAttention is not installed"
         assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale
@@ -147,7 +153,9 @@ class FlashCrossAttention(nn.Module):
         deterministic=False,
     ):
         super().__init__()
-        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
+        assert (
+            flash_attn_varlen_kvpacked_func is not None
+        ), "FlashAttention is not installed"
         assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale
@@ -313,7 +321,10 @@ class CrossAttention(nn.Module):
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
         if key_padding_mask is not None:
             padding_mask = torch.full(
-                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
+                (batch_size, seqlen_k),
+                -10000.0,
+                dtype=scores.dtype,
+                device=scores.device,
             )
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
@@ -425,20 +436,26 @@ class MHA(nn.Module):
         else:
             alibi_slopes = None
         if window_size != (-1, -1):
-            assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
+            assert (
+                use_flash_attn
+            ), "Local (sliding window) attention code path requires flash_attn"

         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
         assert (
             self.num_heads % self.num_heads_kv == 0
         ), "num_heads must be divisible by num_heads_kv"
-        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
+        assert (
+            self.embed_dim % num_heads == 0
+        ), "embed_dim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
         kv_dim = 2 * self.head_dim * self.num_heads_kv

         if self.rotary_emb_dim > 0:
-            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
+            assert (
+                not cross_attn
+            ), "MHA with rotary embedding does not support cross-attention yet"
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
             self.rotary_emb = RotaryEmbedding(
                 self.rotary_emb_dim,
@@ -453,23 +470,33 @@ class MHA(nn.Module):

         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         linear_resid_cls = (
-            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
+            LinearResidual
+            if not fused_bias_fc
+            else partial(FusedDense, return_residual=True)
         )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = (
-            partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
+            partial(
+                FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size
+            )
             if use_flash_attn
             else SelfAttention
         )
         inner_cross_attn_cls = (
-            partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
+            partial(
+                FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size
+            )
             if use_flash_attn
             else CrossAttention
         )
         if not self.cross_attn:
-            self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
+            self.Wqkv = wqkv_cls(
+                embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
+            )
         else:
-            self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
+            self.Wq = linear_cls(
+                embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
+            )
             self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
         if self.dwconv:
             if self.num_heads_kv == self.num_heads:
@@ -480,7 +507,9 @@ class MHA(nn.Module):
                 self.dwconv_q = nn.Conv1d(
                     embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
                 )
-                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
+                self.dwconv_kv = nn.Conv1d(
+                    kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
+                )
         self.inner_attn = inner_attn_cls(
             causal=causal,
             softmax_scale=softmax_scale,
@@ -489,7 +518,9 @@ class MHA(nn.Module):
         self.inner_cross_attn = inner_cross_attn_cls(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
         )
-        self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
+        self.out_proj = linear_cls(
+            embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
+        )

     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
@@ -507,7 +538,9 @@ class MHA(nn.Module):
     def _update_kv_cache(self, kv, inference_params):
         """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
         assert not self.dwconv, "Generation does not support dwconv yet"
-        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+        assert (
+            self.layer_idx is not None
+        ), "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)

     def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
@@ -523,7 +556,10 @@ class MHA(nn.Module):
             self.rotary_emb._update_cos_sin_cache(
                 inference_params.max_seqlen, device=q.device, dtype=q.dtype
             )
-            rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
+            rotary_cos, rotary_sin = (
+                self.rotary_emb._cos_cached,
+                self.rotary_emb._sin_cached,
+            )
         else:
             rotary_cos, rotary_sin = None, None
         batch = q.shape[0]
@@ -545,7 +581,9 @@ class MHA(nn.Module):
             cache_seqlens=cache_seqlens,
             softmax_scale=self.inner_cross_attn.softmax_scale,
             causal=self.inner_cross_attn.causal,
-            rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            rotary_interleaved=(
+                self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False
+            ),
             alibi_slopes=alibi_slopes,
         )
         return context
@@ -640,40 +678,49 @@ class MHA(nn.Module):
             )
         )
         rotary_max_seqlen = (
-            inference_params.max_sequence_len if inference_params is not None else max_seqlen
+            inference_params.max_sequence_len
+            if inference_params is not None
+            else max_seqlen
         )
-        batch, seqlen = x.shape[:2]
-        lora_kwargs = {}
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None

             if adapter_mask is not None:
                 unique_tasks = torch.unique(adapter_mask)
                 qkv_dtype = next(self.Wqkv.parameters()).dtype
-                qkv = torch.empty(
-                    *x.shape[:-1], self.Wqkv.out_features, dtype=qkv_dtype, device=x.device)
+                qkv = torch.empty(
+                    *x.shape[:-1],
+                    self.Wqkv.out_features,
+                    dtype=qkv_dtype,
+                    device=x.device,
+                )
                 for task_id in unique_tasks:
                     task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                     task_tensor = x[task_indices]
                     if not self.return_residual:
                         task_qkv = self.Wqkv(task_tensor, task_id=task_id)
                     else:
-                        task_qkv, _ = self.Wqkv(task_tensor, task_id=task_id, residual=True)
+                        task_qkv, _ = self.Wqkv(
+                            task_tensor, task_id=task_id, residual=True
+                        )
                     qkv[task_indices] = task_qkv
             else:
                 if not self.return_residual:
                     qkv = self.Wqkv(x)
                 else:
-                    if hasattr(self.Wqkv, 'parametrizations'):
+                    if hasattr(self.Wqkv, "parametrizations"):
                         qkv, x = self.Wqkv(x, residual=True)
                     else:
                         qkv, x = self.Wqkv(x)

             if self.dwconv:
                 qkv = rearrange(
-                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
+                    "b d s -> b s d",
                 ).contiguous()
-            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            qkv = rearrange(
+                qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
+            )
             if (
                 inference_params is None
                 or inference_params.seqlen_offset == 0
@@ -691,7 +738,9 @@ class MHA(nn.Module):
                 if not self.checkpointing:
                     context = self.inner_attn(qkv, **kwargs)
                 else:
-                    context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+                    context = torch.utils.checkpoint.checkpoint(
+                        self.inner_attn, qkv, **kwargs
+                    )
             else:
                 context = self._update_kvcache_attention(
                     qkv[:, :, 0], qkv[:, :, 1:], inference_params
@@ -720,13 +769,17 @@ class MHA(nn.Module):
             q = qkv[..., : self.num_heads * self.head_dim]
             kv = qkv[..., self.num_heads * self.head_dim :]
             q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
-            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+            kv = rearrange(
+                kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
+            )
             if self.dwconv:
                 q = rearrange(
-                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
+                    "b d s -> b s d",
                 ).contiguous()
                 kv = rearrange(
-                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
+                    "b d s -> b s d",
                 ).contiguous()
             if (
                 inference_params is None
@@ -752,14 +805,20 @@ class MHA(nn.Module):
                 else:
                     context = self._update_kvcache_attention(q, kv, inference_params)
             else:
-                context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
+                context = self._apply_rotary_update_kvcache_attention(
+                    q, kv, inference_params
+                )

         inp = rearrange(context, "... h d -> ... (h d)")
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             out_dtype = next(self.out_proj.parameters()).dtype
-            out = torch.empty(
-                *inp.shape[:-1], self.out_proj.out_features, dtype=out_dtype, device=inp.device)
+            out = torch.empty(
+                *inp.shape[:-1],
+                self.out_proj.out_features,
+                dtype=out_dtype,
+                device=inp.device,
+            )
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = inp[task_indices]
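For reference, the ALiBi slope recursion that the mha.py hunk above reformats is easier to read in isolation. The sketch below reproduces the non-power-of-two branch exactly as it appears in the diff; the inner get_slopes_power_of_2 helper is not visible in the hunk, so its body here follows the standard ALiBi slope formula and should be treated as an assumption rather than a copy of the repository code.

import math


def get_alibi_slopes(nheads: int):
    # Geometric slopes for a power-of-two head count (standard ALiBi formula,
    # assumed since the helper's body lies outside the hunk shown above).
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    # Non-power-of-two branch, as in the diff: take the slopes of the closest
    # power of two and interleave extra slopes from the next power of two.
    closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
    return (
        get_slopes_power_of_2(closest_power_of_2)
        + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
    )


print(len(get_alibi_slopes(12)))  # 12 slopes, e.g. for a 12-head model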
mlp.py
CHANGED
@@ -8,14 +8,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup

-
 try:
     from flash_attn.ops.activations import swiglu
 except ImportError:
     swiglu = None

 try:
-    from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
+    from flash_attn.ops.fused_dense import (ColumnParallelLinear,
+                                            RowParallelLinear)
 except ImportError:
     ColumnParallelLinear, RowParallelLinear = None, None

@@ -41,18 +41,23 @@ class Mlp(nn.Module):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         out_features = out_features if out_features is not None else in_features
-        hidden_features = hidden_features if hidden_features is not None else in_features * 4
+        hidden_features = (
+            hidden_features if hidden_features is not None else in_features * 4
+        )
         self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
         self.activation = activation
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
+        self.fc2 = nn.Linear(
+            hidden_features, out_features, bias=bias2, **factory_kwargs
+        )

     def forward(self, x, adapter_mask=None):
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             fc1_dtype = next(self.fc1.parameters()).dtype
-            y = torch.empty(
-                *x.shape[:-1], self.fc1.out_features, dtype=fc1_dtype, device=x.device)
+            y = torch.empty(
+                *x.shape[:-1], self.fc1.out_features, dtype=fc1_dtype, device=x.device
+            )
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = x[task_indices]
@@ -66,8 +71,9 @@ class Mlp(nn.Module):
         if adapter_mask is not None:
             unique_tasks = torch.unique(adapter_mask)
             fc2_dtype = next(self.fc2.parameters()).dtype
-            out = torch.empty(
-                *y.shape[:-1], self.fc2.out_features, dtype=fc2_dtype, device=y.device)
+            out = torch.empty(
+                *y.shape[:-1], self.fc2.out_features, dtype=fc2_dtype, device=y.device
+            )
             for task_id in unique_tasks:
                 task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
                 task_tensor = y[task_indices]
@@ -98,7 +104,9 @@ class ParallelMLP(nn.Module):
         assert ColumnParallelLinear is not None, "Need to install fused_dense"
         assert RowParallelLinear is not None, "Need to install fused_dense"
         out_features = out_features if out_features is not None else in_features
-        hidden_features = hidden_features if hidden_features is not None else in_features * 4
+        hidden_features = (
+            hidden_features if hidden_features is not None else in_features * 4
+        )
         self.fc1 = ColumnParallelLinear(
             in_features,
             hidden_features,
@@ -144,17 +152,25 @@ class GatedMlp(nn.Module):
         hidden_features = (
             hidden_features if hidden_features is not None else int(8 * in_features / 3)
         )
-        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        hidden_features = (
+            (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        )
         self.return_residual = return_residual
-        self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
+        self.fc1 = nn.Linear(
+            in_features, 2 * hidden_features, bias=bias1, **factory_kwargs
+        )
         self.activation = activation
-        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
+        self.fc2 = nn.Linear(
+            hidden_features, out_features, bias=bias2, **factory_kwargs
+        )

     def forward(self, x):
         y = self.fc1(x)
         if self.activation == F.sigmoid:  # Special case for GLU
             y = F.glu(y, dim=-1)
-        elif self.activation == F.silu and swiglu is not None:  # Special case for SwiGLU
+        elif (
+            self.activation == F.silu and swiglu is not None
+        ):  # Special case for SwiGLU
             y, gate = y.chunk(2, dim=-1)
             y = swiglu(gate, y)
         else:
@@ -187,7 +203,9 @@ class ParallelGatedMlp(nn.Module):
         hidden_features = (
             hidden_features if hidden_features is not None else int(8 * in_features / 3)
         )
-        hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        hidden_features = (
+            (hidden_features + multiple_of - 1) // multiple_of * multiple_of
+        )
         if ColumnParallelLinear is None or RowParallelLinear is None:
             raise ImportError("fused_dense is not installed")
         self.fc1 = ColumnParallelLinear(
@@ -216,4 +234,4 @@ class ParallelGatedMlp(nn.Module):
         y, gate = y.chunk(2, dim=-1)
         y = y * self.activation(gate)
         y = self.fc2(y)
-        return y
+        return y
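Two details in the mlp.py hunks are easy to verify by hand: the gated MLP's hidden width defaults to 8*d/3 and is then rounded up to a multiple of multiple_of, and fc1 produces twice that width so its output can be split into a value half and a gate half. A small standalone check, assuming multiple_of=128 (the actual default sits in the class signature, which this diff does not show):

import torch
import torch.nn.functional as F

in_features = 1024
multiple_of = 128                                    # assumed default
hidden_features = int(8 * in_features / 3)           # 2730
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
print(hidden_features)                               # 2816

# fc1 outputs 2*hidden_features; split into value and gate, as in the diff.
y = torch.randn(2, 2 * hidden_features)
value, gate = y.chunk(2, dim=-1)
out = value * F.silu(gate)   # pure-PyTorch equivalent of the fused swiglu path
print(out.shape)             # torch.Size([2, 2816])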
modeling_lora.py
CHANGED
@@ -1,6 +1,5 @@
 import math
 import os
-import warnings
 from functools import partial
 from typing import Iterator, List, Optional, Tuple, Union

@@ -12,7 +11,8 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import PretrainedConfig

-from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
+from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
+                                   XLMRobertaPreTrainedModel)


 def initialized_weights(
@@ -177,7 +177,9 @@ class LoRAParametrization(nn.Module):

     def new_forward(self, input, task_id=None, residual=False):
         if task_id is not None:
-            weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
+            weights = self.parametrizations.weight[0].lora_forward(
+                self.weight, current_task=task_id
+            )
         else:
             weights = self.weight

@@ -204,13 +206,21 @@ class LoRAParametrization(nn.Module):

     def new_forward(self, input, task_id=None):
         if task_id is not None:
-            weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
+            weights = self.parametrizations.weight[0].lora_forward(
+                self.weight, current_task=task_id
+            )
         else:
             weights = self.weight

         out = F.embedding(
-            input,
-            weights, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+            input,
+            weights,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+        )

         return out

@@ -219,9 +229,7 @@ class LoRAParametrization(nn.Module):

 class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
     def __init__(
-        self,
-        config: XLMRobertaFlashConfig,
-        roberta: Optional[XLMRobertaModel] = None
+        self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
     ):
         super().__init__(config)
         if roberta is None:
@@ -235,7 +243,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             or len(self._lora_adaptations) < 1
         ):
             raise ValueError(
-                f'`lora_adaptations` must be a list and contain at least one element'
+                f"`lora_adaptations` must be a list and contain at least one element"
             )
         self._lora_prompts = config.lora_prompts
         if (
@@ -244,9 +252,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
         ):
             raise ValueError(
-                f'`lora_prompts` must be a dict and contain the same number of elements '
-                f'as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`.'
-            )
+                f"`lora_prompts` must be a dict and contain the same number of elements "
+                f"as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`."
+            )
         self._adaptation_map = {
             name: idx for idx, name in enumerate(self._lora_adaptations)
         }
@@ -261,7 +269,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         )
         self.main_params_trainable = config.lora_main_params_trainable

-
     @property
     def rotary_emb_base(self):
         return self.roberta.rotary_emb_base
@@ -305,13 +312,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         config = XLMRobertaFlashConfig.from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
-
         if config.load_trained_adapters:
             return super().from_pretrained(
                 pretrained_model_name_or_path, *model_args, **kwargs
             )
         else:
-            roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+            roberta = XLMRobertaModel.from_pretrained(
+                pretrained_model_name_or_path, *model_args, **kwargs
+            )
             return cls(config, roberta=roberta)

     def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
@@ -367,5 +375,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
         if task_type:
             task_id = self._adaptation_map[task_type]
             num_examples = 1 if isinstance(sentences, str) else len(sentences)
-            adapter_mask = torch.full(
-                (num_examples,), task_id, dtype=torch.int32, device=self.device)
+            adapter_mask = torch.full(
+                (num_examples,), task_id, dtype=torch.int32, device=self.device
+            )
+        return self.roberta.encode(
+            sentences, *args, adapter_mask=adapter_mask, **kwargs
+        )
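The last modeling_lora.py hunk shows how encode() turns a task_type string into the adapter_mask tensor consumed by the rest of the stack: look up the task's integer id in the adaptation map, then broadcast it over the batch. A hedged sketch of just that step; the adaptation names below are placeholders, not the model's actual lora_adaptations config:

import torch

lora_adaptations = ["retrieval.query", "retrieval.passage", "classification"]  # placeholder names
adaptation_map = {name: idx for idx, name in enumerate(lora_adaptations)}

sentences = ["a query", "another query"]
task_type = "retrieval.query"

task_id = adaptation_map[task_type]
num_examples = 1 if isinstance(sentences, str) else len(sentences)
# One task id per example in the batch, as in the diff above.
adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32)
print(adapter_mask)  # tensor([0, 0], dtype=torch.int32)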
modeling_xlm_roberta.py
CHANGED
@@ -13,39 +13,29 @@ import re
|
|
13 |
from collections import OrderedDict
|
14 |
from collections.abc import Sequence
|
15 |
from functools import partial
|
16 |
-
import
|
17 |
|
|
|
18 |
import torch
|
19 |
import torch.nn as nn
|
20 |
import torch.nn.functional as F
|
21 |
import torch.utils.checkpoint
|
22 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
23 |
-
from
|
24 |
-
from transformers import
|
|
|
25 |
from transformers.modeling_utils import PreTrainedModel
|
26 |
-
from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
|
27 |
-
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
|
28 |
-
|
29 |
from transformers.models.bert.modeling_bert import (
|
30 |
-
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
-
|
32 |
-
|
33 |
|
34 |
-
from typing import List, Optional, Tuple, Union
|
35 |
-
|
36 |
-
from .xlm_padding import (
|
37 |
-
index_first_axis,
|
38 |
-
index_first_axis_residual,
|
39 |
-
pad_input,
|
40 |
-
unpad_input,
|
41 |
-
)
|
42 |
-
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
43 |
from .block import Block
|
|
|
44 |
from .embedding import XLMRobertaEmbeddings
|
45 |
from .mha import MHA
|
46 |
from .mlp import FusedMLP, Mlp
|
47 |
-
from .
|
48 |
-
from .rotary import RotaryEmbedding
|
49 |
|
50 |
try:
|
51 |
from flash_attn.ops.fused_dense import FusedDense
|
@@ -79,7 +69,7 @@ def get_use_flash_attn(config: XLMRobertaFlashConfig):
|
|
79 |
return False
|
80 |
if importlib.util.find_spec("flash_attn") is None:
|
81 |
logger.warning(
|
82 |
-
|
83 |
)
|
84 |
return False
|
85 |
return True
|
@@ -109,7 +99,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
|
109 |
fused_bias_fc=fused_bias_fc,
|
110 |
use_flash_attn=use_flash_attn,
|
111 |
return_residual=return_residual,
|
112 |
-
use_alibi=config.position_embedding_type ==
|
113 |
**rotary_kwargs,
|
114 |
)
|
115 |
return mixer_cls
|
@@ -204,15 +194,17 @@ class XLMRobertaEncoder(nn.Module):
|
|
204 |
def gradient_checkpointing(self, value):
|
205 |
self._grad_checkpointing = value
|
206 |
|
207 |
-
def forward(
|
|
|
|
|
208 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
209 |
This means that we only compute the last layer output for these tokens.
|
210 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
211 |
"""
|
212 |
if key_padding_mask is None or not self.use_flash_attn:
|
213 |
-
mixer_kwargs = {
|
214 |
if key_padding_mask is not None:
|
215 |
-
mixer_kwargs[
|
216 |
for layer in self.layers:
|
217 |
if self._grad_checkpointing:
|
218 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
@@ -227,10 +219,14 @@ class XLMRobertaEncoder(nn.Module):
|
|
227 |
hidden_states = hidden_states[subset_mask]
|
228 |
else:
|
229 |
batch, seqlen = hidden_states.shape[:2]
|
230 |
-
hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask =
|
231 |
-
hidden_states, key_padding_mask, adapter_mask
|
232 |
)
|
233 |
-
mixer_kwargs = {
|
|
|
|
|
|
|
|
|
234 |
|
235 |
if subset_mask is None:
|
236 |
for layer in self.layers:
|
@@ -315,12 +311,18 @@ class XLMRobertaPooler(nn.Module):
|
|
315 |
if adapter_mask is not None:
|
316 |
unique_tasks = torch.unique(adapter_mask)
|
317 |
pool_dtype = next(self.dense.parameters()).dtype
|
318 |
-
pooled_output = torch.empty(
|
319 |
-
|
|
|
|
|
|
|
|
|
320 |
for task_id in unique_tasks:
|
321 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
322 |
task_first_token_tensor = first_token_tensor[task_indices]
|
323 |
-
task_pooled_output = self.dense(
|
|
|
|
|
324 |
pooled_output[task_indices] = task_pooled_output
|
325 |
else:
|
326 |
pooled_output = self.dense(first_token_tensor)
|
@@ -413,12 +415,11 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
|
|
413 |
*args,
|
414 |
**kwargs,
|
415 |
):
|
416 |
-
if not
|
417 |
-
kwargs[
|
418 |
return super().from_pretrained(*args, **kwargs)
|
419 |
|
420 |
|
421 |
-
|
422 |
class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
423 |
def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
|
424 |
super().__init__(config)
|
@@ -439,7 +440,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
439 |
self.embeddings = XLMRobertaEmbeddings(
|
440 |
config.hidden_size,
|
441 |
config.vocab_size,
|
442 |
-
|
|
|
|
|
|
|
|
|
443 |
config.type_vocab_size,
|
444 |
padding_idx=config.pad_token_id,
|
445 |
)
|
@@ -449,16 +454,18 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
449 |
self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
|
450 |
|
451 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
452 |
-
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
|
453 |
self._rotary_emb_base = config.rotary_emb_base
|
454 |
|
455 |
@torch.inference_mode()
|
456 |
def encode(
|
457 |
-
self:
|
458 |
sentences: Union[str, List[str]],
|
459 |
batch_size: int = 32,
|
460 |
show_progress_bar: Optional[bool] = None,
|
461 |
-
output_value: str =
|
462 |
convert_to_numpy: bool = True,
|
463 |
convert_to_tensor: bool = False,
|
464 |
device: Optional[torch.device] = None,
|
@@ -516,12 +523,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
516 |
if convert_to_tensor:
|
517 |
convert_to_numpy = False
|
518 |
|
519 |
-
if output_value !=
|
520 |
convert_to_tensor = False
|
521 |
convert_to_numpy = False
|
522 |
|
523 |
input_was_string = False
|
524 |
-
if isinstance(sentences, str) or not hasattr(sentences,
|
525 |
sentences = [sentences]
|
526 |
input_was_string = True
|
527 |
|
@@ -532,11 +539,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
532 |
inverse_permutation = np.argsort(permutation)
|
533 |
sentences = [sentences[idx] for idx in permutation]
|
534 |
|
535 |
-
tokenizer_kwargs[
|
536 |
-
tokenizer_kwargs[
|
537 |
-
|
538 |
)
|
539 |
-
tokenizer_kwargs[
|
540 |
|
541 |
all_embeddings = []
|
542 |
|
@@ -550,11 +557,13 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
550 |
)
|
551 |
else:
|
552 |
range_iter = range(0, len(sentences), batch_size)
|
553 |
-
lora_arguments =
|
|
|
|
|
554 |
for i in range_iter:
|
555 |
encoded_input = self.tokenizer(
|
556 |
sentences[i : i + batch_size],
|
557 |
-
return_tensors=
|
558 |
**tokenizer_kwargs,
|
559 |
).to(self.device)
|
560 |
token_embs = self.forward(**encoded_input, **lora_arguments)[0]
|
@@ -562,18 +571,18 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
562 |
# Accumulate in fp32 to avoid overflow
|
563 |
token_embs = token_embs.float()
|
564 |
|
565 |
-
if output_value ==
|
566 |
raise NotImplementedError
|
567 |
elif output_value is None:
|
568 |
raise NotImplementedError
|
569 |
else:
|
570 |
-
if self.config.emb_pooler ==
|
571 |
embeddings = self.cls_pooling(
|
572 |
-
token_embs, encoded_input[
|
573 |
)
|
574 |
else:
|
575 |
embeddings = self.mean_pooling(
|
576 |
-
token_embs, encoded_input[
|
577 |
)
|
578 |
|
579 |
if normalize_embeddings:
|
@@ -603,14 +612,16 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
603 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
604 |
if not self.config.matryoshka_dimensions:
|
605 |
logger.warning(
|
606 |
-
|
607 |
)
|
608 |
return embeddings
|
609 |
elif truncate_dim in self.config.matryoshka_dimensions:
|
610 |
return [tensor[:truncate_dim] for tensor in embeddings]
|
611 |
else:
|
612 |
-
raise ValueError(
|
613 |
-
|
|
|
|
|
614 |
|
615 |
def mean_pooling(
|
616 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
@@ -622,10 +633,8 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
622 |
input_mask_expanded.sum(1), min=1e-9
|
623 |
)
|
624 |
|
625 |
-
def cls_pooling(
|
626 |
-
|
627 |
-
):
|
628 |
-
return token_embeddings[:,0]
|
629 |
|
630 |
@property
|
631 |
def rotary_emb_base(self):
|
@@ -635,7 +644,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
635 |
def rotary_emb_base(self, base):
|
636 |
if not isinstance(base, (int, float)):
|
637 |
raise TypeError("Base must be an integer or float")
|
638 |
-
logger.info(f
|
639 |
for layer in self.encoder.layers:
|
640 |
layer.mixer.rotary_emb.base = base
|
641 |
self._rotary_emb_base = base
|
@@ -655,12 +664,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
655 |
layer output for these tokens.
|
656 |
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
657 |
"""
|
658 |
-
adapter_mask = kwargs.pop(
|
659 |
if kwargs:
|
660 |
for key, value in kwargs.items():
|
661 |
if value is not None:
|
662 |
logger.warning(
|
663 |
-
|
664 |
key,
|
665 |
)
|
666 |
|
@@ -669,7 +678,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
669 |
)
|
670 |
|
671 |
hidden_states = self.embeddings(
|
672 |
-
input_ids,
|
|
|
|
|
|
|
673 |
)
|
674 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
675 |
# BERT puts embedding LayerNorm before embedding dropout.
|
@@ -693,12 +705,17 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
693 |
subset_mask = None
|
694 |
|
695 |
sequence_output = self.encoder(
|
696 |
-
hidden_states,
|
|
|
|
|
|
|
697 |
)
|
698 |
|
699 |
if masked_tokens_mask is None:
|
700 |
pooled_output = (
|
701 |
-
self.pooler(sequence_output, adapter_mask=adapter_mask)
|
|
|
|
|
702 |
)
|
703 |
else:
|
704 |
# TD [2022-03-01]: the indexing here is very tricky.
|
@@ -712,7 +729,9 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
|
712 |
pool_input = sequence_output[first_col_mask[subset_mask]]
|
713 |
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
714 |
pooled_output = (
|
715 |
-
self.pooler(pool_input, pool=False, adapter_mask=adapter_mask)
|
|
|
|
|
716 |
)
|
717 |
|
718 |
if not return_dict:
|
@@ -817,103 +836,6 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
|
|
817 |
)
|
818 |
|
819 |
|
820 |
-
# class XLMRobertaForPreTraining(XLMRobertaPreTrainedModel):
|
821 |
-
# def __init__(self, config: XLMRobertaFlashConfig):
|
822 |
-
# super().__init__(config)
|
823 |
-
# # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
|
824 |
-
# # (around 15%) to the classifier heads.
|
825 |
-
# self.dense_seq_output = getattr(config, "dense_seq_output", False)
|
826 |
-
# # If last_layer_subset, we only need the compute the last layer for a subset of tokens
|
827 |
-
# # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
|
828 |
-
# self.last_layer_subset = getattr(config, "last_layer_subset", False)
|
829 |
-
# if self.last_layer_subset:
|
830 |
-
# assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
|
831 |
-
# use_xentropy = getattr(config, "use_xentropy", False)
|
832 |
-
# if use_xentropy and CrossEntropyLoss is None:
|
833 |
-
# raise ImportError("xentropy_cuda is not installed")
|
834 |
-
# loss_cls = (
|
835 |
-
# nn.CrossEntropyLoss
|
836 |
-
# if not use_xentropy
|
837 |
-
# else partial(CrossEntropyLoss, inplace_backward=True)
|
838 |
-
# )
|
839 |
-
#
|
840 |
-
# self.xlm = XLMRobertaModel(config)
|
841 |
-
# self.cls = XLMRobertaPreTrainingHeads(config)
|
842 |
-
# self.mlm_loss = loss_cls(ignore_index=0)
|
843 |
-
# self.nsp_loss = loss_cls(ignore_index=-1)
|
844 |
-
#
|
845 |
-
# # Initialize weights and apply final processing
|
846 |
-
# self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
847 |
-
# self.tie_weights()
|
848 |
-
#
|
849 |
-
# def tie_weights(self):
|
850 |
-
# self.cls.predictions.decoder.weight = self.xlm.embeddings.word_embeddings.weight
|
851 |
-
#
|
852 |
-
# def forward(
|
853 |
-
# self,
|
854 |
-
# input_ids,
|
855 |
-
# position_ids=None,
|
856 |
-
# token_type_ids=None,
|
857 |
-
# attention_mask=None,
|
858 |
-
# labels=None,
|
859 |
-
# next_sentence_label=None,
|
860 |
-
# ):
|
861 |
-
# """
|
862 |
-
# If labels are provided, they must be 0 for masked out tokens (as specified in the attention
|
863 |
-
# mask).
|
864 |
-
# Outputs:
|
865 |
-
# if `labels` and `next_sentence_label` are not `None`:
|
866 |
-
# Outputs the total_loss which is the sum of the masked language modeling loss and the next
|
867 |
-
# sentence classification loss.
|
868 |
-
# if `labels` or `next_sentence_label` is `None`:
|
869 |
-
# Outputs a tuple comprising
|
870 |
-
# - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
|
871 |
-
# - the next sentence classification logits of shape [batch_size, 2].
|
872 |
-
#
|
873 |
-
# """
|
874 |
-
# masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
|
875 |
-
# outputs = self.xlm(
|
876 |
-
# input_ids,
|
877 |
-
# position_ids=position_ids,
|
878 |
-
# token_type_ids=token_type_ids,
|
879 |
-
# attention_mask=attention_mask.bool() if attention_mask is not None else None,
|
880 |
-
# masked_tokens_mask=masked_tokens_mask,
|
881 |
-
# )
|
882 |
-
# sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
|
883 |
-
# if self.dense_seq_output and labels is not None:
|
884 |
-
# masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
|
885 |
-
# if not self.last_layer_subset:
|
886 |
-
# sequence_output = index_first_axis(
|
887 |
-
# rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
|
888 |
-
# )
|
889 |
-
# prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
890 |
-
#
|
891 |
-
# total_loss = None
|
892 |
-
# if labels is not None and next_sentence_label is not None:
|
893 |
-
# if (
|
894 |
-
# self.dense_seq_output and labels is not None
|
895 |
-
# ): # prediction_scores are already flattened
|
896 |
-
# masked_lm_loss = self.mlm_loss(
|
897 |
-
# prediction_scores, labels.flatten()[masked_token_idx]
|
898 |
-
# )
|
899 |
-
# else:
|
900 |
-
# masked_lm_loss = self.mlm_loss(
|
901 |
-
# rearrange(prediction_scores, "... v -> (...) v"),
|
902 |
-
# rearrange(labels, "... -> (...)"),
|
903 |
-
# )
|
904 |
-
# next_sentence_loss = self.nsp_loss(
|
905 |
-
# rearrange(seq_relationship_score, "... t -> (...) t"),
|
906 |
-
# rearrange(next_sentence_label, "... -> (...)"),
|
907 |
-
# )
|
908 |
-
# total_loss = masked_lm_loss.float() + next_sentence_loss.float()
|
909 |
-
#
|
910 |
-
# return BertForPreTrainingOutput(
|
911 |
-
# loss=total_loss,
|
912 |
-
# prediction_logits=prediction_scores,
|
913 |
-
# seq_relationship_logits=seq_relationship_score,
|
914 |
-
# )
|
915 |
-
|
916 |
-
|
917 |
def remap_state_dict(state_dict, config: PretrainedConfig):
|
918 |
"""
|
919 |
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
@@ -1065,47 +987,47 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
|
|
1065 |
if not last_layer_subset or d != (config.num_hidden_layers - 1):
|
1066 |
Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
|
1067 |
Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
|
1068 |
-
state_dict[
|
1069 |
-
|
1070 |
-
|
1071 |
-
state_dict[
|
1072 |
-
|
1073 |
-
|
1074 |
-
|
1075 |
-
|
1076 |
-
state_dict[
|
1077 |
-
|
1078 |
-
|
1079 |
-
state_dict[
|
1080 |
-
|
1081 |
-
|
1082 |
-
state_dict[
|
1083 |
-
|
1084 |
-
|
1085 |
-
state_dict[
|
1086 |
-
|
1087 |
-
|
1088 |
else:
|
1089 |
Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
|
1090 |
Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
|
1091 |
Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
|
1092 |
Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
|
1093 |
-
state_dict[
|
1094 |
-
|
1095 |
-
|
1096 |
-
state_dict[
|
1097 |
-
|
1098 |
-
|
1099 |
-
state_dict[
|
1100 |
-
|
1101 |
-
|
1102 |
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
|
1103 |
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
|
1104 |
: Wkv_biases.shape[0] // 2
|
1105 |
]
|
1106 |
-
state_dict[
|
1107 |
-
|
1108 |
-
|
1109 |
|
1110 |
def inv_key_mapping_ln(key):
|
1111 |
key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
|
@@ -1294,4 +1216,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
|
|
1294 |
logits=logits,
|
1295 |
hidden_states=outputs.hidden_states,
|
1296 |
attentions=outputs.attentions,
|
1297 |
-
)
|
|
|
13 |
from collections import OrderedDict
|
14 |
from collections.abc import Sequence
|
15 |
from functools import partial
|
16 |
+
from typing import List, Optional, Tuple, Union
|
17 |
|
18 |
+
import numpy as np
|
19 |
import torch
|
20 |
import torch.nn as nn
|
21 |
import torch.nn.functional as F
|
22 |
import torch.utils.checkpoint
|
23 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
24 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
25 |
+
from transformers.modeling_outputs import (MaskedLMOutput,
|
26 |
+
SequenceClassifierOutput)
|
27 |
from transformers.modeling_utils import PreTrainedModel
|
|
|
|
|
|
|
28 |
from transformers.models.bert.modeling_bert import (
|
29 |
+
BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
|
30 |
+
from transformers.models.xlm_roberta.modeling_xlm_roberta import \
|
31 |
+
XLMRobertaLMHead
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
from .block import Block
|
34 |
+
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
35 |
from .embedding import XLMRobertaEmbeddings
|
36 |
from .mha import MHA
|
37 |
from .mlp import FusedMLP, Mlp
|
38 |
+
from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
|
|
|
39 |
|
40 |
try:
|
41 |
from flash_attn.ops.fused_dense import FusedDense
|
|
|
69 |
return False
|
70 |
if importlib.util.find_spec("flash_attn") is None:
|
71 |
logger.warning(
|
72 |
+
"flash_attn is not installed. Using PyTorch native attention implementation."
|
73 |
)
|
74 |
return False
|
75 |
return True
|
|
|
99 |
fused_bias_fc=fused_bias_fc,
|
100 |
use_flash_attn=use_flash_attn,
|
101 |
return_residual=return_residual,
|
102 |
+
use_alibi=config.position_embedding_type == "alibi",
|
103 |
**rotary_kwargs,
|
104 |
)
|
105 |
return mixer_cls
|
|
|
194 |
def gradient_checkpointing(self, value):
|
195 |
self._grad_checkpointing = value
|
196 |
|
197 |
+
def forward(
|
198 |
+
self, hidden_states, key_padding_mask=None, subset_mask=None, adapter_mask=None
|
199 |
+
):
|
200 |
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
201 |
This means that we only compute the last layer output for these tokens.
|
202 |
subset_mask: (batch, seqlen), dtype=torch.bool
|
203 |
"""
|
204 |
if key_padding_mask is None or not self.use_flash_attn:
|
205 |
+
mixer_kwargs = {"adapter_mask": adapter_mask}
|
206 |
if key_padding_mask is not None:
|
207 |
+
mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
|
208 |
for layer in self.layers:
|
209 |
if self._grad_checkpointing:
|
210 |
hidden_states = torch.utils.checkpoint.checkpoint(
|
|
|
219 |
hidden_states = hidden_states[subset_mask]
|
220 |
else:
|
221 |
batch, seqlen = hidden_states.shape[:2]
|
222 |
+
hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
|
223 |
+
unpad_input(hidden_states, key_padding_mask, adapter_mask)
|
224 |
)
|
225 |
+
mixer_kwargs = {
|
226 |
+
"cu_seqlens": cu_seqlens,
|
227 |
+
"max_seqlen": max_seqlen_in_batch,
|
228 |
+
"adapter_mask": cu_adapter_mask,
|
229 |
+
}
|
230 |
|
231 |
if subset_mask is None:
|
232 |
for layer in self.layers:
|
|
|
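The unpadded code path above concatenates every non-padding token into one (total_tokens, hidden) tensor and passes the per-sequence boundaries to flash attention as cumulative sequence lengths. A minimal standalone sketch of that bookkeeping, using only an attention mask (variable names mirror the code above, but the example values are made up):

import torch
import torch.nn.functional as F

# two sequences of length 3 and 1 inside a padded (2, 4) batch
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)   # tensor([3, 1])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = int(seqlens_in_batch.max())                  # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 3, 4]); tokens of sequence b occupy rows cu_seqlens[b]:cu_seqlens[b+1]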
311 |
if adapter_mask is not None:
|
312 |
unique_tasks = torch.unique(adapter_mask)
|
313 |
pool_dtype = next(self.dense.parameters()).dtype
|
314 |
+
pooled_output = torch.empty(
|
315 |
+
first_token_tensor.shape[0],
|
316 |
+
self.dense.out_features,
|
317 |
+
dtype=pool_dtype,
|
318 |
+
device=first_token_tensor.device,
|
319 |
+
)
|
320 |
for task_id in unique_tasks:
|
321 |
task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
|
322 |
task_first_token_tensor = first_token_tensor[task_indices]
|
323 |
+
task_pooled_output = self.dense(
|
324 |
+
task_first_token_tensor, task_id=task_id
|
325 |
+
)
|
326 |
pooled_output[task_indices] = task_pooled_output
|
327 |
else:
|
328 |
pooled_output = self.dense(first_token_tensor)
|
|
|
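The adapter-aware pooler above routes each row of the batch through the dense layer under its own task id and writes the results back in place. A condensed sketch of that routing pattern; a plain nn.Linear stands in for the LoRA-aware dense layer here, so the task_id argument is omitted and all names are illustrative:

import torch
import torch.nn as nn

dense = nn.Linear(8, 8)
first_token_tensor = torch.randn(5, 8)        # CLS vectors for a batch of 5
adapter_mask = torch.tensor([0, 1, 0, 1, 1])  # one task id per row

pooled_output = torch.empty(5, dense.out_features, dtype=first_token_tensor.dtype)
for task_id in torch.unique(adapter_mask):
    task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
    # a LoRA-aware layer would also receive task_id here to select the right adapter
    pooled_output[task_indices] = dense(first_token_tensor[task_indices])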
415 |
*args,
|
416 |
**kwargs,
|
417 |
):
|
418 |
+
if not "torch_dtype" in kwargs:
|
419 |
+
kwargs["torch_dtype"] = "auto"
|
420 |
return super().from_pretrained(*args, **kwargs)
|
421 |
|
422 |
|
|
|
423 |
class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
424 |
def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
|
425 |
super().__init__(config)
|
|
|
440 |
self.embeddings = XLMRobertaEmbeddings(
|
441 |
config.hidden_size,
|
442 |
config.vocab_size,
|
443 |
+
(
|
444 |
+
config.max_position_embeddings
|
445 |
+
if config.position_embedding_type == "absolute"
|
446 |
+
else -1
|
447 |
+
),
|
448 |
config.type_vocab_size,
|
449 |
padding_idx=config.pad_token_id,
|
450 |
)
|
|
|
454 |
self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
|
455 |
|
456 |
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
457 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
458 |
+
self.name_or_path, trust_remote_code=True
|
459 |
+
)
|
460 |
self._rotary_emb_base = config.rotary_emb_base
|
461 |
|
462 |
@torch.inference_mode()
|
463 |
def encode(
|
464 |
+
self: "XLMRobertaModel",
|
465 |
sentences: Union[str, List[str]],
|
466 |
batch_size: int = 32,
|
467 |
show_progress_bar: Optional[bool] = None,
|
468 |
+
output_value: str = "sentence_embedding",
|
469 |
convert_to_numpy: bool = True,
|
470 |
convert_to_tensor: bool = False,
|
471 |
device: Optional[torch.device] = None,
|
|
|
523 |
if convert_to_tensor:
|
524 |
convert_to_numpy = False
|
525 |
|
526 |
+
if output_value != "sentence_embedding":
|
527 |
convert_to_tensor = False
|
528 |
convert_to_numpy = False
|
529 |
|
530 |
input_was_string = False
|
531 |
+
if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
|
532 |
sentences = [sentences]
|
533 |
input_was_string = True
|
534 |
|
|
|
539 |
inverse_permutation = np.argsort(permutation)
|
540 |
sentences = [sentences[idx] for idx in permutation]
|
541 |
|
542 |
+
tokenizer_kwargs["padding"] = tokenizer_kwargs.get("padding", True)
|
543 |
+
tokenizer_kwargs["max_length"] = tokenizer_kwargs.get(
|
544 |
+
"max_length", self.tokenizer.init_kwargs.get("model_max_length", 8192)
|
545 |
)
|
546 |
+
tokenizer_kwargs["truncation"] = tokenizer_kwargs.get("truncation", True)
|
547 |
|
548 |
all_embeddings = []
|
549 |
|
|
|
557 |
)
|
558 |
else:
|
559 |
range_iter = range(0, len(sentences), batch_size)
|
560 |
+
lora_arguments = (
|
561 |
+
{"adapter_mask": adapter_mask} if adapter_mask is not None else {}
|
562 |
+
)
|
563 |
for i in range_iter:
|
564 |
encoded_input = self.tokenizer(
|
565 |
sentences[i : i + batch_size],
|
566 |
+
return_tensors="pt",
|
567 |
**tokenizer_kwargs,
|
568 |
).to(self.device)
|
569 |
token_embs = self.forward(**encoded_input, **lora_arguments)[0]
|
|
|
571 |
# Accumulate in fp32 to avoid overflow
|
572 |
token_embs = token_embs.float()
|
573 |
|
574 |
+
if output_value == "token_embeddings":
|
575 |
raise NotImplementedError
|
576 |
elif output_value is None:
|
577 |
raise NotImplementedError
|
578 |
else:
|
579 |
+
if self.config.emb_pooler == "cls":
|
580 |
embeddings = self.cls_pooling(
|
581 |
+
token_embs, encoded_input["attention_mask"]
|
582 |
)
|
583 |
else:
|
584 |
embeddings = self.mean_pooling(
|
585 |
+
token_embs, encoded_input["attention_mask"]
|
586 |
)
|
587 |
|
588 |
if normalize_embeddings:
|
|
|
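As a usage sketch, the encode() method above would typically be driven like this; the checkpoint name is a placeholder, and loading through AutoModel with trust_remote_code is an assumption about how this repository is consumed, not something shown in the diff:

from transformers import AutoModel

# hypothetical checkpoint id
model = AutoModel.from_pretrained("your-org/your-flash-xlmr-checkpoint", trust_remote_code=True)
embeddings = model.encode(
    ["A sentence to embed.", "Another one."],
    batch_size=2,
    convert_to_numpy=True,
)
# one pooled (and optionally normalized or truncated) vector per input sentence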
612 |
def truncate_embeddings(self, embeddings, truncate_dim):
|
613 |
if not self.config.matryoshka_dimensions:
|
614 |
logger.warning(
|
615 |
+
"Matryoshka embeddings are not supported, so dimension truncation will not be performed."
|
616 |
)
|
617 |
return embeddings
|
618 |
elif truncate_dim in self.config.matryoshka_dimensions:
|
619 |
return [tensor[:truncate_dim] for tensor in embeddings]
|
620 |
else:
|
621 |
+
raise ValueError(
|
622 |
+
f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
|
623 |
+
f"Supported dimensions are {self.config.matryoshka_dimensions}."
|
624 |
+
)
|
625 |
|
626 |
def mean_pooling(
|
627 |
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
|
|
633 |
input_mask_expanded.sum(1), min=1e-9
|
634 |
)
|
635 |
|
636 |
+
def cls_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
|
637 |
+
return token_embeddings[:, 0]
|
|
|
|
|
638 |
|
639 |
@property
|
640 |
def rotary_emb_base(self):
|
|
|
644 |
def rotary_emb_base(self, base):
|
645 |
if not isinstance(base, (int, float)):
|
646 |
raise TypeError("Base must be an integer or float")
|
647 |
+
logger.info(f"Changing RoPE base value to {base}")
|
648 |
for layer in self.encoder.layers:
|
649 |
layer.mixer.rotary_emb.base = base
|
650 |
self._rotary_emb_base = base
|
|
|
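The two pooling helpers above reduce per-token embeddings to one vector per sentence: mean_pooling averages over non-padding positions, cls_pooling takes the first token. A self-contained sketch of the same arithmetic with made-up shapes:

import torch

token_embeddings = torch.randn(2, 4, 8)        # (batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
mean_pooled = (token_embeddings * mask).sum(1) / torch.clamp(mask.sum(1), min=1e-9)
cls_pooled = token_embeddings[:, 0]            # first (CLS) token of every sequence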
664 |
layer output for these tokens.
|
665 |
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
666 |
"""
|
667 |
+
adapter_mask = kwargs.pop("adapter_mask", None)
|
668 |
if kwargs:
|
669 |
for key, value in kwargs.items():
|
670 |
if value is not None:
|
671 |
logger.warning(
|
672 |
+
"Flash attention implementation does not support kwargs: %s",
|
673 |
key,
|
674 |
)
|
675 |
|
|
|
678 |
)
|
679 |
|
680 |
hidden_states = self.embeddings(
|
681 |
+
input_ids,
|
682 |
+
position_ids=position_ids,
|
683 |
+
token_type_ids=token_type_ids,
|
684 |
+
adapter_mask=adapter_mask,
|
685 |
)
|
686 |
# TD [2022-12:18]: Don't need to force residual in fp32
|
687 |
# BERT puts embedding LayerNorm before embedding dropout.
|
|
|
705 |
subset_mask = None
|
706 |
|
707 |
sequence_output = self.encoder(
|
708 |
+
hidden_states,
|
709 |
+
key_padding_mask=attention_mask,
|
710 |
+
subset_mask=subset_mask,
|
711 |
+
adapter_mask=adapter_mask,
|
712 |
)
|
713 |
|
714 |
if masked_tokens_mask is None:
|
715 |
pooled_output = (
|
716 |
+
self.pooler(sequence_output, adapter_mask=adapter_mask)
|
717 |
+
if self.pooler is not None
|
718 |
+
else None
|
719 |
)
|
720 |
else:
|
721 |
# TD [2022-03-01]: the indexing here is very tricky.
|
|
|
729 |
pool_input = sequence_output[first_col_mask[subset_mask]]
|
730 |
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
731 |
pooled_output = (
|
732 |
+
self.pooler(pool_input, pool=False, adapter_mask=adapter_mask)
|
733 |
+
if self.pooler is not None
|
734 |
+
else None
|
735 |
)
|
736 |
|
737 |
if not return_dict:
|
|
|
836 |
)
|
837 |
|
838 |
|
|
839 |
def remap_state_dict(state_dict, config: PretrainedConfig):
|
840 |
"""
|
841 |
Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
|
|
|
987 |
if not last_layer_subset or d != (config.num_hidden_layers - 1):
|
988 |
Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
|
989 |
Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
|
990 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
|
991 |
+
Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
|
992 |
+
)
|
993 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
|
994 |
+
Wqkv_weights[
|
995 |
+
Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
|
996 |
+
]
|
997 |
+
)
|
998 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
|
999 |
+
Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
|
1000 |
+
)
|
1001 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = (
|
1002 |
+
Wqkv_biases[: Wqkv_biases.shape[0] // 3]
|
1003 |
+
)
|
1004 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = (
|
1005 |
+
Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
|
1006 |
+
)
|
1007 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
|
1008 |
+
Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
|
1009 |
+
)
|
1010 |
else:
|
1011 |
Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
|
1012 |
Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
|
1013 |
Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
|
1014 |
Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
|
1015 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
|
1016 |
+
Wq_weight
|
1017 |
+
)
|
1018 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
|
1019 |
+
Wkv_weights[: Wkv_weights.shape[0] // 2, :]
|
1020 |
+
)
|
1021 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
|
1022 |
+
Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
|
1023 |
+
)
|
1024 |
state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
|
1025 |
state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
|
1026 |
: Wkv_biases.shape[0] // 2
|
1027 |
]
|
1028 |
+
state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
|
1029 |
+
Wkv_biases[Wkv_biases.shape[0] // 2 :]
|
1030 |
+
)
|
1031 |
|
1032 |
def inv_key_mapping_ln(key):
|
1033 |
key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
|
|
|
1216 |
logits=logits,
|
1217 |
hidden_states=outputs.hidden_states,
|
1218 |
attentions=outputs.attentions,
|
1219 |
+
)
|
modeling_xlm_roberta_for_glue.py
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
from typing import Optional, Union, Tuple
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch import nn
|
5 |
-
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
|
6 |
-
from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
|
7 |
-
|
8 |
-
from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
|
9 |
-
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
10 |
-
|
11 |
-
|
12 |
-
class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
|
13 |
-
def __init__(self, config: XLMRobertaFlashConfig):
|
14 |
-
super().__init__(config)
|
15 |
-
self.num_labels = config.num_labels
|
16 |
-
self.config = config
|
17 |
-
|
18 |
-
self.roberta = XLMRobertaModel(config)
|
19 |
-
classifier_dropout = (
|
20 |
-
config.classifier_dropout
|
21 |
-
if config.classifier_dropout is not None
|
22 |
-
else config.hidden_dropout_prob
|
23 |
-
)
|
24 |
-
self.dropout = nn.Dropout(classifier_dropout)
|
25 |
-
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
26 |
-
|
27 |
-
# Initialize weights and apply final processing
|
28 |
-
self.post_init()
|
29 |
-
|
30 |
-
|
31 |
-
def forward(
|
32 |
-
self,
|
33 |
-
input_ids: Optional[torch.Tensor] = None,
|
34 |
-
attention_mask: Optional[torch.Tensor] = None,
|
35 |
-
token_type_ids: Optional[torch.Tensor] = None,
|
36 |
-
position_ids: Optional[torch.Tensor] = None,
|
37 |
-
head_mask: Optional[torch.Tensor] = None,
|
38 |
-
inputs_embeds: Optional[torch.Tensor] = None,
|
39 |
-
labels: Optional[torch.Tensor] = None,
|
40 |
-
output_attentions: Optional[bool] = None,
|
41 |
-
output_hidden_states: Optional[bool] = None,
|
42 |
-
return_dict: Optional[bool] = None,
|
43 |
-
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
44 |
-
r"""
|
45 |
-
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
46 |
-
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
47 |
-
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
48 |
-
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
49 |
-
"""
|
50 |
-
return_dict = (
|
51 |
-
return_dict if return_dict is not None else self.config.use_return_dict
|
52 |
-
)
|
53 |
-
|
54 |
-
assert head_mask is None
|
55 |
-
assert inputs_embeds is None
|
56 |
-
assert output_attentions is None
|
57 |
-
assert output_hidden_states is None
|
58 |
-
assert return_dict
|
59 |
-
outputs = self.roberta(
|
60 |
-
input_ids,
|
61 |
-
attention_mask=attention_mask,
|
62 |
-
token_type_ids=token_type_ids,
|
63 |
-
position_ids=position_ids,
|
64 |
-
head_mask=head_mask,
|
65 |
-
inputs_embeds=inputs_embeds,
|
66 |
-
output_attentions=output_attentions,
|
67 |
-
output_hidden_states=output_hidden_states,
|
68 |
-
return_dict=return_dict,
|
69 |
-
)
|
70 |
-
|
71 |
-
pooled_output = outputs[1]
|
72 |
-
|
73 |
-
pooled_output = self.dropout(pooled_output)
|
74 |
-
logits = self.classifier(pooled_output)
|
75 |
-
|
76 |
-
loss = None
|
77 |
-
if labels is not None:
|
78 |
-
if self.config.problem_type is None:
|
79 |
-
if self.num_labels == 1:
|
80 |
-
self.config.problem_type = "regression"
|
81 |
-
elif self.num_labels > 1 and (
|
82 |
-
labels.dtype == torch.long or labels.dtype == torch.int
|
83 |
-
):
|
84 |
-
self.config.problem_type = "single_label_classification"
|
85 |
-
else:
|
86 |
-
self.config.problem_type = "multi_label_classification"
|
87 |
-
|
88 |
-
if self.config.problem_type == "regression":
|
89 |
-
loss_fct = MSELoss()
|
90 |
-
if self.num_labels == 1:
|
91 |
-
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
92 |
-
else:
|
93 |
-
loss = loss_fct(logits, labels)
|
94 |
-
elif self.config.problem_type == "single_label_classification":
|
95 |
-
loss_fct = CrossEntropyLoss()
|
96 |
-
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
97 |
-
elif self.config.problem_type == "multi_label_classification":
|
98 |
-
loss_fct = BCEWithLogitsLoss()
|
99 |
-
loss = loss_fct(logits, labels)
|
100 |
-
if not return_dict:
|
101 |
-
output = (logits,) + outputs[2:]
|
102 |
-
return ((loss,) + output) if loss is not None else output
|
103 |
-
|
104 |
-
return SequenceClassifierOutput(
|
105 |
-
loss=loss,
|
106 |
-
logits=logits,
|
107 |
-
hidden_states=outputs.hidden_states,
|
108 |
-
attentions=outputs.attentions,
|
109 |
-
)
|
|
|
rotary.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1 |
-
#
|
|
2 |
# Copyright (c) 2023, Tri Dao.
|
3 |
|
4 |
import math
|
@@ -11,8 +14,9 @@ if torch.cuda.is_available():
|
|
11 |
try:
|
12 |
from flash_attn.ops.triton.rotary import apply_rotary
|
13 |
except ImportError:
|
|
|
14 |
def apply_rotary(*args, **kwargs):
|
15 |
-
raise RuntimeError(
|
16 |
|
17 |
|
18 |
def rotate_half(x, interleaved=False):
|
@@ -21,7 +25,9 @@ def rotate_half(x, interleaved=False):
|
|
21 |
return torch.cat((-x2, x1), dim=-1)
|
22 |
else:
|
23 |
x1, x2 = x[..., ::2], x[..., 1::2]
|
24 |
-
return rearrange(
|
|
|
|
|
25 |
|
26 |
|
27 |
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
@@ -32,13 +38,20 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
|
32 |
ro_dim = cos.shape[-1] * 2
|
33 |
assert ro_dim <= x.shape[-1]
|
34 |
cos, sin = (
|
35 |
-
cos[:x.shape[1]],
|
36 |
-
sin[:x.shape[1]],
|
|
37 |
)
|
38 |
-
cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
39 |
-
sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
|
40 |
return torch.cat(
|
41 |
-
[
|
|
42 |
dim=-1,
|
43 |
)
|
44 |
|
@@ -68,7 +81,9 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
|
68 |
)
|
69 |
|
70 |
if isinstance(seqlen_offsets, int):
|
71 |
-
ctx.save_for_backward(
|
|
|
|
|
72 |
ctx.seqlen_offsets = seqlen_offsets
|
73 |
else:
|
74 |
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
@@ -336,7 +351,9 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
|
|
336 |
max_seqlen=max_seqlen,
|
337 |
)
|
338 |
if isinstance(seqlen_offsets, int):
|
339 |
-
ctx.save_for_backward(
|
|
|
|
|
340 |
ctx.seqlen_offsets = seqlen_offsets
|
341 |
else:
|
342 |
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
@@ -451,7 +468,8 @@ class RotaryEmbedding(torch.nn.Module):
|
|
451 |
self.interleaved = interleaved
|
452 |
self.scale_base = scale_base
|
453 |
scale = (
|
454 |
-
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
|
|
455 |
if scale_base is not None
|
456 |
else None
|
457 |
)
|
@@ -477,7 +495,10 @@ class RotaryEmbedding(torch.nn.Module):
|
|
477 |
def _compute_inv_freq(self, device=None):
|
478 |
return 1.0 / (
|
479 |
self.base
|
480 |
-
** (
|
|
481 |
)
|
482 |
|
483 |
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
@@ -516,10 +537,14 @@ class RotaryEmbedding(torch.nn.Module):
|
|
516 |
self._sin_cached = torch.sin(freqs).to(dtype)
|
517 |
else:
|
518 |
power = (
|
519 |
-
torch.arange(
|
|
|
|
|
520 |
- seqlen // 2
|
521 |
) / self.scale_base
|
522 |
-
scale = self.scale.to(device=power.device) ** rearrange(
|
|
|
|
|
523 |
# We want the multiplication by scale to happen in fp32
|
524 |
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
525 |
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
@@ -550,7 +575,9 @@ class RotaryEmbedding(torch.nn.Module):
|
|
550 |
if max_seqlen is not None:
|
551 |
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
552 |
elif isinstance(seqlen_offset, int):
|
553 |
-
self._update_cos_sin_cache(
|
|
|
|
|
554 |
if kv is None:
|
555 |
if self.scale is None:
|
556 |
return apply_rotary_emb_qkv_(
|
@@ -606,4 +633,4 @@ class RotaryEmbedding(torch.nn.Module):
|
|
606 |
cu_seqlens=cu_seqlens,
|
607 |
max_seqlen=max_seqlen,
|
608 |
)
|
609 |
-
return q, kv
|
|
|
1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
|
2 |
+
# Commit id: 3566596ad867ee415dd3c12616dd50c610176f6c
|
3 |
+
# Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
|
4 |
+
|
5 |
# Copyright (c) 2023, Tri Dao.
|
6 |
|
7 |
import math
|
|
|
14 |
try:
|
15 |
from flash_attn.ops.triton.rotary import apply_rotary
|
16 |
except ImportError:
|
17 |
+
|
18 |
def apply_rotary(*args, **kwargs):
|
19 |
+
raise RuntimeError("RoPE requires flash-attention to be installed")
|
20 |
|
21 |
|
22 |
def rotate_half(x, interleaved=False):
|
|
|
25 |
return torch.cat((-x2, x1), dim=-1)
|
26 |
else:
|
27 |
x1, x2 = x[..., ::2], x[..., 1::2]
|
28 |
+
return rearrange(
|
29 |
+
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
30 |
+
)
|
31 |
|
32 |
|
33 |
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
|
|
38 |
ro_dim = cos.shape[-1] * 2
|
39 |
assert ro_dim <= x.shape[-1]
|
40 |
cos, sin = (
|
41 |
+
cos[: x.shape[1]],
|
42 |
+
sin[: x.shape[1]],
|
43 |
+
)
|
44 |
+
cos = repeat(
|
45 |
+
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
46 |
+
)
|
47 |
+
sin = repeat(
|
48 |
+
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
49 |
)
|
|
|
|
|
50 |
return torch.cat(
|
51 |
+
[
|
52 |
+
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
53 |
+
x[..., ro_dim:],
|
54 |
+
],
|
55 |
dim=-1,
|
56 |
)
|
57 |
|
|
|
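For reference, the pure-PyTorch path above applies rotary embeddings without flash attention: _compute_inv_freq yields 1 / base**(2i/dim), and each position rotates pairs of channels in the first ro_dim dimensions by a position-dependent angle. A small standalone sketch of that formula for the non-interleaved layout (not the repository's API, just the arithmetic):

import torch

dim, base, seqlen = 8, 10000.0, 4
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
t = torch.arange(seqlen, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                 # (seqlen, dim // 2)
cos, sin = torch.cos(freqs), torch.sin(freqs)

x = torch.randn(seqlen, dim)
x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]  # the two halves of each head dimension
x_rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)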
81 |
)
|
82 |
|
83 |
if isinstance(seqlen_offsets, int):
|
84 |
+
ctx.save_for_backward(
|
85 |
+
cos, sin, cu_seqlens
|
86 |
+
) # Can't save int with save_for_backward
|
87 |
ctx.seqlen_offsets = seqlen_offsets
|
88 |
else:
|
89 |
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
|
|
351 |
max_seqlen=max_seqlen,
|
352 |
)
|
353 |
if isinstance(seqlen_offsets, int):
|
354 |
+
ctx.save_for_backward(
|
355 |
+
cos, sin, cu_seqlens
|
356 |
+
) # Can't save int with save_for_backward
|
357 |
ctx.seqlen_offsets = seqlen_offsets
|
358 |
else:
|
359 |
ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
|
|
|
468 |
self.interleaved = interleaved
|
469 |
self.scale_base = scale_base
|
470 |
scale = (
|
471 |
+
(torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
|
472 |
+
/ (1.4 * dim)
|
473 |
if scale_base is not None
|
474 |
else None
|
475 |
)
|
|
|
495 |
def _compute_inv_freq(self, device=None):
|
496 |
return 1.0 / (
|
497 |
self.base
|
498 |
+
** (
|
499 |
+
torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
|
500 |
+
/ self.dim
|
501 |
+
)
|
502 |
)
|
503 |
|
504 |
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
|
|
537 |
self._sin_cached = torch.sin(freqs).to(dtype)
|
538 |
else:
|
539 |
power = (
|
540 |
+
torch.arange(
|
541 |
+
seqlen, dtype=self.scale.dtype, device=self.scale.device
|
542 |
+
)
|
543 |
- seqlen // 2
|
544 |
) / self.scale_base
|
545 |
+
scale = self.scale.to(device=power.device) ** rearrange(
|
546 |
+
power, "s -> s 1"
|
547 |
+
)
|
548 |
# We want the multiplication by scale to happen in fp32
|
549 |
self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
|
550 |
self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
|
|
|
575 |
if max_seqlen is not None:
|
576 |
self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
|
577 |
elif isinstance(seqlen_offset, int):
|
578 |
+
self._update_cos_sin_cache(
|
579 |
+
seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype
|
580 |
+
)
|
581 |
if kv is None:
|
582 |
if self.scale is None:
|
583 |
return apply_rotary_emb_qkv_(
|
|
|
633 |
cu_seqlens=cu_seqlens,
|
634 |
max_seqlen=max_seqlen,
|
635 |
)
|
636 |
+
return q, kv
|
stochastic_depth.py
CHANGED
@@ -34,7 +34,7 @@
|
|
34 |
|
35 |
import torch
|
36 |
import torch.fx
|
37 |
-
from torch import
|
38 |
|
39 |
|
40 |
def stochastic_depth(
|
|
|
34 |
|
35 |
import torch
|
36 |
import torch.fx
|
37 |
+
from torch import Tensor, nn
|
38 |
|
39 |
|
40 |
def stochastic_depth(
|
xlm_padding.py
CHANGED
@@ -18,7 +18,9 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
18 |
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
19 |
# return input[indices]
|
20 |
return torch.gather(
|
21 |
-
rearrange(input, "b ... -> b (...)"),
|
|
|
|
|
22 |
).reshape(-1, *other_shape)
|
23 |
|
24 |
@staticmethod
|
@@ -34,7 +36,9 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
34 |
)
|
35 |
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
36 |
# grad_input[indices] = grad_output
|
37 |
-
grad_input.scatter_(
|
|
|
|
|
38 |
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
39 |
|
40 |
|
@@ -112,9 +116,15 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
|
|
112 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
113 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
114 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
115 |
-
cu_seqlens = F.pad(
|
|
|
|
|
116 |
|
117 |
-
cu_adapter_mask =
|
|
118 |
|
119 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
120 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
@@ -184,14 +194,18 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
|
|
184 |
"""
|
185 |
length = attention_mask_in_length.sum(dim=-1)
|
186 |
seqlen = attention_mask_in_length.size(-1)
|
187 |
-
attention_mask_2d = torch.arange(
|
188 |
-
|
189 |
-
|
190 |
-
real_indices_idx = torch.nonzero(
|
|
|
|
|
191 |
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
|
192 |
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
|
193 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
194 |
-
cu_seqlens = F.pad(
|
|
|
|
|
195 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
196 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
197 |
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
@@ -219,4 +233,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
|
|
219 |
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
220 |
# output[indices] = hidden_states
|
221 |
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
222 |
-
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
|
|
18 |
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
19 |
# return input[indices]
|
20 |
return torch.gather(
|
21 |
+
rearrange(input, "b ... -> b (...)"),
|
22 |
+
0,
|
23 |
+
repeat(indices, "z -> z d", d=second_dim),
|
24 |
).reshape(-1, *other_shape)
|
25 |
|
26 |
@staticmethod
|
|
|
36 |
)
|
37 |
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
38 |
# grad_input[indices] = grad_output
|
39 |
+
grad_input.scatter_(
|
40 |
+
0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
|
41 |
+
)
|
42 |
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
43 |
|
44 |
|
|
|
116 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
117 |
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
118 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
119 |
+
cu_seqlens = F.pad(
|
120 |
+
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
121 |
+
)
|
122 |
|
123 |
+
cu_adapter_mask = (
|
124 |
+
torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1])
|
125 |
+
if adapter_mask is not None
|
126 |
+
else None
|
127 |
+
)
|
128 |
|
129 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
130 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
|
|
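One detail from the unpad_input hunk above is worth spelling out: after unpadding, the per-sequence adapter_mask has to become per-token, which is exactly what the repeat_interleave over the segment lengths does. A tiny illustration with made-up values:

import torch

cu_seqlens = torch.tensor([0, 3, 4], dtype=torch.int32)  # two sequences: 3 tokens, then 1
adapter_mask = torch.tensor([0, 2])                      # one task id per sequence
cu_adapter_mask = torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1])
# cu_adapter_mask == tensor([0, 0, 0, 2]): one task id per unpadded token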
194 |
"""
|
195 |
length = attention_mask_in_length.sum(dim=-1)
|
196 |
seqlen = attention_mask_in_length.size(-1)
|
197 |
+
attention_mask_2d = torch.arange(
|
198 |
+
seqlen, device=length.device, dtype=length.dtype
|
199 |
+
).expand(len(length), seqlen) < length.unsqueeze(1)
|
200 |
+
real_indices_idx = torch.nonzero(
|
201 |
+
attention_mask_in_length.flatten(), as_tuple=False
|
202 |
+
).flatten()
|
203 |
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
|
204 |
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
|
205 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
206 |
+
cu_seqlens = F.pad(
|
207 |
+
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
|
208 |
+
)
|
209 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
210 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
211 |
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
|
|
233 |
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
234 |
# output[indices] = hidden_states
|
235 |
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
236 |
+
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|
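Finally, the comments in IndexFirstAxis and pad_input above describe the pattern used throughout this file: select rows with torch.gather instead of boolean indexing, and scatter them back into a zero buffer to re-pad. A compact sketch of that round trip (einops is assumed available, as in the file itself; names are illustrative):

import torch
from einops import rearrange, repeat

batch, seqlen, hidden = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 0, 0, 0]])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

# unpad: gather the kept rows from the flattened (batch * seqlen, hidden) view
flat = rearrange(hidden_states, "b s h -> (b s) h")
unpadded = torch.gather(flat, 0, repeat(indices, "z -> z h", h=hidden))

# re-pad: scatter the rows back into a zero buffer and restore (batch, seqlen, hidden)
output = torch.zeros(batch * seqlen, hidden)
output[indices] = unpadded
padded = rearrange(output, "(b s) h -> b s h", b=batch)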