prajdabre committed (verified) · Commit 0375fe8 · Parent(s): 0969b22

Update configuration and tokenization

modeling_rotary_indictrans.py CHANGED
@@ -43,7 +43,7 @@ try:
     )
 except ImportError:
     logger.warning(
-        "It is highly recommended to use `flash_attention_2` for better performance with RotaryIndicTrans."
+        "It is highly recommended to use `flash_attention_2` for better performance with RotaryIndicTrans."
         "Falling back to the default `eager` implementation."
     )
 
@@ -96,25 +96,23 @@ def rotate_half(x):
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(cos, sin, t):
     rot_dim = cos.shape[-1]
-    assert rot_dim <= t.shape[-1] and cos.shape == sin.shape
     t_left, t_right = t[..., :rot_dim], t[..., rot_dim:]
     t_transformed = (t_left * cos) + (rotate_half(t_left) * sin)
     return torch.cat((t_transformed, t_right), dim=-1).type(t.dtype)
 
 
 class RotaryEmbedding(torch.nn.Module):
-    def __init__(
-        self, dim, theta=10000, interpolate_factor=1.0, cache_max_seq_len=8192
-    ):
+    def __init__(self, dim, theta=10000, scaling_factor=1.0, max_seq_len=8192):
         super().__init__()
 
-        freqs_ = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
-        self.cache_max_seq_len = cache_max_seq_len
-        self.interpolate_factor = interpolate_factor
+        self.max_seq_len = max_seq_len
+        self.scaling_factor = scaling_factor
+
+        inv_freq_ = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device).float() / dim))
 
-        self.freqs = torch.nn.Parameter(freqs_, requires_grad=False).to(device)
+        self.register_buffer("inv_freq", inv_freq_, persistent=False)
+        self.precompute_freqs(max_seq_len)
         self.apply_rotary_emb = staticmethod(apply_rotary_emb)
-        self.precompute_freqs(cache_max_seq_len)
 
     def precompute_freqs(self, max_seq_len):
         thetas = self.forward(max_seq_len, device=device)
@@ -124,9 +122,9 @@ class RotaryEmbedding(torch.nn.Module):
     def rotate_queries_or_keys(self, t, seq_dim=-2, offset=0):
         seq_len = t.shape[seq_dim]
 
-        if seq_len > self.cache_max_seq_len:
-            self.cache_max_seq_len = seq_len * 2
-            self.precompute_freqs(self.cache_max_seq_len)
+        if seq_len > self.max_seq_len:
+            self.max_seq_len = seq_len * 2
+            self.precompute_freqs(self.max_seq_len)
 
         cos, sin = (
             self.cached_cos[offset : (offset + seq_len)],
@@ -136,8 +134,8 @@ class RotaryEmbedding(torch.nn.Module):
 
     @autocast("cuda", enabled=False)
     def forward(self, seq_len, device):
-        seq = torch.arange(seq_len, device=device) / self.interpolate_factor
-        thetas = einsum("..., f -> ... f", seq, self.freqs)
+        seq = torch.arange(seq_len, device=device) / self.scaling_factor
+        thetas = einsum("..., f -> ... f", seq, self.inv_freq)
         thetas = repeat(thetas, "... n -> ... (n r)", r=2)
         return thetas
 
@@ -176,7 +174,7 @@ class RotaryIndicTransAttention(nn.Module):
             RotaryEmbedding(
                 dim=self.head_dim // 2,
                 theta=config.rope_args.get("theta", 10000),
-                interpolate_factor=config.rope_args.get("interpolate_factor", 1.0),
+                scaling_factor=config.rope_args.get("scaling_factor", 1.0),
             )
             if not is_cross_attention
             else None
@@ -1653,4 +1651,4 @@ class RotaryIndicTransForConditionalGeneration(
                 past_state.index_select(0, beam_idx) for past_state in layer_past
             ),
         )
-        return reordered_past
+        return reordered_past
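In short, this change renames the RoPE knobs (`interpolate_factor` → `scaling_factor`, `cache_max_seq_len` → `max_seq_len`), stores the inverse frequencies in a non-persistent `inv_freq` buffer instead of a frozen `nn.Parameter`, and drops the shape assert in `apply_rotary_emb`. Below is a minimal, self-contained sketch of the rotary-embedding math being configured here, assuming standard RoPE with linear position scaling; `rope_angles` and `apply_rotary` are illustrative names rather than the module's API, and device handling, einops usage, and cos/sin caching are simplified.

# Minimal sketch of the rotary-embedding logic this diff touches.
# It mirrors the renamed knobs (`scaling_factor`, `max_seq_len`, `inv_freq`) but is
# not the module itself: caching, einops, and device placement are simplified.
import torch


def rotate_half(x):
    # Pairwise rotation used by interleaved RoPE: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def rope_angles(seq_len, dim, theta=10000.0, scaling_factor=1.0):
    # Positions are divided by `scaling_factor` (linear position interpolation),
    # matching `seq = torch.arange(seq_len, device=device) / self.scaling_factor`.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    pos = torch.arange(seq_len).float() / scaling_factor
    thetas = torch.einsum("n,f->nf", pos, inv_freq)   # (seq_len, dim // 2)
    return thetas.repeat_interleave(2, dim=-1)        # (seq_len, dim)


def apply_rotary(t, thetas):
    # Rotate only the leading `rot_dim` channels and pass the rest through,
    # matching the t_left / t_right split in apply_rotary_emb.
    cos, sin = thetas.cos(), thetas.sin()
    rot_dim = cos.shape[-1]
    t_left, t_right = t[..., :rot_dim], t[..., rot_dim:]
    t_left = (t_left * cos) + (rotate_half(t_left) * sin)
    return torch.cat((t_left, t_right), dim=-1)


q = torch.randn(1, 8, 16, 64)                                 # (batch, heads, seq, head_dim)
angles = rope_angles(seq_len=16, dim=32, scaling_factor=1.0)  # dim = head_dim // 2, as in the attention module
q_rot = apply_rotary(q, angles)
print(q_rot.shape)                                            # torch.Size([1, 8, 16, 64])

With `scaling_factor=2.0`, positions are compressed by half before the angles are computed, which is the usual linear position-interpolation trick for running past the trained context length; `config.rope_args` is where that factor would be supplied.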
tokenization_indictrans.py CHANGED
@@ -1,6 +1,5 @@
 import os
 import json
-from functools import lru_cache
 
 from transformers.utils import logging
 from typing import Dict, List, Optional, Union, Tuple
@@ -11,10 +10,8 @@ from transformers.tokenization_utils import PreTrainedTokenizer
 
 logger = logging.get_logger(__name__)
 
-SPIECE_UNDERLINE = "▁"
-
-# Convert SPECIAL_TAGS to a frozen set for faster lookups
-SPECIAL_TAGS = frozenset(
+# Convert LANGUAGE_TAGS to a frozen set for faster lookups
+LANGUAGE_TAGS = frozenset(
     {
         "asm_Beng",
         "awa_Deva",
@@ -137,9 +134,9 @@ class IndicTransTokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
-    def add_new_special_tags(self, new_tags: List[str]) -> None:
-        global SPECIAL_TAGS
-        SPECIAL_TAGS = frozenset(SPECIAL_TAGS | set(new_tags))
+    def add_new_language_tags(self, new_tags: List[str]) -> None:
+        global LANGUAGE_TAGS
+        LANGUAGE_TAGS = frozenset(LANGUAGE_TAGS | set(new_tags))
 
     def _switch_to_input_mode(self) -> None:
         self.spm = self.src_spm
@@ -197,10 +194,12 @@ class IndicTransTokenizer(PreTrainedTokenizer):
         return self.decoder.get(index, self.unk_token)
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        return "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+        return "".join(tokens).replace("▁", " ").strip()
 
     def _src_tokenize(self, text: str) -> List[str]:
         src_lang, tgt_lang, text = text.split(" ", 2)
+        assert src_lang in LANGUAGE_TAGS, f"Invalid source language tag: {src_lang}"
+        assert tgt_lang in LANGUAGE_TAGS, f"Invalid target language tag: {tgt_lang}"
         return [src_lang, tgt_lang] + self.spm.EncodeAsPieces(text)
 
     def _tgt_tokenize(self, text: str) -> List[str]:
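In short, this change drops the unused `lru_cache` import and the `SPIECE_UNDERLINE` constant, renames `SPECIAL_TAGS`/`add_new_special_tags` to `LANGUAGE_TAGS`/`add_new_language_tags`, and makes `_src_tokenize` assert that the source and target prefixes are known language tags. Below is a small standalone sketch of that tag handling, assuming inputs of the form "<src_tag> <tgt_tag> <sentence>" as in `_src_tokenize`; `validate_and_split` is a hypothetical helper and the tag set is a tiny illustrative subset, not the tokenizer's full list.

# Standalone sketch of the language-tag handling introduced by this diff.
from typing import List, Tuple

# Illustrative subset only; the tokenizer ships the full FLORES-style tag list.
LANGUAGE_TAGS = frozenset({"asm_Beng", "awa_Deva", "eng_Latn", "hin_Deva"})


def add_new_language_tags(new_tags: List[str]) -> None:
    # Mirrors add_new_language_tags: rebind the module-level frozenset with the
    # union of the old tags and the new ones.
    global LANGUAGE_TAGS
    LANGUAGE_TAGS = frozenset(LANGUAGE_TAGS | set(new_tags))


def validate_and_split(text: str) -> Tuple[str, str, str]:
    # Inputs are expected as "<src_tag> <tgt_tag> <sentence>"; unknown tags now
    # fail fast instead of being passed through to SentencePiece.
    src_lang, tgt_lang, sentence = text.split(" ", 2)
    assert src_lang in LANGUAGE_TAGS, f"Invalid source language tag: {src_lang}"
    assert tgt_lang in LANGUAGE_TAGS, f"Invalid target language tag: {tgt_lang}"
    return src_lang, tgt_lang, sentence


print(validate_and_split("eng_Latn hin_Deva How are you?"))
add_new_language_tags(["kas_Deva"])
print("kas_Deva" in LANGUAGE_TAGS)  # True

The decoding side is unchanged in behaviour: `convert_tokens_to_string` still joins the SentencePiece pieces and replaces the "▁" marker with spaces, only now with the literal character instead of the removed constant.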