datasets:
  - Omartificial-Intelligence-Space/Arabic-NLi-Triplet
language:
  - ar
base_model: intfloat/multilingual-e5-small
library_name: sentence-transformers
pipeline_tag: sentence-similarity
tags:
  - sentence-transformers
  - sentence-similarity
  - feature-extraction
  - arabic
  - triplet-loss
widget: []

Arabic NLI Triplet - Sentence Transformer Model

This repository contains a Sentence Transformer model fine-tuned from intfloat/multilingual-e5-small on the Omartificial-Intelligence-Space/Arabic-NLi-Triplet dataset. It produces 384-dimensional embeddings for Arabic semantic similarity tasks such as paraphrase mining, sentence similarity, and clustering.

Model Overview

  • Model Type: Sentence Transformer
  • Base Model: intfloat/multilingual-e5-small
  • Training Dataset: Omartificial-Intelligence-Space/Arabic-NLi-Triplet
  • Similarity Function: Cosine Similarity
  • Embedding Dimensionality: 384 dimensions
  • Maximum Sequence Length: 128 tokens
  • Performance Improvement: Roughly 10% better than the base model when both are evaluated on the test split of the training dataset.

Dataset

Arabic NLI Triplet Dataset

The dataset contains triplets of sentences in Arabic: an anchor sentence, a positive sentence (semantically similar to the anchor), and a negative sentence (semantically dissimilar to the anchor). The dataset is designed for learning sentence representations through triplet margin loss.

Dataset Link: Omartificial-Intelligence-Space/Arabic-NLi-Triplet

Training Process

Loss Function: Triplet Margin Loss

We used PyTorch's TripletMarginLoss with a margin of 1.0. The loss pulls the anchor and positive embeddings together and pushes the anchor and negative embeddings apart, until the negative is at least the margin farther from the anchor than the positive is.
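For reference, the quantity being minimized can be written out by hand. The sketch below is a framework-free illustration for a single triplet. It assumes the Euclidean distance that PyTorch's TripletMarginLoss uses by default (p=2) and ignores its small numerical-stability epsilon. The function names and toy vectors are illustrative and not part of the training code.

```python
import math

def euclidean_distance(u, v):
    """L2 distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """max(d(anchor, positive) - d(anchor, negative) + margin, 0) for one triplet."""
    d_pos = euclidean_distance(anchor, positive)
    d_neg = euclidean_distance(anchor, negative)
    return max(d_pos - d_neg + margin, 0.0)

# Toy 2-D "embeddings" (hypothetical values chosen for easy arithmetic).
anchor = [0.0, 0.0]
positive = [3.0, 4.0]       # distance 5 from the anchor
far_negative = [6.0, 8.0]   # distance 10 from the anchor
near_negative = [0.0, 1.0]  # distance 1 from the anchor

# The negative is already more than `margin` farther away than the positive: loss is 0.
print(triplet_margin_loss(anchor, positive, far_negative))   # 0.0
# The negative is closer than the positive: loss is 5 - 1 + 1 = 5.
print(triplet_margin_loss(anchor, positive, near_negative))  # 5.0
```

Once a triplet satisfies the margin it contributes nothing to the loss, so training focuses on the triplets that are still violated. With PyTorch's default reduction, the per-triplet values are averaged over the batch.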

Training Loss Progress:

Below is the training loss recorded at various steps during the training process:

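| Step | Training Loss |
|------|---------------|
| 500  | 0.136500 |
| 1000 | 0.126500 |
| 1500 | 0.127300 |
| 2000 | 0.114500 |
| 2500 | 0.110600 |
| 3000 | 0.102300 |
| 3500 | 0.101300 |
| 4000 | 0.106900 |
| 4500 | 0.097200 |
| 5000 | 0.091700 |
| 5500 | 0.092400 |
| 6000 | 0.095500 |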

Model Training Code

The model was trained using the following code (without resuming from checkpoints):

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer
from torch.nn import TripletMarginLoss

# Load dataset
dataset = load_dataset("Omartificial-Intelligence-Space/Arabic-NLi-Triplet")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-small")

# Tokenize function
def tokenize_function(examples):
    anchor_encodings = tokenizer(examples['anchor'], truncation=True, padding='max_length', max_length=128)
    positive_encodings = tokenizer(examples['positive'], truncation=True, padding='max_length', max_length=128)
    negative_encodings = tokenizer(examples['negative'], truncation=True, padding='max_length', max_length=128)

    return {
        'anchor_input_ids': anchor_encodings['input_ids'],
        'anchor_attention_mask': anchor_encodings['attention_mask'],
        'positive_input_ids': positive_encodings['input_ids'],
        'positive_attention_mask': positive_encodings['attention_mask'],
        'negative_input_ids': negative_encodings['input_ids'],
        'negative_attention_mask': negative_encodings['attention_mask'],
    }

tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

# Define triplet loss
triplet_loss = TripletMarginLoss(margin=1.0)

def compute_loss(anchor_embedding, positive_embedding, negative_embedding):
    return triplet_loss(anchor_embedding, positive_embedding, negative_embedding)

# Load model
model = AutoModel.from_pretrained("intfloat/multilingual-e5-small")

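class TripletTrainer(Trainer):
    # **kwargs absorbs extra keyword arguments (such as num_items_in_batch)
    # that newer Transformers releases pass to compute_loss
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Move each side of the triplet to the training device
        anchor_input_ids = inputs['anchor_input_ids'].to(self.args.device)
        anchor_attention_mask = inputs['anchor_attention_mask'].to(self.args.device)
        positive_input_ids = inputs['positive_input_ids'].to(self.args.device)
        positive_attention_mask = inputs['positive_attention_mask'].to(self.args.device)
        negative_input_ids = inputs['negative_input_ids'].to(self.args.device)
        negative_attention_mask = inputs['negative_attention_mask'].to(self.args.device)

        # Encode each sentence and average the token embeddings over the sequence dimension
        anchor_embeds = model(input_ids=anchor_input_ids, attention_mask=anchor_attention_mask).last_hidden_state.mean(dim=1)
        positive_embeds = model(input_ids=positive_input_ids, attention_mask=positive_attention_mask).last_hidden_state.mean(dim=1)
        negative_embeds = model(input_ids=negative_input_ids, attention_mask=negative_attention_mask).last_hidden_state.mean(dim=1)

        # Calls the module-level compute_loss function defined above
        loss = compute_loss(anchor_embeds, positive_embeds, negative_embeds)

        # Trainer expects a (loss, outputs) tuple when return_outputs=True
        if return_outputs:
            return loss, (anchor_embeds, positive_embeds, negative_embeds)
        return loss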

# Training arguments
training_args = TrainingArguments(
    output_dir="/content/drive/MyDrive/results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='/content/drive/MyDrive/logs',
    remove_unused_columns=False,
    fp16=True,
    save_total_limit=3,
)

# Initialize trainer
trainer = TripletTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
)

# Start training
trainer.train()

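# Save the fine-tuned model
trainer.save_model("/content/drive/MyDrive/fine-tuned-multilingual-e5")

# trainer.evaluate() requires an eval_dataset, and none is passed to the
# trainer above, so the call is left commented out here.
# results = trainer.evaluate()
# print(results)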
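The training code above pools with last_hidden_state.mean(dim=1), which averages over all 128 positions, including the padding added by padding='max_length'. A common alternative is to average only over real tokens by using the attention mask. The sketch below illustrates that arithmetic on plain Python lists. It is not what the training code does, and masked_mean_pool is a hypothetical helper name.

```python
def masked_mean_pool(token_embeddings, attention_mask):
    """Average token vectors where attention_mask is 1, skipping padding.

    token_embeddings: list of token vectors (one per position)
    attention_mask:   list of 0/1 flags, same length as token_embeddings
    """
    dim = len(token_embeddings[0])
    totals = [0.0] * dim
    n_real = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            n_real += 1
            for i, value in enumerate(vector):
                totals[i] += value
    # Guard against an all-padding sequence to avoid division by zero.
    n_real = max(n_real, 1)
    return [t / n_real for t in totals]

# One hypothetical sequence: two real tokens followed by two padding positions.
tokens = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0], [9.0, 9.0]]
mask = [1, 1, 0, 0]

print(masked_mean_pool(tokens, mask))          # [2.0, 3.0]
# Treating every position as real pulls the average toward the padding vectors.
print(masked_mean_pool(tokens, [1, 1, 1, 1]))  # [5.5, 6.0]
```

In PyTorch the same idea is usually expressed by multiplying the hidden states by the attention mask (expanded along the hidden dimension), summing over the sequence dimension, and dividing by the number of unmasked tokens.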

Framework Versions

  • Python: 3.10.11
  • Sentence Transformers: 3.0.1
  • Transformers: 4.44.2
  • PyTorch: 2.4.0
  • Datasets: 2.21.0

How to Use

To use the model, first install the Sentence Transformers library:

pip install -U sentence-transformers

Then load the model and encode a few sentences:

from sentence_transformers import SentenceTransformer

# Load the fine-tuned model
model = SentenceTransformer("gimmeursocks/ara-e5-small")

# Run inference
sentences = ['أنا سعيد', 'الجو جميل اليوم', 'هذا كلب كبير']
embeddings = model.encode(sentences)
print(embeddings.shape)
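Since the model card lists cosine similarity as the similarity function, pairs of embeddings can be compared as sketched below. The example uses short hypothetical vectors standing in for rows of embeddings, and cosine_similarity is an illustrative helper rather than a function of the library.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors; 1.0 means same direction."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Hypothetical 2-D vectors; real embeddings from this model have 384 dimensions.
vec_a = [3.0, 4.0]
vec_b = [6.0, 8.0]   # same direction as vec_a
vec_c = [-4.0, 3.0]  # orthogonal to vec_a

print(cosine_similarity(vec_a, vec_b))  # 1.0
print(cosine_similarity(vec_a, vec_c))  # 0.0
```

With the real model, rows of the encoded output can be passed in place of these vectors, for example cosine_similarity(embeddings[0], embeddings[1]). Sentence Transformers 3.x also provides model.similarity(embeddings, embeddings), which returns the full pairwise similarity matrix.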

Citation

If you use this model or dataset, please cite the corresponding paper or dataset source.