VarunGumma commited on
Commit
c46c5c5
·
verified ·
1 Parent(s): 5a72625

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +12 -12
README.md CHANGED
@@ -57,17 +57,22 @@ Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2
57
 
58
  ```python
59
  import torch
60
- from transformers import (
61
- AutoModelForSeq2SeqLM,
62
- AutoTokenizer,
63
- )
64
  from IndicTransToolkit import IndicProcessor
 
 
 
65
 
66
-
67
- model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
68
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
69
 
70
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
 
 
 
 
 
71
 
72
  ip = IndicProcessor(inference=True)
73
 
@@ -78,8 +83,6 @@ input_sentences = [
78
  "मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
79
  ]
80
 
81
- src_lang, tgt_lang = "hin_Deva", "tam_Taml"
82
-
83
  batch = ip.preprocess_batch(
84
  input_sentences,
85
  src_lang=src_lang,
@@ -124,9 +127,6 @@ for input_sentence, translation in zip(input_sentences, translations):
124
  print(f"{tgt_lang}: {translation}")
125
  ```
126
 
127
- **Note: IndicTrans2 is now compatible with AutoTokenizer, however you need to use IndicProcessor from [IndicTransTokenizer](https://github.com/VarunGumma/IndicTransTokenizer) for preprocessing before tokenization.**
128
-
129
-
130
  ### Citation
131
 
132
  If you consider using our work then please cite using:
 
57
 
58
  ```python
59
  import torch
60
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
 
 
61
  from IndicTransToolkit import IndicProcessor
62
+ # recommended to run this on a gpu with flash_attn installed
63
+ # don't set attn_implemetation if you don't have flash_attn
64
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
65
 
66
+ src_lang, tgt_lang = "hin_Deva", "tam_Taml"
67
+ model_name = "ai4bharat/indictrans2-indic-indic-dist-200M"
68
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
69
 
70
+ model = AutoModelForSeq2SeqLM.from_pretrained(
71
+ model_name,
72
+ trust_remote_code=True,
73
+ torch_dtype=torch.float16, # performance might slightly vary for bfloat16
74
+ attn_implementation="flash_attention_2"
75
+ ).to(DEVICE)
76
 
77
  ip = IndicProcessor(inference=True)
78
 
 
83
  "मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
84
  ]
85
 
 
 
86
  batch = ip.preprocess_batch(
87
  input_sentences,
88
  src_lang=src_lang,
 
127
  print(f"{tgt_lang}: {translation}")
128
  ```
129
 
 
 
 
130
  ### Citation
131
 
132
  If you consider using our work then please cite using: