sade-adrien committed on
Commit 15e8d2f
1 Parent(s): 6b9513a

Upload 2 files

mapping_adapter_checkpoint_114000steps.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cb6e95ae3b9cd81f6d5bdcad0387c65aa804ec172336db5f89ba7ad7ffc1f8d2
+ size 125866547
representation_mapping.py ADDED
@@ -0,0 +1,220 @@
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup
+ from torch.utils.data import DataLoader
+ import transformers
+ from sklearn.model_selection import train_test_split
+ from datasets import load_dataset, DatasetDict
+ import torch.nn as nn
+ import torch
+ import wandb
+ from tqdm import tqdm
+
+ args_max_epoch = 1
+ args_batch_size = 64
+ args_learning_rate = 3e-5
+ args_num_warmup_steps = 100
+ args_gradient_accumulation_steps_default = 2
+ adapter_hidden_dim = 4096
+
+ device = 'cuda'
+
+
+ def main():
+     wandb.init(project="MappingAdapater_training_v6", name="training_run")
+
+     model = MappingStructure(checkpointE = "sentence-transformers/stsb-roberta-large",
+                              checkpointD = "mistralai/Mistral-7B-Instruct-v0.1",
+                              hidden_dim = adapter_hidden_dim,
+                              torch_dtype = torch.float16,
+                              flash_attn = True,
+                              ).to(device)
+
+     # Train only the mapping adapter; encoder and decoder stay frozen.
+     for n, p in model.named_parameters():
+         if 'mapping' not in n:
+             p.requires_grad = False
+         else:
+             p.requires_grad = True
+
+     dataset = load_dataset("sade-adrien/redpajama_v2_sample_10M")['train']
+     train_dataset, val_dataset = split_dataset(dataset, train_size=.989333)
+     datasets = DatasetDict({
+         'train': train_dataset,
+         'val': val_dataset
+     })
+
+     train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True)
+     val_dataloader = DataLoader(datasets['val'], batch_size=args_batch_size, shuffle=False)
+
+     optimizer = AdamW(model.parameters(), lr=args_learning_rate)
+     scheduler = get_linear_schedule_with_warmup(optimizer, args_num_warmup_steps, args_max_epoch*len(train_dataloader))
+
+     global_step = 0
+     for epoch in range(args_max_epoch):
+         train_dataloader = DataLoader(datasets['train'], batch_size=args_batch_size, shuffle=True, worker_init_fn=lambda _: torch.manual_seed(epoch))
+
+         for batch in tqdm(train_dataloader):
+             input_prompt = batch['raw_content']
+             outputs = model(input_prompt=input_prompt, compute_loss=True)
+             loss = outputs['loss']
+
+             # Gradient accumulation
+             loss = loss / args_gradient_accumulation_steps_default
+             loss.backward()
+
+             if (global_step + 1) % args_gradient_accumulation_steps_default == 0:
+                 optimizer.step()
+                 optimizer.zero_grad()
+                 scheduler.step()
+
+
+             if (global_step + 1) % 2000 == 0:
+                 torch.save({
+                     'epoch': epoch,
+                     'mapping_state_dict': model.mapping.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'scheduler_state_dict': scheduler.state_dict(),
+                     'global_step': global_step,
+                 }, f'models/mapping_adapter_checkpoint_{global_step + 1}steps.pth')
+
+             global_step += 1
+             val_loss = None
+             if (global_step + 1) % 8000 == 0:
+                 model.eval()
+                 val_loss = 0.0
+                 with torch.no_grad():
+                     for val_batch in tqdm(val_dataloader):
+                         val_inputs = val_batch['raw_content']
+                         val_outputs = model(input_prompt=val_inputs, compute_loss=True)
+                         val_loss += val_outputs['loss']
+                 val_loss /= len(val_dataloader)
+
+                 model.train()
+
+             wandb.log({
+                 'step': global_step + 1,
+                 'learning_rate': scheduler.get_last_lr()[0],
+                 'train_loss': loss.item() * args_gradient_accumulation_steps_default,
+                 'val_loss': val_loss.item() if val_loss else None
+             })
+
+
+
+
+ def split_dataset(dataset, train_size=.9):
+     index = int(len(dataset) * train_size)
+     return dataset.select(range(index)), dataset.select(range(index, len(dataset)))
+
+ class MappingAdapter(nn.Module):
+     def __init__(self, input_dim, output_dim, hidden_dim):
+         super(MappingAdapter, self).__init__()
+         self.layer1 = nn.Linear(input_dim, hidden_dim)
+         self.layer2 = nn.Linear(hidden_dim, output_dim)
+         self.activation = nn.LeakyReLU(.01)
+
+     def forward(self, x):
+         x = self.layer1(x)
+         x = self.activation(x)
+         x = self.layer2(x)
+         return x
+
+ class MappingStructure(nn.Module):
+     def __init__(self, checkpointE, checkpointD, hidden_dim=2048, torch_dtype=torch.float32, flash_attn=False):
+         super(MappingStructure, self).__init__()
+
+         self.configE = AutoConfig.from_pretrained(checkpointE)
+         self.Encoder = AutoModel.from_pretrained(checkpointE,
+                                                  low_cpu_mem_usage = True,
+                                                  torch_dtype = torch_dtype,
+                                                  config = self.configE
+                                                  )
+
+         self.configD = AutoConfig.from_pretrained(checkpointD)
+         if flash_attn:
+             self.configD.update({'_flash_attn_2_enabled' : True})
+         self.Decoder = AutoModel.from_pretrained(checkpointD,
+                                                  low_cpu_mem_usage = True,
+                                                  torch_dtype = torch_dtype,
+                                                  config = self.configD
+                                                  )
+
+         # Adapter maps decoder hidden states into the encoder's embedding space.
+         self.mapping = MappingAdapter(self.configD.hidden_size, self.configE.hidden_size, hidden_dim=hidden_dim).to(torch_dtype)
+
+         self._init_tokenizers(checkpointE, checkpointD)
+
+     def _init_tokenizers(self, checkpointE, checkpointD):
+         self.tokenizerE = AutoTokenizer.from_pretrained(checkpointE, use_fast = False, revision = 'main', config = self.configE, padding_side='left')
+         self.tokenizerD = AutoTokenizer.from_pretrained(checkpointD, use_fast = False, revision = 'main', config = self.configD, padding_side='left')
+         self.tokenizerD.pad_token_id = self.tokenizerD.unk_token_id
+
+     def cosine_sim(self, u, v):
+         assert u.shape == v.shape, "u and v must have the same shape"
+         u_normalized = u / torch.norm(u, dim=1, keepdim=True)
+         v_normalized = v / torch.norm(v, dim=1, keepdim=True)
+
+         # Compute cosine similarity using dot product
+         return torch.sum(u_normalized * v_normalized, dim=1)
+
+
+     def mean_pooling(self, hidden_state, attention_mask):
+         token_embeddings = hidden_state
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+     def build_batch(self, input_prompt):
+         # Truncate each prompt to a random length that fits the encoder, then pad to a batch.
+         size = torch.randint(1, self.configE.max_position_embeddings-2, (1,)).item()
+         targets = []
+
+         for prompt in input_prompt:
+             tokenized_input = self.tokenizerE(prompt)
+             tokenized_input = {'input_ids': tokenized_input['input_ids'][:size],
+                                'attention_mask': tokenized_input['attention_mask'][:size],
+                                }
+             targets.append(tokenized_input)
+         targets = self.tokenizerE.pad(targets, padding=True, return_tensors='pt')
+
+         return targets
178
+
179
+ def forward(self, input_prompt, compute_loss=False):
180
+ loss = None
181
+
182
+ # Slice prompt of needed to fit encoder max position embeddings (hard constraint)
183
+ if not compute_loss:
184
+ inputs = self.tokenizerD(input_prompt, return_tensors='pt', padding=True).to(device)
185
+
186
+ hidden_state_D = self.Decoder(**inputs).last_hidden_state
187
+ hidden_state_D_mapped = self.mapping(hidden_state_D)
188
+
189
+ else:
190
+ targets = self.build_batch(input_prompt).to(device)
191
+
192
+ input_prompt_sliced = self.tokenizerE.batch_decode(targets['input_ids'], skip_special_tokens=True)
193
+ inputs = self.tokenizerD(input_prompt_sliced, return_tensors='pt', padding=True).to(device)
194
+
195
+ hidden_state_D = self.Decoder(**inputs).last_hidden_state
196
+ hidden_state_D_mapped = self.mapping(hidden_state_D)
197
+
198
+ hidden_state_E = self.Encoder(**targets).last_hidden_state
199
+
200
+ proj_E = self.mean_pooling(hidden_state_E, targets['attention_mask'])
201
+ proj_D = self.mean_pooling(hidden_state_D_mapped, inputs['attention_mask'])
202
+
203
+ loss = 1 - torch.mean(self.cosine_sim(proj_E, proj_D))
204
+
205
+ del inputs
206
+ del targets
207
+ del input_prompt_sliced
208
+ del hidden_state_E
209
+ del proj_E
210
+ del proj_D
211
+ torch.cuda.empty_cache()
212
+
213
+ return {'loss': loss,
214
+ 'last_hidden_state': hidden_state_D,
215
+ 'last_hidden_state_mapped': hidden_state_D_mapped,
216
+ }
217
+
218
+
219
+ if __name__ == '__main__':
220
+ main()
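
For reference, below is a minimal sketch (not part of the uploaded files) of how the `mapping_adapter_checkpoint_114000steps.pth` checkpoint could be loaded back into the mapping adapter for inference. It assumes `representation_mapping.py` is importable from the working directory and the checkpoint sits alongside it; the dictionary keys are the ones written by the `torch.save` call in the training loop above. `flash_attn=True` mirrors the training setup and can be dropped if flash attention is unavailable.

```python
import torch
from representation_mapping import MappingStructure, adapter_hidden_dim, device

# Rebuild the structure with the same encoder/decoder checkpoints used for training.
model = MappingStructure(checkpointE="sentence-transformers/stsb-roberta-large",
                         checkpointD="mistralai/Mistral-7B-Instruct-v0.1",
                         hidden_dim=adapter_hidden_dim,
                         torch_dtype=torch.float16,
                         flash_attn=True,
                         ).to(device)

# The training loop stores only the adapter weights under 'mapping_state_dict'.
checkpoint = torch.load("mapping_adapter_checkpoint_114000steps.pth", map_location=device)
model.mapping.load_state_dict(checkpoint["mapping_state_dict"])
model.eval()

# Inference pass: decoder hidden states mapped into the encoder's embedding space.
with torch.no_grad():
    outputs = model(input_prompt=["An example prompt."], compute_loss=False)
print(outputs["last_hidden_state_mapped"].shape)
```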