{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Arabic Dialect Classifier\n", "This notebook contains the training of the classifier model. The goal is to classify the dialects at the country level." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "from datasets import DatasetDict, Dataset\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.model_selection import RandomizedSearchCV\n", "from sklearn.preprocessing import LabelEncoder\n", "import torch\n", "from transformers import AutoModel, AutoTokenizer\n", "import xgboost as xgb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Exploring the Dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Tab-separated files; the labeled DEV split is used as the held-out test set in this notebook.\n", "df_train = pd.read_csv(\"../data/DA_train_labeled.tsv\", sep=\"\\t\")\n", "df_test = pd.read_csv(\"../data/DA_dev_labeled.tsv\", sep=\"\\t\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
#1_tweetid#2_tweet#3_country_label#4_province_label
0TRAIN_0حاجة حلوة اكيدEgypteg_Faiyum
1TRAIN_1عم بشتغلوا للشعب الاميركي اما نحن يكذبوا ويغشو...Iraqiq_Dihok
2TRAIN_2ابشر طال عمركSaudi_Arabiasa_Ha'il
3TRAIN_3منطق 2017: أنا والغريب علي إبن عمي وأنا والغري...Mauritaniamr_Nouakchott
4TRAIN_4شهرين وتروح والباقي غير صيف مليناAlgeriadz_El-Oued
\n", "
" ], "text/plain": [ " #1_tweetid #2_tweet \\\n", "0 TRAIN_0 حاجة حلوة اكيد \n", "1 TRAIN_1 عم بشتغلوا للشعب الاميركي اما نحن يكذبوا ويغشو... \n", "2 TRAIN_2 ابشر طال عمرك \n", "3 TRAIN_3 منطق 2017: أنا والغريب علي إبن عمي وأنا والغري... \n", "4 TRAIN_4 شهرين وتروح والباقي غير صيف ملينا \n", "\n", " #3_country_label #4_province_label \n", "0 Egypt eg_Faiyum \n", "1 Iraq iq_Dihok \n", "2 Saudi_Arabia sa_Ha'il \n", "3 Mauritania mr_Nouakchott \n", "4 Algeria dz_El-Oued " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Preview the first rows of the training split.\n", "df_train.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
#1_tweetid#2_tweet#3_country_label#4_province_label
0DEV_0قولنا اون لاين لا يا علي اون لاين لاEgypteg_Alexandria
1DEV_1ههههه بايخه ههههه URL  …Omanom_Muscat
2DEV_2ربنا يخليك يا دوك ولك المثل :DLebanonlb_South-Lebanon
3DEV_3#اوامر_ملكيه ياشباب اي واحد فيكم عنده شي يذكره...Syriasy_Damascus-City
4DEV_4شد عالخط حتى هيا اكويسهLibyaly_Misrata
\n", "
" ], "text/plain": [ " #1_tweetid #2_tweet \\\n", "0 DEV_0 قولنا اون لاين لا يا علي اون لاين لا \n", "1 DEV_1 ههههه بايخه ههههه URL  … \n", "2 DEV_2 ربنا يخليك يا دوك ولك المثل :D \n", "3 DEV_3 #اوامر_ملكيه ياشباب اي واحد فيكم عنده شي يذكره... \n", "4 DEV_4 شد عالخط حتى هيا اكويسه \n", "\n", " #3_country_label #4_province_label \n", "0 Egypt eg_Alexandria \n", "1 Oman om_Muscat \n", "2 Lebanon lb_South-Lebanon \n", "3 Syria sy_Damascus-City \n", "4 Libya ly_Misrata " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Preview the first rows of the held-out (dev) split.\n", "df_test.head()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#1_tweetid 0\n", " #2_tweet 0\n", " #3_country_label 0\n", " #4_province_label 0\n", " dtype: int64,\n", " #1_tweetid 0\n", " #2_tweet 0\n", " #3_country_label 0\n", " #4_province_label 0\n", " dtype: int64)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Number of missing values per column, for the train and test splits respectively.\n", "df_train.isnull().sum(), df_test.isnull().sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the distribution of the labels" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Value counts of country label in train data')" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compute the per-country counts once and reuse them for the bar labels and lengths.\n", "train_label_counts = df_train[\"#3_country_label\"].value_counts().sort_values(ascending=True)\n", "\n", "fig, ax = plt.subplots()\n", "ax.barh(y=train_label_counts.index, width=train_label_counts)\n", "ax.set_xlabel(\"Number of tweets\")\n", "ax.set_title(\"Value counts of country label in train data\");" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 1.0, 'Value counts of country label in test data')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Compute the per-country counts once and reuse them for the bar labels and lengths.\n", "test_label_counts = df_test[\"#3_country_label\"].value_counts().sort_values(ascending=True)\n", "\n", "fig, ax = plt.subplots()\n", "ax.barh(y=test_label_counts.index, width=test_label_counts)\n", "ax.set_xlabel(\"Number of tweets\")\n", "ax.set_title(\"Value counts of country label in test data\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some countries don't have a lot of observations, which means that it might be harder to detect their dialects. We need to take this into consideration when training and evaluating the model (by assigning weights/oversampling, and by choosing appropriate evaluation metrics)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Training the Classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this classifier, we will convert the tweets into vector embeddings using the AraBART model. We will use the last hidden layer of the model to extract the features." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.1 Data Preparation\n", "The first step is to prepare our data by tokenizing it for use with AraBART." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we load the model and its tokenizer." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.device" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fall back to the CPU when no CUDA device is available, so the notebook still runs end to end.\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "device" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# from_pretrained already returns the model in eval mode; .eval() makes explicit that\n", "# dropout stays disabled while the model is only used as a frozen feature extractor.\n", "model = AutoModel.from_pretrained(\"moussaKam/AraBART\").to(device).eval()\n", "tokenizer = AutoTokenizer.from_pretrained(\"moussaKam/AraBART\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we convert the datasets into a DatasetDict object." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "mapper = {\"#2_tweet\": \"tweet\", \"#3_country_label\": \"label\"}\n", "columns_to_keep = [\"tweet\", \"label\"]\n", "\n", "df_train = df_train.rename(columns=mapper)[columns_to_keep]\n", "df_test = df_test.rename(columns=mapper)[columns_to_keep]\n", "\n", "train_dataset = Dataset.from_pandas(df_train)\n", "test_dataset = Dataset.from_pandas(df_test)\n", "data = DatasetDict({'train': train_dataset, 'test': test_dataset})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we tokenize the dataset." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "def tokenize(batch):\n", "    \"\"\"Tokenize a batch of tweets for AraBART.\n", "\n", "    Sequences are padded to the longest tweet in the batch and truncated to the\n", "    tokenizer's maximum model length, so an unusually long input cannot overflow\n", "    the model's position embeddings.\n", "    \"\"\"\n", "    return tokenizer(batch[\"tweet\"], padding=True, truncation=True)\n", "\n", "# batch_size=None passes each split to `tokenize` as a single batch, so all rows of a\n", "# split are padded to the same length and can later be stacked into torch tensors.\n", "data_encoded = data.map(tokenize, batched=True, batch_size=None)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['tweet', 'label', 'input_ids', 'attention_mask']" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_encoded[\"train\"].column_names" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.2 Feature Extraction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we will extract the output of the last hidden layer of AraBART, and use those embeddings as the features of our classifier." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def extract_hidden_states(batch):\n", "    \"\"\"Embed a batch of tokenized tweets with AraBART.\n", "\n", "    Returns {\"hidden_state\": array of shape (batch_size, hidden_size)}: one vector\n", "    per tweet, taken from the model's last hidden layer.\n", "\n", "    Raises ValueError if a tokenized row contains no </s> (eos) token.\n", "    \"\"\"\n", "    inputs = {k: v.to(device) for k, v in batch.items()\n", "              if k in tokenizer.model_input_names}\n", "    with torch.no_grad():\n", "        last_hidden_state = model(**inputs).last_hidden_state\n", "\n", "    # For a BART-style encoder-decoder loaded with AutoModel, last_hidden_state is the\n", "    # decoder output, and position 0 holds the shifted-in decoder-start token (there is\n", "    # no [CLS] token). Following the BART paper / BartForSequenceClassification, use the\n", "    # decoder state at the final </s> (eos) token of each sequence instead.\n", "    eos_mask = inputs[\"input_ids\"].eq(model.config.eos_token_id)\n", "    if not bool(eos_mask.any(dim=1).all()):\n", "        raise ValueError(\"Every tokenized tweet must contain an </s> (eos) token.\")\n", "    seq_len = eos_mask.size(1)\n", "    # flip + argmax finds the first eos from the right, i.e. the last eos in each row.\n", "    last_eos_pos = seq_len - 1 - eos_mask.flip(dims=[1]).int().argmax(dim=1)\n", "    rows = torch.arange(eos_mask.size(0), device=eos_mask.device)\n", "    sentence_embeddings = last_hidden_state[rows, last_eos_pos]\n", "\n", "    return {\"hidden_state\": sentence_embeddings.cpu().numpy()}" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "data_encoded.set_format(\"torch\", columns=[\"input_ids\", \"attention_mask\", \"label\"])\n", "data_hidden = data_encoded.map(extract_hidden_states, batched=True, batch_size=50)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Cache the embeddings locally: computing them needs a full AraBART pass over both splits.\n", "with open(\"../data/data_hidden.pkl\", \"wb\") as f:\n", " pickle.dump(data_hidden, f)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Reload the cached embeddings. If extract_hidden_states changes, re-run the two cells\n", "# above to refresh this cache. Only unpickle files this notebook wrote itself:\n", "# pickle.load can execute arbitrary code.\n", "with open(\"../data/data_hidden.pkl\", \"rb\") as f:\n", " data_hidden = pickle.load(f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.3 Model Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we only need to convert the data into numpy arrays, and we are ready to train the models." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((21000, 768), (21000,))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train = np.array(data_hidden[\"train\"][\"hidden_state\"])\n", "X_test = np.array(data_hidden[\"test\"][\"hidden_state\"])\n", "y_train = np.array(data_hidden[\"train\"][\"label\"])\n", "y_test = np.array(data_hidden[\"test\"][\"label\"])\n", "\n", "X_train.shape, y_train.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can try different models. 
\n", "\n", "For the ensemble models (random forest and XGBoost), we can tune the hyperparameters with a randomized or grid search; here we use a randomized search. We use a 5-fold cross-validation strategy and optimize for the macro-averaged F1 score (because we want to give equal importance to each class, regardless of how many observations it has)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.3.1 Logistic Regression" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
LogisticRegression(class_weight='balanced', max_iter=1000,\n",
       "                   multi_class='multinomial', random_state=2024)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "LogisticRegression(class_weight='balanced', max_iter=1000,\n", " multi_class='multinomial', random_state=2024)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# multi_class is left at its default: the parameter is deprecated since scikit-learn 1.5,\n", "# and the default lbfgs solver already fits a multinomial (softmax) model when there\n", "# are more than two classes.\n", "lr_model = LogisticRegression(class_weight=\"balanced\",\n", "                              max_iter=1000,\n", "                              random_state=2024)\n", "lr_model.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.3.2 Random Forest" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RandomizedSearchCV(cv=5,\n",
       "                   estimator=RandomForestClassifier(class_weight='balanced',\n",
       "                                                    random_state=2024),\n",
       "                   n_iter=20,\n",
       "                   param_distributions={'max_depth': [3, 4, 5, 6, 7, 8],\n",
       "                                        'n_estimators': [100, 150, 200, 250,\n",
       "                                                         300, 400, 500]},\n",
       "                   scoring='f1_macro')
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "RandomizedSearchCV(cv=5,\n", " estimator=RandomForestClassifier(class_weight='balanced',\n", " random_state=2024),\n", " n_iter=20,\n", " param_distributions={'max_depth': [3, 4, 5, 6, 7, 8],\n", " 'n_estimators': [100, 150, 200, 250,\n", " 300, 400, 500]},\n", " scoring='f1_macro')" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rf_model = RandomForestClassifier(class_weight=\"balanced\", random_state=2024)\n", "parameters = {\n", "    \"n_estimators\": [100, 150, 200, 250, 300, 400, 500],\n", "    # NOTE(review): an earlier run selected max_depth=8, the largest value tried here;\n", "    # consider widening this range.\n", "    \"max_depth\": [3, 4, 5, 6, 7, 8]\n", "}\n", "# random_state fixes which n_iter parameter combinations are sampled, so the search is\n", "# reproducible (the forest's own random_state does not control this sampling).\n", "rf_search = RandomizedSearchCV(estimator=rf_model, param_distributions=parameters,\n", "                               scoring=\"f1_macro\", cv=5, n_iter=20, random_state=2024)\n", "rf_search.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best Parameters: {'n_estimators': 400, 'max_depth': 8}\n", "Best Score: 0.15591886021384346\n" ] } ], "source": [ "print(\"Best Parameters:\", rf_search.best_params_)\n", "print(\"Best Score:\", rf_search.best_score_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### 2.3.3 XGBoost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For XGBoost, we first need to encode the target variable." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "label_encoder = LabelEncoder()\n", "y_train_encoded = label_encoder.fit_transform(y_train)\n", "y_test_encoded = label_encoder.transform(y_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# XGBClassifier has no class_weight option, so mirror class_weight=\"balanced\" (used for the\n", "# other models) with per-sample weights: n_samples / (n_classes * count of the sample's class).\n", "# LabelEncoder yields contiguous labels 0..n_classes-1, so every bincount entry is >= 1.\n", "class_counts = np.bincount(y_train_encoded)\n", "xgb_sample_weight = (len(y_train_encoded) / (len(class_counts) * class_counts))[y_train_encoded]\n", "\n", "xgb_model = xgb.XGBClassifier(device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n", "                              random_state=2024)\n", "parameters = {\n", "    \"n_estimators\" : [100, 150, 200, 300, 400, 450, 500],\n", "    \"max_depth\" : [3, 4, 5, 6, 7, 8],\n", "    \"learning_rate\": [0.1, 0.05, 0.01, 0.005, 0.001]\n", "}\n", "# random_state makes the sampled parameter combinations reproducible; sample_weight is\n", "# forwarded to XGBClassifier.fit and sliced to each CV training fold.\n", "xgb_search = RandomizedSearchCV(estimator=xgb_model, param_distributions=parameters,\n", "                                scoring=\"f1_macro\", cv=5, n_iter=20, random_state=2024)\n", "xgb_search.fit(X_train, y_train_encoded, sample_weight=xgb_sample_weight)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"Best Parameters:\", xgb_search.best_params_)\n", "print(\"Best Score (Macro Average F1):\", xgb_search.best_score_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The best parameters and best score below were obtained with an earlier version of this search (without class-balanced sample weights); re-run the cells above to refresh them: \n", "```\n", "Best Parameters: {'n_estimators': 450, 'max_depth': 7, 'learning_rate': 0.1} \n", "Best Score (Macro Average F1): 0.17356889596239114\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Evaluating the Performance" ] } ], "metadata": { "kernelspec": { "display_name": "adc", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }