mwritescode commited on
Commit
6edaa8b
·
1 Parent(s): e12246b

Upload model

Browse files
Files changed (4) hide show
  1. config.json +52 -0
  2. generation_config.json +7 -0
  3. gpt2.py +209 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "gelu_new",
3
+ "architectures": [
4
+ "GPT2PrefixTuningWithLMHeadModel"
5
+ ],
6
+ "attn_pdrop": 0.1,
7
+ "auto_map": {
8
+ "AutoConfig": "gpt2.GPT2PrefixTuningConfig",
9
+ "AutoModelForCausalLM": "gpt2.GPT2PrefixTuningWithLMHeadModel"
10
+ },
11
+ "bos_token_id": 50256,
12
+ "embd_pdrop": 0.1,
13
+ "eos_token_id": 50256,
14
+ "initializer_range": 0.02,
15
+ "is_flat": false,
16
+ "layer_norm_epsilon": 1e-05,
17
+ "model_type": "gpt2",
18
+ "n_ctx": 1024,
19
+ "n_embd": 1024,
20
+ "n_head": 16,
21
+ "n_inner": null,
22
+ "n_layer": 24,
23
+ "n_positions": 1024,
24
+ "n_special": 0,
25
+ "objective_type": "sentence",
26
+ "pad_token_id": 50257,
27
+ "plm_name_or_path": "gpt2-medium",
28
+ "predict_special_tokens": true,
29
+ "prefix_dropout_prob": 0.0,
30
+ "prefix_hidden_size": 512,
31
+ "prefix_len": 5,
32
+ "reorder_and_upcast_attn": false,
33
+ "resid_pdrop": 0.1,
34
+ "scale_attn_by_inverse_layer_idx": false,
35
+ "scale_attn_weights": true,
36
+ "summary_activation": null,
37
+ "summary_first_dropout": 0.1,
38
+ "summary_proj_to_labels": true,
39
+ "summary_type": "cls_index",
40
+ "summary_use_proj": true,
41
+ "task_specific_params": {
42
+ "text-generation": {
43
+ "do_sample": true,
44
+ "max_length": 50
45
+ }
46
+ },
47
+ "torch_dtype": "float32",
48
+ "transformers_version": "4.26.0",
49
+ "use_cache": true,
50
+ "use_layer_dep": false,
51
+ "vocab_size": 50258
52
+ }
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50256,
4
+ "eos_token_id": 50256,
5
+ "pad_token_id": 50257,
6
+ "transformers_version": "4.26.0"
7
+ }
gpt2.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import PretrainedConfig, AutoConfig
4
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
5
+ from transformers.models.gpt2.modeling_gpt2 import GPT2PreTrainedModel, GPT2LMHeadModel
6
+
7
+ from src.utils.prefix import PrefixEncoder
8
+
9
+ class GPT2PrefixTuningConfig(PretrainedConfig):
10
+ attribute_map = {
11
+ "hidden_size": "n_embd",
12
+ "max_position_embeddings": "n_positions",
13
+ "num_attention_heads": "n_head",
14
+ "num_hidden_layers": "n_layer",
15
+ }
16
+ model_type = "gpt2"
17
+ keys_to_ignore_at_inference = ["past_key_values"]
18
+
19
+ def __init__(self,
20
+ plm_name_or_path='gpt2-medium',
21
+ prefix_len=5,
22
+ prefix_dropout_prob=0.0,
23
+ prefix_hidden_size=512,
24
+ is_flat=False,
25
+ pad_token_id=50257,
26
+ objective_type='sentence',
27
+ use_layer_dep=False,
28
+ **kwargs):
29
+ super().__init__(**kwargs)
30
+ self.plm_name_or_path = plm_name_or_path
31
+ self.prefix_len = prefix_len
32
+ self.prefix_dropout_prob = prefix_dropout_prob
33
+ self.prefix_hidden_size = prefix_hidden_size
34
+ self.is_flat = is_flat
35
+ plm_config = AutoConfig.from_pretrained(plm_name_or_path).to_dict()
36
+ del plm_config['_name_or_path']
37
+ self.update(plm_config)
38
+ self.pad_token_id = pad_token_id
39
+ self.vocab_size = self.pad_token_id + 1
40
+ self.objective_type = objective_type # or 'sentence' or 'token' which is the classical objective
41
+ self.use_layer_dep = use_layer_dep
42
+
43
+ class GPT2PrefixTuningWithLMHeadModel(GPT2PreTrainedModel):
44
+ def __init__(self, config, pretrained_model=None):
45
+ super().__init__(config)
46
+ print(config)
47
+ if pretrained_model is None:
48
+ self.pretrained_model = GPT2LMHeadModel.from_pretrained(config.plm_name_or_path, pad_token_id=config.pad_token_id)
49
+ self.pretrained_model.resize_token_embeddings(config.vocab_size)
50
+ else:
51
+ self.pretrained_model = pretrained_model
52
+
53
+ for param in self.pretrained_model.parameters():
54
+ param.requires_grad = False
55
+
56
+ self.prefix_len = config.prefix_len
57
+ self.prefix_encoder = PrefixEncoder(config)
58
+
59
+ def train(self, mode=True):
60
+ super().train(mode)
61
+ self.pretrained_model.eval()
62
+
63
+ def get_input_embeddings(self) -> nn.Module:
64
+ return self.pretrained_model.get_input_embeddings()
65
+
66
+ def get_output_embeddings(self):
67
+ return self.pretrained_model.lm_head
68
+
69
+ def set_output_embeddings(self, new_embeddings):
70
+ self.pretrained_model.set_output_embeddings(new_embeddings=new_embeddings)
71
+
72
+ def get_input_embeddings(self):
73
+ return self.pretrained_model.get_input_embeddings()
74
+
75
+ def set_input_embeddings(self, new_embeddings):
76
+ self.pretrained_model.set_input_embeddings(new_embeddings=new_embeddings)
77
+
78
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
79
+ token_type_ids = kwargs.get("token_type_ids", None)
80
+
81
+ # only last token for inputs_ids if past is defined in kwargs
82
+ if past_key_values:
83
+ input_ids = input_ids[:, -1].unsqueeze(-1)
84
+ if token_type_ids is not None:
85
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
86
+
87
+ batch_size = input_ids.shape[0]
88
+ attention_mask = kwargs.get("attention_mask", None)
89
+ position_ids = kwargs.get("position_ids", None)
90
+
91
+ if attention_mask is not None:
92
+ prefix_attention_mask = torch.ones(batch_size, self.prefix_len).to(input_ids.device)
93
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
94
+
95
+ if attention_mask is not None and position_ids is None:
96
+ # create position_ids on the fly for batch generation
97
+ position_ids = attention_mask.long().cumsum(-1) - 1
98
+ position_ids.masked_fill_(attention_mask == 0, 1)
99
+ if past_key_values:
100
+ position_ids = position_ids[:, -1].unsqueeze(-1)
101
+ else:
102
+ position_ids = None
103
+
104
+ if past_key_values is None:
105
+ past_key_values = self.prefix_encoder(batch_size=batch_size)
106
+ if position_ids is not None:
107
+ position_ids = position_ids[:, self.prefix_len:]
108
+
109
+ return {
110
+ "input_ids": input_ids,
111
+ "past_key_values": past_key_values,
112
+ "use_cache": kwargs.get("use_cache"),
113
+ "position_ids": position_ids,
114
+ "attention_mask": attention_mask,
115
+ "token_type_ids": token_type_ids,
116
+ }
117
+
118
+ def forward(
119
+ self,
120
+ input_ids,
121
+ past_key_values=None,
122
+ attention_mask=None,
123
+ token_type_ids=None,
124
+ position_ids=None,
125
+ head_mask=None,
126
+ inputs_embeds=None,
127
+ encoder_hidden_states=None,
128
+ encoder_attention_mask=None,
129
+ labels=None,
130
+ use_cache=None,
131
+ output_attentions=None,
132
+ output_hidden_states=None,
133
+ return_dict=None,
134
+ ):
135
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
136
+
137
+ if past_key_values is not None and self.training:
138
+ raise ValueError("past_key_value is dedicated to prefix tokens in this implementation. Please don't use it for anything else.")
139
+
140
+ if past_key_values is None:
141
+ batch_size = input_ids.shape[0]
142
+ past_key_values = self.prefix_encoder(batch_size=batch_size)
143
+ if attention_mask is not None:
144
+ prefix_attention_mask = torch.ones(batch_size, self.prefix_len).to(input_ids.device)
145
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
146
+
147
+ labels_for_plm = None if self.config.objective_type == 'sentence' else labels
148
+
149
+ position_ids = None if not self.training and input_ids.shape[1] == 1 else position_ids
150
+ if position_ids is not None:
151
+ position_ids = position_ids.contiguous()
152
+
153
+ transformer_outputs = self.pretrained_model(
154
+ input_ids,
155
+ past_key_values=past_key_values,
156
+ attention_mask=attention_mask,
157
+ token_type_ids=token_type_ids,
158
+ position_ids=position_ids,
159
+ head_mask=head_mask,
160
+ inputs_embeds=inputs_embeds,
161
+ encoder_hidden_states=encoder_hidden_states,
162
+ encoder_attention_mask=encoder_attention_mask,
163
+ labels=labels_for_plm,
164
+ use_cache=use_cache,
165
+ output_attentions=output_attentions,
166
+ output_hidden_states=output_hidden_states,
167
+ return_dict=return_dict,
168
+ )
169
+
170
+ if labels_for_plm is None:
171
+ lm_logits = transformer_outputs.logits if return_dict else transformer_outputs[0]
172
+
173
+ loss = None
174
+ if labels is not None:
175
+ # Shift so that tokens < n predict n
176
+ shift_logits = lm_logits[..., :-1, :].contiguous()
177
+ shift_labels = labels[..., 1:].contiguous()
178
+ loss_fct = nn.CrossEntropyLoss(reduction='none')
179
+ batch_size, seqlen, _ = shift_logits.shape
180
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
181
+ loss = loss.view(batch_size, seqlen).sum(dim=-1)
182
+ loss = loss.mean()
183
+
184
+ if not return_dict:
185
+ output = (lm_logits,) + transformer_outputs[1:]
186
+ return ((loss,) + output) if loss is not None else output
187
+
188
+ return CausalLMOutputWithCrossAttentions(
189
+ loss=loss,
190
+ logits=lm_logits,
191
+ past_key_values=transformer_outputs.past_key_values,
192
+ hidden_states=transformer_outputs.hidden_states,
193
+ attentions=transformer_outputs.attentions,
194
+ cross_attentions=transformer_outputs.cross_attentions,
195
+ )
196
+ else:
197
+ return transformer_outputs
198
+
199
+ @staticmethod
200
+ def _reorder_cache(past, beam_idx):
201
+ """
202
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
203
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
204
+ beam_idx at every generation step.
205
+ """
206
+ return tuple(
207
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
208
+ for layer_past in past
209
+ )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecf2cdfb5b0a3726e5eebf5461a65ba9adc54ae6ca427a918f7dda0624d6c3ec
3
+ size 1547553167