Add src folder
- src/__init__.py +0 -0
- src/api_demo.py +101 -0
- src/cli_demo.py +70 -0
- src/export_model.py +19 -0
- src/train_ppo.py +81 -0
- src/train_pt.py +81 -0
- src/train_rm.py +76 -0
- src/train_sft.py +97 -0
- src/utils/__init__.py +19 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/common.cpython-310.pyc +0 -0
- src/utils/__pycache__/config.cpython-310.pyc +0 -0
- src/utils/__pycache__/data_collator.cpython-310.pyc +0 -0
- src/utils/__pycache__/other.cpython-310.pyc +0 -0
- src/utils/__pycache__/pairwise.cpython-310.pyc +0 -0
- src/utils/__pycache__/peft_trainer.cpython-310.pyc +0 -0
- src/utils/__pycache__/ppo.cpython-310.pyc +0 -0
- src/utils/__pycache__/seq2seq.cpython-310.pyc +0 -0
- src/utils/__pycache__/template.cpython-310.pyc +0 -0
- src/utils/common.py +561 -0
- src/utils/config.py +283 -0
- src/utils/data_collator.py +64 -0
- src/utils/other.py +196 -0
- src/utils/pairwise.py +57 -0
- src/utils/peft_trainer.py +132 -0
- src/utils/ppo.py +223 -0
- src/utils/seq2seq.py +96 -0
- src/utils/template.py +138 -0
- src/web_demo.py +150 -0
src/__init__.py
ADDED
File without changes
src/api_demo.py
ADDED
@@ -0,0 +1,101 @@
# coding=utf-8
# Implements API for fine-tuned models.
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

# Request:
# curl http://127.0.0.1:8000 --header 'Content-Type: application/json' --data '{"prompt": "Hello there!", "history": []}'

# Response:
# {
#   "response": "'Hi there!'",
#   "history": "[('Hello there!', 'Hi there!')]",
#   "status": 200,
#   "time": "2000-00-00 00:00:00"
# }


import json
import torch
import uvicorn
import datetime
from fastapi import FastAPI, Request

from utils import (
    Template,
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)


def torch_gc():
    if torch.cuda.is_available():
        num_gpus = torch.cuda.device_count()
        for device_id in range(num_gpus):
            with torch.cuda.device(device_id):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()


app = FastAPI()


@app.post("/")
async def create_item(request: Request):
    global model, tokenizer, prompt_template, generating_args

    # Parse the request JSON
    json_post_raw = await request.json()
    json_post = json.dumps(json_post_raw)
    json_post_list = json.loads(json_post)
    prompt = json_post_list.get("prompt")
    history = json_post_list.get("history")
    max_new_tokens = json_post_list.get("max_new_tokens", None)
    top_p = json_post_list.get("top_p", None)
    temperature = json_post_list.get("temperature", None)

    # Tokenize the input prompt
    input_ids = tokenizer([prompt_template.get_prompt(prompt, history)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    # Generation arguments
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["input_ids"] = input_ids
    gen_kwargs["logits_processor"] = get_logits_processor()
    gen_kwargs["max_new_tokens"] = max_new_tokens if max_new_tokens else gen_kwargs["max_new_tokens"]
    gen_kwargs["top_p"] = top_p if top_p else gen_kwargs["top_p"]
    gen_kwargs["temperature"] = temperature if temperature else gen_kwargs["temperature"]

    # Generate response
    with torch.no_grad():
        generation_output = model.generate(**gen_kwargs)
    outputs = generation_output.tolist()[0][len(input_ids[0]):]
    response = tokenizer.decode(outputs, skip_special_tokens=True)

    # Update history
    history = history + [(prompt, response)]

    # Prepare response
    now = datetime.datetime.now()
    time = now.strftime("%Y-%m-%d %H:%M:%S")
    answer = {
        "response": repr(response),
        "history": repr(history),
        "status": 200,
        "time": time
    }

    # Log and clean up
    log = "[" + time + "] " + "\", prompt:\"" + prompt + "\", response:\"" + repr(response) + "\""
    print(log)
    torch_gc()

    return answer


if __name__ == "__main__":
    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)
    prompt_template = Template(data_args.prompt_template)

    uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
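
For reference, a minimal client for this API could look like the sketch below. It is illustrative only and not part of the commit: it assumes the server above is running locally on port 8000 and uses the third-party requests package.

# Illustrative client for api_demo.py (assumes the server listens on http://127.0.0.1:8000).
import requests

history = []
payload = {"prompt": "Hello there!", "history": history, "max_new_tokens": 128}
resp = requests.post("http://127.0.0.1:8000/", json=payload, timeout=60)
data = resp.json()
print(data["response"])   # generated reply (repr-encoded by the server)
print(data["history"])    # updated conversation history (repr-encoded)
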
src/cli_demo.py
ADDED
@@ -0,0 +1,70 @@
# coding=utf-8
# Implements stream chat in command line for fine-tuned models.
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint


from utils import (
    Template,
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)
from threading import Thread
from transformers import TextIteratorStreamer


def main():

    model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
    model, tokenizer = load_pretrained(model_args, finetuning_args)

    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
    prompt_template = Template(data_args.prompt_template)

    def predict_and_print(query, history: list) -> list:
        input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
        input_ids = input_ids.to(model.device)

        streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

        gen_kwargs = generating_args.to_dict()
        gen_kwargs["input_ids"] = input_ids
        gen_kwargs["logits_processor"] = get_logits_processor()
        gen_kwargs["streamer"] = streamer

        thread = Thread(target=model.generate, kwargs=gen_kwargs)
        thread.start()

        print("{}: ".format(model_name), end="", flush=True)
        response = ""
        for new_text in streamer:
            print(new_text, end="", flush=True)
            response += new_text
        print()
        history = history + [(query, response)]
        return history

    history = []
    print("Welcome to the {} model. Type a message to chat, \"clear\" to clear the conversation history, or \"stop\" to exit.".format(model_name))
    while True:
        try:
            query = input("\nInput: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "stop":
            break

        if query.strip() == "clear":
            history = []
            print("History has been removed.")
            continue

        history = predict_and_print(query, history)


if __name__ == "__main__":
    main()
src/export_model.py
ADDED
@@ -0,0 +1,19 @@
# coding=utf-8
# Exports the fine-tuned model.
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model


from utils import load_pretrained, prepare_args


def main():

    model_args, _, training_args, finetuning_args = prepare_args(stage="sft")
    model, tokenizer = load_pretrained(model_args, finetuning_args)
    model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
    tokenizer.save_pretrained(training_args.output_dir)
    print("model and tokenizer have been saved at:", training_args.output_dir)


if __name__ == "__main__":
    main()
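
Once exported, the saved directory can be reloaded like a regular Hugging Face checkpoint. A minimal sketch, assuming the export was written to a hypothetical ./exported_model directory (trust_remote_code may be required for some architectures, as in common.py below):

# Illustrative only: reload the exported model with plain transformers.
from transformers import AutoModelForCausalLM, AutoTokenizer

output_dir = "./exported_model"  # hypothetical --output_dir used above
tokenizer = AutoTokenizer.from_pretrained(output_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(output_dir, trust_remote_code=True)
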
src/train_ppo.py
ADDED
@@ -0,0 +1,81 @@
# coding=utf-8
# Implements parameter-efficient PPO training of fine-tuned models.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py

import math

from torch.optim import AdamW
from transformers.optimization import get_scheduler
from trl import PPOConfig

from utils import (
    DynamicDataCollatorWithPadding,
    PPOPeftTrainer,
    LogCallback,
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="ppo")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
    data_collator = DynamicDataCollatorWithPadding(tokenizer)

    ppo_config = PPOConfig(
        model_name=model_args.model_name_or_path,
        learning_rate=training_args.learning_rate,
        mini_batch_size=training_args.per_device_train_batch_size,
        batch_size=training_args.per_device_train_batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        ppo_epochs=1,
        max_grad_norm=training_args.max_grad_norm
    )

    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
    total_train_batch_size = \
        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
    lr_scheduler = get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
    )

    # Initialize our Trainer
    ppo_trainer = PPOPeftTrainer(
        training_args=training_args,
        finetuning_args=finetuning_args,
        callbacks=[LogCallback()],
        config=ppo_config,
        model=model,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=dataset,
        data_collator=data_collator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler
    )

    ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
    ppo_trainer.save_model()
    ppo_trainer.save_state() # must be after save_model
    if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
        plot_loss(training_args.output_dir, keys=["loss", "reward"])


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
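
To make the scheduler arithmetic above concrete, here is a small worked example with illustrative numbers (the values are assumptions, not taken from the commit):

# Worked example of the num_training_steps computation above (illustrative values).
import math

per_device_train_batch_size = 4
gradient_accumulation_steps = 4
world_size = 1                      # single GPU
num_train_epochs = 2
dataset_size = 8000

total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size   # 16
num_training_steps = num_train_epochs * math.ceil(dataset_size / total_train_batch_size)          # 2 * 500 = 1000
print(total_train_batch_size, num_training_steps)
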
src/train_pt.py
ADDED
@@ -0,0 +1,81 @@
# coding=utf-8
# Implements several parameter-efficient pre-training methods.
# This code is inspired by
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py


import math

from utils import (
    DynamicDataCollatorWithPadding,
    PeftTrainer,
    LogCallback,
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="pt")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        **trainer_kwargs
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")

        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
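
The dev_ratio split used here (and again in train_rm.py and train_sft.py below) relies on datasets.Dataset.train_test_split, which returns a dict-like object with "train" and "test" splits. A small illustration with toy data (illustrative only):

# Illustrative: how the dev_ratio split above behaves on a toy dataset.
from datasets import Dataset

toy = Dataset.from_dict({"prompt": ["example {}".format(i) for i in range(100)]})
splits = toy.train_test_split(test_size=0.1)       # corresponds to dev_ratio=0.1
print(len(splits["train"]), len(splits["test"]))   # 90 10
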
src/train_rm.py
ADDED
@@ -0,0 +1,76 @@
# coding=utf-8
# Implements parameter-efficient training of reward models.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py


from utils import (
    PairwiseDataCollatorWithPadding,
    PairwisePeftTrainer,
    LogCallback,
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    compute_accuracy,
    plot_loss
)

def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="rm")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
    data_collator = PairwiseDataCollatorWithPadding(tokenizer)

    training_args.remove_unused_columns = False # important for pairwise dataset

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = PairwisePeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        compute_metrics=compute_accuracy,
        **trainer_kwargs
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval")
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
src/train_sft.py
ADDED
@@ -0,0 +1,97 @@
# coding=utf-8
# Implements several parameter-efficient supervised fine-tuning methods.
# This code is inspired by
# https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py


from utils import (
    DynamicDataCollatorWithPadding,
    Seq2SeqPeftTrainer,
    ComputeMetrics,
    LogCallback,
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    get_logits_processor,
    plot_loss
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
    data_collator = DynamicDataCollatorWithPadding(tokenizer, data_args.ignore_pad_token_for_loss)

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
        training_args.generation_max_length is not None else data_args.max_target_length
    training_args.generation_num_beams = data_args.eval_num_beams if \
        data_args.eval_num_beams is not None else training_args.generation_num_beams

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = Seq2SeqPeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **trainer_kwargs
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_new_tokens": data_args.max_target_length + 1,
        "temperature": 0.95,
        "logits_processor": get_logits_processor()
    }

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results, tokenizer)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
src/utils/__init__.py
ADDED
@@ -0,0 +1,19 @@
from .common import (
    load_pretrained,
    prepare_args,
    prepare_infer_args,
    prepare_data,
    preprocess_data
)

from .data_collator import DynamicDataCollatorWithPadding

from .peft_trainer import PeftTrainer, LogCallback

from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer, compute_accuracy
from .ppo import PPOPeftTrainer

from .template import Template

from .other import get_logits_processor, plot_loss
src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (785 Bytes)
src/utils/__pycache__/common.cpython-310.pyc
ADDED
Binary file (16 kB)
src/utils/__pycache__/config.cpython-310.pyc
ADDED
Binary file (11.3 kB)
src/utils/__pycache__/data_collator.cpython-310.pyc
ADDED
Binary file (2.99 kB)
src/utils/__pycache__/other.cpython-310.pyc
ADDED
Binary file (7.37 kB)
src/utils/__pycache__/pairwise.cpython-310.pyc
ADDED
Binary file (2.93 kB)
src/utils/__pycache__/peft_trainer.cpython-310.pyc
ADDED
Binary file (5.13 kB)
src/utils/__pycache__/ppo.cpython-310.pyc
ADDED
Binary file (7.01 kB)
src/utils/__pycache__/seq2seq.cpython-310.pyc
ADDED
Binary file (4.15 kB)
src/utils/__pycache__/template.cpython-310.pyc
ADDED
Binary file (3.12 kB)
src/utils/common.py
ADDED
@@ -0,0 +1,561 @@
import os
import sys
import torch
import hashlib
from itertools import chain
from typing import List, Literal, Optional, Tuple

import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    BitsAndBytesConfig
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer

import datasets
from datasets import Dataset, concatenate_datasets, load_dataset

from peft import (
    PeftModel,
    TaskType,
    LoraConfig,
    get_peft_model
)

from peft.utils import CONFIG_NAME, WEIGHTS_NAME

from trl import AutoModelForCausalLMWithValueHead

from .config import (
    ModelArguments,
    DataTrainingArguments,
    FinetuningArguments,
    GeneratingArguments
)

from .template import Template

from .other import (
    get_logger,
    load_trainable_params,
    load_valuehead_params,
    print_trainable_params,
    prepare_model_for_training,
    IGNORE_INDEX
)

check_min_version("4.29.1")
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")


logger = get_logger(__name__)


def _init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool,
    is_mergeable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Support full-parameter, freeze and LoRA training.

    Note that the trainable parameters must be cast to float32.
    """

    if finetuning_args.finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_args.finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()

    if finetuning_args.finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")
        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)

    if model_args.checkpoint_dir is not None:
        if finetuning_args.finetuning_type != "lora":
            assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
        else:
            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."

    if finetuning_args.finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        lastest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            if os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)) and \
                    not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
                raise ValueError("The given checkpoint may not be a LoRA checkpoint, \
                    please specify `--finetuning_type full/freeze` instead.")

            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if lastest_checkpoint is not None: # resume lora training or quantized inference
                model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable)

        if is_trainable and lastest_checkpoint is None: # create new lora weights while training
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    if model_args.checkpoint_dir is not None:
        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))

    return model


def load_pretrained(
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Loads pretrained model and tokenizer.

    Support both training and inference.
    """
    if (not is_trainable) and model_args.checkpoint_dir is None:
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with the LoRA method."

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)

    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    is_mergeable = True

    # Quantization configurations (using bitsandbytes library).
    if model_args.quantization_bit is not None:
        if model_args.quantization_bit == 8:
            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
            config_kwargs["load_in_8bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
            )
        elif model_args.quantization_bit == 4:
            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
            require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1")
            require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3")
            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
            config_kwargs["load_in_4bit"] = True
            config_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=model_args.compute_dtype,
                bnb_4bit_use_double_quant=model_args.double_quantization,
                bnb_4bit_quant_type=model_args.quantization_type
            )
        is_mergeable = False
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    if not is_trainable: # `device_map=auto` should be used for inference only
        config_kwargs["device_map"] = "auto"

    # Load and prepare pretrained models (without valuehead).
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
        low_cpu_mem_usage=True,
        **config_kwargs
    )
    model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
    model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)

    if stage == "rm" or stage == "ppo": # add value head
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })

        if stage == "ppo": # load reward model
            assert is_trainable, "PPO stage cannot be performed at evaluation."
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            logger.info("Load reward model from {}".format(model_args.reward_model))
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

    if not is_trainable:
        model.requires_grad_(False) # fix all model params
        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16

    print_trainable_params(model)

    return model, tokenizer


def prepare_args(
    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    if stage != "sft" and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")

    if training_args.do_train and training_args.predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True while training.")

    if training_args.do_predict and (not training_args.predict_with_generate):
        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

    if model_args.quantization_bit is not None and (not training_args.do_train):
        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

    if training_args.do_train and (not training_args.fp16):
        logger.warning("We recommend enabling fp16 mixed precision training.")

    if data_args.prompt_template == "alpaca":
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

    if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
        logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
        training_args.ddp_find_unused_parameters = False

    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

    if model_args.quantization_bit is not None:
        if training_args.fp16:
            model_args.compute_dtype = torch.float16
        elif training_args.bf16:
            model_args.compute_dtype = torch.bfloat16
        else:
            model_args.compute_dtype = torch.float32

    # Log on each process the small summary:
    logger.info(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args


def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]:

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
        model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()

    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

    if data_args.prompt_template == "alpaca":
        logger.warning("Please specify `prompt_template` if you are using other pre-trained models.")

    return model_args, data_args, finetuning_args, generating_args


def prepare_data(
    model_args: ModelArguments,
    data_args: DataTrainingArguments
) -> Dataset:

    def checksum(file_path, hash):
        with open(file_path, "rb") as datafile:
            binary_data = datafile.read()
        sha1 = hashlib.sha1(binary_data).hexdigest()
        if sha1 != hash:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    max_samples = data_args.max_samples
    all_datasets: List[Dataset] = [] # support multiple datasets

    for dataset_attr in data_args.dataset_list:

        logger.info("Loading dataset {}...".format(dataset_attr))

        if dataset_attr.load_from == "hf_hub":
            raw_datasets = load_dataset(dataset_attr.dataset_name, cache_dir=model_args.cache_dir)
        elif dataset_attr.load_from == "script":
            raw_datasets = load_dataset(
                os.path.join(data_args.dataset_dir, dataset_attr.dataset_name),
                cache_dir=model_args.cache_dir
            )
        elif dataset_attr.load_from == "file":
            data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name)

            extension = dataset_attr.file_name.split(".")[-1]
            if extension == "csv":
                file_type = "csv"
            elif extension == "json" or extension == "jsonl":
                file_type = "json"
            else:
                file_type = "text"

            if dataset_attr.file_sha1 is not None:
                checksum(data_file, dataset_attr.file_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")

            raw_datasets = load_dataset(
                file_type,
                data_files=data_file,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None
            )
        else:
            raise NotImplementedError

        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        for column_name, target_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]: # every dataset will have 4 columns same as each other
            if getattr(dataset_attr, column_name) != target_name:
                if getattr(dataset_attr, column_name):
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
                else: # None or empty string
                    dataset = dataset.add_column(target_name, dummy_data)
        all_datasets.append(dataset)

    if len(data_args.dataset_list) == 1:
        all_datasets = all_datasets[0]
    else:
        all_datasets = concatenate_datasets(all_datasets)

    return all_datasets


def preprocess_data(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Literal["pt", "sft", "rm", "ppo"]
) -> Dataset:

    column_names = list(dataset.column_names)
    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
    prompt_template = Template(data_args.prompt_template)

    # support question with a single answer or multiple answers
    def get_dialog(examples):
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                query = query + "\n" + examples["query"][i] if examples["query"][i] else query
                dialog = prompt_template.get_dialog(query, answer, examples["history"][i], prefix)
                yield dialog

    def preprocess_pretrain_dataset(examples):
        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
        text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
        concatenated_ids = list(chain(*text_ids))
        total_length = len(concatenated_ids)
        block_size = data_args.max_source_length - 1
        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
        total_length = (total_length // block_size) * block_size
        # split by chunks of max_source_length
        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
                  for i in range(0, total_length, block_size)]
        return {
            "input_ids": result,
            "labels": result.copy()
        }

    def preprocess_supervised_dataset(examples):
        # build inputs with format `X [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] Y [EOS]`
        # for input with history, we build multiple input-label pairs just like:
        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
        model_inputs = {"input_ids": [], "labels": []}
        for dialog in get_dialog(examples):
            input_ids, labels = [], []

            for i in range(len(dialog) // 2):
                source_ids = tokenizer.encode(text=dialog[2*i], add_special_tokens=False)
                target_ids = tokenizer.encode(text=dialog[2*i+1], add_special_tokens=False)
                input_ids += source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id]
                labels += [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [tokenizer.eos_token_id]

            model_inputs["input_ids"].append(input_ids[:data_args.max_source_length + data_args.max_target_length])
            model_inputs["labels"].append(labels[:data_args.max_source_length + data_args.max_target_length])
        return model_inputs

    def preprocess_unsupervised_dataset(examples):
        # build inputs with format `X [BOS]` and labels with format `Y [BOS]`
        model_inputs = {"input_ids": [], "labels": []}
        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 1: # bos token
                source_ids = source_ids[:data_args.max_source_length - 1]
            if len(target_ids) > data_args.max_target_length - 1: # bos token
                target_ids = target_ids[:data_args.max_target_length - 1]

            input_ids = source_ids + [tokenizer.bos_token_id]
            labels = target_ids + [tokenizer.bos_token_id]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
        return model_inputs

    def preprocess_pairwise_dataset(examples):
        # build input pairs with format `X [BOS] Y1 [EOS]` and `X [BOS] Y2 [EOS]`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for dialog in get_dialog(examples):
            prompt, answer = "".join(dialog[:-1]), dialog[-1]

            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 1: # bos token
                source_ids = source_ids[:data_args.max_source_length - 1]
            if len(accept_ids) > data_args.max_target_length - 1: # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1: # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id]
            reject_ids = source_ids + [tokenizer.bos_token_id] + reject_ids + [tokenizer.eos_token_id]

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def print_supervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
        print("label_ids:\n{}".format(example["labels"]))
        print("labels:\n{}".format(
            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
        )

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))

    def print_unsupervised_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))

    if stage == "pt":
        preprocess_function = preprocess_pretrain_dataset
    elif stage == "sft":
        preprocess_function = preprocess_unsupervised_dataset \
            if training_args.predict_with_generate else preprocess_supervised_dataset
    elif stage == "rm":
        preprocess_function = preprocess_pairwise_dataset
    elif stage == "ppo":
        preprocess_function = preprocess_unsupervised_dataset

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

        if stage == "pt":
            print_unsupervised_dataset_example(dataset[0])
        elif stage == "sft":
            print_supervised_dataset_example(dataset[0])
        elif stage == "rm":
            print_pairwise_dataset_example(dataset[0])
        elif stage == "ppo":
            print_unsupervised_dataset_example(dataset[0])

        return dataset
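
To illustrate the `X [BOS] Y [EOS]` layout that preprocess_supervised_dataset builds, here is a toy example with made-up token ids. The masking value is an assumption (other.py is not shown in this commit; -100 is the usual Hugging Face convention for IGNORE_INDEX):

# Toy illustration of the supervised format `X [BOS] Y [EOS]` (all ids are made up).
IGNORE_INDEX = -100          # assumption: the usual HF masking value
BOS, EOS = 1, 2              # hypothetical special token ids

source_ids = [37, 52, 99]    # tokenized prompt X
target_ids = [11, 12]        # tokenized response Y

input_ids = source_ids + [BOS] + target_ids + [EOS]
labels = [IGNORE_INDEX] * (len(source_ids) + 1) + target_ids + [EOS]

print(input_ids)  # [37, 52, 99, 1, 11, 12, 2]
print(labels)     # [-100, -100, -100, -100, 11, 12, 2]
# The prompt and BOS positions are masked out of the loss; only Y and EOS are learned.
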
src/utils/config.py
ADDED
@@ -0,0 +1,283 @@
import os
import json
import torch
from typing import Any, Dict, List, Literal, Optional
from dataclasses import asdict, dataclass, field


@dataclass
class DatasetAttr:

    load_from: str
    dataset_name: Optional[str] = None
    file_name: Optional[str] = None
    file_sha1: Optional[str] = None

    def __repr__(self) -> str:
        if self.dataset_name is not None:
            return self.dataset_name
        else:
            return self.file_name

    def __post_init__(self):
        self.prompt_column = "instruction"
        self.query_column = "input"
        self.response_column = "output"
        self.history_column = None


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
    )
    use_fast_tokenizer: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
    )
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
    )
    model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
    )
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model."}
    )
    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
        default="nf4",
        metadata={"help": "Quantization data type to use in int4 training."}
    )
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use double quantization in int4 training or not."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
    )
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
    )
    resume_lora_training: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
    )
    plot_loss: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
    )

    def __post_init__(self):
        if self.checkpoint_dir is not None: # support merging multiple lora weights
            self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]

        if self.quantization_bit is not None:
            assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """
    dataset: Optional[str] = field(
        default="alpaca_zh",
        metadata={"help": "The name of provided dataset(s) to use. Use comma to separate multiple datasets."}
    )
    dataset_dir: Optional[str] = field(
        default="data",
        metadata={"help": "The name of the folder containing datasets."}
    )
    split: Optional[str] = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."}
    )
    overwrite_cache: Optional[bool] = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."}
    )
    max_source_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total input sequence length after tokenization."}
    )
    max_target_length: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum total output sequence length after tokenization."}
    )
    max_samples: Optional[int] = field(
        default=None,
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
129 |
+
)
|
130 |
+
eval_num_beams: Optional[int] = field(
|
131 |
+
default=None,
|
132 |
+
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
|
133 |
+
)
|
134 |
+
ignore_pad_token_for_loss: Optional[bool] = field(
|
135 |
+
default=True,
|
136 |
+
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
137 |
+
)
|
138 |
+
source_prefix: Optional[str] = field(
|
139 |
+
default=None,
|
140 |
+
metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
141 |
+
)
|
142 |
+
dev_ratio: Optional[float] = field(
|
143 |
+
default=0,
|
144 |
+
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
|
145 |
+
)
|
146 |
+
prompt_template: Optional[str] = field(
|
147 |
+
default="alpaca",
|
148 |
+
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
149 |
+
)
|
150 |
+
|
151 |
+
def __post_init__(self): # support mixing multiple datasets
|
152 |
+
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
153 |
+
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
154 |
+
dataset_info = json.load(f)
|
155 |
+
|
156 |
+
self.dataset_list: List[DatasetAttr] = []
|
157 |
+
for name in dataset_names:
|
158 |
+
if name not in dataset_info:
|
159 |
+
raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
|
160 |
+
|
161 |
+
if "hf_hub_url" in dataset_info[name]:
|
162 |
+
dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
|
163 |
+
elif "script_url" in dataset_info[name]:
|
164 |
+
dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
|
165 |
+
else:
|
166 |
+
dataset_attr = DatasetAttr(
|
167 |
+
"file",
|
168 |
+
file_name=dataset_info[name]["file_name"],
|
169 |
+
file_sha1=dataset_info[name].get("file_sha1", None)
|
170 |
+
)
|
171 |
+
|
172 |
+
if "columns" in dataset_info[name]:
|
173 |
+
dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
|
174 |
+
dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
|
175 |
+
dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
|
176 |
+
dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
|
177 |
+
|
178 |
+
self.dataset_list.append(dataset_attr)
|
179 |
+
|
180 |
+
|
181 |
+
@dataclass
|
182 |
+
class FinetuningArguments:
|
183 |
+
"""
|
184 |
+
Arguments pertaining to which techniques we are going to fine-tuning with.
|
185 |
+
"""
|
186 |
+
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
|
187 |
+
default="lora",
|
188 |
+
metadata={"help": "Which fine-tuning method to use."}
|
189 |
+
)
|
190 |
+
num_layer_trainable: Optional[int] = field(
|
191 |
+
default=3,
|
192 |
+
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
|
193 |
+
)
|
194 |
+
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
|
195 |
+
default="mlp",
|
196 |
+
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
|
197 |
+
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
198 |
+
BLOOM choices: [\"mlp\", \"self_attention\"], \
|
199 |
+
Baichuan choices: [\"mlp\", \"self_attn\"]"}
|
200 |
+
)
|
201 |
+
lora_rank: Optional[int] = field(
|
202 |
+
default=8,
|
203 |
+
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
|
204 |
+
)
|
205 |
+
lora_alpha: Optional[float] = field(
|
206 |
+
default=32.0,
|
207 |
+
metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
|
208 |
+
)
|
209 |
+
lora_dropout: Optional[float] = field(
|
210 |
+
default=0.1,
|
211 |
+
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
212 |
+
)
|
213 |
+
lora_target: Optional[str] = field(
|
214 |
+
default="q_proj,v_proj",
|
215 |
+
metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
|
216 |
+
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
217 |
+
BLOOM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
218 |
+
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
|
219 |
+
)
|
220 |
+
|
221 |
+
def __post_init__(self):
|
222 |
+
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
|
223 |
+
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
|
224 |
+
|
225 |
+
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
226 |
+
trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
|
227 |
+
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
228 |
+
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
|
229 |
+
|
230 |
+
self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
231 |
+
|
232 |
+
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
|
233 |
+
|
234 |
+
def save_to_json(self, json_path: str):
|
235 |
+
"""Saves the content of this instance in JSON format inside `json_path`."""
|
236 |
+
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
237 |
+
with open(json_path, "w", encoding="utf-8") as f:
|
238 |
+
f.write(json_string)
|
239 |
+
|
240 |
+
@classmethod
|
241 |
+
def load_from_json(cls, json_path: str):
|
242 |
+
"""Creates an instance from the content of `json_path`."""
|
243 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
244 |
+
text = f.read()
|
245 |
+
return cls(**json.loads(text))
|
246 |
+
|
247 |
+
|
248 |
+
@dataclass
|
249 |
+
class GeneratingArguments:
|
250 |
+
"""
|
251 |
+
Arguments pertaining to specify the decoding parameters.
|
252 |
+
"""
|
253 |
+
do_sample: Optional[bool] = field(
|
254 |
+
default=True,
|
255 |
+
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
|
256 |
+
)
|
257 |
+
temperature: Optional[float] = field(
|
258 |
+
default=0.95,
|
259 |
+
metadata={"help": "The value used to modulate the next token probabilities."}
|
260 |
+
)
|
261 |
+
top_p: Optional[float] = field(
|
262 |
+
default=0.7,
|
263 |
+
metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
|
264 |
+
)
|
265 |
+
top_k: Optional[int] = field(
|
266 |
+
default=50,
|
267 |
+
metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
|
268 |
+
)
|
269 |
+
num_beams: Optional[int] = field(
|
270 |
+
default=1,
|
271 |
+
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
272 |
+
)
|
273 |
+
max_new_tokens: Optional[int] = field(
|
274 |
+
default=512,
|
275 |
+
metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
|
276 |
+
)
|
277 |
+
repetition_penalty: Optional[float] = field(
|
278 |
+
default=1.0,
|
279 |
+
metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
|
280 |
+
)
|
281 |
+
|
282 |
+
def to_dict(self) -> Dict[str, Any]:
|
283 |
+
return asdict(self)
|
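A minimal sketch of how these dataclasses could be consumed on the command line, assuming the standard HfArgumentParser from transformers; the repository's own prepare_args/prepare_infer_args helpers (defined elsewhere) may wrap this differently:

# Hypothetical usage sketch, not part of this commit.
from transformers import HfArgumentParser, Seq2SeqTrainingArguments

parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
print(model_args.model_name_or_path, finetuning_args.lora_target)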
src/utils/data_collator.py
ADDED
@@ -0,0 +1,64 @@
import torch

from typing import Dict, Optional, Sequence, Union

from transformers import DataCollatorWithPadding, BatchEncoding
from transformers.tokenization_utils import PreTrainedTokenizer

from .other import IGNORE_INDEX


class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
    r"""
    Inherits DataCollatorWithPadding. It is capable of dynamically padding for batched data.
    """
    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        ignore_pad_token_for_loss: Optional[bool] = False
    ):
        super().__init__(tokenizer, padding=True)
        self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id

    def get_attention_masks(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
        r"""
        Generates attention masks for left-padded sequences.
        """
        batch_size, seq_length = input_ids.size()
        attention_mask = torch.ones((batch_size, seq_length), device=device)
        for i, seq in enumerate(input_ids):
            attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
        attention_mask = attention_mask.bool()
        return attention_mask

    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> BatchEncoding:
        r"""
        Pads batched data to the longest sequence in the batch.

        We adopt left-padding in both training and evaluation.
        """
        if isinstance(features[0]["input_ids"], torch.Tensor):
            input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
        else:
            input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]

        if "labels" in features[0]:
            if isinstance(features[0]["labels"], torch.Tensor):
                labels = [feature["labels"].clone().detach().flip(0) for feature in features]
            else:
                labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
            input_ids = input_ids + labels # pad them to the same length

        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id).flip(-1)

        batch = {}

        if "labels" in features[0]:
            input_ids, labels = input_ids.split(len(features), dim=0)
            labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id)
            batch["labels"] = labels

        batch["input_ids"] = input_ids
        batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)

        return BatchEncoding(batch)
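The left-padding above relies on a flip / pad / flip trick: each sequence is reversed, right-padded by pad_sequence, then reversed back. A standalone illustration (pad_token_id assumed to be 0 for the example):

import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
flipped = [s.flip(0) for s in seqs]  # reverse each sequence
padded = torch.nn.utils.rnn.pad_sequence(flipped, batch_first=True, padding_value=0).flip(-1)
print(padded)  # tensor([[1, 2, 3], [0, 4, 5]]) -> the padding ends up on the left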
src/utils/other.py
ADDED
@@ -0,0 +1,196 @@
import os
import sys
import json
import torch
import logging
from typing import Dict, List, Optional

from transformers.trainer import TRAINER_STATE_NAME
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor

from peft.utils import WEIGHTS_NAME


IGNORE_INDEX = -100
VALUE_HEAD_FILE_NAME = "value_head.bin"
FINETUNING_ARGS_NAME = "finetuning_args.json"


def get_logger(name: str) -> logging.Logger:
    return logging.getLogger(name)


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)]
)


logger = get_logger(__name__)


class AverageMeter:
    r"""
    Computes and stores the average and current value.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Avoid runtime error in model.generate(do_sample=True).
class InvalidScoreLogitsProcessor(LogitsProcessor):

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 0] = 1.0
        return scores


def get_logits_processor() -> LogitsProcessorList:
    logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    return logits_processor


# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
def prepare_model_for_training(
    model: PreTrainedModel,
    finetuning_type: str,
    output_embedding_layer_name: Optional[str] = "lm_head",
    use_gradient_checkpointing: Optional[bool] = True,
    layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
) -> PreTrainedModel:

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            param.data = param.data.to(torch.float32)

    if use_gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        model.gradient_checkpointing_enable()
        model.config.use_cache = False # turn off when gradient checkpointing is enabled

    if finetuning_type != "full" and hasattr(model, output_embedding_layer_name):
        output_embedding_layer: torch.nn.Linear = getattr(model, output_embedding_layer_name)
        input_dtype = output_embedding_layer.weight.dtype

        class CastOutputToFloat(torch.nn.Sequential):

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return super().forward(x.to(input_dtype)).to(torch.float32)

        setattr(model, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))

    return model


def print_trainable_params(model: torch.nn.Module) -> None:
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
                trainable_params, all_param, 100 * trainable_params / all_param))


def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
    state_dict = model.state_dict()
    filtered_state_dict = {}

    for k, v in model.named_parameters():
        if v.requires_grad:
            filtered_state_dict[k] = state_dict[k].cpu().clone().detach()

    return filtered_state_dict


def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
    if not os.path.exists(weights_file):
        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
        return False
    model_state_dict = torch.load(weights_file, map_location="cpu")
    model.load_state_dict(model_state_dict, strict=False) # skip missing keys
    return True


def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
    valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
    if not os.path.exists(valuehead_file):
        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
        return False
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
    model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
    model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
    return True


def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
    r"""
    EMA implementation according to TensorBoard.
    """
    last = scalars[0]
    smoothed = list()
    for next_val in scalars:
        smoothed_val = last * weight + (1 - weight) * next_val
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed


def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
    import matplotlib.pyplot as plt
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for i in range(len(data["log_history"])):
            if key in data["log_history"][i]:
                steps.append(data["log_history"][i]["step"])
                metrics.append(data["log_history"][i][key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        plt.figure()
        plt.plot(steps, metrics, alpha=0.4, label="original")
        plt.plot(steps, smooth(metrics), label="smoothed")
        plt.title("training {} of {}".format(key, save_dictionary))
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
        plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
        print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
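A quick worked example of the exponential moving average computed by smooth (weight = 0.5 here only to keep the arithmetic readable; the default is 0.9):

# smooth([1.0, 2.0, 3.0], weight=0.5)
# step 1: 0.5 * 1.0 + 0.5 * 1.0 = 1.0
# step 2: 0.5 * 1.0 + 0.5 * 2.0 = 1.5
# step 3: 0.5 * 1.5 + 0.5 * 3.0 = 2.25
print(smooth([1.0, 2.0, 3.0], weight=0.5))  # [1.0, 1.5, 2.25], the "smoothed" curve drawn by plot_loss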
src/utils/pairwise.py
ADDED
@@ -0,0 +1,57 @@
import torch
import numpy as np
from typing import Dict, Sequence, Tuple, Union

from .data_collator import DynamicDataCollatorWithPadding

from .peft_trainer import PeftTrainer

from .other import get_logger

logger = get_logger(__name__)


def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
    preds, _ = eval_preds
    preds = np.array(preds)
    return {"accuracy": (preds[:, 0] > preds[:, 1]).sum() / len(preds)}


class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Union[torch.Tensor, Sequence[int]]]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
        return super().__call__(features)


class PairwisePeftTrainer(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute pairwise loss.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.can_return_loss = True # override property to return eval_loss

    def compute_loss(self, model, inputs, return_outputs=False):
        r"""
        Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.

        We use score on the EOS token to represent reward of the whole sentence.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        batch_size = inputs["input_ids"].size(0) // 2
        _, _, values = model(**inputs)
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
        return (loss, torch.stack((r_accept, r_reject), dim=-1)) if return_outputs else loss
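A numeric sanity check of the pairwise loss above, loss = -log(sigmoid(r_accept - r_reject)); a larger margin between chosen and rejected rewards drives the loss toward zero:

import torch

r_accept, r_reject = torch.tensor([1.0]), torch.tensor([-1.0])
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(round(loss.item(), 4))  # 0.1269; equal rewards would give -log(0.5) = 0.6931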
src/utils/peft_trainer.py
ADDED
@@ -0,0 +1,132 @@
import os
import json
import time
import torch
from typing import Dict, Optional
from datetime import timedelta

from transformers import (
    Seq2SeqTrainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments
)

from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model

from peft.utils.other import WEIGHTS_NAME

from .config import FinetuningArguments

from .other import (
    get_logger,
    get_state_dict,
    load_trainable_params,
    load_valuehead_params,
    FINETUNING_ARGS_NAME,
    VALUE_HEAD_FILE_NAME
)


logger = get_logger(__name__)


class LogCallback(TrainerCallback):
    r"""
    TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
    The on_log function primarily collects process parameters during training, such as training loss, learning rate,
    and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
    time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
    purposes.
    """

    def __init__(self):
        self.start_time = time.time()

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        if "loss" not in state.log_history[-1]:
            return
        cur_time = time.time()
        cur_steps = state.log_history[-1].get("step")
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_steps = state.max_steps - cur_steps
        remaining_time = remaining_steps * avg_time_per_step
        log_dict = {
            "current_steps": cur_steps,
            "total_steps": state.max_steps,
            "loss": state.log_history[-1].get("loss", None),
            "reward": state.log_history[-1].get("reward", None),
            "learning_rate": state.log_history[-1].get("learning_rate", None),
            "epoch": state.log_history[-1].get("epoch", None),
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
        }
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
            f.write(json.dumps(log_dict) + "\n")


class PeftTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
    """

    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""
        Saves trainable parameters as model checkpoint.

        This function will only be executed at the process zero.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        model = unwrap_model(self.model)

        if hasattr(model, "pretrained_model"): # for models with valuehead
            backbone_model = getattr(model, "pretrained_model")
        else:
            backbone_model = model

        if hasattr(backbone_model, "peft_config"): # peft methods
            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) # save lora weights
        else:
            torch.save(get_state_dict(backbone_model), os.path.join(output_dir, WEIGHTS_NAME)) # save trainable weights

        if hasattr(model, "v_head"): # save valuehead weights
            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))

        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
            f.write(self.args.to_json_string() + "\n")
        self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

    def _load_best_model(self):
        r"""
        Loads trainable parameters from model checkpoint.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
        model = unwrap_model(self.model)
        if hasattr(model, "peft_config"): # peft methods
            model.load_adapter(self.state.best_model_checkpoint, getattr(model, "active_adapter"))
        else:
            load_trainable_params(model, self.state.best_model_checkpoint)

        if hasattr(model, "v_head"):
            load_valuehead_params(model, self.state.best_model_checkpoint)
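Since LogCallback appends one JSON object per line to trainer_log.jsonl, training progress can be inspected with a few lines of Python (the output directory below is a hypothetical example):

import json

with open("path_to_output_dir/trainer_log.jsonl", "r", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]
for rec in records[-3:]:
    print(rec["current_steps"], "/", rec["total_steps"], "loss:", rec["loss"])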
src/utils/ppo.py
ADDED
@@ -0,0 +1,223 @@
import os
import math
import torch
from tqdm import tqdm
from typing import Callable, Dict, List, Literal, Optional, Tuple

from transformers import Seq2SeqTrainingArguments, TrainerState
from transformers.modeling_utils import PreTrainedModel

from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

from .peft_trainer import PeftTrainer, LogCallback

from .config import FinetuningArguments

from .other import (
    AverageMeter,
    get_logger,
    get_logits_processor
)


logger = get_logger(__name__)


def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
    if target == "reward": # save original head temporarily
        valuehead_state_dict = model.v_head.state_dict()

        setattr(model, "origin_head_weight", valuehead_state_dict["summary.weight"])
        setattr(model, "origin_head_bias", valuehead_state_dict["summary.bias"])

    model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target))
    })


def cast_layernorm_dtype(
    model: AutoModelForCausalLMWithValueHead,
    layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting
    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:

    layer_norm_state_dict = {}

    for name, param in model.named_parameters():
        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
            if layer_norm_params is not None:
                param.data = layer_norm_params[name] # restore float32 weights
            else:
                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
                param.data = param.data.to(torch.float16)

    return model, layer_norm_state_dict


class PPOPeftTrainer(PPOTrainer, PeftTrainer):
    r"""
    Inherits PPOTrainer.
    """

    def __init__(
        self,
        training_args: Seq2SeqTrainingArguments,
        finetuning_args: FinetuningArguments,
        callbacks: List[LogCallback],
        **kwargs
    ):
        PPOTrainer.__init__(self, **kwargs)
        self.args = training_args
        self.finetuning_args = finetuning_args
        self.log_callback = callbacks[0]
        self.state = TrainerState()
        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer

    def ppo_train(self, max_target_length: int) -> None:
        r"""
        Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
        """
        total_train_batch_size = self.config.batch_size * self.config.gradient_accumulation_steps * self.args.world_size
        len_dataloader = len(self.dataloader)
        num_steps_per_epoch = max(len_dataloader // self.config.gradient_accumulation_steps, 1)
        num_examples = len(self.dataset)
        num_train_epochs = self.args.num_train_epochs
        max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        if self.is_world_process_zero():
            logger.info("***** Running training *****")
            logger.info(f"  Num examples = {num_examples}")
            logger.info(f"  Num Epochs = {num_train_epochs}")
            logger.info(f"  Instantaneous batch size per device = {self.config.batch_size}")
            logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
            logger.info(f"  Gradient Accumulation steps = {self.config.gradient_accumulation_steps}")
            logger.info(f"  Total optimization steps = {max_steps}")
            logger.info(f"  Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")

        # Keyword arguments for `model.generate`
        gen_kwargs = {
            "top_k": 0.0,
            "top_p": 1.0,
            "do_sample": True,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "logits_processor": get_logits_processor()
        }
        output_length_sampler = LengthSampler(max_target_length // 2, max_target_length)
        unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)

        dataiter = iter(self.dataloader)
        steps_trained = 0
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()

        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero()):

            for _ in range(self.config.gradient_accumulation_steps):

                batch = next(dataiter)
                steps_trained += 1

                unwrapped_model.gradient_checkpointing_disable()
                unwrapped_model.config.use_cache = True

                # Get response from model
                query_tensors: torch.Tensor = batch["input_ids"]
                response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)

                queries: List[torch.Tensor] = []
                responses: List[torch.Tensor] = []
                for i in range(len(query_tensors)):
                    query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
                    response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
                    queries.append(query_tensors[i, query_length:]) # remove padding from left
                    if response_length < 2: # make response have at least 2 tokens
                        responses.append(response_tensors.new_empty(2).fill_(self.tokenizer.eos_token_id))
                    else:
                        responses.append(response_tensors[i, :response_length]) # remove padding from right

                # Compute rewards
                replace_model(unwrapped_model, target="reward")
                _, _, values = self.model(**self.prepare_model_inputs(queries, responses))
                rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
                replace_model(unwrapped_model, target="default") # make sure the model is default at the end

                # Run PPO step
                unwrapped_model.gradient_checkpointing_enable()
                unwrapped_model.config.use_cache = False

                stats = self.step(queries, responses, rewards)

                loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
                reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

                if steps_trained == len_dataloader:
                    dataiter = iter(self.dataloader)
                    steps_trained = 0

            if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
                logs = {
                    "loss": round(loss_meter.avg, 4),
                    "reward": round(reward_meter.avg, 4),
                    "learning_rate": stats["ppo/learning_rate"],
                    "epoch": round(step / num_steps_per_epoch, 2)
                }
                print(logs)
                logs["step"] = step
                self.state.log_history.append(logs)
                self.log_callback.on_log(self.args, self.state, None)
                loss_meter.reset()
                reward_meter.reset()

            if (step+1) % self.args.save_steps == 0: # save checkpoint
                self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))

    @torch.no_grad()
    def generate(
        self,
        inputs: Dict[str, torch.Tensor],
        length_sampler: Optional[Callable] = None,
        return_prompt: Optional[bool] = True,
        **generation_kwargs,
    ) -> torch.Tensor:
        r"""
        Generates model's responses given queries.

        Subclass and override to inject custom behavior.
        """
        self.model, layer_norm_params = cast_layernorm_dtype(self.model)

        if length_sampler is not None:
            generation_kwargs["max_new_tokens"] = length_sampler()

        unwrapped_model = self.accelerator.unwrap_model(self.model)

        response = unwrapped_model.generate(**inputs, **generation_kwargs)

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
        if unwrapped_model.pretrained_model.generation_config._from_model_config:
            unwrapped_model.pretrained_model.generation_config._from_model_config = False

        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)

        if not return_prompt and not self.is_encoder_decoder:
            return response[:, inputs["input_ids"].size(1):]
        return response

    def save_model(self, output_dir: Optional[str] = None) -> None:
        r"""
        Saves model checkpoint.

        Subclass and override to inject custom behavior.
        """
        if self.args.should_save:
            self._save(output_dir)
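Note how a scalar reward per sequence is obtained in ppo_train above: the value head scores every token and values[:, -1] keeps only the score at the last position. A tiny illustration with made-up numbers:

import torch

values = torch.tensor([[0.1, 0.3, 0.9],
                       [0.2, 0.5, 0.4]])  # per-token value-head scores, shape (batch, seq_len)
rewards = [r for r in values[:, -1].to(torch.float32)]
print(rewards)  # [tensor(0.9000), tensor(0.4000)]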
src/utils/seq2seq.py
ADDED
@@ -0,0 +1,96 @@
import os
import json
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple, Union

from transformers.trainer import PredictionOutput
from transformers.tokenization_utils import PreTrainedTokenizer

import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

from .peft_trainer import PeftTrainer

from .other import get_logger, IGNORE_INDEX


logger = get_logger(__name__)


@dataclass
class ComputeMetrics:
    r"""
    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.

    Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
    """

    tokenizer: PreTrainedTokenizer

    def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
        r"""
        Uses the model predictions to compute metrics.
        """
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        # Replace IGNORE_INDEX in the labels with pad_token_id as we cannot decode them if ignore_pad_token_for_loss=True.
        preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
        labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)

        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
        for pred, label in zip(preds, labels):
            pred = pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] # remove the query
            hypothesis = list(jieba.cut(self.tokenizer.decode(pred, skip_special_tokens=True)))
            reference = list(jieba.cut(self.tokenizer.decode(label, skip_special_tokens=True)))

            if len(" ".join(hypothesis).split()) == 0:
                result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
            else:
                rouge = Rouge()
                scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
                result = scores[0]

            for k, v in result.items():
                score_dict[k].append(round(v["f"] * 100, 4))

            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        return {k: float(np.mean(v)) for k, v in score_dict.items()}


class Seq2SeqPeftTrainer(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
    """

    def save_predictions(
        self,
        predict_results: PredictionOutput,
        tokenizer: PreTrainedTokenizer
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)

        preds = [pred[(pred == self.tokenizer.bos_token_id).nonzero()[0][0]:] for pred in preds] # remove the queries
        preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
        labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for pred, label in zip(preds, labels):
                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
            writer.write("\n".join(res))
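Each line that save_predictions writes to generated_predictions.jsonl is a compact JSON object; the values below are placeholders:

import json
print(json.dumps({"label": "reference answer", "predict": "model answer"}, ensure_ascii=False))
# {"label": "reference answer", "predict": "model answer"}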
src/utils/template.py
ADDED
@@ -0,0 +1,138 @@
from typing import List, Optional
from dataclasses import dataclass


@dataclass
class Template:

    name: str

    def __post_init__(self):

        if self.name == "vanilla":
            r"""
            Supports language model inference without histories.
            """
            self._register_template(
                prefix="",
                prompt="{query}",
                sep="",
                use_history=False
            )

        elif self.name == "alpaca":
            r"""
            Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
                      https://github.com/ymcui/Chinese-LLaMA-Alpaca
            """
            self._register_template(
                prefix="Below is an instruction that describes a task. "
                       "Write a response that appropriately completes the request.\n\n",
                prompt="### Instruction:\n{query}\n\n### Response:\n",
                sep="\n\n",
                use_history=True
            )

        elif self.name == "vicuna":
            r"""
            Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
                      https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
            """
            self._register_template(
                prefix="A chat between a curious user and an artificial intelligence assistant. "
                       "The assistant gives helpful, detailed, and polite answers to the user's questions.",
                prompt="USER: {query} ASSISTANT: ",
                sep="</s>",
                use_history=True
            )

        elif self.name == "belle":
            r"""
            Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
            """
            self._register_template(
                prefix="",
                prompt="Human: {query}\n\nBelle: ",
                sep="\n\n",
                use_history=True
            )

        elif self.name == "linly":
            r"""
            Supports: https://github.com/CVI-SZU/Linly
            """
            self._register_template(
                prefix="",
                prompt="User: {query}\nBot: ",
                sep="\n",
                use_history=True
            )

        elif self.name == "billa":
            r"""
            Supports: https://github.com/Neutralzz/BiLLa
            """
            self._register_template(
                prefix="",
                prompt="Human: {query}\nAssistant: ",
                sep="\n",
                use_history=True
            )

        elif self.name == "ziya":
            r"""
            Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
            """
            self._register_template(
                prefix="",
                prompt="<human>:{query}\n<bot>:",
                sep="\n",
                use_history=True
            )

        elif self.name == "aquila":
            r"""
            Supports: https://huggingface.co/qhduan/aquilachat-7b
            """
            self._register_template(
                prefix="A chat between a curious human and an artificial intelligence assistant. "
                       "The assistant gives helpful, detailed, and polite answers to the human's questions.",
                prompt="Human: {query}\nAssistant: ",
                sep="###",
                use_history=True
            )

        else:
            raise ValueError("Template {} does not exist.".format(self.name))

    def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
        r"""
        Returns a string containing prompt without response.
        """
        return "".join(self._format_example(query, history, prefix))

    def get_dialog(self, query: str, resp: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
        r"""
        Returns a list containing 2 * n elements where the 2k-th is a query and the (2k+1)-th is a response.
        """
        return self._format_example(query, history, prefix) + [resp]

    def _register_template(self, prefix: str, prompt: str, sep: str, use_history: Optional[bool] = True) -> None:
        self.prefix = prefix
        self.prompt = prompt
        self.sep = sep
        self.use_history = use_history

    def _format_example(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> List[str]:
        prefix = prefix if prefix else self.prefix
        history = history if (history and self.use_history) else []
        history = history + [(query, "<dummy>")]
        convs = []
        for turn_idx, (user_query, bot_resp) in enumerate(history):
            if turn_idx == 0:
                convs.append(prefix + self.prompt.format(query=user_query))
                convs.append(bot_resp)
            else:
                convs.append(self.sep + self.prompt.format(query=user_query))
                convs.append(bot_resp)
        return convs[:-1] # drop last
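A usage sketch for the Template class, with the rendered string inferred from _format_example above:

prompt_template = Template("alpaca")
text = prompt_template.get_prompt("How are you?", history=[("Hello", "Hi there!")])
# The history turn is rendered first, then the separator, then the new query:
# "...### Instruction:\nHello\n\n### Response:\nHi there!\n\n### Instruction:\nHow are you?\n\n### Response:\n"
print(text)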
src/web_demo.py
ADDED
@@ -0,0 +1,150 @@
# coding=utf-8
# Implements user interface in browser for fine-tuned models.
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint


import mdtex2html
import gradio as gr

from threading import Thread
from utils import (
    Template,
    load_pretrained,
    prepare_infer_args,
    get_logits_processor
)

from transformers import TextIteratorStreamer
from transformers.utils.versions import require_version


require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")


model_args, data_args, finetuning_args, generating_args = prepare_infer_args()
model, tokenizer = load_pretrained(model_args, finetuning_args)

prompt_template = Template(data_args.prompt_template)


def postprocess(self, y):
    r"""
    Overrides Chatbot.postprocess
    """
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = "<pre><code class=\"language-{}\">".format(items[-1])
            else:
                lines[i] = "<br /></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    # Note: the HTML entity escapes below were decoded away in the diff view;
                    # they are restored here based on the referenced ChuanhuChatGPT source.
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br />" + line
    text = "".join(lines)
    return text


def predict(query, chatbot, max_length, top_p, temperature, history):
    chatbot.append((parse_text(query), ""))

    input_ids = tokenizer([prompt_template.get_prompt(query, history)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = {
        "input_ids": input_ids,
        "do_sample": generating_args.do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "num_beams": generating_args.num_beams,
        "max_length": max_length,
        "repetition_penalty": generating_args.repetition_penalty,
        "logits_processor": get_logits_processor(),
        "streamer": streamer
    }

    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        new_history = history + [(query, response)]
        chatbot[-1] = (parse_text(query), parse_text(response))
        yield chatbot, new_history


def reset_user_input():
    return gr.update(value="")


def reset_state():
    return [], []


with gr.Blocks() as demo:

    gr.HTML("""
    <h1 align="center">
        <a href="https://chato.cn/" target="_blank">
            百姓AI助手
        </a>
    </h1>
    """)

    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")

        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=generating_args.top_p, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1.5, value=generating_args.temperature, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history], show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True)