Update configuration and tokenization
- modeling_rotary_indictrans.py  +15 -17
- tokenization_indictrans.py  +8 -9
modeling_rotary_indictrans.py
CHANGED
@@ -43,7 +43,7 @@ try:
     )
 except ImportError:
     logger.warning(
-        "It is highly recommended to use `flash_attention_2` for better performance with RotaryIndicTrans."
+        "It is highly recommended to use `flash_attention_2` for better performance with RotaryIndicTrans."
         "Falling back to the default `eager` implementation."
     )

@@ -96,25 +96,23 @@ def rotate_half(x):
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(cos, sin, t):
     rot_dim = cos.shape[-1]
-    assert rot_dim <= t.shape[-1] and cos.shape == sin.shape
     t_left, t_right = t[..., :rot_dim], t[..., rot_dim:]
     t_transformed = (t_left * cos) + (rotate_half(t_left) * sin)
     return torch.cat((t_transformed, t_right), dim=-1).type(t.dtype)


 class RotaryEmbedding(torch.nn.Module):
-    def __init__(
-        self, dim, theta=10000, interpolate_factor=1.0, cache_max_seq_len=8192
-    ):
+    def __init__(self, dim, theta=10000, scaling_factor=1.0, max_seq_len=8192):
         super().__init__()

-
-        self.
-
+        self.max_seq_len = max_seq_len
+        self.scaling_factor = scaling_factor
+
+        inv_freq_ = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))

-        self.
+        self.register_buffer("inv_freq", inv_freq_, persistent=False)
+        self.precompute_freqs(max_seq_len)
         self.apply_rotary_emb = staticmethod(apply_rotary_emb)
-        self.precompute_freqs(cache_max_seq_len)

     def precompute_freqs(self, max_seq_len):
         thetas = self.forward(max_seq_len, device=device)
@@ -124,9 +122,9 @@ class RotaryEmbedding(torch.nn.Module):
     def rotate_queries_or_keys(self, t, seq_dim=-2, offset=0):
         seq_len = t.shape[seq_dim]

-        if seq_len > self.cache_max_seq_len:
-            self.cache_max_seq_len = seq_len * 2
-            self.precompute_freqs(self.cache_max_seq_len)
+        if seq_len > self.max_seq_len:
+            self.max_seq_len = seq_len * 2
+            self.precompute_freqs(self.max_seq_len)

         cos, sin = (
             self.cached_cos[offset : (offset + seq_len)],
@@ -136,8 +134,8 @@ class RotaryEmbedding(torch.nn.Module):

     @autocast("cuda", enabled=False)
     def forward(self, seq_len, device):
-        seq = torch.arange(seq_len, device=device) / self.interpolate_factor
-        thetas = einsum("..., f -> ... f", seq, self.
+        seq = torch.arange(seq_len, device=device) / self.scaling_factor
+        thetas = einsum("..., f -> ... f", seq, self.inv_freq)
         thetas = repeat(thetas, "... n -> ... (n r)", r=2)
         return thetas

@@ -176,7 +174,7 @@ class RotaryIndicTransAttention(nn.Module):
             RotaryEmbedding(
                 dim=self.head_dim // 2,
                 theta=config.rope_args.get("theta", 10000),
-
+                scaling_factor=config.rope_args.get("scaling_factor", 1.0),
             )
             if not is_cross_attention
             else None
@@ -1653,4 +1651,4 @@ class RotaryIndicTransForConditionalGeneration(
                     past_state.index_select(0, beam_idx) for past_state in layer_past
                 ),
             )
-        return reordered_past
+        return reordered_past
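For context, a minimal usage sketch (not part of the diff above): the new `scaling_factor` entry is read from `config.rope_args` alongside the existing `theta`, so it can be supplied through the model config when loading with remote code. The checkpoint id and the scaling value below are placeholders, and `attn_implementation="flash_attention_2"` is optional; per the warning in the first hunk, the model falls back to eager attention when flash-attn is unavailable.

from transformers import AutoConfig, AutoModelForSeq2SeqLM

# Placeholder checkpoint id; substitute the actual repository name.
checkpoint = "your-org/rotary-indictrans-checkpoint"

config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
# Keys match the lookups in RotaryIndicTransAttention:
# rope_args.get("theta", 10000) and rope_args.get("scaling_factor", 1.0).
config.rope_args = {"theta": 10000, "scaling_factor": 2.0}

model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint,
    config=config,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",  # optional; requires flash-attn to be installed
)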
tokenization_indictrans.py
CHANGED
@@ -1,6 +1,5 @@
 import os
 import json
-from functools import lru_cache

 from transformers.utils import logging
 from typing import Dict, List, Optional, Union, Tuple
@@ -11,10 +10,8 @@ from transformers.tokenization_utils import PreTrainedTokenizer

 logger = logging.get_logger(__name__)

-
-
-# Convert SPECIAL_TAGS to a frozen set for faster lookups
-SPECIAL_TAGS = frozenset(
+# Convert LANGUAGE_TAGS to a frozen set for faster lookups
+LANGUAGE_TAGS = frozenset(
     {
         "asm_Beng",
         "awa_Deva",
@@ -137,9 +134,9 @@ class IndicTransTokenizer(PreTrainedTokenizer):
             **kwargs,
         )

-    def
-        global SPECIAL_TAGS
-
+    def add_new_language_tags(self, new_tags: List[str]) -> None:
+        global LANGUAGE_TAGS
+        LANGUAGE_TAGS = frozenset(LANGUAGE_TAGS | set(new_tags))

     def _switch_to_input_mode(self) -> None:
         self.spm = self.src_spm
@@ -197,10 +194,12 @@ class IndicTransTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        return "".join(tokens).replace(
+        return "".join(tokens).replace("▁", " ").strip()

     def _src_tokenize(self, text: str) -> List[str]:
         src_lang, tgt_lang, text = text.split(" ", 2)
+        assert src_lang in LANGUAGE_TAGS, f"Invalid source language tag: {src_lang}"
+        assert tgt_lang in LANGUAGE_TAGS, f"Invalid target language tag: {tgt_lang}"
        return [src_lang, tgt_lang] + self.spm.EncodeAsPieces(text)

     def _tgt_tokenize(self, text: str) -> List[str]:
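Likewise, a minimal sketch of the tokenizer-side changes (assumptions: a placeholder checkpoint id, placeholder text, and a hypothetical extra tag; the two language tags used are entries visible in LANGUAGE_TAGS above). Source-side input is expected in "src_lang tgt_lang text" form and both tags are now validated in _src_tokenize; add_new_language_tags registers extra tags, and because LANGUAGE_TAGS is a module-level global, added tags apply to every tokenizer instance in the process.

from transformers import AutoTokenizer

# Placeholder checkpoint id; substitute the actual repository name.
tokenizer = AutoTokenizer.from_pretrained(
    "your-org/rotary-indictrans-checkpoint", trust_remote_code=True
)

# "src_lang tgt_lang text" format; both tags must be present in LANGUAGE_TAGS.
batch = tokenizer("asm_Beng awa_Deva <source sentence here>", return_tensors="pt")

# Register a hypothetical extra tag so that validation accepts it afterwards.
tokenizer.add_new_language_tags(["xyz_Latn"])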