mT5 question answering fine-tuning generates empty sentences during inference
Fine-tuning mT5-small for question answering converges to high training accuracy, high validation accuracy, and a low loss. However, when I test the model on questions it was trained on, it always returns an empty answer.
Experiment Language: Arabic
Dataset used: Arabic SQUAD
Optimizers tested: Adam and AdamW, with learning rate 3e-4
Loss function: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
Epochs tested: 5, 30
I also trained the model for 30 epochs and got the same result:
Output:
Output: [0, 250099, 1]
It is very strange that the model converges to high accuracy and low loss, yet produces empty sentences during inference. I have checked the questions and answers in the dataset, and they are correct.
Below are some important code snippets:
```python
def preprocess_function(examples):
    padding = "max_length"
    max_length = 200
    inputs = [ex for ex in examples["question"]]
    targets = [ex for ex in examples["text"]]
    # Tokenize the questions (encoder inputs), padded to max_length
    model_inputs = tokenizer(inputs, max_length=max_length, padding=padding, truncation=True)
    # Tokenize the answers; their input_ids become the decoder labels
    labels = tokenizer(targets, max_length=max_length, padding=padding, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
```
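With padding="max_length", every label row is filled with the pad ID (0 for the mT5 tokenizer) up to max_length. The plain-Python sketch below uses the sample answer's label IDs listed later in this post to show how much of a 200-token label row is padding, and what token-level accuracy a model would score by predicting the pad ID everywhere. It assumes the accuracy metric counts every label position, padding included. The helpers pad_labels and all_pad_accuracy are hypothetical illustrations, not library functions.

```python
def pad_labels(label_ids, max_length, pad_id=0):
    """Truncate/pad a list of label IDs to max_length (hypothetical helper)."""
    ids = list(label_ids[:max_length])
    return ids + [pad_id] * (max_length - len(ids))


def all_pad_accuracy(padded_labels, pad_id=0):
    """Accuracy of a model that predicts pad_id at every position,
    assuming every position (padding included) is scored."""
    matches = sum(1 for t in padded_labels if t == pad_id)
    return matches / len(padded_labels)


# Label IDs of the sample answer (9 tokens, ending with </s> = 1)
answer_ids = [259, 42501, 3234, 966, 548, 36270, 259, 36136, 1]
labels = pad_labels(answer_ids, max_length=200)

print(len(labels), labels.count(0))  # 200 191
print(all_pad_accuracy(labels))      # 0.955
```

Under that assumption, 191 of the 200 positions are padding, so an all-padding prediction already reaches 95.5% token accuracy on this example.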
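```python
from transformers import DataCollatorForSeq2Seq

# Pads each batch; padded label positions are filled with the tokenizer's pad ID
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=tokenizer.pad_token_id,
    pad_to_multiple_of=64,
    return_tensors="np",
)
```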
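The collator above is configured with pad_to_multiple_of=64. Assuming DataCollatorForSeq2Seq rounds the label length up to the next multiple of that value and fills the extra positions with label_pad_token_id (here the pad ID), the 200-token labels produced by preprocess_function would become 256 positions per row. The hypothetical helper padded_length below reproduces only that rounding step; the last lines compute the padded share for the 9-token sample answer.

```python
def padded_length(longest, pad_to_multiple_of=None):
    """Round a length up to the next multiple of pad_to_multiple_of, if given."""
    if pad_to_multiple_of is None:
        return longest
    # Ceiling division, then scale back up to a multiple
    return -(-longest // pad_to_multiple_of) * pad_to_multiple_of


label_len = padded_length(200, pad_to_multiple_of=64)
answer_tokens = 9  # length of the sample answer's label IDs, including </s>
pad_positions = label_len - answer_tokens

print(label_len, pad_positions, pad_positions / label_len)  # 256 247 0.96484375
```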
```python
tf_train_dataset = model.prepare_tf_dataset(
    train_dataset,
    collate_fn=data_collator,
    batch_size=8,
    shuffle=True,
)
```
```python
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])
```
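The accuracy in the run summary comes from the built-in "accuracy" metric passed to model.compile. To compare it against accuracy on the answer tokens alone, a hypothetical plain-Python helper such as masked_token_accuracy below can skip padded label positions. The ignore_id parameter is there so the same helper works whether label padding is filled with the pad ID 0 (as in the collator here) or with -100, which is DataCollatorForSeq2Seq's default label_pad_token_id. This is a sketch for inspecting predicted token IDs, not a Keras metric.

```python
def masked_token_accuracy(labels, predictions, ignore_id=0):
    """Token accuracy over label positions whose ID is not ignore_id."""
    kept = [(y, p) for y, p in zip(labels, predictions) if y != ignore_id]
    if not kept:
        return 0.0
    return sum(1 for y, p in kept if y == p) / len(kept)


answer_ids = [259, 42501, 3234, 966, 548, 36270, 259, 36136, 1]
labels = answer_ids + [0] * (200 - len(answer_ids))  # padded with pad ID 0
all_pad = [0] * len(labels)                          # a model that only predicts padding

print(masked_token_accuracy(labels, all_pad))  # 0.0
print(masked_token_accuracy(labels, labels))   # 1.0
```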
Output of model.fit() for the first 5 epochs:
[Run history plots: accuracy, epoch, loss, val_accuracy, val_loss]
Run summary:
accuracy: 0.96812
best_epoch: 4
best_val_loss: 0.21643
epoch: 4
loss: 0.35643
val_accuracy: 0.97813
val_loss: 0.21643
Sample question:
Q: ['ما هي قيمة العقد بين شركة Under Armor و Notre Dame؟'] (What is the value of the contract between Under Armor and Notre Dame?)
A: قيمته 100 مليون دولار (Its value is 100 million dollars.)
A: [259, 42501, 3234, 966, 548, 36270, 259, 36136, 1]
Input: ما هي قيمة العقد بين شركة Under Armor و Notre Dame؟
Input: [1415, 7383, 2588, 23283, 402, 27419, 5373, 259, 11319, 8427, 259, 117220, 341, 259, 37126, 34600, 2273, 1]
Output:
Output: [0, 250099, 1]