tomaarsen (HF staff) committed
Commit 1440571
Parent: c159890

Update train_st_gooaq.py

Files changed (1):
  1. train_st_gooaq.py +87 -86
train_st_gooaq.py CHANGED
@@ -1,87 +1,88 @@
 # Copyright 2024 onwards Answer.AI, LightOn, and contributors
 # License: Apache-2.0
 
 import argparse
 
 from datasets import load_dataset
 from sentence_transformers import (
     SentenceTransformer,
     SentenceTransformerTrainer,
     SentenceTransformerTrainingArguments,
 )
 from sentence_transformers.evaluation import NanoBEIREvaluator
 from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
 from sentence_transformers.training_args import BatchSamplers
 
 def main():
     # Parse the learning rate & model name from the command line
     parser = argparse.ArgumentParser()
     parser.add_argument("--lr", type=float, default=8e-5)
     parser.add_argument("--model_name", type=str, default="answerdotai/ModernBERT-base")
     args = parser.parse_args()
     lr = args.lr
     model_name = args.model_name
     model_shortname = model_name.split("/")[-1]
 
     # 1. Load a model to finetune
     model = SentenceTransformer(model_name)
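+    # ModernBERT natively supports a context length of up to 8192 tokens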
+    model.max_seq_length = 8192
 
     # 2. Load a dataset to finetune on
     dataset = load_dataset("sentence-transformers/gooaq", split="train")
     dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
     train_dataset = dataset_dict["train"]
     eval_dataset = dataset_dict["test"]
 
     # 3. Define a loss function
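     # Answers to the other questions in a batch act as negatives; gradient caching
     # (GradCache) lets the effective batch size grow without a matching growth in VRAM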
     loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=128)  # Increase mini_batch_size if you have enough VRAM
 
     run_name = f"{model_shortname}-gooaq-{lr}"
     # 4. (Optional) Specify training arguments
     args = SentenceTransformerTrainingArguments(
         # Required parameter:
         output_dir=f"output/{model_shortname}/{run_name}",
         # Optional training parameters:
         num_train_epochs=1,
         per_device_train_batch_size=2048,
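         # The large batch supplies up to 2047 in-batch negatives per question; the loss
         # processes it in mini-batches of 128 (see step 3) to bound peak memory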
         per_device_eval_batch_size=2048,
         learning_rate=lr,
         warmup_ratio=0.05,
         fp16=False,  # Set to False if GPU can't handle FP16
         bf16=True,  # Set to True if GPU supports BF16
         batch_sampler=BatchSamplers.NO_DUPLICATES,  # (Cached)MultipleNegativesRankingLoss benefits from no duplicates
         # Optional tracking/debugging parameters:
         eval_strategy="steps",
         eval_steps=50,
         save_strategy="steps",
         save_steps=50,
         save_total_limit=2,
         logging_steps=10,
         run_name=run_name,  # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed
     )
 
     # 5. (Optional) Create an evaluator & evaluate the base model
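     # NanoBEIR runs retrieval evaluation on small subsets of the BEIR datasets (here NQ & MSMARCO)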
     dev_evaluator = NanoBEIREvaluator(dataset_names=["NQ", "MSMARCO"])
     dev_evaluator(model)
 
     # 6. Create a trainer & train
     trainer = SentenceTransformerTrainer(
         model=model,
         args=args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         loss=loss,
         evaluator=dev_evaluator,
     )
     trainer.train()
 
     # 7. (Optional) Evaluate the trained model on the evaluator after training
     dev_evaluator(model)
 
     # 8. Save the model
     model.save_pretrained(f"output/{model_shortname}/{run_name}/final")
 
     # 9. (Optional) Push it to the Hugging Face Hub
     model.push_to_hub(run_name, private=False)
 
 if __name__ == "__main__":
     main()
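
Usage note: after a run completes, the final checkpoint from step 8 can be loaded like any other Sentence Transformers model. A minimal sketch, assuming the default --lr of 8e-5 (so the run name resolves to ModernBERT-base-gooaq-8e-05); the query and documents are illustrative:

from sentence_transformers import SentenceTransformer

# Load the checkpoint saved by step 8 of the script
model = SentenceTransformer("output/ModernBERT-base/ModernBERT-base-gooaq-8e-05/final")

# Embed a question and candidate answers, then rank the answers by similarity
query_embedding = model.encode("how do solar panels work?")
doc_embeddings = model.encode([
    "Solar panels convert sunlight into electricity through the photovoltaic effect.",
    "A mortgage is a loan used to purchase or maintain real estate.",
])
print(model.similarity(query_embedding, doc_embeddings))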