hqsiswiliam committed on
Commit
8359bb1
1 Parent(s): 824dd1d

Upload 43 files

Files changed (43)
  1. Figures/Exp.png +0 -0
  2. Figures/SelectivePromptTuning-SPT.png +0 -0
  3. README.md +112 -0
  4. config/all_values.yml +39 -0
  5. config/convai2/llama2-7b-selective-linear-both-prompt-causal-convai2-adding-target-noise.yml +34 -0
  6. config/convai2/llama2-7b-selective-linear-both-prompt-causal-convai2.yml +33 -0
  7. config/convai2/opt-1.3b-selective-linear-both-prompt-causal-convai2.yml +33 -0
  8. config/convai2/opt-125m-selective-linear-both-prompt-causal-convai2.yml +33 -0
  9. config/convai2/opt-2.7b-selective-linear-both-prompt-causal-convai2.yml +33 -0
  10. config/default.yml +16 -0
  11. dataset/__pycache__/dataset.cpython-310.pyc +0 -0
  12. dataset/__pycache__/dataset_helper.cpython-310.pyc +0 -0
  13. dataset/dataset.py +189 -0
  14. dataset/dataset_helper.py +117 -0
  15. ds_config.json +28 -0
  16. env.yml +257 -0
  17. evaluate_runs_results.py +150 -0
  18. evaluation.py +92 -0
  19. interactive_test.py +205 -0
  20. models/__pycache__/llm_chat.cpython-310.pyc +0 -0
  21. models/__pycache__/selective_llm_chat.cpython-310.pyc +0 -0
  22. models/llm_chat.py +227 -0
  23. models/selective_llm_chat.py +390 -0
  24. test.py +204 -0
  25. train.py +129 -0
  26. trainer/__init__.py +1 -0
  27. trainer/__pycache__/__init__.cpython-310.pyc +0 -0
  28. trainer/__pycache__/peft_trainer.cpython-310.pyc +0 -0
  29. trainer/peft_trainer.py +187 -0
  30. utils/__pycache__/config.cpython-310.pyc +0 -0
  31. utils/__pycache__/configure_optimizers.cpython-310.pyc +0 -0
  32. utils/__pycache__/dist_helper.cpython-310.pyc +0 -0
  33. utils/__pycache__/format_inputs.cpython-310.pyc +0 -0
  34. utils/__pycache__/model_helpers.cpython-310.pyc +0 -0
  35. utils/__pycache__/parser_helper.cpython-310.pyc +0 -0
  36. utils/__pycache__/seed_everything.cpython-310.pyc +0 -0
  37. utils/config.py +50 -0
  38. utils/configure_optimizers.py +6 -0
  39. utils/dist_helper.py +5 -0
  40. utils/format_inputs.py +173 -0
  41. utils/model_helpers.py +31 -0
  42. utils/parser_helper.py +17 -0
  43. utils/seed_everything.py +44 -0
Figures/Exp.png ADDED
Figures/SelectivePromptTuning-SPT.png ADDED
README.md ADDED
@@ -0,0 +1,112 @@
1
+ # SPT: Selective Prompting Tuning for Personalized Conversations with LLMs
2
 + Repository for `Selective Prompting Tuning for Personalized Conversations with LLMs`. The paper is available at [Selective Prompting Tuning for Personalized Conversations with LLMs](https://openreview.net/pdf?id=Royo7My_EJ).
3
+ ## Introduction
4
+
5
 + In conversational AI, personalizing dialogues with persona profiles and contextual understanding is essential. Although large language models (LLMs) have improved response coherence, effective persona integration remains a challenge. In this work, we first study two common approaches for personalizing LLMs: textual prompting and direct fine-tuning. We observe that textual prompting often struggles to yield responses similar to the ground truths in datasets, while direct fine-tuning tends to produce repetitive or overly generic replies. To alleviate these issues, we propose **S**elective **P**rompt **T**uning (SPT), which softly prompts LLMs for personalized conversations in a selective way. Concretely, SPT initializes a set of soft prompts and uses a trainable dense retriever to adaptively select suitable soft prompts for the LLM according to the input context, where the prompt retriever is dynamically updated through feedback from the LLM. Additionally, we propose context-prompt contrastive learning and prompt fusion learning to encourage SPT to enhance the diversity of personalized conversations. Experiments on the ConvAI2 dataset demonstrate that SPT significantly enhances response diversity by up to 90%, along with improvements in other critical performance indicators. These results highlight the efficacy of SPT in fostering engaging and personalized dialogue generation. The SPT model code is publicly available for further exploration.
6
+
7
+ ## Architecture
8
+ ![spt-arch](Figures/SelectivePromptTuning-SPT.png)
9
+
10
+ ## Experimental Results
11
+ ![exp](Figures/Exp.png)
12
+
13
+ ## Repo Details
14
+ ### Basic Project Structure
15
 + - `config`: contains all the configuration YAML files, from OPT-125M to Llama2-13B
16
 + - `data_file`: contains the ConvAI2 dataset files; the dataset can be downloaded from this [Huggingface Repo](https://huggingface.co/hqsiswiliam/SPT)
17
 + - `dataset`: contains the dataloader class and pre-processing methods
18
+ - `models`: contains SPT model classes
19
+ - `trainer`: contains trainer classes, responsible for model training & updating
20
+ - `utils`: provides helper classes and functions
21
 + - `test.py`: the entry-point script for model decoding
22
 + - `train.py`: the entry-point script for model training
23
 + ### Checkpoint Downloading
24
 + - The trained checkpoints are located in the `public_ckpt` folder of the [Huggingface Repo](https://huggingface.co/hqsiswiliam/SPT)
25
+
26
+ ### Environment Initialization
27
+ #### Modifying `env.yml`
28
 + Since DeepSpeed requires cuDNN and CUDA, and the NVIDIA-related tools are installed through Anaconda, it is essential to modify the environment variables in the last two lines of `env.yml`:
29
+ ```yml
30
+ variables:
31
+ LD_LIBRARY_PATH: <CONDA_PATH>/envs/SPT/lib
32
+ LIBRARY_PATH: <CONDA_PATH>/envs/SPT/lib
33
+ ```
34
 + Please replace `<CONDA_PATH>` with your actual conda installation path before creating the environment from `env.yml`.
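 + For example, a quick way to apply this substitution (a minimal sketch; it assumes `conda` is on your `PATH` and that an in-place `sed` edit is acceptable):
 + ```bash
 + # Find the conda base directory and substitute it into env.yml in place.
 + CONDA_PATH=$(conda info --base)
 + sed -i "s|<CONDA_PATH>|${CONDA_PATH}|g" env.yml
 + ```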
35
+ #### Environment Creation
36
 + The SPT environment can be built using Anaconda (which we recommend); we provide `env.yml` for environment creation:
37
+ ```bash
38
+ conda env create -f env.yml
39
+ ```
40
 + Then activate the environment:
41
+ ```bash
42
+ conda activate SPT
43
+ ```
44
+ ## Model Training
45
 + Use the following command to start training:
46
+ ```bash
47
+ deepspeed --num_nodes=1 train.py \
48
+ --config=config/convai2/opt-125m-selective-linear-both-prompt-causal-convai2.yml \
49
+ --batch=2 \
50
+ --lr=0.0001 \
51
+ --epoch=1 \
52
+ --save_model=yes \
53
+ --num_workers=0 \
54
+ --training_ratio=1.0 \
55
+ --log_dir=runs_ds_dev \
56
+ --deepspeed \
57
+ --deepspeed_config ds_config.json
58
+ ```
59
 + You can adjust the DeepSpeed launcher arguments (e.g., `--num_nodes`, or `--num_gpus` for multiple GPUs in one node) to match your hardware.
60
+ ### Main Arguments
61
+ - `config`: the training configuration file
62
+ - `batch`: the batch size per GPU
63
+ - `lr`: learning rate
64
+ - `epoch`: epoch number
65
+ - `save_model`: whether to save model
66
+ - `training_ratio`: the percentage of data used for training, 1.0 means 100%
67
+ - `log_dir`: the log and model save directory
68
 + - `--deepspeed` & `--deepspeed_config`: the arguments required to initialize DeepSpeed
69
+ - `selective_loss_weight`: weight for selection loss
70
+ - `contrastive_weight`: weight for contrastive loss
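 + For example, the loss weights above can be overridden on the command line. This is a hypothetical invocation; it assumes `train.py` exposes `--selective_loss_weight` and `--contrastive_weight` as flags, as the list above suggests (values here match the OPT configs):
 + ```bash
 + deepspeed --num_nodes=1 train.py \
 + --config=config/convai2/opt-125m-selective-linear-both-prompt-causal-convai2.yml \
 + --batch=2 \
 + --lr=0.0001 \
 + --epoch=1 \
 + --selective_loss_weight=0.4 \
 + --contrastive_weight=0.4 \
 + --log_dir=runs_ds_dev \
 + --deepspeed \
 + --deepspeed_config ds_config.json
 + ```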
71
+ ## Model Inference
72
 + Model inference can be invoked with the following command:
73
+ ```bash
74
+ deepspeed test.py \
75
+ --model_path=public_ckpt/OPT-125M-SPT \
76
+ --batch_size=16 \
77
+ --skip_exists=no \
78
+ --deepspeed \
79
+ --deepspeed_config ds_config.json
80
+ ```
81
+ ### Main Arguments
82
+ - `model_path`: the path to the checkpoint, containing the `ds_ckpt` folder
83
+ - `skip_exists`: whether to skip decoding if `evaluation_result.txt` exists
84
+
85
+ ## Computing Metrics for Generation Results
86
 + To compute the metrics for the evaluation results, simply run:
87
+
88
+ `python evaluate_runs_results.py`
89
+
90
+ The input path can be changed in the script via:
91
+ ```python
92
+ _main_path = 'public_ckpt'
93
+ ```
94
+
95
+ ## Interactive Testing
96
 + We also support interactive testing via:
97
+ ```bash
98
+ deepspeed interactive_test.py \
99
+ --model_path=public_ckpt/Llama2-7B-SPT \
100
+ --batch_size=1 \
101
+ --deepspeed \
102
+ --deepspeed_config ds_config.json
103
+ ```
104
 + An interactive interface will then be launched, for example:
105
+
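 + The loop looks roughly like this (an illustrative transcript based on the print format in `interactive_test.py`; the persona and responses below are made up):
 + ```text
 + Please enter your input:
 + hi! what do you do for a living?
 + 
 + Persona: i am a nurse . i have two dogs .
 + Dialogue:
 + USER: hi! what do you do for a living?
 + SPT: i am a nurse , i work at a hospital . [SPT Index: 2]
 + ```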
106
+ Some shortcut keys:
107
 + - `exit`: exit the interactive shell
108
 + - `clear`: clear the current dialog history and load a new random persona
109
 + - `r`: remove the last SPT response from the history (so it can be regenerated)
 + - `save`: save the current dialogue to the `interactive_dialog` folder
110
+
111
+ ## Citation
112
+ Will be available soon.
config/all_values.yml ADDED
@@ -0,0 +1,39 @@
1
+ model:
2
+ model_type: 'selective_pt'
3
+ model_name: "facebook/opt-125m"
4
+ load_bit: 32
5
+ peft_type: "prompt_tuning"
6
+ K: 4
7
+ peft_config:
8
+ num_virtual_tokens: 8
9
+ normalizer: linear
10
+ normalizer_on: ['prompt', 'lm']
11
+ retriever:
12
+ retriever_on: ['extra', 'lm']
13
+ retriever_type: transformer_encoder
14
+ n_head: 4
15
+ num_layers: 2
16
+
17
+ training:
18
+ learning_rate: 1e-5
19
+ batch_size: 32
20
+ num_epochs: 1
21
+ mode: causal
22
+ only_longest: True
23
+ task_type: generate_response
24
+ log_dir: runs_prompt_selective_linear
25
+ contrastive: true
26
+ ensemble: true
27
+ selective_loss_weight: 0.4
28
+ contrastive_metric: bleu
29
+ contrastive_threshold: 20.0
30
+ contrastive_weight: 0.4
31
+ freeze_persona: yes
32
+ freeze_context: yes
33
+
34
+
35
+ dataset:
36
+ train: data_file/ConvAI2/train_self_original_no_cands.txt
37
+ valid: data_file/ConvAI2/valid_self_original_no_cands.txt
38
+ max_context_turns: -1
39
+ max_token_length: 512
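As a rough illustration of how a config like the one above can be consumed (a minimal sketch; it assumes PyYAML and `dotmap`, both pinned in `env.yml` — the repo's own `utils/config.py` may merge `config/default.yml` and differ in details):

```python
import yaml                # pyyaml==6.0 in env.yml
from dotmap import DotMap  # dotmap==1.3.30 in env.yml

# Load one of the YAML configs into an attribute-accessible object.
with open("config/all_values.yml") as f:
    cfg = DotMap(yaml.safe_load(f))

print(cfg.model.model_name)        # "facebook/opt-125m"
print(cfg.training.learning_rate)  # note: PyYAML loads "1e-5" as a string; the repo's loader may cast it
```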
config/convai2/llama2-7b-selective-linear-both-prompt-causal-convai2-adding-target-noise.yml ADDED
@@ -0,0 +1,34 @@
1
+ model:
2
+ model_type: 'selective_pt'
3
+ model_name: "Llama-2-7b-chat-hf"
4
+ load_bit: 16
5
+ peft_type: "prompt_tuning"
6
+ K: 4
7
+ peft_config:
8
+ num_virtual_tokens: 1
9
+ normalizer: linear
10
+ normalizer_on: ['prompt', 'lm']
11
+
12
+ training:
13
+ learning_rate: 1e-5
14
+ batch_size: 32
15
+ num_epochs: 1
16
+ mode: causal
17
+ adding_noise: 0.1
18
+ only_longest: False
19
+ task_type: generate_response
20
+ log_dir: runs_prompt_convai2_selective_linear
21
+ contrastive: true
22
+ ensemble: true
23
+ selective_loss_weight: 1.0
24
+ contrastive_metric: bleu
25
+ contrastive_threshold: 20.0
26
+ contrastive_weight: 1.0
27
+ freeze_persona: yes
28
+ freeze_context: yes
29
+
30
+ dataset:
31
+ train: data_file/ConvAI2/train_self_original_no_cands.txt
32
+ valid: data_file/ConvAI2/valid_self_original_no_cands.txt
33
+ max_context_turns: -1
34
+ max_token_length: 512
config/convai2/llama2-7b-selective-linear-both-prompt-causal-convai2.yml ADDED
@@ -0,0 +1,33 @@
1
+ model:
2
+ model_type: 'selective_pt'
3
+ model_name: "Llama-2-7b-chat-hf"
4
+ load_bit: 16
5
+ peft_type: "prompt_tuning"
6
+ K: 4
7
+ peft_config:
8
+ num_virtual_tokens: 1
9
+ normalizer: linear
10
+ normalizer_on: ['prompt', 'lm']
11
+
12
+ training:
13
+ learning_rate: 1e-5
14
+ batch_size: 32
15
+ num_epochs: 1
16
+ mode: causal
17
+ only_longest: False
18
+ task_type: generate_response
19
+ log_dir: runs_prompt_convai2_selective_linear
20
+ contrastive: true
21
+ ensemble: true
22
+ selective_loss_weight: 1.0
23
+ contrastive_metric: bleu
24
+ contrastive_threshold: 20.0
25
+ contrastive_weight: 1.0
26
+ freeze_persona: yes
27
+ freeze_context: yes
28
+
29
+ dataset:
30
+ train: data_file/ConvAI2/train_self_original_no_cands.txt
31
+ valid: data_file/ConvAI2/valid_self_original_no_cands.txt
32
+ max_context_turns: -1
33
+ max_token_length: 512
config/convai2/opt-1.3b-selective-linear-both-prompt-causal-convai2.yml ADDED
@@ -0,0 +1,33 @@
1
+ model:
2
+ model_type: 'selective_pt'
3
+ model_name: "facebook/opt-1.3b"
4
+ load_bit: 32
5
+ peft_type: "prompt_tuning"
6
+ K: 4
7
+ peft_config:
8
+ num_virtual_tokens: 8
9
+ normalizer: linear
10
+ normalizer_on: ['prompt', 'lm']
11
+
12
+ training:
13
+ learning_rate: 1e-5
14
+ batch_size: 32
15
+ num_epochs: 1
16
+ mode: causal
17
+ only_longest: False
18
+ task_type: generate_response
19
+ log_dir: runs_prompt_convai2_selective_linear
20
+ contrastive: true
21
+ ensemble: true
22
+ selective_loss_weight: 0.4
23
+ contrastive_metric: bleu
24
+ contrastive_threshold: 20.0
25
+ contrastive_weight: 0.4
26
+ freeze_persona: yes
27
+ freeze_context: yes
28
+
29
+ dataset:
30
+ train: data_file/ConvAI2/train_self_original_no_cands.txt
31
+ valid: data_file/ConvAI2/valid_self_original_no_cands.txt
32
+ max_context_turns: -1
33
+ max_token_length: 512
config/convai2/opt-125m-selective-linear-both-prompt-causal-convai2.yml ADDED
@@ -0,0 +1,33 @@
1
+ model:
2
+ model_type: 'selective_pt'
3
+ model_name: "facebook/opt-125m"
4
+ load_bit: 32
5
+ peft_type: "prompt_tuning"
6
+ K: 4
7
+ peft_config:
8
+ num_virtual_tokens: 8
9
+ normalizer: linear
10
+ normalizer_on: ['prompt', 'lm']
11
+
12
+ training:
13
+ learning_rate: 1e-5
14
+ batch_size: 32
15
+ num_epochs: 1
16
+ mode: causal
17
+ only_longest: False
18
+ task_type: generate_response
19
+ log_dir: runs_prompt_convai2_selective_linear
20
+ contrastive: true
21
+ ensemble: true
22
+ selective_loss_weight: 0.4
23
+ contrastive_metric: bleu
24
+ contrastive_threshold: 20.0
25
+ contrastive_weight: 0.4
26
+ freeze_persona: yes
27
+ freeze_context: yes
28
+
29
+ dataset:
30
+ train: data_file/ConvAI2/train_self_original_no_cands.txt
31
+ valid: data_file/ConvAI2/valid_self_original_no_cands.txt
32
+ max_context_turns: -1
33
+ max_token_length: 512
config/convai2/opt-2.7b-selective-linear-both-prompt-causal-convai2.yml ADDED
@@ -0,0 +1,33 @@
1
+ model:
2
+ model_type: 'selective_pt'
3
+ model_name: "facebook/opt-2.7b"
4
+ load_bit: 16
5
+ peft_type: "prompt_tuning"
6
+ K: 4
7
+ peft_config:
8
+ num_virtual_tokens: 8
9
+ normalizer: linear
10
+ normalizer_on: ['prompt', 'lm']
11
+
12
+ training:
13
+ learning_rate: 1e-5
14
+ batch_size: 32
15
+ num_epochs: 1
16
+ mode: causal
17
+ only_longest: False
18
+ task_type: generate_response
19
+ log_dir: runs_prompt_convai2_selective_linear
20
+ contrastive: true
21
+ ensemble: true
22
+ selective_loss_weight: 0.4
23
+ contrastive_metric: bleu
24
+ contrastive_threshold: 20.0
25
+ contrastive_weight: 0.4
26
+ freeze_persona: yes
27
+ freeze_context: yes
28
+
29
+ dataset:
30
+ train: data_file/ConvAI2/train_self_original_no_cands.txt
31
+ valid: data_file/ConvAI2/valid_self_original_no_cands.txt
32
+ max_context_turns: -1
33
+ max_token_length: 512
config/default.yml ADDED
@@ -0,0 +1,16 @@
1
+ dataset:
2
+ train: data_file/ConvAI2/train_self_original_no_cands.txt
3
+ valid: data_file/ConvAI2/valid_self_original_no_cands.txt
4
+ max_context_turns: -1
5
+ max_token_length: 512
6
+
7
+ model:
8
+ score_activation: 'softplus'
9
+
10
+ training:
11
+ mode: normal
12
+ only_longest: False
13
+ task_type: generate_response
14
+ ensemble: false
15
+ tau_gold: 1.0
16
+ tau_sim: 1.0
dataset/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (5.37 kB). View file
 
dataset/__pycache__/dataset_helper.cpython-310.pyc ADDED
Binary file (3.46 kB). View file
 
dataset/dataset.py ADDED
@@ -0,0 +1,189 @@
1
+ import torch
2
+ from pytorch_lightning import LightningDataModule
3
+ from torch.utils.data import DataLoader
4
+
5
+ from dataset.dataset_helper import read_personachat_split
6
+ from utils.format_inputs import TASK_TYPE
7
+
8
+
9
+ class PersonaChatDataset(torch.utils.data.Dataset):
10
+ # longest first for batch finder
11
+ def __init__(self, data_path, max_context_turns=-1,
12
+ add_role_indicator=True, only_longest=False, training_ratio=1.0,
13
+ task_type=TASK_TYPE.GENERATE_RESPONSE):
14
+ self.path = data_path
15
+ self.add_role_indicator = add_role_indicator
16
+ self.max_context_turns = max_context_turns
17
+ self.turns_data = read_personachat_split(data_path, only_longest=only_longest)
18
+ self.only_longest = only_longest
19
+ self.training_ratio = training_ratio
20
+ if training_ratio < 1.0:
21
+ self.turns_data = self.turns_data[:int(len(self.turns_data) * training_ratio)]
22
+ self.task_type = task_type
23
+ # # For debug only
24
+ # os.makedirs("data_logs", exist_ok=True)
25
+ # random_num = random.randint(0, 100000)
26
+ # self.file = open(f"data_logs/{random_num}_{data_path.split(os.sep)[-1]}", 'w')
27
+ # # add id to turns_data
28
+ # self.turns_data = [{'id': idx, **turn} for idx, turn in enumerate(self.turns_data)]
29
+ # self.file.write(f"total_turns: {len(self.turns_data)}\n")
30
+
31
+ def sort_longest_first(self):
32
+ self.turns_data = sorted(self.turns_data, key=lambda x: len(
33
+ (' '.join(x['persona']) + ' '.join(x['context']) + x['response']).split(' ')), reverse=True)
34
+
35
+ def __getitem__(self, idx):
36
+ # self.file.write(str(idx) + "\n")
37
+ # self.file.flush()
38
+ input_data = self.turns_data[idx]
39
+ persona_list = input_data['persona']
40
+ target = input_data['response']
41
+ context_input = input_data['context']
42
+ if self.add_role_indicator:
43
+ roled_context_input = [['Q: ', 'R: '][c_idx % 2] + context for c_idx, context in enumerate(context_input)]
44
+ context_input = roled_context_input
45
+ if self.max_context_turns != -1:
46
+ truncated_context = context_input[-(self.max_context_turns * 2 - 1):]
47
+ context_input = truncated_context
48
+ if self.only_longest:
49
+ context_input = context_input[:-1]
50
+ return {
51
+ 'context_input': context_input,
52
+ 'persona_list': persona_list,
53
+ 'target': target
54
+ }
55
+
56
+ def __len__(self):
57
+ return len(self.turns_data)
58
+
59
+
60
+ # class HGPersonaChatDataset(PersonaChatDataset):
61
+ # def __init__(self, data_path, max_context_turns=-1,
62
+ # add_role_indicator=True, only_longest=False, tokenizer=None):
63
+ # super().__init__(data_path, max_context_turns, add_role_indicator, only_longest)
64
+ # self.tokenizer = tokenizer
65
+ #
66
+ # def __getitem__(self, idx):
67
+ # data = super().__getitem__(idx)
68
+ # input = "P: " + ' '.join(data['persona_list']) + " C: " + ' '.join(data['context_input']) + " R: " + data[
69
+ # 'target']
70
+ # tokenized = self.tokenizer(input)
71
+ # return {**data, **tokenized}
72
+
73
+
74
+ def collate_fn(sample_list):
75
+ dont_be_a_tensor = ['context_input', 'persona_list', 'target']
76
+ to_be_flattened = [*dont_be_a_tensor]
77
+ data = {}
78
+ for key in to_be_flattened:
79
+ if key not in sample_list[0].keys():
80
+ continue
81
+ if sample_list[0][key] is None:
82
+ continue
83
+ flatten_samples = [sample[key] for sample in sample_list]
84
+ if flatten_samples[-1].__class__ == str or key in dont_be_a_tensor:
85
+ data[key] = flatten_samples
86
+ else:
87
+ data[key] = torch.tensor(flatten_samples)
88
+ return data
89
+
90
+
91
+ def collate_fn_straight(sample_list):
92
+ sample_list = collate_fn(sample_list)
93
+ return sample_list
94
+
95
+
96
+ def collate_fn_straight_with_fn(fn):
97
+ def build_collate_fn(sample_list):
98
+ sample_list = collate_fn(sample_list)
99
+ sample_list_processed = fn(sample_list)
100
+ return {**sample_list, **sample_list_processed}
101
+
102
+ return build_collate_fn
103
+
104
+
105
+ def get_dataloader(dataset, batch_size, shuffle=False, num_workers=None, collate_fn=None, sampler=None):
106
+ if num_workers is None:
107
+ num_workers = batch_size // 4
108
+ # num_workers = min(num_workers, batch_size)
109
+ if collate_fn == None:
110
+ _collate_fn = collate_fn_straight
111
+ else:
112
+ _collate_fn = collate_fn_straight_with_fn(collate_fn)
113
+ return DataLoader(dataset, batch_size=batch_size,
114
+ collate_fn=_collate_fn,
115
+ shuffle=shuffle,
116
+ num_workers=num_workers,
117
+ sampler=sampler)
118
+
119
+
120
+ def get_lightening_dataloader(dataset, batch_size, shuffle=False, num_workers=None):
121
+ return LitDataModule(batch_size, dataset, shuffle, num_workers)
122
+
123
+
124
+ class LitDataModule(LightningDataModule):
125
+ def __init__(self, batch_size, dataset, shuffle, num_workers):
126
+ super().__init__()
127
+ self.save_hyperparameters(ignore=['dataset'])
128
+ # or
129
+ self.batch_size = batch_size
130
+ self.dataset = dataset
131
+
132
+ def train_dataloader(self):
133
+ return DataLoader(self.dataset, batch_size=self.batch_size,
134
+ collate_fn=collate_fn_straight,
135
+ shuffle=self.hparams.shuffle,
136
+ num_workers=self.hparams.num_workers)
137
+
138
+ if __name__ == '__main__':
139
+ import json
140
+ train_ds = PersonaChatDataset(data_path='data_file/ConvAI2/train_self_original_no_cands.txt',
141
+ )
142
+ from tqdm import tqdm
143
+
144
+ jsonfy_data = []
145
+
146
+ for data in tqdm(train_ds):
147
+ context_input = "\n".join(data['context_input'])
148
+ persona_input = '\n'.join(data['persona_list'])
149
+ jsonfy_data.append({
150
+ "instruction": f"""Given the dialog history between Q and R is:
151
+ {context_input}
152
+
153
+ Given the personality of the R as:
154
+ {persona_input}
155
+
156
+ Please response to Q according to both the dialog history and the R's personality.
157
+ Now, the R would say:""",
158
+ "input": "",
159
+ "output": data['target'],
160
+ "answer": "",
161
+ })
162
+ with open('data_file/train.json', 'w') as writer:
163
+ json.dump(jsonfy_data, writer)
164
+ jsonfy_data = []
165
+ del train_ds
166
+
167
+ train_ds = PersonaChatDataset(data_path='data_file/ConvAI2/valid_self_original_no_cands.txt',
168
+ )
169
+
170
+ for data in tqdm(train_ds):
171
+ context_input = "\n".join(data['context_input'])
172
+ persona_input = '\n'.join(data['persona_list'])
173
+ jsonfy_data.append({
174
+ "instruction": f"""Given the dialog history between Q and R is:
175
+ {context_input}
176
+
177
+ Given the personality of the R as:
178
+ {persona_input}
179
+
180
+ Please response to Q according to both the dialog history and the R's personality.
181
+ Now, the R would say:""",
182
+ "input": "",
183
+ "output": data['target'],
184
+ "answer": "",
185
+ })
186
+ with open('data_file/valid.json', 'w') as writer:
187
+ json.dump(jsonfy_data, writer)
188
+ with open('data_file/test.json', 'w') as writer:
189
+ json.dump(jsonfy_data, writer)
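A minimal usage sketch for the classes above (it assumes the ConvAI2 files from the Hugging Face repo are under `data_file/ConvAI2`; argument names follow the signatures shown in this file):

```python
from dataset.dataset import PersonaChatDataset, get_dataloader

train_ds = PersonaChatDataset(
    data_path="data_file/ConvAI2/train_self_original_no_cands.txt",
    max_context_turns=-1,
)
loader = get_dataloader(train_ds, batch_size=2, shuffle=True, num_workers=0)

batch = next(iter(loader))
# collate_fn_straight keeps these fields as lists of strings rather than tensors
print(batch.keys())  # dict_keys(['context_input', 'persona_list', 'target'])
```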
dataset/dataset_helper.py ADDED
@@ -0,0 +1,117 @@
1
+ import re
2
+
3
+ from tqdm import tqdm
4
+
5
+
6
+ def read_personachat_split(split_dir, only_longest=False):
7
+ results = []
8
+ their_per_group = None
9
+ try:
10
+ file = open(split_dir, 'r')
11
+ lines = file.readlines()
12
+ persona = []
13
+ context = []
14
+ response = None
15
+ candidates = []
16
+ is_longest = False
17
+ for line in tqdm(lines[:], desc='loading {}'.format(split_dir)):
18
+ if line.startswith('1 your persona:'):
19
+ is_longest = True
20
+ if is_longest and only_longest:
21
+ if response is not None:
22
+ results.append({'persona': persona.copy(), 'context': context.copy(), 'response': response,
23
+ 'candidates': candidates.copy()})
24
+ is_longest = False
25
+ persona = []
26
+ context = []
27
+ if 'persona:' in line:
28
+ persona.append(line.split(':')[1].strip())
29
+ if 'persona:' not in line:
30
+ context.append(re.sub(r"^\d+ ", "", line.split("\t")[0].strip()))
31
+ response = line.split("\t")[1].strip()
32
+ if len(line.split("\t\t"))==1:
33
+ candidates = []
34
+ else:
35
+ candidates = line.split("\t\t")[1].strip().split("|")
36
+ if not only_longest:
37
+ results.append({'persona': persona.copy(), 'context': context.copy(), 'response': response,
38
+ 'candidates': candidates.copy()})
39
+ context.append(response)
40
+ except FileNotFoundError:
41
+ print(f"Sorry! The file {split_dir} can't be found.")
42
+ return results
43
+
44
+
45
+ def combine_persona_query_response(persona, query, response, candidates):
46
+ assert ((len(persona) == len(query)) and (len(query) == len(response))), \
47
+ 'the length of persona, query, response must be equivalent'
48
+ data = {}
49
+ for index, psn in enumerate(persona):
50
+ split_persona = psn.strip().split("\t")
51
+ psn = psn.replace("\t", " ").strip()
52
+ if psn not in data.keys():
53
+ data[psn] = {'persona': psn, 'query': [], 'response': [], 'dialog': [], 'response_turns': 0,
54
+ 'persona_list': split_persona, 'candidates': []}
55
+ data[psn]['query'].append(query[index])
56
+ data[psn]['response'].append(response[index])
57
+ data[psn]['dialog'].append(query[index])
58
+ data[psn]['dialog'].append(response[index])
59
+ data[psn]['candidates'].append(candidates[index])
60
+ data[psn]['response_turns'] += 1
61
+ return data
62
+
63
+
64
+ def preprocess_text(text):
65
+ punctuations = '.,?'
66
+ for punc in punctuations:
67
+ text = text.replace(punc, ' {} '.format(punc))
68
+ text = re.sub(' +', ' ', text).strip()
69
+ return text
70
+
71
+
72
+ def preprocess_texts(text_array):
73
+ return [preprocess_text(t) for t in text_array]
74
+
75
+
76
+ # "turns" means we need at least how many turns
77
+ # "max_context_turns" means how many history turns should be kept
78
+ def get_chat_by_turns(combined_data, turns=1,
79
+ sep_token='[SEP]', add_role_indicator=True,
80
+ add_persona_indicator=True, max_context_turns=-1):
81
 + assert turns > 0, 'turns must be larger than 0'
82
+ all_persona = list(combined_data.keys())
83
+ filtered_persona = list(filter(lambda p: combined_data[p]['response_turns'] >= turns, all_persona))
84
+ data = []
85
+
86
+ for single_persona in filtered_persona:
87
+ single_persona_data = combined_data[single_persona]
88
+ persona_list = single_persona_data['persona_list']
89
+ context = []
90
+ for index, (query, response) in enumerate(
91
+ zip(single_persona_data['query'], single_persona_data['response'])
92
+ ):
93
+ if max_context_turns != -1 and \
94
+ index + 1 < single_persona_data['response_turns'] - max_context_turns:
95
+ continue
96
+ if add_role_indicator:
97
+ query = "Q: {}".format(query)
98
+ if not index + 1 >= turns:
99
+ response = "R: {}".format(response)
100
+ context += [query, response]
101
+ if index + 1 >= turns:
102
+ break
103
+
104
+ response = context[-1]
105
+ context = context[:-1]
106
+
107
+ input_x_str = " {} ".format(sep_token).join(context)
108
+ input_x_str = re.sub(" +", " ", input_x_str)
109
+ if add_persona_indicator:
110
+ single_persona = "P: {}".format(single_persona)
111
+ data.append({'input': preprocess_texts(context),
112
+ 'input_str': preprocess_text(input_x_str),
113
+ 'target': preprocess_text(response),
114
+ 'persona': preprocess_text(single_persona),
115
+ 'persona_list': preprocess_texts(persona_list),
116
+ 'candidates': preprocess_texts(single_persona_data['candidates'][-1])})
117
+ return data
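For reference, one element returned by `read_personachat_split` has roughly this shape (an illustrative example; the utterances below are invented, and `candidates` stays empty for the `*_no_cands.txt` splits):

```python
example_turn = {
    "persona": ["i like to ski .", "i have two dogs ."],
    # role prefixes ("Q: " / "R: ") are added later by PersonaChatDataset, not here
    "context": ["hello , what are you doing today ?"],
    "response": "i am relaxing at home with my dogs .",
    "candidates": [],
}
```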
ds_config.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "train_micro_batch_size_per_gpu ": 1,
3
+ "gradient_accumulation_steps": 1,
4
+ "optimizer": {
5
+ "type": "Adam",
6
+ "params": {
7
+ "lr": 0.00015
8
+ }
9
+ },
10
+ "bf16": {
11
+ "enabled": false
12
+ },
13
+ "float16": {
14
+ "enabled": false
15
+ },
16
+
17
+ "zero_optimization": {
18
+ "stage": 2,
19
+ "offload_param": {
20
+ "device": "cpu",
21
+ "pin_memory": true,
22
+ "buffer_count": 5,
23
+ "buffer_size": 1e8,
24
+ "max_in_cpu": 1e9
25
+ }
26
+ }
27
+
28
+ }
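The test scripts adapt this file at runtime rather than using it verbatim; a condensed sketch of the logic visible in `interactive_test.py` (the batch size and precision switches here are placeholders):

```python
import json

with open("ds_config.json") as f:
    ds_config = json.load(f)

batch_size = 16
ds_config["train_micro_batch_size_per_gpu"] = batch_size

# 16-bit checkpoints run in fp16, except Llama models, which are switched to bf16.
ds_config["float16"]["enabled"] = True
# ds_config["bf16"]["enabled"] = True   # for Llama-2 checkpoints
```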
env.yml ADDED
@@ -0,0 +1,257 @@
1
+ name: SPT
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - nvidia/label/cuda-11.8.0
6
+ - anaconda
7
+ - defaults
8
+ dependencies:
9
+ - _libgcc_mutex=0.1=main
10
+ - _openmp_mutex=5.1=1_gnu
11
+ - blas=1.0=mkl
12
+ - brotlipy=0.7.0=py310h7f8727e_1002
13
+ - bzip2=1.0.8=h7b6447c_0
14
+ - ca-certificates=2023.08.22=h06a4308_0
15
+ - certifi=2023.11.17=py310h06a4308_0
16
+ - cffi=1.15.1=py310h5eee18b_3
17
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
18
+ - cryptography=39.0.1=py310h9ce1e76_2
19
+ - cuda-cccl=11.8.89=0
20
+ - cuda-compiler=11.8.0=0
21
+ - cuda-cudart=11.8.89=0
22
+ - cuda-cudart-dev=11.8.89=0
23
+ - cuda-cuobjdump=11.8.86=0
24
+ - cuda-cupti=11.8.87=0
25
+ - cuda-cuxxfilt=11.8.86=0
26
+ - cuda-libraries=11.8.0=0
27
+ - cuda-nvcc=11.8.89=0
28
+ - cuda-nvprune=11.8.86=0
29
+ - cuda-nvrtc=11.8.89=0
30
+ - cuda-nvtx=11.8.86=0
31
+ - cuda-runtime=11.8.0=0
32
+ - cudatoolkit=11.8.0=h6a678d5_0
33
+ - ffmpeg=4.3=hf484d3e_0
34
+ - filelock=3.9.0=py310h06a4308_0
35
+ - freetype=2.12.1=h4a9f257_0
36
+ - giflib=5.2.1=h5eee18b_3
37
+ - gmp=6.2.1=h295c915_3
38
+ - gmpy2=2.1.2=py310heeb90bb_0
39
+ - gnutls=3.6.15=he1e5248_0
40
+ - idna=3.4=py310h06a4308_0
41
+ - intel-openmp=2023.1.0=hdb19cb5_46305
42
+ - jinja2=3.1.2=py310h06a4308_0
43
+ - jpeg=9e=h5eee18b_1
44
+ - lame=3.100=h7b6447c_0
45
+ - lcms2=2.12=h3be6417_0
46
+ - ld_impl_linux-64=2.38=h1181459_1
47
+ - lerc=3.0=h295c915_0
48
+ - libcublas=11.11.3.6=0
49
+ - libcublas-dev=11.11.3.6=0
50
+ - libcufft=10.9.0.58=0
51
+ - libcufile=1.6.1.9=0
52
+ - libcurand=10.3.2.106=0
53
+ - libcusolver=11.4.1.48=0
54
+ - libcusolver-dev=11.4.1.48=0
55
+ - libcusparse=11.7.5.86=0
56
+ - libcusparse-dev=11.7.5.86=0
57
+ - libdeflate=1.17=h5eee18b_0
58
+ - libffi=3.4.4=h6a678d5_0
59
+ - libgcc-ng=11.2.0=h1234567_1
60
+ - libgomp=11.2.0=h1234567_1
61
+ - libiconv=1.16=h7f8727e_2
62
+ - libidn2=2.3.4=h5eee18b_0
63
+ - libnpp=11.8.0.86=0
64
+ - libnvjpeg=11.9.0.86=0
65
+ - libpng=1.6.39=h5eee18b_0
66
+ - libstdcxx-ng=11.2.0=h1234567_1
67
+ - libtasn1=4.19.0=h5eee18b_0
68
+ - libtiff=4.5.0=h6a678d5_2
69
+ - libunistring=0.9.10=h27cfd23_0
70
+ - libuuid=1.41.5=h5eee18b_0
71
+ - libwebp=1.2.4=h11a3e52_1
72
+ - libwebp-base=1.2.4=h5eee18b_1
73
+ - lz4-c=1.9.4=h6a678d5_0
74
+ - markupsafe=2.1.1=py310h7f8727e_0
75
+ - mkl=2023.1.0=h6d00ec8_46342
76
+ - mkl-service=2.4.0=py310h5eee18b_1
77
+ - mkl_fft=1.3.6=py310h1128e8f_1
78
+ - mkl_random=1.2.2=py310h1128e8f_1
79
+ - mpc=1.1.0=h10f8cd9_1
80
+ - mpfr=4.0.2=hb69a4c5_1
81
+ - ncurses=6.4=h6a678d5_0
82
+ - nettle=3.7.3=hbbd107a_1
83
+ - networkx=2.8.4=py310h06a4308_1
84
+ - numpy=1.25.0=py310h5f9d8c6_0
85
+ - numpy-base=1.25.0=py310hb5e798b_0
86
+ - openh264=2.1.1=h4ff587b_0
87
+ - openssl=3.0.12=h7f8727e_0
88
+ - pillow=9.4.0=py310h6a678d5_0
89
+ - pip=23.1.2=py310h06a4308_0
90
+ - pycparser=2.21=pyhd3eb1b0_0
91
+ - pyopenssl=23.0.0=py310h06a4308_0
92
+ - pysocks=1.7.1=py310h06a4308_0
93
+ - python=3.10.11=h955ad1f_3
94
+ - pytorch-cuda=11.8=h7e8668a_5
95
+ - pytorch-mutex=1.0=cuda
96
+ - readline=8.2=h5eee18b_0
97
+ - requests=2.29.0=py310h06a4308_0
98
+ - setuptools=67.8.0=py310h06a4308_0
99
+ - sqlite=3.41.2=h5eee18b_0
100
+ - sympy=1.11.1=py310h06a4308_0
101
+ - tbb=2021.8.0=hdb19cb5_0
102
+ - tk=8.6.12=h1ccaba5_0
103
+ - torchtriton=2.0.0=py310
104
+ - typing_extensions=4.6.3=py310h06a4308_0
105
+ - urllib3=1.26.16=py310h06a4308_0
106
+ - wheel=0.38.4=py310h06a4308_0
107
+ - xz=5.4.2=h5eee18b_0
108
+ - zlib=1.2.13=h5eee18b_0
109
+ - zstd=1.5.5=hc292b87_0
110
+ - pip:
111
+ - absl-py==1.4.0
112
+ - accelerate==0.20.3
113
+ - aiohttp==3.8.4
114
+ - aiosignal==1.3.1
115
+ - annotated-types==0.6.0
116
+ - asttokens==2.2.1
117
+ - astunparse==1.6.3
118
+ - async-timeout==4.0.2
119
+ - attrs==23.1.0
120
+ - backcall==0.2.0
121
+ - bert-score==0.3.13
122
+ - bitsandbytes==0.41.0
123
+ - bleurt==0.0.2
124
+ - brotli==1.1.0
125
+ - cachetools==5.3.1
126
+ - click==8.1.7
127
+ - cmake==3.25.0
128
+ - colorama==0.4.6
129
+ - contourpy==1.1.0
130
+ - cycler==0.11.0
131
+ - datasets==2.13.1
132
+ - decorator==5.1.1
133
+ - deepspeed==0.12.6
134
+ - dill==0.3.6
135
+ - dotmap==1.3.30
136
+ - emoji==2.10.0
137
+ - evaluate==0.4.1
138
+ - executing==1.2.0
139
+ - flatbuffers==24.3.25
140
+ - fonttools==4.40.0
141
+ - frozenlist==1.3.3
142
+ - fsspec==2023.6.0
143
+ - gast==0.5.4
144
+ - google-auth==2.21.0
145
+ - google-auth-oauthlib==1.0.0
146
+ - google-pasta==0.2.0
147
+ - grpcio==1.56.0
148
+ - h5py==3.10.0
149
+ - hjson==3.1.0
150
+ - huggingface-hub==0.15.1
151
+ - inflate64==1.0.0
152
+ - ipython==8.14.0
153
+ - jedi==0.19.0
154
+ - joblib==1.3.2
155
+ - jsonlines==3.1.0
156
+ - keras==3.1.1
157
+ - kiwisolver==1.4.4
158
+ - libclang==18.1.1
159
+ - lightning-utilities==0.9.0
160
+ - lit==15.0.7
161
+ - loralib==0.1.1
162
+ - lxml==4.9.2
163
+ - markdown==3.4.3
164
+ - markdown-it-py==3.0.0
165
+ - matplotlib==3.7.2
166
+ - matplotlib-inline==0.1.6
167
+ - mdurl==0.1.2
168
+ - ml-dtypes==0.3.2
169
+ - mpmath==1.2.1
170
+ - multidict==6.0.4
171
+ - multiprocess==0.70.14
172
+ - multivolumefile==0.2.3
173
+ - namex==0.0.7
174
+ - ninja==1.11.1.1
175
+ - nltk==3.8.1
176
+ - nvidia-cuda-runtime-cu11==11.7.99
177
+ - oauthlib==3.2.2
178
+ - openai==0.27.8
179
+ - opencv-python==4.9.0.80
180
+ - opt-einsum==3.3.0
181
+ - optree==0.11.0
182
+ - packaging==23.1
183
+ - pandas==2.0.2
184
+ - parso==0.8.3
185
+ - peft==0.3.0
186
+ - pexpect==4.8.0
187
+ - pickleshare==0.7.5
188
+ - portalocker==2.7.0
189
+ - prompt-toolkit==3.0.39
190
+ - protobuf==4.23.3
191
+ - psutil==5.9.5
192
+ - ptyprocess==0.7.0
193
+ - pure-eval==0.2.2
194
+ - py-cpuinfo==9.0.0
195
+ - py7zr==0.20.8
196
+ - pyarrow==12.0.1
197
+ - pyasn1==0.5.0
198
+ - pyasn1-modules==0.3.0
199
+ - pybcj==1.0.2
200
+ - pycryptodomex==3.19.0
201
+ - pydantic==2.5.3
202
+ - pydantic-core==2.14.6
203
+ - pydotmap==0.1.3
204
+ - pygments==2.16.1
205
+ - pynvml==11.5.0
206
+ - pyparsing==3.0.9
207
+ - pyppmd==1.1.0
208
+ - python-dateutil==2.8.2
209
+ - python-dotenv==1.0.0
210
+ - pytictoc==1.5.2
211
+ - pytorch-lightning==2.0.4
212
+ - pytz==2023.3
213
+ - pyyaml==6.0
214
+ - pyzstd==0.15.9
215
+ - regex==2023.6.3
216
+ - requests-oauthlib==1.3.1
217
+ - responses==0.18.0
218
+ - retrying==1.3.4
219
+ - rich==13.7.1
220
+ - rouge==1.0.1
221
+ - rouge-score==0.1.2
222
+ - rsa==4.9
223
+ - sacrebleu==2.3.1
224
+ - safetensors==0.3.1
225
+ - scikit-learn==1.3.0
226
+ - scipy==1.11.0
227
+ - sentencepiece==0.1.99
228
+ - six==1.16.0
229
+ - sklearn==0.0.post7
230
+ - stack-data==0.6.2
231
+ - tabulate==0.9.0
232
+ - tensorboard==2.16.2
233
+ - tensorboard-data-server==0.7.1
234
+ - tensorflow==2.16.1
235
+ - tensorflow-io-gcs-filesystem==0.36.0
236
+ - termcolor==2.4.0
237
+ - texttable==1.7.0
238
+ - tf-slim==1.1.0
239
+ - threadpoolctl==3.2.0
240
+ - timm==0.4.5
241
+ - tokenizers==0.13.3
242
+ - torch==2.0.1+cu118
243
+ - torchaudio==2.0.2+cu118
244
+ - torchmetrics==0.11.4
245
+ - torchvision==0.15.2+cu118
246
+ - tqdm==4.65.0
247
+ - traitlets==5.9.0
248
+ - transformers==4.30.2
249
+ - tzdata==2023.3
250
+ - wcwidth==0.2.6
251
+ - werkzeug==2.3.6
252
+ - wrapt==1.16.0
253
+ - xxhash==3.2.0
254
+ - yarl==1.9.2
255
+ variables:
256
+ LD_LIBRARY_PATH: <CONDA_PATH>/envs/SPT/lib
257
+ LIBRARY_PATH: <CONDA_PATH>/envs/SPT/lib
evaluate_runs_results.py ADDED
@@ -0,0 +1,150 @@
1
+ import csv
2
+ import glob
3
+ import pickle
4
+ import re
5
+
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+ import numpy
10
+ from bert_score import BERTScorer
11
+ from evaluation import f1_score
12
+
13
+ import evaluate
14
+
15
+ rouge = evaluate.load('rouge')
16
+ # bertscore = BERTScorer(lang='en', device='cuda')
17
+ bertscore = BERTScorer(model_type='microsoft/deberta-xlarge-mnli', device='cuda')
18
+ _main_path = 'public_ckpt'
19
+
20
+ ADD_METEOR = True
21
+ if ADD_METEOR:
22
+ meteor_scorer = evaluate.load('meteor')
23
+
24
+ DO_PRED_CLEAN = True
25
+
26
+
27
+
28
+ def evaluate_folder(main_path, skip_exists=True):
29
+ results_path = f'{main_path}/results.txt'
30
+ results_csv_path = f'{main_path}/results.csv'
31
+ paths = glob.glob(f'{main_path}/*/evaluation_result*.pkl')
32
+ all_results = []
33
+ csv_results = []
34
+ csv_results.append(['path',
35
+ 'ppl',
36
+ 'F1',
37
+ 'bleu',
38
+ 'bleu-1',
39
+ 'bleu-2',
40
+ 'bleu-3',
41
+ 'bleu-4',
42
+ 'rouge1',
43
+ 'rouge2',
44
+ 'rougel',
45
+ 'BERT f1',
46
+ 'BERT precision',
47
+ 'BERT recall',
48
+ 'dist-1',
49
+ 'dist-2',
50
+ 'meteor',
51
+ 'valid_num'])
52
+ for path in paths:
53
+ with open(path, 'rb') as file:
54
+ results = pickle.load(file)
55
+ if results.get('result_str') is not None and skip_exists:
56
+ all_results.append(results['result_str'])
57
+ csv_results.append(results['csv'])
58
+ continue
59
+ preds = results['pred_text']
60
+ clean_preds = []
61
+ if DO_PRED_CLEAN:
62
+ for pred in preds:
63
+ search_result = re.search('R:|Q:|Summary:|\n|\:', pred)
64
+ if search_result is not None:
65
+ clean_preds.append(pred[:search_result.span()[0]])
66
+ else:
67
+ clean_preds.append(pred)
68
+ preds = clean_preds
69
+ tgt = results['gt_text']
70
+
71
+ def bleu_score(prediction, ground_truths):
72
+ from sacrebleu import BLEU
73
+ bleu = BLEU()
74
+ score = bleu.corpus_score(prediction, ground_truths)
75
+ return score
76
+
77
+ bleu = bleu_score(preds, [tgt])
78
+
79
+ precision, recall, f1 = bertscore.score(preds, tgt, verbose=False, batch_size=64)
80
+ mean_precision = precision.mean().item()
81
+ mean_recall = recall.mean().item()
82
+ mean_f1 = f1.mean().item()
83
+
84
+ def eval_distinct(corpus):
85
+ unigrams = []
86
+ bigrams = []
87
+ for n, rep in enumerate(corpus):
88
+ rep = rep.strip()
89
+ temp = rep.split(' ')
90
+ unigrams += temp
91
+ for i in range(len(temp) - 1):
92
+ bigrams.append(temp[i] + ' ' + temp[i + 1])
93
+ distink_1 = len(set(unigrams)) * 1.0 / len(unigrams)
94
+ distink_2 = len(set(bigrams)) * 1.0 / len(bigrams)
95
+ return distink_1, distink_2
96
+
97
+ rouge_results = rouge.compute(predictions=preds, references=tgt)
98
+ rouge1, rouge2, rougel = rouge_results['rouge1'], rouge_results['rouge2'], rouge_results['rougeL']
99
+ me_score = 0
100
+ if ADD_METEOR:
101
+ me_score = meteor_scorer.compute(predictions=preds, references=tgt)['meteor']
102
+ from evaluation import rouge_score
103
+ _rouge = rouge_score(preds, [tgt])
104
+ f1 = [f1_score(p, [t]) for p, t in zip(preds, tgt)]
105
+ f1 = numpy.asfarray(f1).mean()
106
+ ppl=''
107
+
108
+ result_str = f"""
109
+ path: {path}
110
+ F1: {f1}
111
+ bleu: {bleu.score}
112
+ bleu detail: {bleu.precisions}
113
+ rouge1, rouge2, rougel: {rouge1, rouge2, rougel}
114
+ BERT f1: {mean_f1}
115
+ BERT precision: {mean_precision}
116
+ BERT recall: {mean_recall}
117
+ dist: {eval_distinct(preds)}
118
+ METEOR: {me_score}
119
+ valid_num: {len(preds)}
120
+ """
121
+ csv_data = [path,
122
+ f1 * 100.0,
123
+ bleu.score,
124
+ *bleu.precisions,
125
+ rouge1 * 100.0,
126
+ rouge2 * 100.0,
127
+ rougel * 100.0,
128
+ mean_f1 * 100.0,
129
+ mean_precision * 100.0,
130
+ mean_recall * 100.0,
131
+ *eval_distinct(preds),
132
+ me_score,
133
+ len(preds)]
134
+ csv_results.append(csv_data)
135
+ print(result_str)
136
+ all_results.append(result_str)
137
+ with open(path, 'wb') as file:
138
+ results['result_str'] = result_str
139
+ results['csv'] = csv_data
140
+ pickle.dump(results, file)
141
+
142
+ with open(results_path, 'w') as file:
143
+ file.write("\n=====\n".join(all_results))
144
+ with open(results_csv_path, 'w') as file:
145
+ writer = csv.writer(file)
146
+ writer.writerows(csv_results)
147
+
148
+
149
+ if __name__ == '__main__':
150
+ evaluate_folder(_main_path, skip_exists=False)
evaluation.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import string
9
+ from collections import Counter
10
+ from typing import Callable
11
+
12
+ import regex
13
+ from rouge import Rouge
14
+
15
+ rouge = Rouge()
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ # Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
21
+ def normalize_answer(s: str) -> str:
22
+ def remove_articles(text):
23
+ return regex.sub(r"\b(a|an|the)\b", " ", text)
24
+
25
+ def white_space_fix(text):
26
+ return " ".join(text.split())
27
+
28
+ def remove_punc(text):
29
+ exclude = set(string.punctuation)
30
+ return "".join(ch for ch in text if ch not in exclude)
31
+
32
+ def lower(text):
33
+ return text.lower()
34
+
35
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
36
+
37
+
38
+ def em(prediction, ground_truth, normalize_fn):
39
+ return float(normalize_fn(prediction) == normalize_fn(ground_truth))
40
+
41
+
42
+ def f1(prediction, ground_truth, normalize_fn):
43
+ prediction_tokens = normalize_fn(prediction).split()
44
+ ground_truth_tokens = normalize_fn(ground_truth).split()
45
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
46
+ num_same = sum(common.values())
47
+
48
+ if num_same == 0:
49
+ return 0
50
+ precision = 1.0 * num_same / len(prediction_tokens)
51
+ recall = 1.0 * num_same / len(ground_truth_tokens)
52
+ f1 = (2 * precision * recall) / (precision + recall)
53
+ return f1
54
+
55
+
56
+ def rouge_wrapper(prediction, ground_truth):
57
+ try:
58
+ result = rouge.get_scores(prediction, ground_truth, avg=True)
59
+ return result["rouge-1"]["f"], result["rouge-2"]["f"], result["rouge-l"]["f"]
60
+ except:
61
+ return 0.0, 0.0, 0.0
62
+
63
+
64
+ # pred = [p1, p2 ..., pn] gt=[[g1,g2,...,gn]]
65
+ def f1_score(prediction, ground_truths, normalize_fn: Callable[[str], str] = lambda x: x):
66
+ return max([f1(prediction, gt, normalize_fn) for gt in ground_truths])
67
+
68
+
69
+ def exact_match_score(prediction, ground_truths, normalize_fn: Callable[[str], str] = lambda x: x):
70
+ return max([em(prediction, gt, normalize_fn) for gt in ground_truths])
71
+
72
+
73
+ # pred = [p1, p2 ..., pn] gt=[[g1,g2,...,gn]]
74
+ def rouge_score(prediction, ground_truths):
75
+ ground_truths = [x for x in ground_truths if len(x) > 0]
76
+ if (
77
+ len(prediction) == 0 or len(ground_truths) == 0
78
+ ): # check if empty prediction or if there is no hypothesis with len > 0
79
+ return 0.0, 0.0, 0.0
80
+ scores = [rouge_wrapper(prediction, gt) for gt in ground_truths]
81
+ rouge1 = max(s[0] for s in scores)
82
+ rouge2 = max(s[1] for s in scores)
83
+ rougel = max(s[2] for s in scores)
84
+ return rouge1, rouge2, rougel
85
+
86
+
87
+ # pred = [p1, p2 ..., pn] gt=[[g1,g2,...,gn]]
88
+ def bleu_score(prediction, ground_truths):
89
+ from sacrebleu import BLEU
90
+ bleu = BLEU()
91
+ score = bleu.corpus_score(prediction, ground_truths)
92
+ return score
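A small usage sketch for the helpers above (imports mirror how `evaluate_runs_results.py` uses this module; the strings are made-up examples):

```python
from evaluation import f1_score, bleu_score, normalize_answer

pred = "i love skiing in the winter"
refs = ["i love to ski in winter"]

# token-level F1 against the best-matching reference, after SQuAD-style normalization
print(f1_score(pred, refs, normalize_fn=normalize_answer))
# corpus-level BLEU via sacrebleu (expects a list of hypotheses and a list of reference streams)
print(bleu_score([pred], [refs]).score)
```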
interactive_test.py ADDED
@@ -0,0 +1,205 @@
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import locale
5
+ import os
6
+ import random
7
+ import re
8
+ import time
9
+ from multiprocessing import freeze_support
10
+
11
+ import deepspeed
12
+ import torch
13
+ from dotenv import load_dotenv
14
+ from torch.utils.data import DistributedSampler
15
+
16
+ from dataset.dataset import PersonaChatDataset
17
+ from utils.dist_helper import setup
18
+ from utils.format_inputs import TASK_TYPE
19
+ from utils.parser_helper import str2bool
20
+
21
+ os.environ["PYTHONIOENCODING"] = "utf-8"
22
+ myLocale = locale.setlocale(category=locale.LC_ALL, locale="C.UTF-8")
23
+ load_dotenv()
24
+
25
+ argparse = argparse.ArgumentParser()
26
+ argparse.add_argument('--model_path', type=str, default=None)
27
+ argparse.add_argument('--path_pattern', type=str, default=None)
28
+ argparse.add_argument('--batch_size', type=int)
29
+ argparse.add_argument('--valid_path', type=str, default=None)
30
+ argparse.add_argument('--local_rank', type=int, default=-1)
31
+ argparse.add_argument('--skip_exists', type=str2bool, default=False)
32
+ argparse.add_argument('--selection_noise', type=float, default=None)
33
+ parser = deepspeed.add_config_arguments(argparse)
34
+ args = argparse.parse_args()
35
+ _cmd_args = parser.parse_args()
36
+ freeze_support()
37
+
38
+ VICUNA_PREFIX = 'PATH_TO_VICUNA'
39
+
40
+
41
+ def test_process(model_paths, batch_size, valid_path, skip_exists, selection_noise, cmd_args):
42
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
43
+ with open(cmd_args.deepspeed_config) as json_file:
44
+ ds_config = json.load(json_file)
45
+ del cmd_args.deepspeed_config
46
+
47
+ setup()
48
+ for model_path in model_paths:
49
+ try:
50
+ if selection_noise is not None:
51
+ save_dir = os.sep.join(
52
+ model_path.split(os.sep)[:-1]) + os.sep + f'evaluation_result_selection_noise={selection_noise}.pkl'
53
+ else:
54
+ save_dir = os.sep.join(model_path.split(os.sep)[:-1]) + os.sep + 'evaluation_result.pkl'
55
+ if os.path.exists(save_dir) and (skip_exists):
56
+ continue
57
+ print(
58
+ f"Start setup rank {deepspeed.comm.get_local_rank()} of {world_size} on GPU {torch.cuda.current_device()}")
59
+
60
+ ckpt = torch.load(os.sep.join(model_path.split(os.sep)[:-1]) + os.sep + 'checkpoint_best.pth',
61
+ map_location=f'cpu')
62
+ config = ckpt['config']
63
+ ds_config['train_micro_batch_size_per_gpu'] = batch_size
64
+ load_precision = '32'
65
+ if config.model.load_bit == 16:
66
+ ds_config['float16']['enabled'] = True
67
+ load_precision = 'fp16'
68
+ if 'llama' in config.model.model_name.lower():
69
+ ds_config['float16']['enabled'] = False
70
+ ds_config['bf16']['enabled'] = True
71
+ load_precision = 'bf16'
72
+ load_bit_map = {
73
+ 'fp16': torch.float16,
74
+ 'bf16': torch.bfloat16,
75
+ '32': torch.float32}
76
+
77
+ if config.model.model_type == 'selective_pt':
78
+ from models.selective_llm_chat import SelectLLMChat as LLMChat
79
+ else:
80
+ from models.llm_chat import LLMChat
81
+ if 'vicuna' in config.model.model_name and (not os.path.exists(config.model.model_name)):
82
+ config.model.model_name = VICUNA_PREFIX + os.sep + config.model.model_name.split(os.sep)[-1]
83
+ _model = LLMChat(config, batch_size)
84
+ left_tokenizer = _model.left_tokenizer
85
+ right_tokenizer = _model.right_tokenizer
86
+ print(f'LOADING {model_path} with {load_precision} precision')
87
+ model_engine, _, _, _ = deepspeed.initialize(args=cmd_args,
88
+ model=_model,
89
+ config=ds_config,
90
+ )
91
+ model_engine.load_checkpoint(model_path, load_module_strict=False, load_optimizer_states=False,
92
+ load_lr_scheduler_states=False,
93
+ load_module_only=True)
94
+ valid_path_file = valid_path
95
+ if valid_path_file is None:
96
+ valid_path_file = config.dataset.valid
97
+ if config.dataset.test.__class__ is str:
98
+ valid_path_file = config.dataset.test
99
+ print('using train split from personachat')
100
+ task_type = TASK_TYPE(config.training.task_type)
101
+
102
+ valid_dataset = PersonaChatDataset(valid_path_file, max_context_turns=config.dataset.max_context_turns)
103
+ from dataset.dataset import get_dataloader
104
+ max_new_token = 32
105
+ valid_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, shuffle=False,
106
+ drop_last=False)
107
+ valid_dataloader = get_dataloader(valid_dataset, batch_size, num_workers=0, sampler=valid_sampler)
108
+
109
+ context_input = []
110
+ persona_list = []
111
+ dist_pred_text = [None for _ in range(world_size)]
112
+ dist_gt_text = [None for _ in range(world_size)]
113
+ pred_text = []
114
+ gt_text = []
115
+ selected_prompts = []
116
+ print('Please enter your input:')
117
+ first_setence = input()
118
+ chosen_persona = random.choice([p['persona'] for p in valid_dataset.turns_data])
119
+ history = [f"Q: {first_setence}"]
120
+ history_with_prompt_idx = [f"USER: {first_setence}"]
121
+ selected_prompts = []
122
+ while True:
123
+ data = {'context_input': [history],
124
+ 'persona_list': [chosen_persona],
125
+ 'target': ['not use']}
126
+ _, text, batch_selected_prompts = LLMChat.test_step(model_engine, data, left_tokenizer,
127
+ right_tokenizer,
128
+ config, max_new_tokens=max_new_token,
129
+ tqdm_instance=None,
130
+ selection_noise=None,
131
+ no_repeat_ngram_size=4,
132
+ top_p=0.9,
133
+ num_beams=10)
134
+ response = text[0].strip()
135
+ search_result = re.search('R:|Q:|Summary:|\n|\:', response)
136
+ if search_result is not None:
137
+ response = response[:search_result.span()[0]]
138
+ response = response.strip()
139
+
140
+ selected_prompts.append(batch_selected_prompts.item())
141
+ history += [f"R: {response}"]
142
+ history_with_prompt_idx += [f"SPT: {response} [SPT Index: {batch_selected_prompts.item()}]"]
143
+ history_str = "\n".join(history_with_prompt_idx)
144
+ print_str = f"""
145
+ Persona: {' '.join(chosen_persona)}
146
+ Dialogue:
147
+ {history_str}
148
+ """
149
+ print(print_str)
150
+ print('Please enter your input:')
151
+ user_input = input()
152
+ if user_input == 'r':
153
+ history = history[:-1]
154
+ history_with_prompt_idx = history_with_prompt_idx[:-1]
155
+ continue
156
+ if user_input == 'exit':
157
+ exit()
158
+ elif user_input == 'save':
159
+ os.makedirs('interactive_dialog', exist_ok=True)
160
+ with open('interactive_dialog/'+time.strftime('%Y-%m-%d-%H%M')+'.txt', 'w') as file:
161
+ file.write(print_str)
162
+ history = []
163
+ history_with_prompt_idx = []
164
+ chosen_persona = random.choice([p['persona'] for p in valid_dataset.turns_data])
165
+ print('Please enter your input:')
166
+ user_input = input()
167
+ elif user_input == 'clear':
168
+ history = []
169
+ history_with_prompt_idx = []
170
+ chosen_persona = random.choice([p['persona'] for p in valid_dataset.turns_data])
171
+ print('Please enter your input:')
172
+ user_input = input()
173
+ history += [f"Q: {user_input}"]
174
+ history_with_prompt_idx += [f"USER: {user_input}"]
175
+
176
+ except Exception as e:
177
+ save_dir = os.sep.join(model_path.split(os.sep)[:-1]) + os.sep + "test_error.txt"
178
+ print(f'WRITING TESTING ERROR! ERROR: {str(e)}')
179
+ with open(save_dir, 'w') as file:
180
+ file.write(str(e))
181
+ deepspeed.comm.barrier()
182
+ deepspeed.comm.barrier()
183
+
184
+
185
+ model_path_arg = args.model_path
186
+ model_paths = [model_path_arg]
187
+ if len(glob.glob(model_path_arg + os.sep + 'ds_ckpt' + os.sep + '*')):
188
+ model_paths = [model_path_arg + os.sep + 'ds_ckpt']
189
+ elif not model_path_arg.endswith('.pth'):
190
+ import glob
191
+
192
+ path_pattern = args.path_pattern
193
+ if path_pattern is not None:
194
+ model_paths = glob.glob(f'{model_path_arg}/{path_pattern}/ds_ckpt/*/*.pt')
195
+ else:
196
+ model_paths = glob.glob(f'{model_path_arg}/*/ds_ckpt/*/*.pt')
197
+ model_paths = list(set([os.sep.join(p.split(os.sep)[:-2]) for p in model_paths]))
198
+ print(model_paths)
199
+ num_of_gpus = torch.cuda.device_count()
200
+ print(f"{num_of_gpus} GPUs available")
201
+ test_process(model_paths, args.batch_size, args.valid_path,
202
+ args.skip_exists, args.selection_noise, cmd_args=_cmd_args)
203
+ deepspeed.comm.barrier()
204
+ deepspeed.comm.destroy_process_group()
205
+ print('Test Ends')
models/__pycache__/llm_chat.cpython-310.pyc ADDED
Binary file (7.63 kB). View file
 
models/__pycache__/selective_llm_chat.cpython-310.pyc ADDED
Binary file (13 kB). View file
 
models/llm_chat.py ADDED
@@ -0,0 +1,227 @@
1
+ import torch
2
+ from peft import get_peft_model, LoraConfig, PromptTuningConfig, TaskType, PrefixTuningConfig
3
+ from torch import nn, autocast
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from transformers.deepspeed import HfDeepSpeedConfig
6
+
7
+ from utils.format_inputs import TASK_TYPE
8
+ from utils.format_inputs import format_causal_personachat_input, format_personachat_input, \
9
+ format_generate_persona_input
10
+ from utils.model_helpers import print_trainable_parameters
11
+
12
+
13
+ # TODO: we need to extract LORA Weight and Bias from the model
14
+ # TODO: we need to do adaptive applied LORA
15
+ class LLMChat(nn.Module):
16
+ def __init__(self, config, batch_size, ds_config=None):
17
+ if ds_config is not None:
18
+ _hfdsc = HfDeepSpeedConfig(ds_config)
19
+ super(LLMChat, self).__init__()
20
+ self.model_name = config.model.model_name
21
+ self.load_bit = config.model.load_bit
22
+ self.left_tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
23
+ original_vocab_size = len(self.left_tokenizer)
24
+ if config.training.mode != 'causal':
25
+ self.left_tokenizer.add_special_tokens({'pad_token': '[PAD]',
26
+ 'bos_token': '[BOS]',
27
+ 'eos_token': '[EOS]',
28
+ 'unk_token': '[UNK]',
29
+ 'sep_token': '[SEP]',
30
+ 'cls_token': '[CLS]',
31
+ 'mask_token': '[MASK]'})
32
+ self.left_tokenizer.padding_side = 'left'
33
+ self.left_tokenizer.truncation_side = 'left'
34
+ self.right_tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
35
+ if config.training.mode != 'causal':
36
+ self.right_tokenizer.add_special_tokens({'pad_token': '[PAD]',
37
+ 'bos_token': '[BOS]',
38
+ 'eos_token': '[EOS]',
39
+ 'unk_token': '[UNK]',
40
+ 'sep_token': '[SEP]',
41
+ 'cls_token': '[CLS]',
42
+ 'mask_token': '[MASK]'})
43
+ self.right_tokenizer.padding_side = 'right'
44
+ self.right_tokenizer.truncation_side = 'right'
45
+ if self.left_tokenizer.pad_token is None and config.model.pad_token == 'bos':
46
+ self.left_tokenizer.pad_token = self.left_tokenizer.bos_token
47
+ self.right_tokenizer.pad_token = self.right_tokenizer.bos_token
48
+ elif self.left_tokenizer.pad_token_id is None:
49
+ self.left_tokenizer.pad_token = self.left_tokenizer.eos_token
50
+ self.right_tokenizer.pad_token = self.right_tokenizer.eos_token
51
+ self.batch_size = batch_size
52
+ load_bit_map = {4: {'load_in_4bit': True,
53
+ 'bnb_4bit_compute_dtype': torch.bfloat16},
54
+ 8: {'load_in_8bit': True},
55
+ 16: {'torch_dtype': torch.float16},
56
+ 32: {'torch_dtype': torch.float32}}
57
+ assert config.model.load_bit in [16, 32], 'deepspeed is not friendly with bnb!'
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ config.model.model_name,
60
+ **load_bit_map[config.model.load_bit],
61
+ )
62
+ if config.training.mode != 'causal':
63
+ model.resize_token_embeddings(len(self.left_tokenizer))
64
+ # for m in model.children():
65
+ # if hasattr(m, 'gradient_checkpointing_enable'):
66
+ # m.gradient_checkpointing_enable()
67
+ model.gradient_checkpointing_enable()
68
+ if config.model.peft_config is not None:
69
+ for param in model.parameters():
70
+ param.requires_grad = False # freeze the model - train adapters later
71
+ if param.ndim == 1:
72
+ # cast the small parameters (e.g. layernorm) to fp32 for stability
73
+ param.data = param.data.to(torch.float32)
74
+ model.enable_input_require_grads()
75
+
76
+ # # enable special token embedding params, since we resized the vocabulary
77
+ # for name, param in model.named_parameters():
78
+ # if 'embed_tokens' in name:
79
+ # param[original_vocab_size:].requires_grad = True
80
+
81
+ class CastOutputToFloat(nn.Sequential):
82
+ def forward(self, x): return super().forward(x).to(torch.float32)
83
+
84
+ if config.model.peft_type == 'prompt_tuning':
85
+ peft_config = PromptTuningConfig(
86
+ **config.model.peft_config,
87
+ task_type=TaskType.CAUSAL_LM,
88
+ )
89
+ elif config.model.peft_type == 'prefix_tuning':
90
+ peft_config = PrefixTuningConfig(
91
+ **config.model.peft_config,
92
+ task_type=TaskType.CAUSAL_LM,
93
+ )
94
+ else:
95
+ peft_config = LoraConfig(**config.model.peft_config)
96
+ model.lm_head = CastOutputToFloat(model.lm_head)
97
+ model = get_peft_model(model, peft_config)
98
+ self.using_nn_modulelist = False
99
+ if config.model.using_nn_modulelist.__class__ is bool and config.model.using_nn_modulelist:
100
+ self.using_nn_modulelist = config.model.using_nn_modulelist
101
+ self.model = nn.ModuleList([model])
102
+ else:
103
+ self.model = model
104
+ if config.model.add_extra_layers.__class__ is bool and config.model.add_extra_layers:
105
+ self.prompt_normalizer = nn.Linear(
106
+ self.model[0].prompt_encoder.default.embedding.weight.shape[1],
107
+ self.model[0].word_embeddings.weight.shape[1])
108
+ self.score_activation = nn.Softplus(threshold=1, beta=10)
109
+ self.learning_rate = config.training.learning_rate
110
+ self.warmup_steps = config.training.warmup_steps
111
+ self.config = config
112
+ self.find_batch = False
113
+ print_trainable_parameters(self)
114
+
115
+ def print_llm_trainable_parameters(self):
116
+ print_trainable_parameters(self.model)
117
+
118
+ @autocast('cuda')
119
+ def forward(self, x):
120
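+ # NOTE: `_non_exists` is not an actual config key, so this check presumably never passes;
+ # it appears to be here only so that prompt_normalizer and score_activation are referenced in forward()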
+ if self.config._non_exists == 1:
121
+ self.prompt_normalizer(x)
122
+ self.score_activation(x)
123
+ for k in x.keys():
124
+ x[k] = x[k].cuda()
125
+ if self.find_batch:
126
+ x['attention_mask'] = x['attention_mask'].new_ones(x['attention_mask'].shape)
127
+ if self.using_nn_modulelist:
128
+ if self.config.model.using_output_stack.__class__ is bool and self.config.model.using_output_stack:
129
+ _outputs = [_model(**x) for _model in self.model]
130
+ _logits = torch.stack([_output['logits'] for _output in _outputs])
131
+ return {'logits': _logits}
132
+ return self.model[0](**x)
133
+ return self.model(**x)
134
+
135
+ def on_train_start(self) -> None:
136
+ self.print_llm_trainable_parameters()
137
+
138
+ @staticmethod
139
+ def training_step(model, batch, left_tokenizer, right_tokenizer, config, find_batch=False, mode='normal',
140
+ task_type=TASK_TYPE.GENERATE_RESPONSE, **_kwargs):
141
+ assert mode in ['normal', 'causal']
142
+ if task_type == TASK_TYPE.GENERATE_PERSONA and mode == 'normal':
143
+ lm_input, lm_target = format_generate_persona_input(batch, left_tokenizer, right_tokenizer,
144
+ config)
145
+ elif task_type == TASK_TYPE.GENERATE_RESPONSE and mode == 'causal':
146
+ lm_input, lm_target = format_causal_personachat_input(batch, left_tokenizer, right_tokenizer,
147
+ config)
148
+ elif task_type == TASK_TYPE.GENERATE_RESPONSE and mode == 'normal':
149
+ lm_input, lm_target = format_personachat_input(batch, left_tokenizer, right_tokenizer, config)
150
+ else:
151
+ raise NotImplementedError('mode and task_type not implemented')
152
+ output = model(lm_input)
153
+ if find_batch:
154
+ loss = nn.CrossEntropyLoss()(output['logits'].view(-1, output['logits'].shape[-1]),
155
+ lm_target.cuda().view(-1))
156
+ else:
157
+ if config.model.peft_type == 'prompt_tuning':
158
+ virtual_tokens = config.model.peft_config.num_virtual_tokens
159
+ batch_size = lm_target.size()[0]
160
+ _lm_target = torch.cat(
161
+ (lm_target.new_ones((batch_size, virtual_tokens)) * left_tokenizer.pad_token_id, lm_target), dim=1)
162
+ else:
163
+ _lm_target = lm_target
164
+ loss = nn.CrossEntropyLoss(ignore_index=left_tokenizer.pad_token_id)(
165
+ output['logits'].view(-1, output['logits'].shape[-1]),
166
+ _lm_target.cuda().view(-1))
167
+ # self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
168
+ if config.training.normalize_loss.__class__ == bool and config.training.normalize_loss:
169
+ model.module.normalize()
170
+ return loss
171
+
172
+ def normalize(self):
173
+ raise NotImplementedError('normalize trainable weights needs implementation')
174
+ return None
175
+
176
+ @staticmethod
177
+ def validation_step(model, batch, left_tokenizer, right_tokenizer, config, task_type, mode='normal'):
178
+ loss = LLMChat.training_step(model, batch, left_tokenizer, right_tokenizer, config, task_type=task_type,
179
+ find_batch=False, mode=mode)
180
+ return loss
181
+
182
+ def on_test_start(self) -> None:
183
+ from peft import get_peft_model_state_dict, set_peft_model_state_dict
184
+ peft_weight = get_peft_model_state_dict(self.model).copy()
185
+ peft_config = self.model.peft_config
186
+ del self.model
187
+ model = AutoModelForCausalLM.from_pretrained(
188
+ self.config.model.model_name,
189
+ torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
190
+ )
191
+ self.model = get_peft_model(model, peft_config['default'])
192
+ set_peft_model_state_dict(self.model, peft_weight, adapter_name='default')
193
+ self.model.merge_and_unload()
194
+ self.model.eval()
195
+
196
+ @staticmethod
197
+ @autocast('cuda')
198
+ def test_step(model, batch, left_tokenizer, right_tokenizer, config, max_new_tokens=16, tqdm_instance=None, **kwargs):
199
+ model.eval()
200
+ task_type = TASK_TYPE(config.training.task_type)
201
+ with torch.no_grad():
202
+ if config.training.mode == 'causal':
203
+ lm_input, lm_target, inference_tokenized = format_causal_personachat_input(batch,
204
+ left_tokenizer,
205
+ right_tokenizer,
206
+ config,
207
+ for_test=True)
208
+ else:
209
+ lm_input, lm_target, inference_tokenized = format_personachat_input(batch, left_tokenizer,
210
+ right_tokenizer, config,
211
+ for_test=True)
212
+ inference_tokenized.to('cuda')
213
+ model_for_generation = None
214
+ if 'deepspeed' in str(model.__class__):
215
+ model_for_generation = model.module.model
216
+ else:
217
+ model_for_generation = model.model
218
+ if model_for_generation.__class__ is nn.ModuleList:
219
+ model_for_generation = model_for_generation[0]
220
+ # adding do_sample=False to avoid inf error!
221
+ raw_output = model_for_generation.generate(**inference_tokenized, max_new_tokens=max_new_tokens,
222
+ do_sample=False)
223
+ trunc_output = raw_output[:, inference_tokenized['input_ids'].shape[1]:]
224
+ if trunc_output[trunc_output >= len(left_tokenizer)].size()[0] > 0:
225
+ trunc_output[trunc_output >= len(left_tokenizer)] = left_tokenizer.pad_token_id
226
+ text_output = right_tokenizer.batch_decode(trunc_output, skip_special_tokens=True)
227
+ return trunc_output, text_output, []
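
`LLMChat.training_step` above handles the prompt-tuning case by left-padding the target with `pad_token_id`, so that it lines up with the logits that PEFT lengthens by `num_virtual_tokens` soft-prompt positions; the padded positions are then skipped via `ignore_index`. A minimal sketch of that alignment (illustrative shapes and values only, not repo code):

```python
import torch
from torch import nn

pad_token_id = 0
num_virtual_tokens = 4                      # stands in for peft_config.num_virtual_tokens
batch_size, seq_len, vocab_size = 2, 6, 11

lm_target = torch.randint(1, vocab_size, (batch_size, seq_len))              # token targets
logits = torch.randn(batch_size, num_virtual_tokens + seq_len, vocab_size)   # soft prompts prepended

# left-pad the targets so they align with the virtual-token positions
padded_target = torch.cat(
    (lm_target.new_full((batch_size, num_virtual_tokens), pad_token_id), lm_target), dim=1)

# padded positions are ignored by the loss, as in training_step
loss = nn.CrossEntropyLoss(ignore_index=pad_token_id)(
    logits.reshape(-1, vocab_size), padded_target.reshape(-1))
print(loss.item())
```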
models/selective_llm_chat.py ADDED
@@ -0,0 +1,390 @@
1
+ import deepspeed
+ import torch
4
+ import transformers
5
+ from peft import get_peft_model, PromptTuningConfig, TaskType, PrefixTuningConfig
6
+ from torch import nn, autocast
7
+ from torch.functional import F
8
+ from tqdm import tqdm
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from transformers.deepspeed import HfDeepSpeedConfig
11
+
12
+ from utils.format_inputs import TASK_TYPE
13
+ from utils.format_inputs import format_causal_personachat_input, format_personachat_input, format_generate_persona_input
14
+ from utils.model_helpers import print_trainable_parameters
15
+
16
+
17
+ class SelectLLMChat(nn.Module):
18
+ def __init__(self, config, batch_size, ds_config=None):
19
+ super(SelectLLMChat, self).__init__()
20
+ if ds_config is not None:
21
+ _hfdsc = HfDeepSpeedConfig(ds_config)
22
+ peft_type = config.model.peft_type
23
+ self.peft_type = peft_type
24
+ assert config.model.peft_type in ['prompt_tuning', 'prefix_tuning',
25
+ ], 'only prompt tuning and prefix tuning are supported!'
26
+ K = config.model.K
27
+ self.K = K
28
+ self.ensemble_training = config.training.ensemble
29
+ self.model_name = config.model.model_name
30
+ self.load_bit = config.model.load_bit
31
+ self.left_tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
32
+ if config.training.mode != 'causal':
33
+ self.left_tokenizer.add_special_tokens({'pad_token': '[PAD]',
34
+ 'bos_token': '[BOS]',
35
+ 'eos_token': '[EOS]',
36
+ 'unk_token': '[UNK]',
37
+ 'sep_token': '[SEP]',
38
+ 'cls_token': '[CLS]',
39
+ 'mask_token': '[MASK]'})
40
+ self.left_tokenizer.padding_side = 'left'
41
+ self.left_tokenizer.truncation_side = 'left'
42
+ self.right_tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=False)
43
+ if config.training.mode != 'causal':
44
+ self.right_tokenizer.add_special_tokens({'pad_token': '[PAD]',
45
+ 'bos_token': '[BOS]',
46
+ 'eos_token': '[EOS]',
47
+ 'unk_token': '[UNK]',
48
+ 'sep_token': '[SEP]',
49
+ 'cls_token': '[CLS]',
50
+ 'mask_token': '[MASK]'})
51
+ self.right_tokenizer.padding_side = 'right'
52
+ self.right_tokenizer.truncation_side = 'right'
53
+ if self.left_tokenizer.pad_token is None and config.model.pad_token=='bos':
54
+ self.left_tokenizer.pad_token = self.left_tokenizer.bos_token
55
+ self.right_tokenizer.pad_token = self.right_tokenizer.bos_token
56
+ elif self.left_tokenizer.pad_token_id is None:
57
+ self.left_tokenizer.pad_token = self.left_tokenizer.eos_token
58
+ self.right_tokenizer.pad_token = self.right_tokenizer.eos_token
59
+ self.batch_size = batch_size
60
+ load_bit_map = {4: {'load_in_4bit': True,
61
+ 'bnb_4bit_compute_dtype': torch.bfloat16},
62
+ 8: {'load_in_8bit': True},
63
+ 16: {'torch_dtype': torch.float16},
64
+ 'bf16': {'torch_dtype': torch.bfloat16},
65
+ 32: {'torch_dtype': torch.float32}}
66
+ assert config.model.load_bit in [16, 32, 'bf16'], 'deepspeed is not friendly with bnb!'
67
+ model = AutoModelForCausalLM.from_pretrained(
68
+ config.model.model_name,
69
+ **load_bit_map[config.model.load_bit]
70
+ )
71
+ if config.training.mode != 'causal':
72
+ model.resize_token_embeddings(len(self.left_tokenizer))
73
+ model.gradient_checkpointing_enable()
74
+ if config.model.peft_config is not None:
75
+ for param in model.parameters():
76
+ param.requires_grad = False # freeze the model - train adapters later
77
+ if param.ndim == 1:
78
+ # cast the small parameters (e.g. layernorm) to fp32 for stability
79
+ param.data = param.data.to(torch.float32)
80
+ model.enable_input_require_grads()
81
+ class CastOutputToFloat(nn.Sequential):
82
+ def forward(self, x): return super().forward(x).to(torch.float32)
83
+
84
+ model.lm_head = CastOutputToFloat(model.lm_head)
85
+ self.model = model
86
+ models = []
87
+ peft_config = None
88
+ for _ in range(K):
89
+ if config.model.peft_type == 'prompt_tuning':
90
+ peft_config = PromptTuningConfig(
91
+ **config.model.peft_config,
92
+ task_type=TaskType.CAUSAL_LM,
93
+ )
94
+ elif config.model.peft_type == 'prefix_tuning':
95
+ peft_config = PrefixTuningConfig(
96
+ **config.model.peft_config,
97
+ task_type=TaskType.CAUSAL_LM,
98
+ )
99
+ else:
100
+ raise NotImplementedError()
101
+ _peft_model = get_peft_model(model, peft_config)
102
+ models.append(_peft_model)
103
+ self.models = nn.ModuleList(models)
104
+ self.learning_rate = config.training.learning_rate
105
+ self.warmup_steps = config.training.warmup_steps
106
+ self.config = config
107
+ self.find_batch = False
108
+ self.retriever = None
109
+ if config.model.retriever.retriever_type == 'transformer_encoder':
110
+ encoder_layer = nn.TransformerEncoderLayer(d_model=self.models[0].word_embeddings.weight.shape[1],
111
+ nhead=config.model.retriever.n_head)
112
+ transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.model.retriever.num_layers)
113
+ self.retriever = transformer_encoder
114
+ if config.model.peft_type in ['prompt_tuning'] and config.model.normalizer.__class__ is not str:
115
+ class DoNothing(nn.Sequential):
116
+ def forward(self, x): return x
117
+
118
+ self.prompt_normalizer = DoNothing()
119
+ elif config.model.normalizer == 'linear':
120
+ if config.model.peft_type in ['prompt_tuning', 'prefix_tuning']:
121
+ _d_peft = self.models[0].prompt_encoder.default.embedding.weight.shape[1]
122
+ else:
123
+ raise NotImplementedError('check here!')
124
+ self.prompt_normalizer = nn.Linear(_d_peft, _d_peft)
125
+
126
+ if config.model.score_activation == 'softplus':
127
+ self.score_activation = nn.Softplus(threshold=1, beta=10)
128
+ elif config.model.score_activation == 'relu':
129
+ self.score_activation = nn.ReLU()
130
+ elif config.model.score_activation == 'leaky_relu':
131
+ self.score_activation = nn.LeakyReLU()
132
+ else:
133
+ self.score_activation = nn.Softplus(threshold=1, beta=10)
134
+ # raise NotImplementedError()
135
+ self.retriever_on = ['extra']
136
+ if config.model.retriever.retriever_on.__class__ is list:
137
+ self.retriever_on = config.model.retriever.retriever_on
138
+ if config.training.all_tunable.__class__ is bool and config.training.all_tunable:
139
+ for param in self.parameters():
140
+ param.requires_grad = True
141
+ print_trainable_parameters(self)
142
+ self.contrastive_metric = None
143
+ if config.training.contrastive_metric.__class__ is str:
144
+ self.contrastive_metric = config.training.contrastive_metric
145
+ self.contrastive_threshold = 0.0
146
+ if config.training.contrastive_threshold.__class__ is float:
147
+ self.contrastive_threshold = config.training.contrastive_threshold
148
+ self.config = config
149
+ self.annealing_nll = False
150
+ self.annealing_scalar = 0.0
151
+ if self.config.training.annealing_nll.__class__ == bool:
152
+ self.annealing_nll = self.config.training.annealing_nll
153
+ self.annealing_scalar = self.config.training.annealing_scalar
154
+
155
+
156
+ def print_llm_trainable_parameters(self):
157
+ print_trainable_parameters(self.model)
158
+
159
+ def retrieve_based_on_input_x(self, x, K):
160
+ return self.retrieve_prompts(x, K)
161
+
162
+ @autocast('cuda')
163
+ def retrieve_prompts(self, x, K):
164
+ batch_size = x['input_ids'].shape[0]
165
+ input_ids = x['input_ids']
166
+ spawned_x = input_ids.repeat(K, 1)
167
+ if self.models[0].base_model.__class__ == transformers.models.llama.modeling_llama.LlamaForCausalLM:
168
+ spawned_x_emb = self.models[0].base_model.model.embed_tokens(spawned_x)
169
+ else:
170
+ spawned_x_emb = self.models[0].base_model.model.decoder.embed_tokens(spawned_x)
171
+ if spawned_x_emb.shape[-1] != self.models[0].config.hidden_size:
172
+ # need project_in here
173
+ spawned_x_emb = self.models[0].base_model.model.decoder.project_in(spawned_x_emb)
174
+ prompt_embeddings = torch.stack([_model.prompt_encoder.default.embedding.weight for _model in self.models])
175
+ if self.retriever is not None:
176
+ if 'extra' in self.retriever_on:
177
+ prompt_embeddings = self.retriever(self.prompt_normalizer(prompt_embeddings))
178
+ if 'lm' in self.retriever_on:
179
+ spawned_x_emb = self.retriever(spawned_x_emb)
180
+ spawned_x_emb_mean = spawned_x_emb.mean(dim=1)
181
+ prompt_embeddings_mean = prompt_embeddings.mean(dim=1)
182
+ if self.retriever is None:
183
+ normalizer_on = self.config.model.normalizer_on
184
+ if normalizer_on.__class__ is not list:
185
+ prompt_embeddings_mean = self.prompt_normalizer(prompt_embeddings_mean)
186
+ if 'prompt' in normalizer_on:
187
+ prompt_embeddings_mean = self.prompt_normalizer(prompt_embeddings_mean)
188
+ if 'lm' in normalizer_on:
189
+ spawned_x_emb_mean = self.prompt_normalizer(spawned_x_emb_mean)
190
+ prompt_embeddings_mean_spawn = torch.repeat_interleave(prompt_embeddings_mean, batch_size, dim=0)
191
+ sim_scores = self.score_activation(
192
+ torch.nn.CosineSimilarity()(prompt_embeddings_mean_spawn, spawned_x_emb_mean))
193
+ return sim_scores
194
+
195
+ @autocast('cuda')
196
+ def forward(self, x, mode='training'):
197
+ for k in x.keys():
198
+ x[k] = x[k].cuda(device=deepspeed.comm.get_local_rank())
199
+ if self.find_batch:
200
+ x['attention_mask'] = x['attention_mask'].new_ones(x['attention_mask'].shape)
201
+ if mode == 'training':
202
+ if self.config.training.skip_retrieval.__class__ is bool and self.config.training.skip_retrieval:
203
+ sim_scores = None
204
+ else:
205
+ sim_scores = self.retrieve_based_on_input_x(x, self.K)
206
+ # get pt embeddings
207
+ _outputs = [_model(**x) for _model in self.models]
208
+ _logits = torch.stack([_output['logits'] for _output in _outputs])
209
+ return {'logits': _logits, 'sim_scores': sim_scores}
210
+ else:
211
+ raise NotImplementedError('validation and testing not implemented')
212
+
213
+ def on_train_start(self) -> None:
214
+ self.print_llm_trainable_parameters()
215
+ deepspeed.zero.Init()
216
+
217
+ @staticmethod
218
+ def training_step(model, batch, left_tokenizer, right_tokenizer, config, mode='normal',
219
+ task_type=TASK_TYPE.GENERATE_RESPONSE, training_process=0.0):
220
+ assert mode in ['normal', 'causal']
221
+ if task_type == TASK_TYPE.GENERATE_PERSONA and mode == 'normal':
222
+ lm_input, lm_target = format_generate_persona_input(batch, left_tokenizer, right_tokenizer,
223
+ config)
224
+ elif task_type == TASK_TYPE.GENERATE_RESPONSE and mode == 'causal':
225
+ lm_input, lm_target = format_causal_personachat_input(batch, left_tokenizer, right_tokenizer,
226
+ config)
227
+ elif task_type == TASK_TYPE.GENERATE_RESPONSE and mode == 'normal':
228
+ lm_input, lm_target = format_personachat_input(batch, left_tokenizer, right_tokenizer, config)
229
+ else:
230
+ raise NotImplementedError('mode and task_type not implemented')
231
+ output = model.module(dict(lm_input))
232
+ # suppose batch=2, K=3, the logits is presented interleave:
233
+ # [0,1]
234
+ # [0,1]
235
+ # [0,1]
236
+ logits = output['logits'] # (K*Batch,SeqLen,VocabSize)
237
+ logits = logits.view(-1, logits.shape[2], logits.shape[3])
238
+ sim_scores = output['sim_scores']
239
+ batch_size = lm_target.size()[0]
240
+ if config.model.peft_type == 'prompt_tuning':
241
+ virtual_tokens = config.model.peft_config.num_virtual_tokens
242
+ _lm_target = torch.cat(
243
+ (lm_target.new_ones((batch_size, virtual_tokens)) * left_tokenizer.pad_token_id, lm_target), dim=1)
244
+ else:
245
+ _lm_target = lm_target
246
+ _lm_target_spawn = _lm_target.repeat(config.model.K, 1)
247
+ losses = nn.CrossEntropyLoss(ignore_index=left_tokenizer.pad_token_id, reduction='none')(
248
+ logits.view(-1, logits.shape[-1]),
249
+ _lm_target_spawn.cuda(device=deepspeed.comm.get_local_rank()).view(-1))
250
+ if config.training.only_nll.__class__ == bool and config.training.only_nll:
251
+ return losses[losses != 0].mean()
252
+
253
+ reshaped_losses = losses.view(logits.shape[0], logits.shape[1]).detach().clone()
254
+ reshaped_losses = torch.stack([_losses[_losses != 0].mean() for _losses in reshaped_losses.detach().clone()])
255
+ # reshaped_losses = reshaped_losses.clone().detach().mean(dim=1)
256
+
257
+ softmaxed_neg_losses = nn.Softmax(dim=0)(
258
+ -reshaped_losses.view(config.model.K, batch_size) / config.training.tau_gold).permute(1, 0)
259
+ if config.training.adding_noise.__class__ is float:
260
+ noise = torch.randn_like(softmaxed_neg_losses, device=softmaxed_neg_losses.device)
261
+ softmaxed_neg_losses = softmaxed_neg_losses + config.training.adding_noise * noise
262
+ logsoftmaxed_sim_scores = F.log_softmax(sim_scores.view(config.model.K, batch_size) / config.training.tau_sim,
263
+ dim=0).permute(1, 0)
264
+ kldiv_loss = nn.KLDivLoss(reduction='batchmean')(logsoftmaxed_sim_scores,
265
+ softmaxed_neg_losses)
266
+ selective_loss_weight = 1.0
267
+ if config.training.annealing_nll.__class__ is bool and config.training.annealing_nll:
268
+ _ann_scalar = config.training.annealing_scalar * (1 - training_process)
269
+ _sim_score = torch.clamp(_ann_scalar * nn.Softmax(-1)(sim_scores),
270
+ config.training.annealing_min, config.training.annealing_max).detach()
271
+ losses = torch.einsum('ab,a->ab', losses.view(logits.shape[0], logits.shape[1]), _sim_score).view(-1)
272
+
273
+ if config.training.selective_loss_weight.__class__ == float:
274
+ selective_loss_weight = config.training.selective_loss_weight
275
+ if config.training.selective_loss.__class__ == bool and (config.training.selective_loss == False):
276
+ loss = losses[losses != 0].mean()
277
+ elif config.training.disable_nll.__class__ is bool and config.training.disable_nll:
278
+ loss = selective_loss_weight * kldiv_loss
279
+ else:
280
+ loss = losses[losses != 0].mean() + selective_loss_weight * kldiv_loss
281
+
282
+ if model.module.ensemble_training:
283
+ K = config.model.K
284
+ enb_losses = []
285
+ for data_idx in range(batch_size):
286
+ data_indices = [data_idx + (batch_size * inc) for inc in range(K)]
287
+ ensemble_preds = logits[data_indices, :, :]
288
+ ensemble_sims = sim_scores[data_indices]
289
+ normed_preds = ensemble_sims.unsqueeze(-1).unsqueeze(-1).mul(ensemble_preds)
290
+ normed_preds = normed_preds.sum(dim=0)
291
+ _target = _lm_target_spawn[data_indices, :]
292
+ assert _target.unique(dim=0).shape[0] == 1, 'error in reassembling the preds'
293
+ enb_loss = nn.CrossEntropyLoss(ignore_index=left_tokenizer.pad_token_id)(normed_preds,
294
+ _target[0].cuda(
295
+ device=deepspeed.comm.get_local_rank()))
296
+ enb_losses.append(enb_loss)
297
+ loss += torch.stack(enb_losses).mean()
298
+ if model.module.contrastive_metric:
299
+ ctr_losses = []
300
+ from sacrebleu import BLEU
301
+ ctr_metrics = BLEU(effective_order=True)
302
+ batch_persona = [' '.join(row) for row in batch['persona_list']]
303
+ statics = {}
304
+ # Dim here
305
+ # x1 x2
306
+ # p1 s11 s21
307
+ # p2 s12 s22
308
+ # p3 s13 s23
309
+ permuted_sim_scores = sim_scores.unsqueeze(0).view(model.module.K, batch_size)
310
+ if model.module.contrastive_metric == 'bleu':
311
+ for idx in range(len(batch_persona) - 1):
312
+ for jdx in range(idx + 1, len(batch_persona)):
313
+ iele = batch_persona[idx]
314
+ jele = batch_persona[jdx]
315
+ scores = ctr_metrics.sentence_score(iele, [jele]).score
316
+ idist = permuted_sim_scores[:, idx]
317
+ jdist = permuted_sim_scores[:, jdx]
318
+ cosine_emb_loss = nn.CosineEmbeddingLoss()
319
+ if scores > model.module.contrastive_threshold:
320
+ cosine_target = 1
321
+ else:
322
+ cosine_target = -1
323
+ cos_loss = cosine_emb_loss(idist, jdist, torch.tensor(cosine_target))
324
+ ctr_losses.append(cos_loss)
325
+ statics[(idx, jdx)] = {'iele': iele, 'jele': jele, 'scores': scores,
326
+ 'idist': idist,
327
+ 'jdist': jdist, 'cos_emb_loss': cos_loss}
328
+ if len(ctr_losses) != 0:
329
+ ctr_losses_pt = torch.stack(ctr_losses).mean()
330
+ loss += config.training.contrastive_weight * ctr_losses_pt
331
+ else:
332
+ print(f'CTR ERROR: {statics}')
333
+ return loss
334
+
335
+ @staticmethod
336
+ def validation_step(model, batch, left_tokenizer, right_tokenizer, config, task_type, mode='normal'):
337
+ loss = SelectLLMChat.training_step(model, batch, left_tokenizer, right_tokenizer, config, task_type=task_type,
338
+ mode=mode, training_process=0.0)
339
+ return loss
340
+
341
+ @staticmethod
342
+ @autocast('cuda')
343
+ def test_step(model, batch, left_tokenizer, right_tokenizer, config, max_new_tokens=16, tqdm_instance: tqdm = None,
344
+ selection_noise=None, **gen_kwargs):
345
+ model.eval()
346
+ with torch.no_grad():
347
+ if config.training.mode == 'causal':
348
+ lm_input, lm_target, inference_tokenized = format_causal_personachat_input(batch,
349
+ left_tokenizer,
350
+ right_tokenizer,
351
+ config,
352
+ for_test=True)
353
+ else:
354
+ lm_input, lm_target, inference_tokenized = format_personachat_input(batch, left_tokenizer,
355
+ right_tokenizer,
356
+ config,
357
+ for_test=True)
358
+ inference_tokenized.to('cuda')
359
+ if 'deepspeed' in str(model.__class__):
360
+ batch_size = inference_tokenized['input_ids'].shape[0]
361
+ sim_scores = model.module.retrieve_based_on_input_x(inference_tokenized, config.model.K)
362
+ sim_scores = sim_scores.reshape(config.model.K, batch_size).permute(1, 0)
363
+ if selection_noise:
364
+ noise = torch.randn_like(sim_scores, device=sim_scores.device)
365
+ sim_scores = sim_scores + selection_noise * noise
366
+ selected_prompts = torch.argmax(sim_scores, dim=1)
367
+ if tqdm_instance is not None:
368
+ tqdm_instance.set_postfix_str(f"selected prompts: {selected_prompts}")
369
+ detached_selected_prompts = selected_prompts.detach().cpu().numpy()
370
+ selected_prompts_set = set(detached_selected_prompts)
371
+ output_dicts = {}
372
+ # adding do_sample=False to avoid inf error!
373
+ for key in selected_prompts_set:
374
+ outputs = model.module.models[key].generate(
375
+ input_ids=inference_tokenized['input_ids'],
376
+ attention_mask=inference_tokenized['attention_mask'],
377
+ max_new_tokens=max_new_tokens,
378
+ do_sample=False,
379
+ **gen_kwargs
380
+ )
381
+ output_dicts[key] = outputs.detach().cpu()
382
+ raw_output = []
383
+ for idx, prompt_idx in enumerate(detached_selected_prompts):
384
+ raw_output.append(output_dicts[prompt_idx][idx][inference_tokenized['input_ids'].shape[1]:])
385
+ # raw_output = torch.stack(raw_output).squeeze(1)
386
+ trunc_output = raw_output
387
+ text_output = right_tokenizer.batch_decode(trunc_output, skip_special_tokens=True)
388
+ return trunc_output, text_output, selected_prompts
389
+ else:
390
+ raise NotImplementedError('not implemented')
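
The selection signal that `SelectLLMChat` trains sits in `retrieve_prompts` plus the KL term in `training_step`: each of the K soft prompts is scored by cosine similarity against the mean-pooled input embedding, and the softmax of those scores is pulled toward the softmax of the negative per-prompt LM losses. A self-contained sketch of that signal with random stand-in tensors (K, batch size, dimensions and temperatures are illustrative; the repo additionally passes the raw similarities through a Softplus):

```python
import torch
from torch import nn
import torch.nn.functional as F

K, batch, dim = 3, 2, 8
tau_sim, tau_gold = 1.0, 1.0

prompt_mean = torch.randn(K, dim)      # mean-pooled soft-prompt embeddings
input_mean = torch.randn(batch, dim)   # mean-pooled input embeddings

# cosine similarity of every prompt against every input, flattened as in retrieve_prompts
sim_scores = F.cosine_similarity(
    prompt_mean.repeat_interleave(batch, dim=0),   # (K*batch, dim): p1,p1,p2,p2,...
    input_mean.repeat(K, 1))                       # (K*batch, dim): x1,x2,x1,x2,...

per_prompt_nll = torch.rand(K, batch)  # stand-in for the per-prompt LM losses

target_dist = F.softmax(-per_prompt_nll / tau_gold, dim=0).permute(1, 0)                 # (batch, K)
log_pred_dist = F.log_softmax(sim_scores.view(K, batch) / tau_sim, dim=0).permute(1, 0)  # (batch, K)

kldiv_loss = nn.KLDivLoss(reduction='batchmean')(log_pred_dist, target_dist)
print(kldiv_loss.item())
```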
test.py ADDED
@@ -0,0 +1,204 @@
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import locale
5
+ import os
6
+ import re
7
+ from functools import reduce
8
+ from multiprocessing import freeze_support
9
+
10
+ import deepspeed
11
+ import torch
12
+ import torch.distributed as dist
13
+ from dotenv import load_dotenv
14
+ from torch.utils.data import DistributedSampler
15
+ from tqdm import tqdm
16
+
17
+ from dataset.dataset import PersonaChatDataset
18
+ from utils.dist_helper import setup
19
+ from utils.format_inputs import TASK_TYPE
20
+ from utils.parser_helper import str2bool
21
+
22
+ os.environ["PYTHONIOENCODING"] = "utf-8"
23
+ myLocale = locale.setlocale(category=locale.LC_ALL, locale="C.UTF-8")
24
+ load_dotenv()
25
+
26
+ arg_parser = argparse.ArgumentParser()
+ arg_parser.add_argument('--model_path', type=str, default=None)
+ arg_parser.add_argument('--path_pattern', type=str, default=None)
+ arg_parser.add_argument('--batch_size', type=int)
+ arg_parser.add_argument('--valid_path', type=str, default=None)
+ arg_parser.add_argument('--local_rank', type=int, default=-1)
+ arg_parser.add_argument('--skip_exists', type=str2bool, default=False)
+ arg_parser.add_argument('--selection_noise', type=float, default=None)
+ parser = deepspeed.add_config_arguments(arg_parser)
+ args = arg_parser.parse_args()
+ _cmd_args = parser.parse_args()
37
+ freeze_support()
38
+
39
+ VICUNA_PREFIX = 'PATH_TO_VICUNA'
40
+
41
+
42
+ def test_process(model_paths, batch_size, valid_path, skip_exists, selection_noise, cmd_args):
43
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
44
+ with open(cmd_args.deepspeed_config) as json_file:
45
+ ds_config = json.load(json_file)
46
+ del cmd_args.deepspeed_config
47
+
48
+ setup()
49
+ for model_path in model_paths:
50
+ try:
51
+ if selection_noise is not None:
52
+ save_dir = os.sep.join(model_path.split(os.sep)[:-1]) + os.sep + f'evaluation_result_selection_noise={selection_noise}.pkl'
53
+ else:
54
+ save_dir = os.sep.join(model_path.split(os.sep)[:-1]) + os.sep + 'evaluation_result.pkl'
55
+ if os.path.exists(save_dir) and (skip_exists):
56
+ continue
57
+ print(
58
+ f"Start setup rank {deepspeed.comm.get_local_rank()} of {world_size} on GPU {torch.cuda.current_device()}")
59
+
60
+ ckpt = torch.load(os.sep.join(model_path.split(os.sep)[:-1]) + os.sep + 'checkpoint_best.pth',
61
+ map_location=f'cpu')
62
+ config = ckpt['config']
63
+ ds_config['train_micro_batch_size_per_gpu'] = batch_size
64
+ load_precision = '32'
65
+ if config.model.load_bit == 16:
66
+ ds_config['float16']['enabled'] = True
67
+ load_precision = 'fp16'
68
+ if 'llama' in config.model.model_name.lower():
69
+ ds_config['float16']['enabled'] = False
70
+ ds_config['bf16']['enabled'] = True
71
+ load_precision = 'bf16'
72
+ load_bit_map = {
73
+ 'fp16': torch.float16,
74
+ 'bf16': torch.bfloat16,
75
+ '32': torch.float32}
76
+
77
+ if config.model.model_type == 'selective_pt':
78
+ from models.selective_llm_chat import SelectLLMChat as LLMChat
79
+ else:
80
+ from models.llm_chat import LLMChat
81
+ if 'vicuna' in config.model.model_name and (not os.path.exists(config.model.model_name)):
82
+ config.model.model_name = VICUNA_PREFIX + os.sep + config.model.model_name.split(os.sep)[-1]
83
+ _model = LLMChat(config, batch_size)
84
+ left_tokenizer = _model.left_tokenizer
85
+ right_tokenizer = _model.right_tokenizer
86
+ print(f'LOADING {model_path} with {load_precision} precision')
87
+ model_engine, _, _, _ = deepspeed.initialize(args=cmd_args,
88
+ model=_model,
89
+ config=ds_config,
90
+ )
91
+ model_engine.load_checkpoint(model_path, load_module_strict=False, load_optimizer_states=False,
92
+ load_lr_scheduler_states=False,
93
+ load_module_only=True)
94
+ valid_path_file = valid_path
95
+ if valid_path_file is None:
96
+ valid_path_file = config.dataset.valid
97
+ if config.dataset.test.__class__ is str:
98
+ valid_path_file = config.dataset.test
99
+ print('using the test split from personachat')
100
+ task_type = TASK_TYPE(config.training.task_type)
101
+ valid_dataset = PersonaChatDataset(valid_path_file, max_context_turns=config.dataset.max_context_turns)
102
+ from dataset.dataset import get_dataloader
103
+ max_new_token = 32
104
+ valid_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, shuffle=False,
105
+ drop_last=False)
106
+ valid_dataloader = get_dataloader(valid_dataset, batch_size, num_workers=0, sampler=valid_sampler)
107
+
108
+ context_input = []
109
+ persona_list = []
110
+ dist_pred_text = [None for _ in range(world_size)]
111
+ dist_gt_text = [None for _ in range(world_size)]
112
+ pred_text = []
113
+ gt_text = []
114
+ tqdm_iterator = tqdm(valid_dataloader, total=len(valid_dataloader))
115
+ selected_prompts = []
116
+ for data in tqdm_iterator:
117
+ _, text, batch_selected_prompts = LLMChat.test_step(model_engine, data, left_tokenizer,
118
+ right_tokenizer,
119
+ config, max_new_tokens=max_new_token,
120
+ tqdm_instance=tqdm_iterator,
121
+ selection_noise=selection_noise)
122
+ if batch_selected_prompts.__class__ != list:
123
+ selected_prompts += (batch_selected_prompts.detach().cpu().tolist())
124
+
125
+ context_input += data['context_input']
126
+ persona_list += data['persona_list']
127
+ pred_text += text
128
+ gt_text += data['target']
129
+
130
+ clean_preds = []
131
+ for pred in pred_text:
132
+ search_result = re.search('R:|Q:|Summary:|\n|\:', pred)
133
+ if search_result is not None:
134
+ clean_preds.append(pred[:search_result.span()[0]])
135
+ else:
136
+ clean_preds.append(pred)
137
+ pred_text = clean_preds
138
+ dist.all_gather_object(dist_pred_text, pred_text)
139
+ dist.all_gather_object(dist_gt_text, gt_text)
140
+ pred_text = reduce(lambda x, y: x + y, dist_pred_text)
141
+ gt_text = reduce(lambda x, y: x + y, dist_gt_text)
142
+ from evaluation import bleu_score, f1_score, normalize_answer
143
+ bleu = bleu_score(pred_text, [gt_text])
144
+ import pickle
145
+
146
+ result = {
147
+ 'context_input': context_input,
148
+ 'persona_list': persona_list,
149
+ 'pred_text': pred_text,
150
+ 'gt_text': gt_text,
151
+ 'bleu': bleu,
152
+ }
153
+ from collections import Counter
154
+ counter = Counter(selected_prompts)
155
+ if deepspeed.comm.get_local_rank() == 0:
156
+ print('bleu: ', bleu)
157
+ with open(save_dir, 'wb') as file:
158
+ pickle.dump(result, file)
159
+ with open(save_dir.replace('.pkl', '.txt'), 'w', encoding='utf-8') as file:
160
+ file.write('bleu: ' + str(bleu) + '\n')
161
+ if len(selected_prompts) > 0:
162
+ file.write('selected prompt: ' + str(counter) + '\n')
163
+ for i in range(len(context_input)):
164
+ if context_input[i].__class__ == list:
165
+ file.write('context: ' + str(u' '.join(context_input[i]).encode('utf-8')) + '\n')
166
+ else:
167
+ file.write('context: ' + str(context_input[i].encode('utf-8')) + '\n')
168
+ file.write('persona: ' + str(u' '.join(persona_list[i]).encode('utf-8')) + '\n')
169
+ file.write('pred: ' + pred_text[i] + '\n')
170
+ file.write('gt: ' + gt_text[i] + '\n')
171
+ if len(selected_prompts) > 0:
172
+ file.write('selected prompt: ' + str(selected_prompts[i]) + '\n')
173
+ file.write('\n')
174
+ except Exception as e:
175
+ save_dir = os.sep.join(model_path.split(os.sep)[:-1]) + os.sep + "test_error.txt"
176
+ print(f'WRITING TESTING ERROR! ERROR: {str(e)}')
177
+ with open(save_dir, 'w') as file:
178
+ file.write(str(e))
179
+ deepspeed.comm.barrier()
180
+ deepspeed.comm.barrier()
181
+
182
+
183
+ model_path_arg = args.model_path
184
+ model_paths = [model_path_arg]
185
+ if len(glob.glob(model_path_arg+os.sep+'ds_ckpt'+os.sep+'*')):
186
+ model_paths = [model_path_arg+os.sep+'ds_ckpt']
187
+ elif not model_path_arg.endswith('.pth'):
188
+ path_pattern = args.path_pattern
190
+ if path_pattern is not None:
191
+ model_paths = glob.glob(f'{model_path_arg}/{path_pattern}/ds_ckpt/*/*.pt')
192
+ else:
193
+ model_paths = glob.glob(f'{model_path_arg}/*/ds_ckpt/*/*.pt')
194
+ model_paths = list(set([os.sep.join(p.split(os.sep)[:-2]) for p in model_paths]))
195
+ print(model_paths)
196
+ num_of_gpus = torch.cuda.device_count()
197
+ print(f"{num_of_gpus} GPUs available")
198
+ test_process(model_paths, args.batch_size, args.valid_path,
199
+ args.skip_exists, args.selection_noise, cmd_args=_cmd_args)
200
+ deepspeed.comm.barrier()
201
+ deepspeed.comm.destroy_process_group()
202
+ # if not model_path_arg.endswith('.pth'):
203
+ # evaluate_folder(model_path_arg, skip_exists=args.skip_exists)
204
+ print('Test Ends')
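
Before scoring, test.py truncates each generation at the first continuation marker so that only the first response segment is compared against the ground truth. A small illustration of that cleanup on a hypothetical prediction:

```python
import re

pred = "i love hiking on weekends. R: do you have any pets?"
match = re.search(r'R:|Q:|Summary:|\n|:', pred)
clean = pred[:match.span()[0]] if match else pred
print(clean)  # -> "i love hiking on weekends. "
```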
train.py ADDED
@@ -0,0 +1,129 @@
1
+ import argparse
2
+ import os
3
+ from multiprocessing import freeze_support
4
+
5
+ import deepspeed
6
+ import torch
7
+ from dotenv import load_dotenv
8
+ from transformers.utils import logging
9
+
10
+ from trainer.peft_trainer import train_generator
11
+ from utils.config import get_config
12
+ from utils.parser_helper import str2bool
13
+
14
+ load_dotenv()
15
+ logging.set_verbosity_error()
16
+ torch.multiprocessing.set_sharing_strategy('file_system')
17
+ torch.set_float32_matmul_precision('medium')
18
+
19
+
20
+ def set_model_config(the_config, value, key):
21
+ if value is not None:
22
+ the_config.model[key] = value
23
+
24
+
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument('--config', type=str)
27
+ parser.add_argument('--batch', type=int, default=2)
28
+ parser.add_argument('--lr', type=float, default=None)
29
+ parser.add_argument('--find_batch', type=str2bool, default=False)
30
+ parser.add_argument('--find_lr', type=str2bool, default=False)
31
+ parser.add_argument('--bf16', type=str2bool, default=True)
32
+ parser.add_argument('--auto_scale_batch_size', type=str2bool, default=False)
33
+ parser.add_argument('--train_after_tune', type=str2bool, default=False)
34
+ parser.add_argument('--num_workers', type=int, default=0)
35
+ parser.add_argument('--epoch', type=int, default=None)
36
+ parser.add_argument('--scheduler_patience', type=int, default=10)
37
+ parser.add_argument('--scheduler_monitor', type=str, default='train_loss', choices=['train_loss'])
38
+ parser.add_argument('--seed', type=int, default=3407)
39
+ parser.add_argument('--grad_clip', type=float, default=-1)
40
+ parser.add_argument('--save_model', type=str2bool, default=True)
41
+ parser.add_argument('--shuffle_train', type=str2bool, default=True)
42
+ parser.add_argument('--training_ratio', type=float, default=1.0)
43
+ parser.add_argument('--adding_noise', type=float, default=None)
44
+
45
+ # parser.add_argument('--retriever_type', type=str, default=None,
46
+ # choices=['bert-base-uncased', 'albert-base-v2'])
47
+
48
+ parser.add_argument('--tokenizer_parallel', type=str2bool, default=True)
49
+ parser.add_argument('--do_test', type=str2bool, default=False)
50
+ parser.add_argument('--exp_name', type=str, default=None)
51
+ parser.add_argument('--mode', default=None, type=str, choices=['normal', 'causal', None])
52
+ parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
53
+ parser.add_argument('--selective_loss_weight', type=float, default=None)
54
+ parser.add_argument('--contrastive_weight', type=float, default=None)
55
+ parser.add_argument('--log_dir', type=str, default=None)
56
+
57
+ parser.add_argument('--warmup_type', type=str, default=None)
58
+ parser.add_argument('--warmup_min', type=float, default=0)
59
+ parser.add_argument('--warmup_ratio', type=float, default=0.05)
60
+
61
+ parser.add_argument('--ckpt_path', type=str, default=None)
62
+
63
+
64
+
65
+
66
+ parser = deepspeed.add_config_arguments(parser)
67
+ cmd_args = parser.parse_args()
68
+
69
+ freeze_support()
70
+ args = parser.parse_args()
71
+ os.environ["TOKENIZERS_PARALLELISM"] = "true" if args.tokenizer_parallel else "false"
72
+ config = get_config(args.config)
73
+ if args.exp_name is not None:
74
+ config.exp_name = args.exp_name
75
+ elif config.exp_name.__class__ != str:
76
+ config.exp_name = args.config.split(os.sep)[-1][:-4]
77
+ if args.lr is not None:
78
+ config.exp_name += f'_LR={args.lr}'
79
+ if args.selective_loss_weight is not None:
80
+ config.training.selective_loss_weight = args.selective_loss_weight
81
+ config.exp_name += f'_SLW={args.selective_loss_weight}'
82
+ if args.contrastive_weight is not None:
83
+ config.training.contrastive_weight = args.contrastive_weight
84
+ config.exp_name += f'_CTRW={args.contrastive_weight}'
85
+ if args.adding_noise is not None:
86
+ config.training.adding_noise = args.adding_noise
87
+ config.exp_name += f'_NOISE={args.adding_noise}'
88
+ if args.training_ratio < 1.0:
89
+ config.exp_name += f'_PTRAIN={args.training_ratio}'
90
+ # Done model config
91
+ generator_type = config.model.generator_type
92
+ if args.mode is not None:
93
+ config.training.mode = args.mode
94
+ if args.epoch is not None:
95
+ config.training.num_epoch = args.epoch
96
+ epoch = config.training.num_epoch
97
+ if args.log_dir is not None:
98
+ config.training.log_dir = args.log_dir
99
+ if 'llama-2' in config.model.model_name.lower():
100
+ folder_name = config.model.model_name.split('/')[-1]
101
+ config.model.model_name = os.getenv('LLAMA2_PATH')+'/'+folder_name
102
+ warmup_config = None
103
+ if args.warmup_type is not None:
104
+ warmup_config = {
105
+ "type": args.warmup_type,
106
+ "params": {
107
+ # "warmup_min_lr": args.warmup_min,
108
+ # "warmup_max_lr": args.lr,
109
+ "warmup_ratio": args.warmup_ratio
110
+ }
111
+ }
112
+ config.exp_name += f'_WP={args.warmup_type}@{args.warmup_ratio}'
113
+ if __name__ == '__main__':
114
+ train_generator(config, args.batch, args.lr, args.num_workers,
115
+ epoch, args.grad_clip, args.seed, args.save_model,
116
+ args.training_ratio, cmd_args=cmd_args, shuffle_train=args.shuffle_train,
117
+ warmup_config=warmup_config, ckpt_path=args.ckpt_path)
118
+ # num_of_gpus = torch.cuda.device_count()
119
+ # print(f"{num_of_gpus} GPUs available")
120
+ # mp.spawn(train_generator, args=(config, args.batch, args.lr, args.num_workers,
121
+ # epoch, args.grad_clip, num_of_gpus, args.seed, args.save_model,
122
+ # args.training_ratio), nprocs=num_of_gpus)
123
+
124
+ # train_generator(args.local_rank, config,
125
+ # batch_size=args.batch,
126
+ # lr=args.lr,
127
+ # num_workers=args.num_workers,
128
+ # epoch=args.epoch,
129
+ # gradient_clipping=args.grad_clip)
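
train.py builds a unique experiment name by appending every supplied override to `exp_name`; the name later becomes part of the TensorBoard log directory created in the trainer. A small sketch of that naming scheme with illustrative flag values:

```python
from dotmap import DotMap

config = DotMap({'exp_name': 'llama2-7b-selective', 'training': {}})

# illustrative values for --lr, --selective_loss_weight and --adding_noise
lr, selective_loss_weight, adding_noise = 5e-5, 0.1, 0.01

config.exp_name += f'_LR={lr}'
config.training.selective_loss_weight = selective_loss_weight
config.exp_name += f'_SLW={selective_loss_weight}'
config.training.adding_noise = adding_noise
config.exp_name += f'_NOISE={adding_noise}'

print(config.exp_name)  # llama2-7b-selective_LR=5e-05_SLW=0.1_NOISE=0.01
```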
trainer/__init__.py ADDED
@@ -0,0 +1 @@
1
+
trainer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes). View file
 
trainer/__pycache__/peft_trainer.cpython-310.pyc ADDED
Binary file (5.85 kB). View file
 
trainer/peft_trainer.py ADDED
@@ -0,0 +1,187 @@
1
+ import json
2
+ import os
3
+ import time
4
+
5
+ import deepspeed
6
+ import torch
7
+ from pytictoc import TicToc
8
+ from torch.utils.data import DistributedSampler
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ from tqdm import tqdm
11
+
12
+ from dataset.dataset import PersonaChatDataset
13
+ from utils.dist_helper import setup
14
+ from utils.format_inputs import TASK_TYPE
15
+ from utils.seed_everything import seed_everything
16
+
17
+
18
+ def save_checkpoint(model, optimizer, config, filename):
19
+ torch.save({
20
+ # 'model_state_dict': model.module.state_dict(),
21
+ # 'optimizer_state_dict': optimizer.state_dict(),
22
+ 'config': config
23
+ }, filename)
24
+
25
+
26
+ def train_generator(config, batch_size, lr,
27
+ num_workers,
28
+ epoch,
29
+ gradient_clipping, seed, save_model,
30
+ training_ratio, cmd_args, shuffle_train=True,warmup_config=None,
31
+ ckpt_path=None):
32
+ with open(cmd_args.deepspeed_config) as json_file:
33
+ ds_config = json.load(json_file)
34
+ del cmd_args.deepspeed_config
35
+ ds_config['train_micro_batch_size_per_gpu'] = batch_size
36
+ ds_config['optimizer']['params']['lr'] = lr
37
+ if config.model.load_bit == 16:
38
+ ds_config['float16']['enabled'] = True
39
+ if config.model.load_bit == 'bf16':
40
+ ds_config['bf16']['enabled'] = True
41
+ if gradient_clipping > 0:
42
+ ds_config['gradient_clipping'] = gradient_clipping
43
+
44
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
45
+ if config.model.model_type == 'selective_pt':
46
+ from models.selective_llm_chat import SelectLLMChat as LLMChat
47
+ else:
48
+ from models.llm_chat import LLMChat
49
+ seed_everything(seed)
50
+ # initialize the distributed environment
51
+ # time setup function using tictoc
52
+ t = TicToc()
53
+ t.tic()
54
+ setup()
55
+ # print(f"Time for setup is {t.tocvalue()} seconds")
56
+ config.training.learning_rate = float(lr)
57
+ # Create model and move it to GPU
58
+
59
+ task_type: str = config.training.task_type
60
+ enum_task = TASK_TYPE(task_type)
61
+ train_dataset = PersonaChatDataset(config.dataset.train, max_context_turns=config.dataset.max_context_turns,
62
+ training_ratio=training_ratio,
63
+ only_longest=config.training.only_longest,
64
+ task_type=enum_task)
65
+ valid_dataset = PersonaChatDataset(config.dataset.valid, max_context_turns=config.dataset.max_context_turns,
66
+ task_type=enum_task)
67
+ from dataset.dataset import get_dataloader
68
+ if warmup_config is not None:
69
+ warmup_config["params"]['warmup_num_steps'] = int(len(train_dataset)/batch_size * warmup_config["params"]['warmup_ratio'] / world_size)
71
+ warmup_config["params"]['total_num_steps'] = int(len(train_dataset)/batch_size)/world_size
72
+ del warmup_config["params"]['warmup_ratio']
73
+ ds_config['scheduler'] = warmup_config
74
+ _pt_model = LLMChat(config, batch_size=batch_size, ds_config=ds_config)
75
+
76
+ # ddp_model = DDP(_pt_model, device_ids=[0], output_device=0, find_unused_parameters=False)
77
+ left_tokenizer = _pt_model.left_tokenizer
78
+ right_tokenizer = _pt_model.right_tokenizer
79
+ # So there are always training samples
80
+ right_tokenizer.truncation_side = 'left'
81
+ # If it is lengthy, cut the right side
82
+ left_tokenizer.truncation_side = 'right'
83
+ # Create distributed sampler
84
+ all_params = [p for p in _pt_model.parameters()]
85
+ require_grads = [p for p in all_params if p.requires_grad]
86
+ model_engine, optimizer, train_dataloader, _ = deepspeed.initialize(args=cmd_args,
87
+ model=_pt_model,
88
+ model_parameters=require_grads,
89
+ training_data=train_dataset,
90
+ config=ds_config,
91
+ )
92
+ if ckpt_path is not None:
93
+ model_engine.load_checkpoint(ckpt_path, load_module_strict=False, load_optimizer_states=True,
94
+ load_lr_scheduler_states=True,
95
+ load_module_only=False)
96
+
97
+ valid_sampler = DistributedSampler(valid_dataset, num_replicas=world_size, shuffle=False,
98
+ drop_last=False)
99
+
100
+ valid_dataloader = get_dataloader(valid_dataset, batch_size, shuffle=False, num_workers=num_workers,
101
+ sampler=valid_sampler)
102
+
103
+ if enum_task in [TASK_TYPE.GENERATE_RESPONSE, TASK_TYPE.GENERATE_PERSONA]:
104
+ train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, shuffle=shuffle_train,
105
+ drop_last=False)
106
+ train_dataloader = get_dataloader(train_dataset, batch_size, shuffle=False, num_workers=num_workers,
107
+ sampler=train_sampler)
108
+
109
+
110
+ # You might want to adjust this depending on your specific requirements
111
+ # scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
112
+ if config.training.log_dir.__class__ is str:
113
+ logdir = f"{config.training.log_dir}/{config.exp_name}_{time.strftime('%Y-%m-%d-%H%M')}"
114
+ else:
115
+ logdir = f"runs/{config.exp_name}_{time.strftime('%Y-%m-%d-%H%M')}"
116
+ # Tensorboard logger
117
+ writer = SummaryWriter(log_dir=logdir)
118
+ best_valid_loss = 65535
119
+ # Training Loop
120
+ counter = 0
121
+ valid_counter = 0
122
+ for _epoch in range(epoch):
123
+ model_engine.train()
124
+ total_loss = 0.0
125
+ gathered_train_loss = [torch.zeros(1, dtype=torch.float32, device=model_engine.device) for _ in range(world_size)]
126
+ train_iter = tqdm(train_dataloader, total=len(train_dataloader), desc=f'epoch: {_epoch}')
127
+ total_steps_per_epoch = len(train_dataloader)
128
+ total_steps = total_steps_per_epoch*epoch
129
+ for idx, inputs in enumerate(train_iter):
130
+ current_step = idx+_epoch*total_steps_per_epoch
131
+ current_training_percent = current_step/total_steps
132
+ model_engine.zero_grad()
133
+ loss = LLMChat.training_step(model_engine, inputs, left_tokenizer, right_tokenizer, config,
134
+ mode=config.training.mode, task_type=enum_task, training_process=current_training_percent)
135
+ skipped = False
136
+ params = []
137
+ if deepspeed.comm.get_local_rank() in [-1, 0]:
138
+ for n, p in model_engine.named_parameters():
139
+ if p.requires_grad:
140
+ params.append(p)
141
+ norm = torch.stack([p.norm() for p in params]).sum()
142
+ print(f'NORM: {norm}')
143
+ if loss.isnan():
144
+ model_engine.backward(loss.new_zeros(loss.shape, requires_grad=True))
145
+ skipped = True
146
+ print(inputs)
147
+ raise ValueError('Meet NaN in training!')
148
+ else:
149
+ model_engine.backward(loss)
150
+ if gradient_clipping > 0:
151
+ model_engine.gradient_clipping()
152
+
153
+ model_engine.step()
154
+
155
+ total_loss += loss.item()
156
+ writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/train', loss.item(), counter)
157
+ counter += 1
158
+ train_iter.set_postfix_str(f'loss: {loss.item()}'+(" (Skipped)" if skipped else ""))
159
+ outputs_valid_losses = [torch.zeros(1, dtype=torch.float32, device=model_engine.device) for _ in range(world_size)]
160
+ valid_loss = []
161
+ for inputs in tqdm(valid_dataloader, total=len(valid_dataloader), desc='valid'):
162
+ model_engine.eval()
163
+ with torch.no_grad():
164
+ loss = LLMChat.validation_step(model_engine, inputs, left_tokenizer, right_tokenizer, config,
165
+ mode=config.training.mode, task_type=enum_task)
166
+ valid_loss.append(loss.item())
167
+ writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/valid', loss.item(), valid_counter)
168
+ valid_counter += 1
169
+ deepspeed.comm.all_gather(outputs_valid_losses, torch.tensor(valid_loss).mean().to(model_engine.device))
170
+ gathered_valid_loss = torch.stack(outputs_valid_losses).mean()
171
+ deepspeed.comm.all_gather(gathered_train_loss, torch.tensor(total_loss / len(train_dataloader), device=model_engine.device))
172
+ writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/total_train', torch.stack(gathered_train_loss).mean(), _epoch)
173
+
174
+ writer.add_scalar(f'Loss-{deepspeed.comm.get_local_rank()}/total_valid', gathered_valid_loss, _epoch)
175
+ deepspeed.comm.barrier()
176
+ print(
177
+ f'\nepoch: {_epoch}, train_loss: {total_loss / len(train_dataloader)}, valid_loss: {gathered_valid_loss}\n')
178
+ if best_valid_loss > gathered_valid_loss and save_model:
179
+ # Save pt_model checkpoint
180
+ if model_engine.global_rank == 0:
181
+ print(f"Saving model checkpoint with valid loss {gathered_valid_loss}")
182
+ save_checkpoint(model_engine, optimizer, config, f'{logdir}/checkpoint_best.pth')
183
+ model_engine.save_checkpoint(f'{logdir}/ds_ckpt', tag='best', exclude_frozen_parameters=True)
184
+ best_valid_loss = gathered_valid_loss
185
+
186
+
187
+ deepspeed.comm.destroy_process_group()
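
When train.py passes a warmup stub, `train_generator` converts the supplied warmup_ratio into the absolute per-GPU step counts that the DeepSpeed scheduler expects before writing it into `ds_config['scheduler']`. A hedged sketch of that conversion (all numbers illustrative; the scheduler name is only an assumed example of what --warmup_type might be set to):

```python
dataset_len, batch_size, world_size, warmup_ratio = 131_000, 8, 4, 0.05

steps_per_gpu = dataset_len / batch_size / world_size
warmup_config = {
    "type": "WarmupDecayLR",   # assumed value of --warmup_type
    "params": {
        "warmup_num_steps": int(steps_per_gpu * warmup_ratio),
        "total_num_steps": int(steps_per_gpu),
    },
}
print(warmup_config)
```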
utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.36 kB). View file
 
utils/__pycache__/configure_optimizers.cpython-310.pyc ADDED
Binary file (354 Bytes). View file
 
utils/__pycache__/dist_helper.cpython-310.pyc ADDED
Binary file (292 Bytes). View file
 
utils/__pycache__/format_inputs.cpython-310.pyc ADDED
Binary file (6.43 kB). View file
 
utils/__pycache__/model_helpers.cpython-310.pyc ADDED
Binary file (948 Bytes). View file
 
utils/__pycache__/parser_helper.cpython-310.pyc ADDED
Binary file (542 Bytes). View file
 
utils/__pycache__/seed_everything.cpython-310.pyc ADDED
Binary file (1.78 kB). View file
 
utils/config.py ADDED
@@ -0,0 +1,50 @@
1
+ import yaml
2
+ from dotmap import DotMap
3
+
4
+
5
+ def extend_dict(extend_me, extend_by):
6
+ if isinstance(extend_me, dict):
7
+ for k, v in extend_by.items():
8
+ if k in extend_me:
9
+ extend_dict(extend_me[k], v)
10
+ else:
11
+ extend_me[k] = v
12
+ else:
13
+ if isinstance(extend_me, list):
14
+ extend_list(extend_me, extend_by)
15
+ else:
16
+ if extend_by is not None:
17
+ extend_me += extend_by
18
+
19
+
20
+ def extend_list(extend_me, extend_by):
21
+ missing = []
22
+ for item1 in extend_me:
23
+ if not isinstance(item1, dict):
24
+ continue
25
+
26
+ for item2 in extend_by:
27
+ if not isinstance(item2, dict) or item2 in missing:
28
+ continue
29
+ extend_dict(item1, item2)
30
+
31
+
32
+ def extend_compatibility_for_gated_transformer(configuration):
33
+ dict_config = configuration.toDict()
34
+ return configuration
35
+
36
+
37
+ def get_config(path):
38
+ with open(path, 'r') as file:
39
+ configuration = yaml.load(file, Loader=yaml.FullLoader)
40
+ with open('config/default.yml', 'r') as file:
41
+ base_configuration = yaml.load(file, Loader=yaml.FullLoader)
42
+ configuration = DotMap(configuration)
43
+ base_configuration = DotMap(base_configuration)
44
+ extend_dict(configuration, base_configuration)
45
+ configuration = extend_compatibility_for_gated_transformer(configuration)
46
+ return configuration
47
+
48
+
49
+ if __name__ == '__main__':
50
+ config = get_config('config/bert-base.yml')
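
`get_config` layers the experiment YAML on top of config/default.yml through `extend_dict`: keys already present keep their experiment values, and anything missing is filled in from the defaults. A simplified illustration of that merge behaviour using plain dicts and a hypothetical `merge_defaults` helper (the repo operates on DotMap objects instead):

```python
def merge_defaults(experiment, defaults):
    # keep experiment values; copy over anything the experiment file does not define
    for key, value in defaults.items():
        if key in experiment:
            if isinstance(experiment[key], dict) and isinstance(value, dict):
                merge_defaults(experiment[key], value)
        else:
            experiment[key] = value
    return experiment

experiment = {'model': {'model_name': 'facebook/opt-1.3b'}, 'training': {'mode': 'causal'}}
defaults = {'model': {'load_bit': 32}, 'training': {'mode': 'normal', 'num_epoch': 5}}

merged = merge_defaults(experiment, defaults)
print(merged['training']['mode'])   # 'causal' (experiment value kept)
print(merged['model']['load_bit'])  # 32 (filled in from defaults)
```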
utils/configure_optimizers.py ADDED
@@ -0,0 +1,6 @@
1
+ import torch
2
+
3
+
4
+ def configure_optimizers(model, lr):
5
+ adam = torch.optim.Adam(model.parameters(), lr=lr)
6
+ return adam
utils/dist_helper.py ADDED
@@ -0,0 +1,5 @@
1
+ import deepspeed
2
+
3
+
4
+ def setup():
5
+ deepspeed.init_distributed()
utils/format_inputs.py ADDED
@@ -0,0 +1,173 @@
1
+ from enum import Enum
2
+
3
+ import torch
4
+
5
+
6
+ class TASK_TYPE(Enum):
7
+ GENERATE_RESPONSE = 'generate_response'
8
+ GENERATE_PERSONA = 'generate_persona'
9
+
10
+
11
+
12
+ def format_personachat_input(batch, left_tokenizer, right_tokenizer, config, for_test=False, find_batch=False):
13
+ batch_size = len(batch['context_input'])
14
+ pad_token_id = left_tokenizer.pad_token_id
15
+ targets = [t.strip() for t in batch['target']]
16
+ eos_token = left_tokenizer.eos_token
17
+ concat_context = [' '.join(context) for context in batch['context_input']]
18
+ concat_persona = [' '.join(persona) for persona in batch['persona_list']]
19
+ concat_input = [f'#persona#{persona}#context#{context}' for persona, context in
20
+ zip(concat_persona, concat_context)]
21
+ inference_tokenized = None
22
+ bos_token = left_tokenizer.bos_token
23
+ if for_test:
24
+ inference_input = [f'#persona#{persona}#context#{context}{bos_token}' for persona, context in
25
+ zip(concat_persona, concat_context)]
26
+ inference_tokenized = left_tokenizer(inference_input, add_special_tokens=False, return_tensors='pt',
27
+ padding='max_length', truncation=True,
28
+ max_length=config.dataset.max_token_length - 16)
29
+ # processing target
30
+ _target_with_bos = [f'{bos_token}{target}{eos_token}' for target in targets]
31
+ _target_with_bos_pt = right_tokenizer(_target_with_bos,
32
+ add_special_tokens=False, return_tensors='pt', \
33
+ padding=True)
34
+ _target_pt = _target_with_bos_pt.copy()
35
+ _target_pt['input_ids'] = torch.cat((_target_pt['input_ids'][:, 1:],
36
+ _target_pt['input_ids'].new_ones(batch_size, 1) * pad_token_id), dim=1)
37
+ _target_pt['attention_mask'] = torch.cat((_target_pt['attention_mask'][:, 1:],
38
+ _target_pt['attention_mask'].new_zeros(batch_size, 1)), dim=1)
39
+ # processing concat
40
+ context_pt = left_tokenizer(concat_input, add_special_tokens=False, return_tensors='pt',
41
+ padding='max_length', truncation=True,
42
+ max_length=config.dataset.max_token_length)
43
+ input_pt = torch.cat((context_pt['input_ids'], _target_with_bos_pt['input_ids']),
44
+ dim=1)[:, -config.dataset.max_token_length:]
45
+ input_attn = torch.cat((context_pt['attention_mask'], _target_with_bos_pt['attention_mask']),
46
+ dim=1)[:, -config.dataset.max_token_length:]
47
+ lm_input = {'input_ids': input_pt, 'attention_mask': input_attn}
48
+ if find_batch:
49
+ lm_target = torch.cat((context_pt['input_ids'],
50
+ _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
51
+ else:
52
+ lm_target = torch.cat((context_pt['input_ids'] * 0 - 1,
53
+ _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
54
+ if for_test:
55
+ return lm_input, lm_target, inference_tokenized
56
+ return lm_input, lm_target
57
+
58
+
59
+# Template Type:
+# 0: </s>
+
+def format_causal_personachat_input(batch, left_tokenizer, right_tokenizer, config, for_test=False,
+                                    find_batch=False, template_type=0):
+    template_types = [
+        '{cinput} R: {target}',
+        '{cinput} R: [COMPLETE] the answer for [COMPLETE] is {target}'
+    ]
+    bos_token = left_tokenizer.bos_token
+    eos_token = left_tokenizer.eos_token
+    batch_size = len(batch['context_input'])
+    pad_token_id = right_tokenizer.pad_token_id
+    targets = [t.strip() for t in batch['target']]
+    concat_context = [' '.join(context) for context in batch['context_input']]
+    concat_persona = [' '.join(persona) for persona in batch['persona_list']]
+    concat_input = [f'given persona: {persona}; context: {context}' for persona, context in
+                    zip(concat_persona, concat_context)]
+    concat_input_target = [template_types[template_type].format(cinput=cinput, target=target) for cinput, target in
+                           zip(concat_input, targets)]
+    bos_concat_input = [f'{bos_token}{cinput}{eos_token}' for cinput in concat_input_target]
+    lm_input = right_tokenizer(bos_concat_input, add_special_tokens=False, return_tensors='pt',
+                               padding='max_length', truncation=True,
+                               max_length=config.dataset.max_token_length)
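+    # Labels are the input ids shifted left by one for next-token prediction; the last slot is padded.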
+    lm_target = lm_input.copy()
+    lm_target = torch.cat((lm_target['input_ids'][:, 1:], lm_target['input_ids'].new_full(
+        (batch_size, 1), pad_token_id)), dim=1)
+    # lm_target['attention_mask'] = torch.cat(
+    #     (lm_target['attention_mask'][:, 1:], lm_target['attention_mask'].new_full(
+    #         (batch_size, 1), 0)), dim=1)
+    # freeze persona
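+    # With freeze_persona enabled, label tokens before the 'context:' marker are overwritten
+    # with the pad id, effectively masking the persona span in the labels.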
+    if config.training.freeze_persona.__class__ is bool and config.training.freeze_persona:
+        for _lm_target in lm_target:
+            if 'given persona:' not in left_tokenizer.decode(_lm_target):
+                continue
+            _tokens = left_tokenizer.convert_ids_to_tokens(_lm_target)
+            _token_ids = _lm_target
+            _token_idx = None
+            for idx in range(0, len(_tokens) - 1):
+                if _tokens[idx].endswith('context') and _tokens[idx + 1].endswith(':'):
+                    _token_idx = idx
+                    break
+                _token_ids[idx] = left_tokenizer.pad_token_id
+    # freeze context
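+    # With freeze_context enabled, label tokens from the 'context:' marker up to the 'R:'
+    # response marker are overwritten with the pad id, masking the dialogue history.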
+    if config.training.freeze_context.__class__ is bool and config.training.freeze_context:
+        for _lm_target in lm_target:
+            _tokens = left_tokenizer.convert_ids_to_tokens(_lm_target)
+            _token_ids = _lm_target
+            _start_idx = None
+            _end_idx = None
+            for idx in range(0, len(_tokens) - 1):
+                if _tokens[idx].endswith('context') and _tokens[idx + 1].endswith(':'):
+                    _start_idx = idx
+                if _tokens[idx].endswith('R') and _tokens[idx + 1].endswith(':'):
+                    _end_idx = idx + 2
+            if _start_idx is None or _end_idx is None:
+                continue
+            for idx in range(_start_idx, _end_idx):
+                _token_ids[idx] = left_tokenizer.pad_token_id
+
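+    # At test time, the prompt is the template rendered with an empty target, so generation
+    # continues from where the response would begin.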
+    if for_test:
+        inference_input = [template_types[template_type].format(cinput=cinput, target='') for cinput in concat_input]
+        bos_concat_input = [f'{bos_token}{cinput}' for cinput in inference_input]
+        inference_tokenized = left_tokenizer(bos_concat_input, add_special_tokens=False,
+                                             return_tensors='pt',
+                                             padding=True, truncation=True,
+                                             max_length=config.dataset.max_token_length)
+        return lm_input, lm_target, inference_tokenized
+    return lm_input, lm_target
+
+
+def format_generate_persona_input(batch, left_tokenizer, right_tokenizer, config, for_test=False, find_batch=False):
+    batch_size = len(batch['context_input'])
+    pad_token_id = left_tokenizer.pad_token_id
+    targets = [' '.join(persona) for persona in batch['persona_list']]
+    eos_token = left_tokenizer.eos_token
+    concat_context = [' '.join(context) for context in batch['context_input']]
+    concat_input = [f'#context#{context}' for context in
+                    concat_context]
+    inference_tokenized = None
+    bos_token = left_tokenizer.bos_token
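+    # Persona-generation task: the prompt contains only the dialogue context and the persona text is the target.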
+    if for_test:
+        inference_input = [f'#context#{context}{bos_token}' for context in
+                           concat_context]
+        inference_tokenized = left_tokenizer(inference_input, add_special_tokens=False, return_tensors='pt',
+                                             padding='max_length', truncation=True,
+                                             max_length=config.dataset.max_token_length - 16)
+    # processing target
+    _target_with_bos = [f'{bos_token}{target}{eos_token}' for target in targets]
+    _target_with_bos_pt = right_tokenizer(_target_with_bos,
+                                          add_special_tokens=False, return_tensors='pt',
+                                          padding=True)
+    _target_pt = _target_with_bos_pt.copy()
+    _target_pt['input_ids'] = torch.cat((_target_pt['input_ids'][:, 1:],
+                                         _target_pt['input_ids'].new_ones(batch_size, 1) * pad_token_id), dim=1)
+    _target_pt['attention_mask'] = torch.cat((_target_pt['attention_mask'][:, 1:],
+                                              _target_pt['attention_mask'].new_zeros(batch_size, 1)), dim=1)
+    # processing concat
+    context_pt = left_tokenizer(concat_input, add_special_tokens=False, return_tensors='pt',
+                                padding='max_length', truncation=True,
+                                max_length=config.dataset.max_token_length)
+    input_pt = torch.cat((context_pt['input_ids'], _target_with_bos_pt['input_ids']),
+                         dim=1)[:, -config.dataset.max_token_length:]
+    input_attn = torch.cat((context_pt['attention_mask'], _target_with_bos_pt['attention_mask']),
+                           dim=1)[:, -config.dataset.max_token_length:]
+    lm_input = {'input_ids': input_pt, 'attention_mask': input_attn}
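+    # As in the response-formatting function above, context positions in the labels are set to -1
+    # unless find_batch is requested.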
+    if find_batch:
+        lm_target = torch.cat((context_pt['input_ids'],
+                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
+    else:
+        lm_target = torch.cat((context_pt['input_ids'] * 0 - 1,
+                               _target_pt['input_ids']), dim=1)[:, -config.dataset.max_token_length:]
+    if for_test:
+        return lm_input, lm_target, inference_tokenized
+    return lm_input, lm_target
utils/model_helpers.py ADDED
@@ -0,0 +1,31 @@
+
+def print_trainable_parameters(model):
+    """
+    Prints the number of trainable parameters in the model.
+    """
+    trainable_params = 0
+    all_param = 0
+    all_param_names = []
+    trainable_param_names = []
+    prompt_weights = 0
+    prompt_normalizer = 0
+    prompt_normalizer_layer = []
+    soft_prompt_layers = []
+    for name, param in model.named_parameters():
+        all_param += param.numel()
+        all_param_names.append(name)
+        if param.requires_grad:
+            print(name)
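+            # Soft-prompt embedding and prompt-normalizer parameters are tallied separately
+            # from the overall trainable count.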
+            if 'prompt_encoder.default.embedding' in name:
+                prompt_weights += param.numel()
+                soft_prompt_layers.append(param)
+            if 'prompt_normalizer' in name:
+                prompt_normalizer += param.numel()
+                prompt_normalizer_layer.append(param)
+            trainable_params += param.numel()
+            trainable_param_names.append(name)
+    print(
+        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
+    )
+    return {"trainable": trainable_params, "all": all_param, "trainable%": 100 * trainable_params / all_param}
utils/parser_helper.py ADDED
@@ -0,0 +1,17 @@
+import argparse
+
+
+def str2bool(v):
+    if v is None:
+        return None
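+    # These string options are passed through unchanged rather than coerced to booleans.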
+    exclusive = ['accurate', 'query', 'document']
+    if v in exclusive:
+        return v
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
utils/seed_everything.py ADDED
@@ -0,0 +1,44 @@
+import os
+import random
+from typing import Optional
+
+import numpy as np
+import torch
+
+max_seed_value = np.iinfo(np.uint32).max
+min_seed_value = np.iinfo(np.uint32).min
+
+
+def seed_everything(seed: Optional[int], workers: bool = False) -> int:
+    """Function that sets seed for pseudo-random number generators in: pytorch, numpy, python.random. In addition,
+    sets the following environment variables:
+
+    - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
+    - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.
+
+    Args:
+        seed: the integer value seed for global random state in Lightning.
+            If `None`, will read seed from `PL_GLOBAL_SEED` env variable
+            or select it randomly.
+        workers: if set to ``True``, will properly configure all dataloaders passed to the
+            Trainer with a ``worker_init_fn``. If the user already provides such a function
+            for their dataloaders, setting this argument will have no influence. See also:
+            :func:`~lightning_fabric.utilities.seed.pl_worker_init_function`.
+    """
+
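+    # Note: despite the docstring above, a None seed is not resolved from PL_GLOBAL_SEED here;
+    # int(None) would raise, so callers are expected to pass an integer (or a numeric string).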
+    if not isinstance(seed, int):
+        seed = int(seed)
+
+    if not (min_seed_value <= seed <= max_seed_value):
+        raise ValueError(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
+
+    print(f"Global seed set to {seed}")
+    os.environ["PL_GLOBAL_SEED"] = str(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"
+
+    return seed