VarunGumma
commited on
Update README.md
Browse files
README.md
CHANGED
@@ -57,17 +57,22 @@ Please refer to the [github repository](https://github.com/AI4Bharat/IndicTrans2
|
|
57 |
|
58 |
```python
|
59 |
import torch
|
60 |
-
from transformers import
|
61 |
-
AutoModelForSeq2SeqLM,
|
62 |
-
AutoTokenizer,
|
63 |
-
)
|
64 |
from IndicTransToolkit import IndicProcessor
|
|
|
|
|
|
|
65 |
|
66 |
-
|
67 |
-
model_name = "ai4bharat/indictrans2-indic-indic-dist-
|
68 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
69 |
|
70 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
ip = IndicProcessor(inference=True)
|
73 |
|
@@ -78,8 +83,6 @@ input_sentences = [
|
|
78 |
"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
|
79 |
]
|
80 |
|
81 |
-
src_lang, tgt_lang = "hin_Deva", "tam_Taml"
|
82 |
-
|
83 |
batch = ip.preprocess_batch(
|
84 |
input_sentences,
|
85 |
src_lang=src_lang,
|
@@ -124,9 +127,6 @@ for input_sentence, translation in zip(input_sentences, translations):
|
|
124 |
print(f"{tgt_lang}: {translation}")
|
125 |
```
|
126 |
|
127 |
-
**Note: IndicTrans2 is now compatible with AutoTokenizer, however you need to use IndicProcessor from [IndicTransTokenizer](https://github.com/VarunGumma/IndicTransTokenizer) for preprocessing before tokenization.**
|
128 |
-
|
129 |
-
|
130 |
### Citation
|
131 |
|
132 |
If you consider using our work then please cite using:
|
|
|
57 |
|
58 |
```python
|
59 |
import torch
|
60 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
|
|
|
61 |
from IndicTransToolkit import IndicProcessor
|
62 |
+
# recommended to run this on a gpu with flash_attn installed
|
63 |
+
# don't set attn_implemetation if you don't have flash_attn
|
64 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
65 |
|
66 |
+
src_lang, tgt_lang = "hin_Deva", "tam_Taml"
|
67 |
+
model_name = "ai4bharat/indictrans2-indic-indic-dist-200M"
|
68 |
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
69 |
|
70 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
71 |
+
model_name,
|
72 |
+
trust_remote_code=True,
|
73 |
+
torch_dtype=torch.float16, # performance might slightly vary for bfloat16
|
74 |
+
attn_implementation="flash_attention_2"
|
75 |
+
).to(DEVICE)
|
76 |
|
77 |
ip = IndicProcessor(inference=True)
|
78 |
|
|
|
83 |
"मेरे मित्र ने मुझे उसके जन्मदिन की पार्टी में बुलाया है, और मैं उसे एक तोहफा दूंगा।",
|
84 |
]
|
85 |
|
|
|
|
|
86 |
batch = ip.preprocess_batch(
|
87 |
input_sentences,
|
88 |
src_lang=src_lang,
|
|
|
127 |
print(f"{tgt_lang}: {translation}")
|
128 |
```
|
129 |
|
|
|
|
|
|
|
130 |
### Citation
|
131 |
|
132 |
If you consider using our work then please cite using:
|