KaleiNeely commited on
Commit
29e5afb
1 Parent(s): 92c7113

Upload 7 files

Browse files
Files changed (3) hide show
  1. README.md +69 -3
  2. modeling_rwkv5.py +56 -37
  3. tokenization_rwkv_world.py +142 -12
README.md CHANGED
@@ -85,7 +85,7 @@ Assistant:"""
85
  model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-3b", trust_remote_code=True, torch_dtype=torch.float16).to(0)
86
  tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-3b", trust_remote_code=True)
87
 
88
- text = "乌兰察布"
89
  prompt = generate_prompt(text)
90
 
91
  inputs = tokenizer(prompt, return_tensors="pt").to(0)
@@ -100,8 +100,74 @@ User: hi
100
 
101
  Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
102
 
103
- User: 乌兰察布
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- Assistant: 乌兰察布市是中国新疆维吾尔自治区的一个地级市,位于新疆维吾尔自治区西南部,毗邻青海省。乌兰察布市是新疆维吾尔自治区的重要城市之一,也是新疆维吾尔自治区的第二大城市。乌兰察布市是新疆的重要经济中心之一,拥有丰富的自然资源和人口密度,是新疆的重要交通枢纽和商
106
  ```
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-3b", trust_remote_code=True, torch_dtype=torch.float16).to(0)
86
  tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-3b", trust_remote_code=True)
87
 
88
+ text = "介绍一下大熊猫"
89
  prompt = generate_prompt(text)
90
 
91
  inputs = tokenizer(prompt, return_tensors="pt").to(0)
 
100
 
101
  Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
102
 
103
+ User: 介绍一下大熊猫
104
+
105
+ Assistant: 大熊猫是一种中国特有的哺乳动物,也是中国的国宝之一。它们的外貌特征是圆形的黑白相间的身体,有着黑色的毛发和白色的耳朵。大熊猫的食物主要是竹子,它们会在竹林中寻找竹子,并且会将竹子放在竹笼中进行储存。大熊猫的寿命约为20至30年,但由于栖息地的丧失和人类活动的
106
+ ```
107
+
108
+ #### Batch Inference
109
+
110
+ ```python
111
+ import torch
112
+ from transformers import AutoModelForCausalLM, AutoTokenizer
113
+
114
+ def generate_prompt(instruction, input=""):
115
+ instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
116
+ input = input.strip().replace('\r\n', '\n').replace('\n\n', '\n')
117
+ if input:
118
+ return f"""Instruction: {instruction}
119
+
120
+ Input: {input}
121
+
122
+ Response:"""
123
+ else:
124
+ return f"""User: hi
125
+
126
+ Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
127
+
128
+ User: {instruction}
129
+
130
+ Assistant:"""
131
+
132
+ model = AutoModelForCausalLM.from_pretrained("RWKV/rwkv-5-world-3b", trust_remote_code=True).to(torch.float32)
133
+ tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-3b", trust_remote_code=True)
134
+
135
+ texts = ["请介绍北京的旅游景点", "介绍一下大熊猫", "乌兰察布"]
136
+ prompts = [generate_prompt(text) for text in texts]
137
+
138
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True)
139
+ outputs = model.generate(inputs["input_ids"], max_new_tokens=128, do_sample=True, temperature=1.0, top_p=0.3, top_k=0, )
140
+
141
+ for output in outputs:
142
+ print(tokenizer.decode(output.tolist(), skip_special_tokens=True))
143
 
 
144
  ```
145
 
146
+ output:
147
+
148
+ ```shell
149
+ User: hi
150
+
151
+ Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
152
+
153
+ User: 请介绍北京的旅游景点
154
+
155
+ Assistant: 北京是中国的首都,拥有丰富的旅游资源和历史文化遗产。以下是一些北京的旅游景点:
156
+ 1. 故宫:位于北京市中心,是明清两代的皇宫,是中国最大的古代宫殿建筑群之一。
157
+ 2. 天安门广场:位于北京市中心,是中国最著名的城市广场之一,也是中国最大的城市广场。
158
+ 3. 颐和
159
+ User: hi
160
+
161
+ Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
162
+
163
+ User: 介绍一下大熊猫
164
+
165
+ Assistant: 大熊猫是一种生活在中国中部地区的哺乳动物,也是中国的国宝之一。它们的外貌特征是圆形的黑白相间的身体,有着黑色的毛发和圆圆的眼睛。大熊猫是一种濒危物种,目前只有在野外的几个保护区才能看到它们的身影。大熊猫的食物主要是竹子,它们会在竹子上寻找食物,并且可以通
166
+ User: hi
167
+
168
+ Assistant: Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.
169
+
170
+ User: 乌兰察布
171
+
172
+ Assistant: 乌兰察布是中国新疆维吾尔自治区的一个县级市,位于新疆维吾尔自治区中部,是新疆的第二大城市。乌兰察布市是新疆的第一大城市,也是新疆的重要城市之一。乌兰察布市是新疆的经济中心,也是新疆的重要交通枢纽之一。乌兰察布市的人口约为2.5万人,其中汉族占绝大多数。乌
173
+ ```
modeling_rwkv5.py CHANGED
@@ -85,33 +85,46 @@ def rwkv_linear_attention_v5_0(H, S, T, hidden, time_decay, time_first, receptan
85
 
86
  return out, state
87
 
88
- def rwkv_linear_attention_v5_2(H, S, T, n_head, hidden, time_decay, time_first, receptance, key, value, gate, lxw, lxb, ow, state, return_state=False, seq_mode=True):
 
 
89
  time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1).reshape(n_head, -1, 1)
90
  time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
91
  lxw = lxw.float()
92
  lxb = lxb.float()
93
- if seq_mode:
94
- out = torch.empty((T, H, S), dtype=receptance.dtype, device=receptance.device)
95
- for t in range(T):
96
- rt = receptance[:,t:t+1,:]
97
- kt = key[:,:,t:t+1]
98
- vt = value[:,t:t+1,:]
99
- at = kt @ vt
100
- out[t] = (rt @ (time_first * at + state.squeeze(0))).squeeze(1)
101
- state = at + time_decay * state
102
-
103
- out = out.reshape(T, H*S)
104
- out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb)
105
- out = out.to(dtype=hidden.dtype) * gate
106
- out = out @ ow
107
- else:
108
- a = key @ value
109
- out = receptance @ (time_first * a + state.squeeze(0))
110
- state = a + time_decay * state
111
- out = out.flatten()
112
- out = F.group_norm(out.unsqueeze(0), num_groups=H, weight=lxw, bias=lxb).squeeze(0)
113
- out = out.to(dtype=hidden.dtype) * gate
114
- out = out @ ow
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  return out, state
117
 
@@ -153,7 +166,7 @@ class RwkvSelfAttention(nn.Module):
153
  self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
154
 
155
  # TODO: maybe jit, otherwise move inside forward
156
- def extract_key_value(self, H, S, T, hidden, state=None):
157
  # Mix hidden with the previous timestep to produce key, value, receptance
158
  if hidden.size(1) == 1 and state is not None:
159
  shifted = state[0][:, :, self.layer_id]
@@ -161,25 +174,27 @@ class RwkvSelfAttention(nn.Module):
161
  shifted = self.time_shift(hidden)
162
  if state is not None:
163
  shifted[:, 0] = state[0][:, :, self.layer_id]
 
 
164
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
165
  value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
166
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
167
  if self.config.model_version == "5_2":
168
  gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
169
 
170
- if hidden.size(1) == 1 and state is not None:
171
- receptance = self.receptance(receptance).to(torch.float32).view(H, 1, S)
172
- key = self.key(key).to(torch.float32).view(H, S, 1)
173
- value = self.value(value).to(torch.float32).view(H, 1, S)
174
- else:
175
- # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
176
- key = self.key(key).to(torch.float32).view(T, H, S).transpose(0, 1).transpose(-2, -1)
177
- value = self.value(value).to(torch.float32).view(T, H, S).transpose(0, 1)
178
- receptance = self.receptance(receptance).to(torch.float32).view(T, H, S).transpose(0, 1)
179
 
180
  if self.config.model_version == "5_2":
181
  gate = F.silu(self.gate(gate))
182
-
183
  if state is not None:
184
  state[0][:, :, self.layer_id] = hidden[:, -1]
185
 
@@ -188,17 +203,19 @@ class RwkvSelfAttention(nn.Module):
188
  return receptance, key, value, state
189
 
190
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
 
191
  H = self.time_decay.shape[0]
192
  S = hidden.shape[-1] // H
193
  T = hidden.shape[1]
194
 
195
  if self.config.model_version == "5_2":
196
- receptance, key, value, gate, state = self.extract_key_value(H, S, T, hidden, state=state)
197
  else:
198
  receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
199
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
200
  if self.config.model_version == "5_2":
201
  rwkv, layer_state = rwkv_linear_attention_v5_2(
 
202
  H,
203
  S,
204
  T,
@@ -273,6 +290,8 @@ class RwkvFeedForward(nn.Module):
273
  shifted = self.time_shift(hidden)
274
  if state is not None:
275
  shifted[:, 0] = state[2][:, :, self.layer_id]
 
 
276
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
277
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
278
 
@@ -594,7 +613,8 @@ class RwkvModel(RwkvPreTrainedModel):
594
 
595
 
596
  hidden_states = inputs_embeds
597
-
 
598
  all_self_attentions = () if output_attentions else None
599
  all_hidden_states = () if output_hidden_states else None
600
  for idx, block in enumerate(self.blocks):
@@ -645,7 +665,6 @@ class RwkvModel(RwkvPreTrainedModel):
645
 
646
  self.layers_are_rescaled = not self.training
647
 
648
-
649
  @add_start_docstrings(
650
  """
651
  The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
 
85
 
86
  return out, state
87
 
88
+ cnt = 0
89
+
90
+ def rwkv_linear_attention_v5_2(B, H, S, T, n_head, hidden, time_decay, time_first, receptance, key, value, gate, lxw, lxb, ow, state, return_state=False, seq_mode=True):
91
  time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1,1,1).reshape(n_head, -1, 1)
92
  time_first = time_first.float().reshape(-1,1,1).reshape(n_head, -1, 1)
93
  lxw = lxw.float()
94
  lxb = lxb.float()
95
+ # if seq_mode:
96
+ out = torch.empty((B, T, H, S), dtype=receptance.dtype, device=receptance.device)
97
+ for t in range(T):
98
+ rt = receptance[:,:,t:t+1,:]
99
+ kt = key[:,:,:,t:t+1]
100
+ vt = value[:,:,t:t+1,:]
101
+ at = kt @ vt
102
+ out[:, t] = (rt @ (time_first * at + state)).squeeze(2)
103
+ state = at + time_decay * state
104
+
105
+ out = out.reshape(B*T, H*S)
106
+ out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, T, H*S)
107
+ out = out.to(dtype=hidden.dtype) * gate
108
+ out = out @ ow
109
+ # else:
110
+ # a = key @ value
111
+ # # print('key.shape: ', key.shape)
112
+ # # print('value.shape: ', value.shape)
113
+ # # print('receptance.shape: ', receptance.shape)
114
+ # # print('a.shape: ', a.shape)
115
+ # # print('time_first.shape: ', time_first.shape)
116
+ # # print('(time_first * a).shape: ', (time_first * a).shape)
117
+ # # print('time_decay.shape: ', time_decay.shape)
118
+ # # print('state.shape: ', state.shape)
119
+ # out = receptance @ (time_first * a + state)
120
+ # # print('out.shape: ', out.shape)
121
+ # state = a + time_decay * state
122
+ # # print('state.shape: ', state.shape)
123
+ # out = out.reshape(B, H*S)
124
+ # out = F.group_norm(out, num_groups=H, weight=lxw, bias=lxb).reshape(B, 1, H*S)
125
+ # out = out.to(dtype=hidden.dtype) * gate
126
+ # out = out @ ow
127
+
128
 
129
  return out, state
130
 
 
166
  self.ln_x = nn.GroupNorm(hidden_size // config.head_size, hidden_size)
167
 
168
  # TODO: maybe jit, otherwise move inside forward
169
+ def extract_key_value(self, B, H, S, T, hidden, state=None):
170
  # Mix hidden with the previous timestep to produce key, value, receptance
171
  if hidden.size(1) == 1 and state is not None:
172
  shifted = state[0][:, :, self.layer_id]
 
174
  shifted = self.time_shift(hidden)
175
  if state is not None:
176
  shifted[:, 0] = state[0][:, :, self.layer_id]
177
+ if len(shifted.size()) == 2:
178
+ shifted = shifted.unsqueeze(1)
179
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
180
  value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value)
181
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
182
  if self.config.model_version == "5_2":
183
  gate = hidden* self.time_mix_gate + shifted * (1 - self.time_mix_gate)
184
 
185
+ # if hidden.size(1) == 1 and state is not None:
186
+ # receptance = self.receptance(receptance).to(torch.float32).view(B, H, 1, S)
187
+ # key = self.key(key).to(torch.float32).view(B, H, S, 1)
188
+ # value = self.value(value).to(torch.float32).view(B, H, 1, S)
189
+ # else:
190
+ # https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py#L693
191
+ key = self.key(key).to(torch.float32).view(B, T, H, S).transpose(1, 2).transpose(-2, -1)
192
+ value = self.value(value).to(torch.float32).view(B, T, H, S).transpose(1, 2)
193
+ receptance = self.receptance(receptance).to(torch.float32).view(B, T, H, S).transpose(1, 2)
194
 
195
  if self.config.model_version == "5_2":
196
  gate = F.silu(self.gate(gate))
197
+
198
  if state is not None:
199
  state[0][:, :, self.layer_id] = hidden[:, -1]
200
 
 
203
  return receptance, key, value, state
204
 
205
  def forward(self, hidden, state=None, use_cache=False, seq_mode=True):
206
+ B = hidden.shape[0]
207
  H = self.time_decay.shape[0]
208
  S = hidden.shape[-1] // H
209
  T = hidden.shape[1]
210
 
211
  if self.config.model_version == "5_2":
212
+ receptance, key, value, gate, state = self.extract_key_value(B, H, S, T, hidden, state=state)
213
  else:
214
  receptance, key, value, state = self.extract_key_value(H, S, T, hidden, state=state)
215
  layer_state = state[1][:, :, :, :, self.layer_id] if state is not None else None
216
  if self.config.model_version == "5_2":
217
  rwkv, layer_state = rwkv_linear_attention_v5_2(
218
+ B,
219
  H,
220
  S,
221
  T,
 
290
  shifted = self.time_shift(hidden)
291
  if state is not None:
292
  shifted[:, 0] = state[2][:, :, self.layer_id]
293
+ if len(shifted.size()) == 2:
294
+ shifted = shifted.unsqueeze(1)
295
  key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key)
296
  receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance)
297
 
 
613
 
614
 
615
  hidden_states = inputs_embeds
616
+ global cnt
617
+ cnt += 1
618
  all_self_attentions = () if output_attentions else None
619
  all_hidden_states = () if output_hidden_states else None
620
  for idx, block in enumerate(self.blocks):
 
665
 
666
  self.layers_are_rescaled = not self.training
667
 
 
668
  @add_start_docstrings(
669
  """
670
  The RWKV Model transformer with a language modeling head on top (linear layer with weights tied to the input
tokenization_rwkv_world.py CHANGED
@@ -107,6 +107,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
107
  self,
108
  vocab_file,
109
  errors="replace",
 
110
  **kwargs
111
  ):
112
  self.add_bos_token = False
@@ -122,11 +123,7 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
122
  assert len(x) == int(l[l.rindex(' '):])
123
  sorted += [x]
124
  self.encoder[idx] = x
125
-
126
- super().__init__(
127
- errors=errors,
128
- **kwargs,
129
- )
130
  self.decoder = {}
131
  for k,v in self.encoder.items():
132
  self.decoder[v] = int(k)
@@ -136,6 +133,14 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
136
  _ = self.trie.add(t, val=(t, i))
137
  self.errors = errors # how to handle errors in decoding
138
  self.cache = {}
 
 
 
 
 
 
 
 
139
 
140
  @property
141
  def vocab_size(self):
@@ -143,6 +148,22 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
143
 
144
  def get_vocab(self):
145
  return dict(self.encoder, **self.added_tokens_encoder)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
148
  if self.add_bos_token:
@@ -219,14 +240,21 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
219
  skip_special_tokens: bool = False,
220
  **kwargs
221
  ) -> str:
222
-
 
 
 
 
 
223
  # Convert inputs to python lists
224
  token_ids = to_py_obj(token_ids)
 
225
  if isinstance(token_ids, int):
226
  if token_ids in self.all_special_ids and skip_special_tokens:
227
  return ""
228
  return self.encoder.get(token_ids, self.unk_token)
229
  elif isinstance(token_ids, list):
 
230
  out_str = ""
231
  out_last = 0
232
  out_tokens = []
@@ -268,6 +296,11 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
268
  def prepare_for_tokenization(self, text, **kwargs):
269
  return (text, kwargs)
270
 
 
 
 
 
 
271
  def _encode_plus(
272
  self,
273
  text: Union[TextInput, EncodedInput],
@@ -352,19 +385,33 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
352
  verbose: bool = True,
353
  **kwargs
354
  ) -> BatchEncoding:
355
- def get_input_ids(text):
 
 
 
356
  if isinstance(text, str):
357
- text_id = self._tokenize(text)
358
- return text_id
 
 
 
359
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
360
- return [self._tokenize(t) for t in text]
 
 
 
 
361
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
 
 
362
  return text
 
363
  else:
364
  raise ValueError(
365
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
366
  )
367
 
 
368
  if return_offsets_mapping:
369
  raise NotImplementedError(
370
  "return_offset_mapping is not available when using Python tokenizers. "
@@ -372,15 +419,29 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
372
  "transformers.PreTrainedTokenizerFast."
373
  )
374
 
375
- input_ids = []
 
376
  for ids_or_pair_ids in batch_text_or_text_pairs:
377
  if not isinstance(ids_or_pair_ids, (list, tuple)):
378
  ids, pair_ids = ids_or_pair_ids, None
379
  else:
380
  ids, pair_ids = ids_or_pair_ids
381
-
382
  first_ids = get_input_ids(ids)
383
  second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  input_ids.append((first_ids, second_ids))
385
 
386
  batch_outputs = self._batch_prepare_for_model(
@@ -401,6 +462,75 @@ class RWKVWorldTokenizer(PreTrainedTokenizer):
401
  )
402
 
403
  return BatchEncoding(batch_outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
404
 
405
  def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
406
  input_ids = []
 
107
  self,
108
  vocab_file,
109
  errors="replace",
110
+ pad_token="0",
111
  **kwargs
112
  ):
113
  self.add_bos_token = False
 
123
  assert len(x) == int(l[l.rindex(' '):])
124
  sorted += [x]
125
  self.encoder[idx] = x
126
+
 
 
 
 
127
  self.decoder = {}
128
  for k,v in self.encoder.items():
129
  self.decoder[v] = int(k)
 
133
  _ = self.trie.add(t, val=(t, i))
134
  self.errors = errors # how to handle errors in decoding
135
  self.cache = {}
136
+ self.first_max_length = 0
137
+
138
+ # pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
139
+ super().__init__(
140
+ errors=errors,
141
+ # pad_token=pad_token,
142
+ **kwargs,
143
+ )
144
 
145
  @property
146
  def vocab_size(self):
 
148
 
149
  def get_vocab(self):
150
  return dict(self.encoder, **self.added_tokens_encoder)
151
+
152
+ def add_tokens(self, new_tokens, special_tokens: bool = False):
153
+ for token in new_tokens:
154
+ token_id = self.convert_tokens_to_ids(token)
155
+ self.added_tokens_decoder[token_id] = token
156
+
157
+ def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
158
+ if isinstance(ids, int):
159
+ ids = [ids]
160
+ tokens = []
161
+ for id_ in ids:
162
+ if id_ in self.added_tokens_decoder:
163
+ tokens.append(self.added_tokens_decoder[id_])
164
+ else:
165
+ tokens.append(self._convert_id_to_token(id_))
166
+ return tokens
167
 
168
  def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
169
  if self.add_bos_token:
 
240
  skip_special_tokens: bool = False,
241
  **kwargs
242
  ) -> str:
243
+
244
+ def remove_zeros_from_first_segment(token_ids, first_max_length):
245
+ first_segment = token_ids[:first_max_length]
246
+ first_segment_cleaned = [token for token in first_segment if token != 0]
247
+ return first_segment_cleaned + token_ids[first_max_length:]
248
+
249
  # Convert inputs to python lists
250
  token_ids = to_py_obj(token_ids)
251
+ token_ids = remove_zeros_from_first_segment(token_ids, self.first_max_length)
252
  if isinstance(token_ids, int):
253
  if token_ids in self.all_special_ids and skip_special_tokens:
254
  return ""
255
  return self.encoder.get(token_ids, self.unk_token)
256
  elif isinstance(token_ids, list):
257
+ self.first_max_length
258
  out_str = ""
259
  out_last = 0
260
  out_tokens = []
 
296
  def prepare_for_tokenization(self, text, **kwargs):
297
  return (text, kwargs)
298
 
299
+ def _get_padding_truncation_strategies(
300
+ self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
301
+ ):
302
+ return PaddingStrategy.LONGEST, TruncationStrategy.DO_NOT_TRUNCATE, -1, kwargs
303
+
304
  def _encode_plus(
305
  self,
306
  text: Union[TextInput, EncodedInput],
 
385
  verbose: bool = True,
386
  **kwargs
387
  ) -> BatchEncoding:
388
+ def get_input_ids(text, max_length=None, pad_token_id=0):
389
+ def pad_sequence(seq, max_len, pad_tok):
390
+ return [pad_tok] * (max_len - len(seq)) + seq
391
+
392
  if isinstance(text, str):
393
+ tokens = self._tokenize(text)
394
+ if max_length is not None:
395
+ tokens = pad_sequence(tokens, max_length, pad_token_id)
396
+ return tokens
397
+
398
  elif isinstance(text, list) and len(text) > 0 and isinstance(text[0], str):
399
+ tokenized_texts = [self._tokenize(t) for t in text]
400
+ if max_length is None:
401
+ max_length = max(len(t) for t in tokenized_texts)
402
+ return [pad_sequence(t, max_length, pad_token_id) for t in tokenized_texts]
403
+
404
  elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
405
+ if max_length is not None and len(text) < max_length:
406
+ return pad_sequence(text, max_length, pad_token_id)
407
  return text
408
+
409
  else:
410
  raise ValueError(
411
  "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
412
  )
413
 
414
+
415
  if return_offsets_mapping:
416
  raise NotImplementedError(
417
  "return_offset_mapping is not available when using Python tokenizers. "
 
419
  "transformers.PreTrainedTokenizerFast."
420
  )
421
 
422
+ first_max_length = 0
423
+ second_max_length = 0
424
  for ids_or_pair_ids in batch_text_or_text_pairs:
425
  if not isinstance(ids_or_pair_ids, (list, tuple)):
426
  ids, pair_ids = ids_or_pair_ids, None
427
  else:
428
  ids, pair_ids = ids_or_pair_ids
 
429
  first_ids = get_input_ids(ids)
430
  second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
431
+ first_max_length = max(first_max_length, len(first_ids))
432
+ if second_ids is not None:
433
+ second_max_length = max(second_max_length, len(second_ids))
434
+
435
+ self.first_max_length = first_max_length
436
+ input_ids = []
437
+ for ids_or_pair_ids in batch_text_or_text_pairs:
438
+ if not isinstance(ids_or_pair_ids, (list, tuple)):
439
+ ids, pair_ids = ids_or_pair_ids, None
440
+ else:
441
+ ids, pair_ids = ids_or_pair_ids
442
+
443
+ first_ids = get_input_ids(ids, max_length=first_max_length)
444
+ second_ids = get_input_ids(pair_ids, max_length=second_max_length) if pair_ids is not None else None
445
  input_ids.append((first_ids, second_ids))
446
 
447
  batch_outputs = self._batch_prepare_for_model(
 
462
  )
463
 
464
  return BatchEncoding(batch_outputs)
465
+
466
+ def decode(
467
+ self,
468
+ token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
469
+ skip_special_tokens: bool = False,
470
+ clean_up_tokenization_spaces: bool = None,
471
+ **kwargs,
472
+ ) -> str:
473
+ """
474
+ Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
475
+ tokens and clean up tokenization spaces.
476
+
477
+ Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
478
+
479
+ Args:
480
+ token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
481
+ List of tokenized input ids. Can be obtained using the `__call__` method.
482
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
483
+ Whether or not to remove special tokens in the decoding.
484
+ clean_up_tokenization_spaces (`bool`, *optional*):
485
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
486
+ `self.clean_up_tokenization_spaces`.
487
+ kwargs (additional keyword arguments, *optional*):
488
+ Will be passed to the underlying model specific decode method.
489
+
490
+ Returns:
491
+ `str`: The decoded sentence.
492
+ """
493
+ # Convert inputs to python lists
494
+ return self._decode(
495
+ token_ids=token_ids,
496
+ skip_special_tokens=skip_special_tokens,
497
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
498
+ **kwargs,
499
+ )
500
+
501
+ def batch_decode(
502
+ self,
503
+ sequences: Union[List[int], List[List[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
504
+ skip_special_tokens: bool = False,
505
+ clean_up_tokenization_spaces: bool = None,
506
+ **kwargs,
507
+ ) -> List[str]:
508
+ """
509
+ Convert a list of lists of token ids into a list of strings by calling decode.
510
+
511
+ Args:
512
+ sequences (`Union[List[int], List[List[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
513
+ List of tokenized input ids. Can be obtained using the `__call__` method.
514
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
515
+ Whether or not to remove special tokens in the decoding.
516
+ clean_up_tokenization_spaces (`bool`, *optional*):
517
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
518
+ `self.clean_up_tokenization_spaces`.
519
+ kwargs (additional keyword arguments, *optional*):
520
+ Will be passed to the underlying model specific decode method.
521
+
522
+ Returns:
523
+ `List[str]`: The list of decoded sentences.
524
+ """
525
+ return [
526
+ self.decode(
527
+ seq,
528
+ skip_special_tokens=skip_special_tokens,
529
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
530
+ **kwargs,
531
+ )
532
+ for seq in sequences
533
+ ]
534
 
535
  def _build_conversation_input_ids(self, conversation: "Conversation") -> List[int]:
536
  input_ids = []