SaranaAbidueva committed on
Commit
f6d9d47
1 Parent(s): 87a18d2

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +35 -1
README.md CHANGED
@@ -7,4 +7,38 @@ datasets:
7
  - SaranaAbidueva/buryat-russian_parallel_corpus
8
  metrics:
9
  - bleu
10
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  - SaranaAbidueva/buryat-russian_parallel_corpus
8
  metrics:
9
  - bleu
10
+ ---
11
+
12
+ How to use in Python:
13
+ ```python
14
+ from transformers import MBartForConditionalGeneration, MBart50Tokenizer
15
+ model = MBartForConditionalGeneration.from_pretrained("SaranaAbidueva/mbart50_bur_ru")
16
+ tokenizer = MBart50Tokenizer.from_pretrained("SaranaAbidueva/mbart50_bur_ru")
17
+ def fix_tokenizer(tokenizer):
18
+ old_len = len(tokenizer) - int('bxr_XX' in tokenizer.added_tokens_encoder)
19
+ tokenizer.lang_code_to_id['bxr_XX'] = old_len-1
20
+ tokenizer.id_to_lang_code[old_len-1] = 'bxr_XX'
21
+ tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
22
+
23
+ tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
24
+ tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
25
+ if 'bxr_XX' not in tokenizer._additional_special_tokens:
26
+ tokenizer._additional_special_tokens.append('bxr_XX')
27
+ tokenizer.added_tokens_encoder = {}
28
+ fix_tokenizer(tokenizer)
29
+
30
+ def translate(text, src='ru_RU', trg='bxr_XX', max_length=200, num_beams=5, repetition_penalty=5.0, **kwargs):
31
+ tokenizer.src_lang = src
32
+ encoded = tokenizer(text, return_tensors="pt")
33
+ generated_tokens = model.generate(
34
+ **encoded.to(model.device),
35
+ forced_bos_token_id=tokenizer.lang_code_to_id[trg],
36
+ max_length=max_length,
37
+ num_beams=num_beams,
38
+ repetition_penalty=repetition_penalty,
39
+ # early_stopping=True,
40
+ )
41
+ return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
42
+
43
+ translate('Евгений Онегин интересная книга')
44
+ ```