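# Fine-tune a sentence embedding model on GooAQ question-answer pairs with an
# in-batch negatives loss, evaluating on NanoBEIR subsets before and after training.
# Example invocation (the script filename here is illustrative):
#   python train_gooaq.py --lr 8e-5 --model_name answerdotai/ModernBERT-base
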
import argparse

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers


def main():
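    # Read the learning rate and base model name from the command line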
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--model_name", type=str, default="answerdotai/ModernBERT-base")
    args = parser.parse_args()
    lr = args.lr
    model_name = args.model_name
    model_shortname = model_name.split("/")[-1]
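
    # Load the base model to fine-tune as a SentenceTransformer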
    model = SentenceTransformer(model_name)
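
    # Load the GooAQ (question, answer) pairs and hold out 1,000 pairs for evaluation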
    dataset = load_dataset("sentence-transformers/gooaq", split="train")
    dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
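
    # In-batch negatives loss; the cached (GradCache) variant processes each large
    # batch in mini-batches of 128 to keep memory usage manageable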
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=128)

    run_name = f"{model_shortname}-gooaq-{lr}"
    args = SentenceTransformerTrainingArguments(
        output_dir=f"output/{model_shortname}/{run_name}",
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=lr,
        warmup_ratio=0.05,
        fp16=False,
        bf16=True,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        eval_strategy="steps",
        eval_steps=50,
        save_strategy="steps",
        save_steps=50,
        save_total_limit=2,
        logging_steps=10,
        run_name=run_name,
    )
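
    # Evaluate the base model on the NQ and MSMARCO subsets of NanoBEIR before training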
    dev_evaluator = NanoBEIREvaluator(dataset_names=["NQ", "MSMARCO"])
    dev_evaluator(model)
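
    # Create a trainer and fine-tune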
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()
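
    # Re-run the NanoBEIR evaluation on the fine-tuned model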
    dev_evaluator(model)
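
    # Save the trained model locally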
    model.save_pretrained(f"output/{model_shortname}/{run_name}/final")
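
    # Upload the trained model to the Hugging Face Hub under the run name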
    model.push_to_hub(run_name, private=False)


if __name__ == "__main__":
    main()