.gitattributes CHANGED
@@ -33,4 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-model.TGT filter=lfs diff=lfs merge=lfs -text

README.md CHANGED
@@ -60,78 +60,7 @@ Please refer to `Appendix D: Model Card` of the [preprint](https://arxiv.org/abs
 
 ### Usage Instructions
 
-Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_interface) for a detail description on how to use HF compatible IndicTrans2 models for inference.
-
-```python
-import torch
-from transformers import (
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
-)
-from IndicTransTokenizer import IndicProcessor
-
-
-model_name = "ai4bharat/indictrans2-en-indic-1B"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
-
-ip = IndicProcessor(inference=True)
-
-input_sentences = [
-    "When I was young, I used to go to the park every day.",
-    "We watched a new movie last week, which was very inspiring.",
-    "If you had met me at that time, we would have gone out to eat.",
-    "My friend has invited me to his birthday party, and I will give him a gift.",
-]
-
-src_lang, tgt_lang = "eng_Latn", "hin_Deva"
-
-batch = ip.preprocess_batch(
-    input_sentences,
-    src_lang=src_lang,
-    tgt_lang=tgt_lang,
-)
-
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
-# Tokenize the sentences and generate input encodings
-inputs = tokenizer(
-    batch,
-    truncation=True,
-    padding="longest",
-    return_tensors="pt",
-    return_attention_mask=True,
-).to(DEVICE)
-
-# Generate translations using the model
-with torch.no_grad():
-    generated_tokens = model.generate(
-        **inputs,
-        use_cache=True,
-        min_length=0,
-        max_length=256,
-        num_beams=5,
-        num_return_sequences=1,
-    )
-
-# Decode the generated tokens into text
-with tokenizer.as_target_tokenizer():
-    generated_tokens = tokenizer.batch_decode(
-        generated_tokens.detach().cpu().tolist(),
-        skip_special_tokens=True,
-        clean_up_tokenization_spaces=True,
-    )
-
-# Postprocess the translations, including entity replacement
-translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
-
-for input_sentence, translation in zip(input_sentences, translations):
-    print(f"{src_lang}: {input_sentence}")
-    print(f"{tgt_lang}: {translation}")
-```
-
-**Note: IndicTrans2 is now compatible with AutoTokenizer, however you need to use IndicProcessor from [IndicTransTokenizer](https://github.com/VarunGumma/IndicTransTokenizer) for preprocessing before tokenization.**
+Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2/tree/main/huggingface_inference) for a detail description on how to use HF compatible IndicTrans2 models for inference.
 
 
 ### Citation
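For quick reference, the flow that the removed snippet demonstrated can be condensed as below. This is only a sketch based on the snippet deleted in this commit: the model id, `IndicProcessor`, and language codes are taken verbatim from that snippet, and since this commit also deletes the in-repo tokenizer files, running it against this revision depends on the external [IndicTransTokenizer](https://github.com/VarunGumma/IndicTransTokenizer) package supplying the tokenizer and processor.

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransTokenizer import IndicProcessor  # external package linked in the removed note

model_name = "ai4bharat/indictrans2-en-indic-1B"  # checkpoint id from the removed snippet
ip = IndicProcessor(inference=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)

src_lang, tgt_lang = "eng_Latn", "hin_Deva"
sentences = ["When I was young, I used to go to the park every day."]

# normalize/tag the batch, tokenize, translate, then detokenize and restore entities
batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
inputs = tokenizer(batch, padding="longest", truncation=True, return_tensors="pt")
with torch.no_grad():
    tokens = model.generate(**inputs, max_length=256, num_beams=5)
with tokenizer.as_target_tokenizer():
    decoded = tokenizer.batch_decode(tokens, skip_special_tokens=True)
print(ip.postprocess_batch(decoded, lang=tgt_lang))
```
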
config.json CHANGED
@@ -9,7 +9,6 @@
     "AutoConfig": "configuration_indictrans.IndicTransConfig",
     "AutoModelForSeq2SeqLM": "modeling_indictrans.IndicTransForConditionalGeneration"
   },
-  "tokenizer_class": "IndicTransTokenizer",
   "attention_dropout": 0.0,
   "bos_token_id": 0,
   "decoder_attention_heads": 16,
@@ -41,6 +40,5 @@
   "share_decoder_input_output_embed": false,
   "torch_dtype": "float32",
   "transformers_version": "4.32.1",
-  "use_cache": true,
-  "attn_implementation": "eager"
+  "use_cache": true
 }

configuration_indictrans.py CHANGED
@@ -118,7 +118,6 @@ class IndicTransConfig(PretrainedConfig):
         pad_token_id=1,
         bos_token_id=0,
         eos_token_id=2,
-        attn_implementation="eager",
         **kwargs,
     ):
         self.encoder_vocab_size = encoder_vocab_size
@@ -147,8 +146,7 @@
         self.num_hidden_layers = encoder_layers
         self.scale_embedding = scale_embedding
         self.share_decoder_input_output_embed = share_decoder_input_output_embed
-        self.attn_implementation = attn_implementation
-
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,

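With `attn_implementation` dropped from both `config.json` and `IndicTransConfig`, the remote code always builds the eager `IndicTransAttention`; there is no attention backend left to select through the config. A minimal sketch of constructing the config after this change — the constructor keyword names and values below are illustrative assumptions, not the shipped 1B hyperparameters:

```python
# Minimal sketch, assuming configuration_indictrans.py from this revision is importable.
# Hyperparameter names/values below are illustrative assumptions only.
from configuration_indictrans import IndicTransConfig

config = IndicTransConfig(
    encoder_layers=6,
    decoder_layers=6,
    encoder_attention_heads=16,
    decoder_attention_heads=16,
    attention_dropout=0.0,
    # attn_implementation="eager"  # no longer accepted/stored after this change
)
print(config.pad_token_id, config.bos_token_id, config.eos_token_id)  # 1 0 2, per the defaults kept above
```
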
dict.SRC.json DELETED
The diff for this file is too large to render. See raw diff
 
dict.TGT.json DELETED
The diff for this file is too large to render. See raw diff
 
model.SRC DELETED
Binary file (759 kB)
 
model.TGT DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ac9257c8e76b8b607705b959cc3d075656ea33032f7a974e467b8941df6e98d4
-size 3256903

model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:35d28fe035cd6ac026536b555558b07762425c8b930670219063e4fc3666c96d
-size 4462265272

modeling_indictrans.py CHANGED
@@ -23,28 +23,15 @@ import torch.nn as nn
 from torch.nn import functional as F
 
 from transformers.activations import ACT2FN
-
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask,
-    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 from transformers.modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
-    Seq2SeqModelOutput
+    Seq2SeqModelOutput,
 )
 
-from transformers.utils import (
-    logging,
-    is_flash_attn_2_available,
-    is_flash_attn_greater_or_equal_2_10,
-)
-
+from transformers.utils import logging
 from transformers.modeling_utils import PreTrainedModel
 
 from .configuration_indictrans import IndicTransConfig
@@ -52,27 +39,9 @@ from .configuration_indictrans import IndicTransConfig
 
 logger = logging.get_logger(__name__)
 
-INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
+_CONFIG_FOR_DOC = "IndicTransConfig"
 
-try:
-    if is_flash_attn_2_available():
-        from flash_attn import flash_attn_func, flash_attn_varlen_func
-        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
-except:
-    pass
-
-
-# Copied from transformers.models.llama.modeling_llama._get_unpad_data
-def _get_unpad_data(attention_mask):
-    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
-    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
-    max_seqlen_in_batch = seqlens_in_batch.max().item()
-    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
-    return (
-        indices,
-        cu_seqlens,
-        max_seqlen_in_batch,
-    )
+INDICTRANS_PRETRAINED_MODEL_ARCHIVE_LIST = [""]
 
 
 # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
@@ -94,6 +63,54 @@ def shift_tokens_right(
     return shifted_input_ids
 
 
+# Copied from transformers.models.bart.modeling_bart._make_causal_mask
+def _make_causal_mask(
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
+):
+    """
+    Make causal mask used for bi-directional self-attention.
+    """
+    bsz, tgt_len = input_ids_shape
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat(
+            [
+                torch.zeros(
+                    tgt_len, past_key_values_length, dtype=dtype, device=device
+                ),
+                mask,
+            ],
+            dim=-1,
+        )
+    return mask[None, None, :, :].expand(
+        bsz, 1, tgt_len, tgt_len + past_key_values_length
+    )
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(
+        inverted_mask.to(torch.bool), torch.finfo(dtype).min
+    )
+
+
 def create_position_ids_from_input_ids(
     input_ids, padding_idx, past_key_values_length=0
 ):
@@ -230,15 +247,12 @@ class IndicTransAttention(nn.Module):
         dropout: float = 0.0,
         is_decoder: bool = False,
         bias: bool = True,
-        is_causal: bool = False,
-        config: Optional[IndicTransConfig] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
-        self.config = config
 
         if (self.head_dim * num_heads) != self.embed_dim:
             raise ValueError(
@@ -247,7 +261,6 @@
             )
         self.scaling = self.head_dim**-0.5
         self.is_decoder = is_decoder
-        self.is_causal = is_causal
 
         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -389,345 +402,17 @@ class IndicTransAttention(nn.Module):
         attn_output = self.out_proj(attn_output)
 
         return attn_output, attn_weights_reshaped, past_key_value
-
-
-class IndicTransFlashAttention2(IndicTransAttention):
-    """
-    IndicTrans flash attention module. This module inherits from `IndicTransAttention` as the weights of the module stays
-    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
-    flash attention and deal with padding tokens in case the input contains any of them.
-    """
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
-        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
-        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
-        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
-
-    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        # IndicTransFlashAttention2 attention does not support output_attentions
-        if output_attentions:
-            raise ValueError("IndicTransFlashAttention2 attention does not support output_attentions")
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        bsz, q_len, _ = hidden_states.size()
-
-        # get query proj
-        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
-        # get key, value proj
-        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
-        # is checking that the `sequence_length` of the `past_key_value` is the same as
-        # the provided `key_value_states` to support prefix tuning
-        if (
-            is_cross_attention
-            and past_key_value is not None
-            and past_key_value[0].shape[2] == key_value_states.shape[1]
-        ):
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0].transpose(1, 2)
-            value_states = past_key_value[1].transpose(1, 2)
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-            key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
-            value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
-        else:
-            # self_attention
-            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
-        # therefore the input hidden states gets silently casted in float32. Hence, we need
-        # cast them back in the correct dtype just to be sure everything works as expected.
-        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (LlamaRMSNorm handles it correctly)
-
-        input_dtype = query_states.dtype
-        if input_dtype == torch.float32:
-            if torch.is_autocast_enabled():
-                target_dtype = torch.get_autocast_gpu_dtype()
-            # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
-            else:
-                target_dtype = self.q_proj.weight.dtype
-
-            logger.warning_once(
-                f"The input hidden states seems to be silently casted in float32, this might be related to"
-                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
-                f" {target_dtype}."
-            )
-
-            query_states = query_states.to(target_dtype)
-            key_states = key_states.to(target_dtype)
-            value_states = value_states.to(target_dtype)
-
-        attn_output = self._flash_attention_forward(
-            query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
-        )
-
-        attn_output = attn_output.reshape(bsz, q_len, -1)
-        attn_output = self.out_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
-    def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`float`):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-        """
-        if not self._flash_attn_uses_top_left_mask:
-            causal = self.is_causal
-        else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
-            causal = self.is_causal and query_length != 1
-
-        # Contains at least one padding token in the sequence
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=causal,
-            )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
-            )
-
-        return attn_output
-
-    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
-    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q,
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
-
-class IndicTransSdpaAttention(IndicTransAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        layer_head_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        """Input shape: Batch x Time x Channel"""
-        if output_attentions or layer_head_mask is not None:
-            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
-            logger.warning_once(
-                "IndicTransModel is using IndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
-                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states,
-                key_value_states=key_value_states,
-                past_key_value=past_key_value,
-                attention_mask=attention_mask,
-                layer_head_mask=layer_head_mask,
-                output_attentions=output_attentions,
-            )
-
-        # if key_value_states are provided this layer is used as a cross-attention layer
-        # for the decoder
-        is_cross_attention = key_value_states is not None
-
-        bsz, tgt_len, _ = hidden_states.size()
-
-        # get query proj
-        query_states = self.q_proj(hidden_states)
-        # get key, value proj
-        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
-        # is checking that the `sequence_length` of the `past_key_value` is the same as
-        # the provided `key_value_states` to support prefix tuning
-        if (
-            is_cross_attention
-            and past_key_value is not None
-            and past_key_value[0].shape[2] == key_value_states.shape[1]
-        ):
-            # reuse k,v, cross_attentions
-            key_states = past_key_value[0]
-            value_states = past_key_value[1]
-        elif is_cross_attention:
-            # cross_attentions
-            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
-            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
-        elif past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-        else:
-            # self_attention
-            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
-            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
-        if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
-            past_key_value = (key_states, value_states)
-
-        query_states = self._shape(query_states, tgt_len, bsz)
-
-        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
-        # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
-        attn_output = F.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            dropout_p=self.dropout if self.training else 0.0,
-            # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
-            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
-        )
-
-        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2)
-
-        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
-        # partitioned across GPUs when using tensor-parallelism.
-        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
-        attn_output = self.out_proj(attn_output)
-
-        return attn_output, None, past_key_value
-
-
-INDICTRANS_ATTENTION_CLASSES = {
-    "eager": IndicTransAttention,
-    "sdpa": IndicTransSdpaAttention,
-    "flash_attention_2": IndicTransFlashAttention2,
-}
 
 
 # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->IndicTrans
 class IndicTransEncoderLayer(nn.Module):
     def __init__(self, config: IndicTransConfig):
         super().__init__()
         self.embed_dim = config.encoder_embed_dim
-        self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = IndicTransAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
             dropout=config.attention_dropout,
-            config=config,
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.dropout = config.dropout
@@ -805,25 +490,22 @@ class IndicTransDecoderLayer(nn.Module):
         super().__init__()
         self.embed_dim = config.decoder_embed_dim
 
-        self.self_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
+        self.self_attn = IndicTransAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
-            is_causal=True,
-            config=config,
         )
         self.dropout = config.dropout
         self.activation_fn = ACT2FN[config.activation_function]
         self.activation_dropout = config.activation_dropout
 
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.encoder_attn = INDICTRANS_ATTENTION_CLASSES[config._attn_implementation](
+        self.encoder_attn = IndicTransAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
-            config=config,
         )
         self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
@@ -1011,9 +693,6 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
             nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
         )
 
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self._use_sdpa = config._attn_implementation == "sdpa"
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
@@ -1100,21 +779,13 @@ class IndicTransEncoder(IndicTransPreTrainedModel):
 
         hidden_states = inputs_embeds + embed_pos
         if self.layernorm_embedding is not None:
-            hidden_states = self.layernorm_embedding(hidden_states)
+            x = self.layernorm_embedding(hidden_states)
         hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
 
+        # expand attention_mask
         if attention_mask is not None:
-            if self._use_flash_attention_2:
-                attention_mask = attention_mask if 0 in attention_mask else None
-            elif self._use_sdpa and head_mask is None and not output_attentions:
-                # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
-                # the manual implementation that requires a 4D causal mask in all cases.
-                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-                attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
-            else:
-                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-                attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
-
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
 
         encoder_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -1238,9 +909,6 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
             nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
         )
 
-        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
-        self._use_sdpa = config._attn_implementation == "sdpa"
-
         self.gradient_checkpointing = False
         # Initialize weights and apply final processing
         self.post_init()
@@ -1363,43 +1031,29 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
         if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-
-        if self._use_flash_attention_2:
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-        elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
-            # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                attention_mask,
+        # create causal mask
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        combined_attention_mask = None
+        if input_shape[-1] > 1:
+            combined_attention_mask = _make_causal_mask(
                input_shape,
-                inputs_embeds,
-                past_key_values_length,
+                inputs_embeds.dtype,
+                device=inputs_embeds.device,
+                past_key_values_length=past_key_values_length,
            )
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, input_shape, inputs_embeds, past_key_values_length
+
+        if attention_mask is not None and combined_attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            combined_attention_mask = combined_attention_mask + _expand_mask(
+                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            )
 
        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
-            if self._use_flash_attention_2:
-                encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
-            elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
-                # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
-                # the manual implementation that requires a 4D causal mask in all cases.
-                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    encoder_attention_mask,
-                    inputs_embeds.dtype,
-                    tgt_len=input_shape[-1],
-                )
-            else:
-                # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-                encoder_attention_mask = _prepare_4d_attention_mask(
-                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-                )
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            encoder_attention_mask = _expand_mask(
+                encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            )
 
        # embed positions
        positions = self.embed_positions(
@@ -1470,7 +1124,7 @@ class IndicTransDecoder(IndicTransPreTrainedModel):
                 layer_outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(decoder_layer),
                     hidden_states,
-                    attention_mask,
+                    combined_attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
@@ -1482,7 +1136,7 @@
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
-                    attention_mask=attention_mask,
+                    attention_mask=combined_attention_mask,
                     encoder_hidden_states=encoder_hidden_states,
                     encoder_attention_mask=encoder_attention_mask,
                     layer_head_mask=(
@@ -1739,7 +1393,7 @@ class IndicTransForConditionalGeneration(IndicTransPreTrainedModel):
             masked_lm_loss = F.cross_entropy(
                 input=lm_logits.view(-1, self.config.decoder_vocab_size),
                 target=labels.view(-1),
-                ignore_index=-100,
+                ignore_index=self.config.pad_token_id,
                 label_smoothing=self._label_smoothing,
             )
 
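The two helpers added above (`_make_causal_mask`, `_expand_mask`) take over the masking previously done by the removed `_prepare_4d_*` utilities, and the LM loss now masks padding via `ignore_index=self.config.pad_token_id` instead of the fixed `-100`. The following standalone sketch reproduces that arithmetic on a toy batch; it restates the helper logic inline rather than importing the module, and uses `pad_token_id=1` from the config defaults shown earlier.

```python
# Standalone illustration of the additive-mask convention used by the helpers above.
import torch

dtype = torch.float32
min_val = torch.finfo(dtype).min
padding_mask = torch.tensor([[1, 1, 1, 0]])  # batch of 1, length 4, last position is padding

# what _expand_mask produces: [bsz, seq_len] -> [bsz, 1, tgt_len, src_len],
# 0.0 for visible positions and dtype-min for masked ones
expanded = 1.0 - padding_mask[:, None, None, :].expand(1, 1, 4, 4).to(dtype)
expanded = expanded.masked_fill(expanded.to(torch.bool), min_val)
print(expanded[0, 0, 0])  # [0, 0, 0, min_val]

# what _make_causal_mask produces for tgt_len=4 (no cache): dtype-min above the diagonal
causal = torch.full((4, 4), min_val)
causal = causal.masked_fill(torch.arange(4) < (torch.arange(4) + 1).view(4, 1), 0.0)[None, None]

# the decoder adds both, exactly like `combined_attention_mask` in the diff above
combined = causal + expanded
print((combined[0, 0] == 0).int())  # lower triangle minus the padded column

# and the LM loss now skips pad positions via ignore_index=pad_token_id (1 by default here)
logits, labels = torch.randn(1, 4, 10), torch.tensor([[5, 6, 7, 1]])
loss = torch.nn.functional.cross_entropy(logits.view(-1, 10), labels.view(-1), ignore_index=1)
```
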
special_tokens_map.json DELETED
@@ -1,6 +0,0 @@
-{
-  "bos_token": "<s>",
-  "eos_token": "</s>",
-  "pad_token": "<pad>",
-  "unk_token": "<unk>"
-}

tokenization_indictrans.py DELETED
@@ -1,261 +0,0 @@
-import os
-import json
-
-from typing import Dict, List, Optional, Union, Tuple
-
-from transformers.utils import logging
-from sentencepiece import SentencePieceProcessor
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-
-logger = logging.get_logger(__name__)
-
-SPIECE_UNDERLINE = "▁"
-
-SPECIAL_TAGS = {
-    "_bt_",
-    "_ft_",
-    "asm_Beng",
-    "awa_Deva",
-    "ben_Beng",
-    "bho_Deva",
-    "brx_Deva",
-    "doi_Deva",
-    "eng_Latn",
-    "gom_Deva",
-    "gon_Deva",
-    "guj_Gujr",
-    "hin_Deva",
-    "hne_Deva",
-    "kan_Knda",
-    "kas_Arab",
-    "kas_Deva",
-    "kha_Latn",
-    "lus_Latn",
-    "mag_Deva",
-    "mai_Deva",
-    "mal_Mlym",
-    "mar_Deva",
-    "mni_Beng",
-    "mni_Mtei",
-    "npi_Deva",
-    "ory_Orya",
-    "pan_Guru",
-    "san_Deva",
-    "sat_Olck",
-    "snd_Arab",
-    "snd_Deva",
-    "tam_Taml",
-    "tel_Telu",
-    "urd_Arab",
-    "unr_Deva",
-}
-
-VOCAB_FILES_NAMES = {
-    "src_vocab_fp": "dict.SRC.json",
-    "tgt_vocab_fp": "dict.TGT.json",
-    "src_spm_fp": "model.SRC",
-    "tgt_spm_fp": "model.TGT",
-}
-
-
-class IndicTransTokenizer(PreTrainedTokenizer):
-    _added_tokens_encoder = {}
-    _added_tokens_decoder = {}
-
-    vocab_files_names = VOCAB_FILES_NAMES
-    model_input_names = ["input_ids", "attention_mask"]
-
-    def __init__(
-        self,
-        src_vocab_fp=None,
-        tgt_vocab_fp=None,
-        src_spm_fp=None,
-        tgt_spm_fp=None,
-        unk_token="<unk>",
-        bos_token="<s>",
-        eos_token="</s>",
-        pad_token="<pad>",
-        do_lower_case=False,
-        **kwargs,
-    ):
-
-        self.src = True
-
-        self.src_vocab_fp = src_vocab_fp
-        self.tgt_vocab_fp = tgt_vocab_fp
-        self.src_spm_fp = src_spm_fp
-        self.tgt_spm_fp = tgt_spm_fp
-
-        self.unk_token = unk_token
-        self.pad_token = pad_token
-        self.eos_token = eos_token
-        self.bos_token = bos_token
-
-        self.encoder = self._load_json(self.src_vocab_fp)
-        if self.unk_token not in self.encoder:
-            raise KeyError("<unk> token must be in vocab")
-        assert self.pad_token in self.encoder
-        self.encoder_rev = {v: k for k, v in self.encoder.items()}
-
-        self.decoder = self._load_json(self.tgt_vocab_fp)
-        if self.unk_token not in self.encoder:
-            raise KeyError("<unk> token must be in vocab")
-        assert self.pad_token in self.encoder
-        self.decoder_rev = {v: k for k, v in self.decoder.items()}
-
-        # load SentencePiece model for pre-processing
-        self.src_spm = self._load_spm(self.src_spm_fp)
-        self.tgt_spm = self._load_spm(self.tgt_spm_fp)
-
-        self.current_spm = self.src_spm
-        self.current_encoder = self.encoder
-        self.current_encoder_rev = self.encoder_rev
-
-        self.unk_token_id = self.encoder[self.unk_token]
-        self.pad_token_id = self.encoder[self.pad_token]
-        self.eos_token_id = self.encoder[self.eos_token]
-        self.bos_token_id = self.encoder[self.bos_token]
-
-        super().__init__(
-            src_vocab_file=self.src_vocab_fp,
-            tgt_vocab_file=self.src_vocab_fp,
-            do_lower_case=do_lower_case,
-            unk_token=unk_token,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            pad_token=pad_token,
-            **kwargs,
-        )
-
-    def add_new_special_tags(self, new_tags: List[str]):
-        SPECIAL_TAGS.update(new_tags)
-
-    def _switch_to_input_mode(self):
-        self.src = True
-        self.padding_side = "left"
-        self.current_spm = self.src_spm
-        self.current_encoder = self.encoder
-        self.current_encoder_rev = self.encoder_rev
-
-    def _switch_to_target_mode(self):
-        self.src = False
-        self.padding_side = "right"
-        self.current_spm = self.tgt_spm
-        self.current_encoder = self.decoder
-        self.current_encoder_rev = self.decoder_rev
-
-    def _load_spm(self, path: str) -> SentencePieceProcessor:
-        return SentencePieceProcessor(model_file=path)
-
-    def _save_json(self, data, path: str) -> None:
-        with open(path, "w", encoding="utf-8") as f:
-            json.dump(data, f, indent=2)
-
-    def _load_json(self, path: str) -> Union[Dict, List]:
-        with open(path, "r", encoding="utf-8") as f:
-            return json.load(f)
-
-    def _split_tags(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
-        tags = [token for token in tokens if token in SPECIAL_TAGS]
-        tokens = [token for token in tokens if token not in SPECIAL_TAGS]
-        return tags, tokens
-
-    def _split_pads(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
-        pads = [token for token in tokens if token == self.pad_token]
-        tokens = [token for token in tokens if token != self.pad_token]
-        return pads, tokens
-
-    @property
-    def src_vocab_size(self) -> int:
-        return len(self.encoder)
-
-    @property
-    def tgt_vocab_size(self) -> int:
-        return len(self.decoder)
-
-    def get_src_vocab(self) -> Dict[str, int]:
-        return dict(self.encoder, **self.added_tokens_encoder)
-
-    def get_tgt_vocab(self) -> Dict[str, int]:
-        return dict(self.decoder, **self.added_tokens_decoder)
-
-    # hack override
-    def get_vocab(self) -> Dict[str, int]:
-        return self.get_src_vocab()
-
-    # hack override
-    @property
-    def vocab_size(self) -> int:
-        return self.src_vocab_size
-
-    def _convert_token_to_id(self, token: str) -> int:
-        """Converts an token (str) into an index (integer) using the source/target vocabulary map."""
-        return self.current_encoder.get(token, self.current_encoder[self.unk_token])
-
-    def _convert_id_to_token(self, index: int) -> str:
-        """Converts an index (integer) into a token (str) using the source/target vocabulary map."""
-        return self.current_encoder_rev.get(index, self.unk_token)
-
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """Uses sentencepiece model for detokenization"""
-        pads, tokens = self._split_pads(tokens)
-
-        if self.src:
-
-            tags, non_tags = self._split_tags(tokens)
-
-            return (
-                " ".join(pads)
-                + " "
-                + " ".join(tags)
-                + " "
-                + "".join(non_tags).replace(SPIECE_UNDERLINE, " ").strip()
-            )
-
-        return (
-            "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
-            + " "
-            + " ".join(pads)
-        )
-
-    def _tokenize(self, text) -> List[str]:
-        if self.src:
-            tokens = text.split(" ")
-            tags, non_tags = self._split_tags(tokens)
-            text = " ".join(non_tags)
-            tokens = self.current_spm.EncodeAsPieces(text)
-            return tags + tokens
-        else:
-            return self.current_spm.EncodeAsPieces(text)
-
-    def build_inputs_with_special_tokens(
-        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
-    ) -> List[int]:
-        if token_ids_1 is None:
-            return token_ids_0 + [self.eos_token_id]
-        # We don't expect to process pairs, but leave the pair logic for API consistency
-        return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
-
-    def save_vocabulary(
-        self, save_directory: str, filename_prefix: Optional[str] = None
-    ) -> Tuple[str]:
-        if not os.path.isdir(save_directory):
-            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-            return
-
-        src_spm_fp = os.path.join(save_directory, "model.SRC")
-        tgt_spm_fp = os.path.join(save_directory, "model.TGT")
-        src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
-        tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
-
-        self._save_json(self.encoder, src_vocab_fp)
-        self._save_json(self.decoder, tgt_vocab_fp)
-
-        with open(src_spm_fp, "wb") as f:
-            f.write(self.src_spm.serialized_model_proto())
-
-        with open(tgt_spm_fp, "wb") as f:
-            f.write(self.tgt_spm.serialized_model_proto())
-
-        return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp

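The deleted tokenizer above is essentially a thin wrapper around two SentencePiece models plus JSON vocabularies, with language tags passed through untouched. A standalone sketch of its source-side encode path, assuming local copies of `model.SRC` and `dict.SRC.json` from an older revision (both files are also removed by this commit), and a tag-prefixed input of the kind the preprocessing step is expected to produce:

```python
# Standalone sketch of the deleted tokenizer's source-side encode path, assuming local
# copies of model.SRC and dict.SRC.json from an older revision of this repo.
import json
from sentencepiece import SentencePieceProcessor

spm = SentencePieceProcessor(model_file="model.SRC")
with open("dict.SRC.json", encoding="utf-8") as f:
    vocab = json.load(f)  # token -> id mapping, as used by _convert_token_to_id above

# a tagged input (eng_Latn / hin_Deva are members of SPECIAL_TAGS above)
text = "eng_Latn hin_Deva When I was young, I used to go to the park every day."
tokens = text.split(" ")
tags = [t for t in tokens if t in {"eng_Latn", "hin_Deva"}]       # kept as-is, like _split_tags
rest = " ".join(t for t in tokens if t not in {"eng_Latn", "hin_Deva"})

pieces = tags + spm.EncodeAsPieces(rest)                          # mirrors _tokenize
ids = [vocab.get(p, vocab["<unk>"]) for p in pieces]              # unknown pieces fall back to <unk>
ids.append(vocab["</s>"])                                         # build_inputs_with_special_tokens adds EOS
print(pieces[:6], ids[:6])
```
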
tokenizer_config.json DELETED
@@ -1,51 +0,0 @@
-{
-  "added_tokens_decoder": {
-    "0": {
-      "content": "<s>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "1": {
-      "content": "<pad>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "2": {
-      "content": "</s>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    },
-    "3": {
-      "content": "<unk>",
-      "lstrip": false,
-      "normalized": false,
-      "rstrip": false,
-      "single_word": false,
-      "special": true
-    }
-  },
-  "bos_token": "<s>",
-  "clean_up_tokenization_spaces": true,
-  "do_lower_case": false,
-  "eos_token": "</s>",
-  "model_max_length": 256,
-  "pad_token": "<pad>",
-  "name_or_path": "ai4bharat/indictrans2-en-indic-1B",
-  "tokenizer_class": "IndicTransTokenizer",
-  "auto_map": {
-    "AutoTokenizer": [
-      "tokenization_indictrans.IndicTransTokenizer",
-      null
-    ]
-  },
-  "unk_token": "<unk>"
-}