
bart_dev_rom_tl

This model is a fine-tuned version of ar5entum/bart_hin_eng_mt on the ar5entum/hindi-english-roman-devnagiri-transliteration-corpus dataset. It achieves the following results on the evaluation set:

  • Loss: 0.8156
  • Bleu: 40.6409
  • Gen Len: 40.3178

Model description

This model is trained on a transliteration dataset of Roman and Devanagari sentence pairs. The objective of this experiment was to transliterate Devanagari sentences into Roman script correctly, using sentence-level context.
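
For a quick orientation, here is a minimal single-sentence usage sketch. It mirrors the settings of the fuller evaluation script in the next section (beam search with num_beams=4, max_length=512), and the expected output shown in the comment is the prediction reported further down this card.

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the fine-tuned checkpoint from the Hub.
tokenizer = AutoTokenizer.from_pretrained("ar5entum/bart_dev_rom_tl")
model = AutoModelForSeq2SeqLM.from_pretrained("ar5entum/bart_dev_rom_tl")

# Devanagari input (one of the example sentences used later in this card).
text = "कुछ ने कहा ये चांद है कुछ ने कहा चेहरा तेरा"
inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
with torch.no_grad():
    pred_ids = model.generate(inputs.input_ids, max_length=512, num_beams=4, early_stopping=True)
print(tokenizer.decode(pred_ids[0], skip_special_tokens=True))
# Reported prediction for this sentence: "kuchh ne kaha ye chand hai kuch ne kaha chehra tera"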

Inference and Evaluation

import torch
import evaluate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def batch_long_string(text):
    # Split a long sentence into chunks of roughly 40+ characters,
    # breaking only at word boundaries, so each chunk stays short
    # enough for comfortable generation.
    batch = []
    temp = []
    count = 0
    for word in text.split():
        count += len(word)
        temp.append(word.strip())
        if count > 40:
            count = 0
            batch.append(" ".join(temp).strip())
            temp = []
    if len(temp) > 0:
        batch.append(" ".join(temp).strip())
    return batch

class BartSmall():
    def __init__(self, model_path='ar5entum/bart_dev_rom_tl', device=None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        if not device:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model.to(device)

    def predict(self, input_text):
        # Transliterate a single sentence using beam search.
        inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
        with torch.no_grad():
            pred_ids = self.model.generate(inputs.input_ids, max_length=512, num_beams=4, early_stopping=True)
        prediction = self.tokenizer.decode(pred_ids[0], skip_special_tokens=True)
        return prediction

    def predict_batch(self, input_texts, batch_size=32):
        # Transliterate a list of sentences in mini-batches of `batch_size`.
        all_predictions = []
        for i in range(0, len(input_texts), batch_size):
            batch_texts = input_texts[i:i + batch_size]
            inputs = self.tokenizer(batch_texts, return_tensors="pt", max_length=512,
                                    truncation=True, padding=True).to(self.device)

            with torch.no_grad():
                pred_ids = self.model.generate(inputs.input_ids,
                                               max_length=512,
                                               num_beams=4,
                                               early_stopping=True)

            predictions = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            all_predictions.extend(predictions)

        return all_predictions

model = BartSmall(device='cuda')

input_texts = [
    "द एजुकेशन रिसर्चर इवैल्युएटेड द इफेक्टिवनेस ऑफ ऑनलाइन लर्निंग", 
    "यह अभिषेक जल, इक्षुरस, दुध, चावल का आटा, लाल चंदन, हल्दी, अष्टगंध, चंदन चुरा, चार कलश, केसर वृष्टि, आरती, सुगंधित कलश, महाशांतिधारा एवं महाअर्घ्य के साथ भगवान नेमिनाथ को समर्पित किया जाता है।",
    "कुछ ने कहा ये चांद है कुछ ने कहा चेहरा तेरा"
    ]
ground_truths = [
    "the education researcher evaluated the effectiveness of online learning.",
    "yah abhishek jal, ikshuras, dudh, chaval ka ataa, laal chandan, haldi, ashtagandh, chandan chura, char kalash, kesar vrishti, aarti, sugandhit kalash, mahashantidhara evam mahaarghya ke saath bhagvan Neminath ko samarpit kiya jata hai.",
    "kuch ne kaha ye chand hai kuch ne kaha chehra ter"
    ]
import time
start = time.time()


predictions = []
for text in input_texts:
    # Split each input into ~40-character chunks, transliterate all chunks in one batch, and rejoin.
    chunks = batch_long_string(text)
    predictions.append(" ".join(model.predict_batch(chunks, batch_size=len(chunks))))
end = time.time()
print("TIME: ", end-start)
for i in range(len(input_texts)):
    print("‾‾‾‾‾‾‾‾‾‾‾‾")
    print("Input text:\t", input_texts[i])
    print("Prediction:\t", predictions[i])
    print("Ground Truth:\t", ground_truths[i])
bleu = evaluate.load("bleu")
results = bleu.compute(predictions=predictions, references=ground_truths)
print(results)

# TIME:  1.6740131378173828
# ‾‾‾‾‾‾‾‾‾‾‾‾
# Input text:	 द एजुकेशन रिसर्चर इवैल्युएटेड द इफेक्टिवनेस ऑफ ऑनलाइन लर्निंग
# Prediction:	 the education researcher evaluated the inflation of online. Larning
# Ground Truth:	 the education researcher evaluated the effectiveness of online learning.
# ‾‾‾‾‾‾‾‾‾‾‾‾
# Input text:	 यह अभिषेक जल, इक्षुरस, दुध, चावल का आटा, लाल चंदन, हल्दी, अष्टगंध, चंदन चुरा, चार कलश, केसर वृष्टि, आरती, सुगंधित कलश, महाशांतिधारा एवं महाअर्घ्य के साथ भगवान नेमिनाथ को समर्पित किया जाता है।
# Prediction:	 yah abhishek jal, ikshuras, dudh, chaval ka aata, laal chandan, Haldi, asthagandh, chandan chura, char kalash, kesar vritti, Aarti, Sugandhit kalash, Mahashantidhara evam Maharghya ke saath bhagwan Nemith ko samarpit kiya jata hai.
# Ground Truth:	 yah abhishek jal, ikshuras, dudh, chaval ka ataa, laal chandan, haldi, ashtagandh, chandan chura, char kalash, kesar vrishti, aarti, sugandhit kalash, mahashantidhara evam mahaarghya ke saath bhagvan Neminath ko samarpit kiya jata hai.
# ‾‾‾‾‾‾‾‾‾‾‾‾
# Input text:	 कुछ ने कहा ये चांद है कुछ ने कहा चेहरा तेरा
# Prediction:	 kuchh ne kaha ye chand hai kuch ne kaha chehra tera
# Ground Truth:	 kuch ne kaha ye chand hai kuch ne kaha chehra ter
# {'bleu': 0.5596481750975065, 'precisions': [0.7910447761194029, 0.609375, 0.4918032786885246, 0.41379310344827586], 'brevity_penalty': 1.0, 'length_ratio': 1.0, 'translation_length': 67, 'reference_length': 67}
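
The script above scores only three hand-written examples. To evaluate against the corpus named at the top of this card, something along the following lines could be used. This is a sketch only: the split name and the column names ("devnagiri" and "roman") are assumptions and may not match the dataset's actual schema.

import evaluate
from datasets import load_dataset

# Hypothetical evaluation on the first 100 rows of the corpus.
# Column names "devnagiri"/"roman" are assumptions; adjust to the real schema.
ds = load_dataset("ar5entum/hindi-english-roman-devnagiri-transliteration-corpus", split="train")
sample = ds.select(range(100))
preds = model.predict_batch(list(sample["devnagiri"]), batch_size=32)  # reuses the BartSmall instance above
bleu = evaluate.load("bleu")
print(bleu.compute(predictions=preds, references=list(sample["roman"])))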

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 3e-05
  • train_batch_size: 100
  • eval_batch_size: 40
  • seed: 42
  • distributed_type: multi-GPU
  • num_devices: 2
  • total_train_batch_size: 200
  • total_eval_batch_size: 80
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_steps: 80
  • num_epochs: 100.0
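
The exact training script is not included in this card. For orientation only, a hedged sketch of how the values above might map onto Seq2SeqTrainingArguments; the output directory is a placeholder, and the total batch sizes (200/80) follow from running the per-device sizes on the 2 devices listed above.

from transformers import Seq2SeqTrainingArguments

# Hypothetical reconstruction of the hyperparameters listed above.
training_args = Seq2SeqTrainingArguments(
    output_dir="bart_dev_rom_tl",        # placeholder, not stated in the card
    learning_rate=3e-05,
    per_device_train_batch_size=100,     # x2 GPUs -> total_train_batch_size 200
    per_device_eval_batch_size=40,       # x2 GPUs -> total_eval_batch_size 80
    seed=42,
    lr_scheduler_type="linear",
    warmup_steps=80,
    num_train_epochs=100.0,
    predict_with_generate=True,          # assumption: needed to report Bleu/Gen Len during eval
)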

Training results

| Training Loss | Epoch | Step | Validation Loss | Bleu | Gen Len |
|:-------------:|:-----:|:----:|:---------------:|:----:|:-------:|
| 6.046 | 1.0 | 71 | 5.7137 | 0.0237 | 78.975 |
| 4.9653 | 2.0 | 142 | 4.6488 | 0.463 | 68.7566 |
| 4.3594 | 3.0 | 213 | 3.9858 | 1.7108 | 51.6638 |
| 3.8595 | 4.0 | 284 | 3.5145 | 3.7857 | 48.8671 |
| 3.5045 | 5.0 | 355 | 3.1973 | 6.3952 | 46.3566 |
| 3.241 | 6.0 | 426 | 2.9686 | 8.4659 | 47.6658 |
| 3.0828 | 7.0 | 497 | 2.7850 | 10.5828 | 48.1118 |
| 2.9064 | 8.0 | 568 | 2.6409 | 11.8302 | 48.8211 |
| 2.7434 | 9.0 | 639 | 2.5048 | 12.5417 | 50.2257 |
| 2.6201 | 10.0 | 710 | 2.3933 | 13.7057 | 45.6704 |
| 2.4511 | 11.0 | 781 | 2.2927 | 14.7807 | 46.4112 |
| 2.3707 | 12.0 | 852 | 2.1978 | 15.9284 | 43.0941 |
| 2.2821 | 13.0 | 923 | 2.1169 | 17.0686 | 45.0566 |
| 2.1725 | 14.0 | 994 | 2.0360 | 17.7927 | 45.0487 |
| 2.0905 | 15.0 | 1065 | 1.9586 | 18.7905 | 43.5625 |
| 2.0224 | 16.0 | 1136 | 1.8913 | 19.8848 | 43.9507 |
| 1.9548 | 17.0 | 1207 | 1.8289 | 20.506 | 43.2441 |
| 1.8764 | 18.0 | 1278 | 1.7778 | 21.0069 | 41.9743 |
| 1.8262 | 19.0 | 1349 | 1.7314 | 22.0322 | 41.9711 |
| 1.7626 | 20.0 | 1420 | 1.6766 | 22.5132 | 43.1888 |
| 1.6689 | 21.0 | 1491 | 1.6242 | 23.3894 | 42.7395 |
| 1.6668 | 22.0 | 1562 | 1.5729 | 24.2888 | 43.1961 |
| 1.5834 | 23.0 | 1633 | 1.5277 | 24.7954 | 41.9934 |
| 1.5352 | 24.0 | 1704 | 1.4837 | 25.7943 | 41.5171 |
| 1.5149 | 25.0 | 1775 | 1.4402 | 26.4075 | 41.5632 |
| 1.4375 | 26.0 | 1846 | 1.4013 | 26.798 | 41.9704 |
| 1.4224 | 27.0 | 1917 | 1.3709 | 27.7495 | 41.4283 |
| 1.3972 | 28.0 | 1988 | 1.3359 | 28.2608 | 41.7559 |
| 1.3475 | 29.0 | 2059 | 1.3065 | 28.579 | 41.4954 |
| 1.3269 | 30.0 | 2130 | 1.2727 | 29.2762 | 41.0467 |
| 1.2329 | 31.0 | 2201 | 1.2481 | 29.2254 | 41.6296 |
| 1.2292 | 32.0 | 2272 | 1.2199 | 30.0158 | 41.7487 |
| 1.1868 | 33.0 | 2343 | 1.1981 | 30.8127 | 41.1414 |
| 1.1662 | 34.0 | 2414 | 1.1777 | 31.0606 | 41.3145 |
| 1.1341 | 35.0 | 2485 | 1.1608 | 31.4376 | 40.8375 |
| 1.1651 | 36.0 | 2556 | 1.1385 | 31.9947 | 41.1934 |
| 1.1019 | 37.0 | 2627 | 1.1238 | 32.5984 | 41.1112 |
| 1.1232 | 38.0 | 2698 | 1.1096 | 33.1094 | 41.0974 |
| 1.0553 | 39.0 | 2769 | 1.0930 | 33.1268 | 41.0842 |
| 1.0536 | 40.0 | 2840 | 1.0812 | 33.4825 | 41.0868 |
| 1.0212 | 41.0 | 2911 | 1.0672 | 34.0163 | 40.8362 |
| 0.9768 | 42.0 | 2982 | 1.0531 | 34.1846 | 41.0447 |
| 0.9923 | 43.0 | 3053 | 1.0426 | 34.4359 | 41.1908 |
| 0.9646 | 44.0 | 3124 | 1.0338 | 34.83 | 40.9336 |
| 0.9858 | 45.0 | 3195 | 1.0211 | 34.8589 | 40.723 |
| 0.963 | 46.0 | 3266 | 1.0159 | 35.1912 | 40.8447 |
| 0.9226 | 47.0 | 3337 | 1.0023 | 35.4973 | 40.7612 |
| 0.9169 | 48.0 | 3408 | 0.9912 | 35.7503 | 41.1454 |
| 0.9173 | 49.0 | 3479 | 0.9864 | 35.9269 | 40.7145 |
| 0.8846 | 50.0 | 3550 | 0.9783 | 36.5519 | 40.6513 |
| 0.9061 | 51.0 | 3621 | 0.9693 | 36.5456 | 40.4079 |
| 0.8699 | 52.0 | 3692 | 0.9601 | 36.9342 | 41.0151 |
| 0.8753 | 53.0 | 3763 | 0.9539 | 37.0866 | 40.6691 |
| 0.8265 | 54.0 | 3834 | 0.9444 | 37.0662 | 41.1809 |
| 0.8238 | 55.0 | 3905 | 0.9411 | 37.4991 | 40.5993 |
| 0.8125 | 56.0 | 3976 | 0.9340 | 37.4722 | 40.9829 |
| 0.8141 | 57.0 | 4047 | 0.9278 | 37.9354 | 40.6638 |
| 0.8089 | 58.0 | 4118 | 0.9221 | 37.8179 | 41.0704 |
| 0.7953 | 59.0 | 4189 | 0.9171 | 38.2691 | 40.6224 |
| 0.7781 | 60.0 | 4260 | 0.9121 | 38.2475 | 40.4526 |
| 0.7858 | 61.0 | 4331 | 0.9061 | 38.4115 | 40.7947 |
| 0.7879 | 62.0 | 4402 | 0.9013 | 38.2173 | 40.4717 |
| 0.7931 | 63.0 | 4473 | 0.8979 | 38.4403 | 40.7276 |
| 0.7698 | 64.0 | 4544 | 0.8942 | 38.7601 | 40.4849 |
| 0.7623 | 65.0 | 4615 | 0.8869 | 38.8371 | 40.8053 |
| 0.7548 | 66.0 | 4686 | 0.8830 | 38.935 | 40.6434 |
| 0.7696 | 67.0 | 4757 | 0.8796 | 38.8151 | 40.4355 |
| 0.7323 | 68.0 | 4828 | 0.8770 | 38.9874 | 40.5763 |
| 0.7357 | 69.0 | 4899 | 0.8733 | 39.2862 | 40.5138 |
| 0.718 | 70.0 | 4970 | 0.8695 | 38.9941 | 40.4559 |
| 0.7105 | 71.0 | 5041 | 0.8647 | 39.0562 | 40.5691 |
| 0.7124 | 72.0 | 5112 | 0.8611 | 39.5159 | 40.6039 |
| 0.7094 | 73.0 | 5183 | 0.8580 | 39.5358 | 40.6257 |
| 0.7137 | 74.0 | 5254 | 0.8542 | 39.7735 | 40.6539 |
| 0.7066 | 75.0 | 5325 | 0.8514 | 39.7981 | 40.3717 |
| 0.7118 | 76.0 | 5396 | 0.8498 | 39.7518 | 40.4428 |
| 0.687 | 77.0 | 5467 | 0.8464 | 39.7604 | 40.4053 |
| 0.683 | 78.0 | 5538 | 0.8426 | 39.9961 | 40.3941 |
| 0.693 | 79.0 | 5609 | 0.8394 | 40.1569 | 40.3941 |
| 0.6855 | 80.0 | 5680 | 0.8380 | 40.0677 | 40.448 |
| 0.6823 | 81.0 | 5751 | 0.8353 | 39.8297 | 40.6493 |
| 0.6603 | 82.0 | 5822 | 0.8324 | 40.0701 | 40.5842 |
| 0.6648 | 83.0 | 5893 | 0.8321 | 40.3281 | 40.4849 |
| 0.6491 | 84.0 | 5964 | 0.8295 | 40.2578 | 40.3303 |
| 0.6715 | 85.0 | 6035 | 0.8276 | 40.3384 | 40.4276 |
| 0.6542 | 86.0 | 6106 | 0.8266 | 40.359 | 40.3776 |
| 0.6273 | 87.0 | 6177 | 0.8257 | 40.5114 | 40.3941 |
| 0.6696 | 88.0 | 6248 | 0.8242 | 40.6565 | 40.3592 |
| 0.6485 | 89.0 | 6319 | 0.8230 | 40.7058 | 40.1993 |
| 0.682 | 90.0 | 6390 | 0.8220 | 40.665 | 40.3296 |
| 0.6625 | 91.0 | 6461 | 0.8196 | 40.6032 | 40.2908 |
| 0.6473 | 92.0 | 6532 | 0.8193 | 40.4884 | 40.3572 |
| 0.6544 | 93.0 | 6603 | 0.8186 | 40.4847 | 40.5513 |
| 0.6599 | 94.0 | 6674 | 0.8177 | 40.5928 | 40.4342 |
| 0.6368 | 95.0 | 6745 | 0.8168 | 40.6436 | 40.4625 |
| 0.6283 | 96.0 | 6816 | 0.8168 | 40.5861 | 40.4066 |
| 0.6301 | 97.0 | 6887 | 0.8165 | 40.62 | 40.2855 |
| 0.6356 | 98.0 | 6958 | 0.8161 | 40.7093 | 40.3072 |
| 0.6542 | 99.0 | 7029 | 0.8158 | 40.5941 | 40.3086 |
| 0.6463 | 100.0 | 7100 | 0.8156 | 40.6409 | 40.3178 |

Framework versions

  • Transformers 4.45.0.dev0
  • Pytorch 2.4.0+cu121
  • Datasets 2.21.0
  • Tokenizers 0.19.1