thennal committed
Commit d139aed
Parent: 36fcac1

Upload LLMEncoder

config.json ADDED
@@ -0,0 +1,103 @@
+ {
+   "architectures": [
+     "LLMEncoder"
+   ],
+   "auto_map": {
+     "AutoConfig": "llmencoder.LLMEncoderConfig",
+     "AutoModel": "llmencoder.LLMEncoder"
+   },
+   "base_model": "microsoft/Phi-3-mini-4k-instruct",
+   "doc_max_length": 400,
+   "model_config": {
+     "_name_or_path": "microsoft/Phi-3-mini-4k-instruct",
+     "add_cross_attention": false,
+     "architectures": [
+       "Phi3ForCausalLM"
+     ],
+     "attention_bias": false,
+     "attention_dropout": 0.0,
+     "auto_map": {
+       "AutoConfig": "microsoft/Phi-3-mini-4k-instruct--configuration_phi3.Phi3Config",
+       "AutoModelForCausalLM": "microsoft/Phi-3-mini-4k-instruct--modeling_phi3.Phi3ForCausalLM"
+     },
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": 1,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "embd_pdrop": 0.0,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 32000,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "silu",
+     "hidden_size": 3072,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "intermediate_size": 8192,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 4096,
+     "min_length": 0,
+     "model_type": "phi3",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 32,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 25,
+     "num_key_value_heads": 32,
+     "num_return_sequences": 1,
+     "original_max_position_embeddings": 4096,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 32000,
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "resid_pdrop": 0.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "rms_norm_eps": 1e-05,
+     "rope_scaling": null,
+     "rope_theta": 10000.0,
+     "sep_token_id": null,
+     "sliding_window": 2047,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": false,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": "bfloat16",
+     "torchscript": false,
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_cache": true,
+     "vocab_size": 32064
+   },
+   "pooling_mode": "weighted_mean",
+   "skip_instruction": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.44.2"
+ }
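
The auto_map block above registers the custom classes with transformers, so this checkpoint can be loaded through the standard Auto classes once remote code is trusted. A minimal sketch, assuming the files in this commit live in a Hub repo or local directory referred to here as path/to/this-repo (a placeholder, not the actual repo id):

from transformers import AutoConfig, AutoModel

# trust_remote_code lets transformers fetch and import llmencoder.py from the
# repo and resolve the AutoConfig/AutoModel entries declared in auto_map.
# "path/to/this-repo" is a placeholder; substitute the real repo id or path.
config = AutoConfig.from_pretrained("path/to/this-repo", trust_remote_code=True)
model = AutoModel.from_pretrained("path/to/this-repo", trust_remote_code=True)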
llmencoder.py ADDED
@@ -0,0 +1,492 @@
+ import json
+ import logging
+ import os
+ from typing import Dict, List, Optional, Union
+
+ import numpy as np
+ import torch
+ import torch.multiprocessing as mp
+ from peft import PeftModel
+ from torch import Tensor, device, nn
+ from tqdm.autonotebook import tqdm, trange
+ from transformers import (
+     AutoModel,
+     AutoConfig,
+     PretrainedConfig,
+     PreTrainedModel,
+     AutoTokenizer,
+     LlamaConfig,
+     MistralConfig,
+     GemmaConfig,
+     Qwen2Config,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def batch_to_device(batch, target_device: device):
+     """
+     Send a PyTorch batch to a device (CPU/GPU).
+     """
+     for key in batch:
+         if isinstance(batch[key], Tensor):
+             batch[key] = batch[key].to(target_device)
+     return batch
+
+
+ class LLMEncoderConfig(PretrainedConfig):
+     def __init__(
+         self,
+         pooling_mode: str = "weighted_mean",
+         max_length: int = 512,
+         doc_max_length: int = 400,
+         skip_instruction: bool = True,
+         **kwargs,
+     ):
+         if pooling_mode not in ["mean", "weighted_mean", "eos_token", "bos_token"]:
+             raise ValueError(
+                 f"Pooling mode {pooling_mode} is not supported. "
+                 "Please choose one of 'mean', 'weighted_mean', 'eos_token', 'bos_token'."
+             )
+         self.pooling_mode = pooling_mode
+         self.max_length = max_length
+         self.doc_max_length = doc_max_length
+         self.skip_instruction = skip_instruction
+         self.model_config = None
+         self.base_model = None
+
+         super().__init__(**kwargs)
+
+ class LLMEncoder(PreTrainedModel):
+     config_class = LLMEncoderConfig
+
+     def __init__(
+         self,
+         model: PreTrainedModel,
+         tokenizer: AutoTokenizer,
+         config: LLMEncoderConfig,
+     ):
+         super().__init__(config)
+         self.model = model
+         self.tokenizer = tokenizer
+         self.pooling_mode = config.pooling_mode
+         self.max_length = config.max_length
+         self.doc_max_length = config.doc_max_length
+         self.skip_instruction = config.skip_instruction
+         self.model_config = None
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         base_model_name_or_path,
+         peft_model_name_or_path=None,
+         config=None,
+         **kwargs,
+     ):
+         """
+         Load a pretrained model from a model identifier or path.
+         Args:
+             base_model_name_or_path: Model identifier or path to a pretrained model.
+             peft_model_name_or_path: Path to any PEFT model to apply on top of the base model.
+         Returns: An LLMEncoder model.
+         """
+
+         if not config:
+             config = LLMEncoderConfig()
+
+         if not config.base_model:
+             config.base_model = base_model_name_or_path
+
+         tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
+         tokenizer.pad_token = tokenizer.eos_token
+         tokenizer.padding_side = "left"
+
+         if config.model_config:
+             model_config = AutoConfig.from_pretrained(config.base_model)
+             model_config = model_config.from_dict(config.model_config)
+         else:
+             model_config = AutoConfig.from_pretrained(base_model_name_or_path)
+             config.model_config = model_config
+
+         model = AutoModel.from_pretrained(base_model_name_or_path, config=model_config, **kwargs)
+
+         if peft_model_name_or_path is not None:
+             model = PeftModel.from_pretrained(
+                 model,
+                 peft_model_name_or_path,
+             )
+             model = model.merge_and_unload()
+
+         return cls(model=model, tokenizer=tokenizer, config=config)
+
+     def prune(self, percent_prune=0):
+         """
+         Prune the model to a fraction of the base model's layers. If percent_prune is
+         equal to or greater than 1, it is instead taken as the exact number of layers
+         to keep. For example, percent_prune=0.3 prunes 30% of the layers, while
+         percent_prune=3 prunes the model down to 3 layers.
+         """
+         # Values >= 1 are taken as the specific layer count to prune to.
+         if percent_prune >= 1:
+             new_num_layers = int(percent_prune)
+         else:
+             new_num_layers = int(self.model.config.num_hidden_layers * (1 - percent_prune))
+         print(f"Pruning to {new_num_layers} layers.")
+         self.model.layers = self.model.layers[:new_num_layers]
+         self.model.config.num_hidden_layers = new_num_layers
+         self.config.model_config.num_hidden_layers = new_num_layers
+
+     def prepare_for_tokenization(self, text):
+         if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
+             text = (
+                 "<|start_header_id|>user<|end_header_id|>\n\n"
+                 + text.strip()
+                 + "<|eot_id|>"
+             )
+             return text
+         if self.model.config._name_or_path in [
+             "mistralai/Mistral-7B-Instruct-v0.2",
+             "meta-llama/Llama-2-7b-chat-hf",
+         ]:
+             text = "[INST] " + text.strip() + " [/INST]"
+         if self.model.config._name_or_path in [
+             "google/gemma-2-9b-it",
+         ]:
+             text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
+         if self.model.config._name_or_path in [
+             "Qwen/Qwen2-1.5B-Instruct",
+             "Qwen/Qwen2-7B-Instruct",
+         ]:
+             text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
+         if self.pooling_mode == "eos_token":
+             if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
+                 text = text.strip() + "<|end_of_text|>"
+             elif isinstance(self.model.config, LlamaConfig) or isinstance(
+                 self.model.config, MistralConfig
+             ):
+                 text = text.strip() + " </s>"
+             elif isinstance(self.model.config, GemmaConfig):
+                 text = text.strip() + "<eos>"
+             elif isinstance(self.model.config, Qwen2Config):
+                 text = text.strip() + "<|endoftext|>"
+         return text
+
+     def tokenize(self, texts):
+         # Each input may embed an instruction and a document separated by the
+         # "!@#$%^&*()" marker; embed_mask marks only the document tokens.
+         texts_2 = []
+         original_texts = []
+         for text in texts:
+             t = text.split("!@#$%^&*()")
+             texts_2.append(t[1] if len(t) > 1 else "")
+             original_texts.append("".join(t))
+
+         original = self.tokenizer(
+             original_texts,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=self.max_length,
+         )
+         embed_mask = None
+         for t_i, t in enumerate(texts_2):
+             ids = self.tokenizer(
+                 [t],
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length,
+                 add_special_tokens=False,
+             )
+             # Since sequences are left-padded, the document occupies the tail.
+             e_m = torch.zeros_like(original["attention_mask"][t_i])
+             if len(ids["input_ids"][0]) > 0:
+                 e_m[-len(ids["input_ids"][0]) :] = torch.ones(
+                     len(ids["input_ids"][0])
+                 )
+             if embed_mask is None:
+                 embed_mask = e_m.unsqueeze(0)
+             else:
+                 embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
+
+         original["embed_mask"] = embed_mask
+         return original
+
+     def _skip_instruction(self, sentence_feature):
+         assert (
+             sentence_feature["attention_mask"].shape
+             == sentence_feature["embed_mask"].shape
+         )
+         sentence_feature["attention_mask"] = sentence_feature["embed_mask"]
+
+     def forward(self, sentence_feature: Dict[str, Tensor]):
+         embed_mask = None
+         if "embed_mask" in sentence_feature:
+             embed_mask = sentence_feature.pop("embed_mask")
+         reps = self.model(**sentence_feature)
+         sentence_feature["embed_mask"] = embed_mask
+
+         return self.get_pooling(sentence_feature, reps.last_hidden_state)
+
+     def get_pooling(self, features, last_hidden_states):  # All models padded from left
+         assert (
+             self.tokenizer.padding_side == "left"
+         ), "Pooling modes are implemented for padding from left."
+         if self.skip_instruction:
+             self._skip_instruction(features)
+         seq_lengths = features["attention_mask"].sum(dim=-1)
+         if self.pooling_mode == "mean":
+             return torch.stack(
+                 [
+                     last_hidden_states[i, -length:, :].mean(dim=0)
+                     for i, length in enumerate(seq_lengths)
+                 ],
+                 dim=0,
+             )
+         elif self.pooling_mode == "weighted_mean":
+             bs, l, _ = last_hidden_states.shape
+             complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
+             for i, seq_l in enumerate(seq_lengths):
+                 if seq_l > 0:
+                     complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1
+                     complete_weights[i] /= torch.clamp(
+                         complete_weights[i].sum(), min=1e-9
+                     )
+             return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
+         elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
+             return last_hidden_states[:, -1]
+         elif self.pooling_mode == "bos_token":
+             return last_hidden_states[
+                 features["input_ids"] == self.tokenizer.bos_token_id
+             ]
+         else:
+             raise ValueError(f"{self.pooling_mode} is not implemented yet.")
+
+     def _convert_to_str(self, instruction, text):
+         tokenized_q = self.tokenizer(
+             text,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=self.max_length,
+             add_special_tokens=False,
+         )
+         tokenized_q_length = len(tokenized_q["input_ids"][0])
+
+         while tokenized_q_length > self.doc_max_length:
+             reduction_ratio = self.doc_max_length / tokenized_q_length
+             reduced_length = int(len(text.split()) * reduction_ratio)
+             text = " ".join(text.split()[:reduced_length])
+             tokenized_q = self.tokenizer(
+                 text,
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True,
+                 max_length=self.max_length,
+                 add_special_tokens=False,
+             )
+             tokenized_q_length = len(tokenized_q["input_ids"][0])
+
+         return (
+             f"{instruction.strip()} !@#$%^&*(){text}"
+             if instruction
+             else f"!@#$%^&*(){text}"
+         )
+
+     def encode(
+         self,
+         sentences: Union[str, List[str]],
+         batch_size: int = 32,
+         show_progress_bar: bool = True,
+         convert_to_numpy: bool = False,
+         convert_to_tensor: bool = False,
+         device: Optional[str] = None,
+     ):
+         """
+         Encode a list of sentences into their respective embeddings. The sentences can
+         be a list of strings or a single string.
+         Args:
+             sentences: sentence or sentences to encode.
+             batch_size: batch size for turning sentence tokens into embeddings.
+             show_progress_bar: whether to show progress bars during encoding steps.
+             convert_to_numpy: If True, return numpy arrays instead of torch tensors.
+             convert_to_tensor: If True, return torch tensors (default).
+             device: torch backend device identifier (e.g., 'cuda', 'cpu', 'mps'). If not
+                 specified, cuda is used when available, otherwise cpu. Note that only
+                 'cuda' supports multiprocessing as currently implemented.
+
+         Returns: embeddings of the sentences. Embeddings are detached and always on
+             the CPU (see the _encode implementation).
+         """
+         if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
+             sentences = [sentences]
+         # required for the MEDI version of MTEB
+         if isinstance(sentences[0], str):
+             sentences = [[""] + [sentence] for sentence in sentences]
+
+         if device is None:
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         concatenated_input_texts = []
+         for sentence in sentences:
+             assert isinstance(sentence[0], str)
+             assert isinstance(sentence[1], str)
+             concatenated_input_texts.append(
+                 self._convert_to_str(sentence[0], sentence[1])
+             )
+         sentences = concatenated_input_texts
+
+         self.eval()
+
+         if convert_to_tensor:
+             convert_to_numpy = False
+
+         length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
+         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
+         all_embeddings = []
+
+         if torch.cuda.device_count() <= 1:
+             # This branch also supports mps devices
+             self.to(device)
+             for start_index in trange(
+                 0,
+                 len(sentences),
+                 batch_size,
+                 desc="Batches",
+                 disable=not show_progress_bar,
+             ):
+                 sentences_batch = sentences_sorted[
+                     start_index : start_index + batch_size
+                 ]
+                 embeddings = self._encode(
+                     sentences_batch, device=device, convert_to_numpy=convert_to_numpy
+                 )
+                 all_embeddings.append(embeddings)
+         else:
+             num_proc = torch.cuda.device_count()
+             cuda_compatible_multiprocess = mp.get_context("spawn")
+             with cuda_compatible_multiprocess.Pool(num_proc) as p:
+                 sentences_batches = [
+                     sentences_sorted[start_index : start_index + batch_size]
+                     for start_index in range(0, len(sentences), batch_size)
+                 ]
+
+                 progress_bar = tqdm(
+                     total=len(sentences_batches),
+                     desc="Batches",
+                     disable=not show_progress_bar,
+                 )
+                 results = []
+
+                 def update(*args):
+                     progress_bar.update()
+
+                 for batch in sentences_batches:
+                     results.append(
+                         p.apply_async(
+                             self._encode,
+                             args=(batch, None, convert_to_numpy, True),
+                             callback=update,
+                         )
+                     )
+
+                 all_embeddings = [result.get() for result in results]
+                 progress_bar.close()
+
+         all_embeddings = torch.cat(all_embeddings, dim=0)
+         all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
+         all_embeddings = all_embeddings.to(torch.float32)
+         if convert_to_numpy:
+             all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
+         return all_embeddings
+
+     def save(self, output_path, merge_before_save=False, save_config=True):
+         if merge_before_save and isinstance(self.model, PeftModel):
+             self.model = self.model.merge_and_unload()
+             if hasattr(self.model, "_hf_peft_config_loaded"):
+                 self.model._hf_peft_config_loaded = False
+
+         self.model.save_pretrained(output_path)
+         self.tokenizer.save_pretrained(output_path)
+
+         l3prune_config = {
+             "pooling_mode": self.pooling_mode,
+             "max_length": self.max_length,
+             "doc_max_length": self.doc_max_length,
+             "skip_instruction": self.skip_instruction,
+         }
+
+         if save_config:
+             os.makedirs(output_path, exist_ok=True)
+             with open(f"{output_path}/l3prune_config.json", "w") as fOut:
+                 json.dump(l3prune_config, fOut, indent=4)
+
+     def _encode(
+         self,
+         sentences_batch,
+         device: Optional[str] = None,
+         convert_to_numpy: bool = False,
+         multiprocessing=False,
+     ):
+         if multiprocessing:
+             # multiprocessing only supports CUDA devices at this time, so we ignore
+             # the value of device and use cuda:rank for the device
+             rank = mp.current_process()._identity[0]
+             if device is None and torch.cuda.is_available():
+                 device = f"cuda:{rank % torch.cuda.device_count()}"
+
+         self.to(device)
+         features = self.tokenize(
+             [self.prepare_for_tokenization(sentence) for sentence in sentences_batch]
+         )
+         features = batch_to_device(features, device)
+
+         with torch.no_grad():
+             embeddings = self.forward(features)
+             embeddings = embeddings.detach()
+             embeddings = embeddings.cpu()
+
+         return embeddings
+
+     def _text_length(self, text: Union[List[int], List[List[int]]]):
+         """
+         Helper function to get the length of the input text. Text can be either a
+         string (a single text), a list of ints (a single tokenized text), or a tuple
+         of lists of ints (several text inputs to the model).
+         """
+         if (
+             isinstance(text, str)
+             or (isinstance(text, list) and isinstance(text[0], int))
+             or len(text) == 0
+         ):  # Single text, list of ints, or empty
+             return len(text)
+         if isinstance(text, dict):  # {key: value} case
+             return len(next(iter(text.values())))
+         elif not hasattr(text, "__len__"):  # Object has no len() method
+             return 1
+         else:
+             return sum([len(t) for t in text])
+
+     def resize_token_embeddings(
+         self,
+         new_num_tokens: Optional[int] = None,
+         pad_to_multiple_of: Optional[int] = None,
+     ) -> nn.Embedding:
+         return self.model.resize_token_embeddings(
+             new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
+         )
+
+     def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+         self.model.gradient_checkpointing_enable(
+             gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
+         )
+
+     def save_pretrained(self, save_directory, **kwargs):
+         self.tokenizer.save_pretrained(save_directory, **kwargs)
+         super().save_pretrained(save_directory, **kwargs)
+
+     def push_to_hub(self, repo_id, **kwargs):
+         self.tokenizer.push_to_hub(repo_id, **kwargs)
+         super().push_to_hub(repo_id, **kwargs)
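
For orientation, here is a minimal usage sketch of the class above; it is an illustration, not an officially documented workflow. It assumes llmencoder.py is importable from the working directory and mirrors the settings in config.json (base_model microsoft/Phi-3-mini-4k-instruct, weighted_mean pooling, 25 remaining layers):

import torch
from llmencoder import LLMEncoder, LLMEncoderConfig

# Build an encoder from the base causal LM named in config.json.
config = LLMEncoderConfig(pooling_mode="weighted_mean", doc_max_length=400)
encoder = LLMEncoder.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",
    config=config,
    torch_dtype=torch.bfloat16,
)

# Values >= 1 are an absolute layer count; config.json records 25 layers.
encoder.prune(25)

# Plain strings are wrapped as (instruction, text) pairs with an empty
# instruction; _convert_to_str joins them with the "!@#$%^&*()" separator.
embeddings = encoder.encode(
    ["What is layer pruning?"],
    convert_to_numpy=True,
)
print(embeddings.shape)  # (1, 3072): one vector per input, hidden_size wide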
model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:dec49db72195150316916aab24b283d2e9d61de74a82407810bfa327ff7c402c
+ size 4972489328
model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cd0c50c320b8680658ad5807d43f1a59676e72d0fc4c30c6d9413a0641e9a979
+ size 887153360
model.safetensors.index.json ADDED
@@ -0,0 +1,159 @@
+ {
+   "metadata": {
+     "total_size": 5859624960
+   },
+   "weight_map": {
+     "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.0.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.1.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.10.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.10.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.10.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.10.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.11.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.11.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.11.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.11.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.11.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.12.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.12.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.12.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.12.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.12.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.12.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.13.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.13.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.13.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.13.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.13.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.13.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.14.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.14.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.14.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.14.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.14.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.14.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.15.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.15.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.15.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.15.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.15.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.15.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.16.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.16.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.16.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.16.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.16.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.16.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.17.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.17.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.17.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.17.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.17.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.17.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.18.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.18.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.18.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.18.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.18.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.18.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.19.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.19.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.19.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.19.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.19.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.19.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.2.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.20.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.20.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.20.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.20.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.20.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.20.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.21.input_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.21.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.21.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.21.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.21.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.22.input_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.22.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.22.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.22.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.23.input_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.23.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.23.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.23.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.23.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.24.input_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.24.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.24.mlp.gate_up_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.24.post_attention_layernorm.weight": "model-00002-of-00002.safetensors",
+     "model.layers.24.self_attn.o_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.24.self_attn.qkv_proj.weight": "model-00002-of-00002.safetensors",
+     "model.layers.3.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.3.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.4.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.5.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.6.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.6.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.6.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.6.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.7.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.7.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.7.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.7.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.8.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.8.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.8.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.8.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.9.input_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.9.mlp.down_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.9.mlp.gate_up_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00002.safetensors",
+     "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00002.safetensors",
+     "model.layers.9.self_attn.qkv_proj.weight": "model-00001-of-00002.safetensors",
+     "model.norm.weight": "model-00002-of-00002.safetensors"
+   }
+ }
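
The weight_map above routes every parameter name to one of the two shards, which is how loaders reassemble the model; total_size counts tensor bytes only, so it is slightly smaller than the summed shard file sizes, which also include each file's safetensors header. A short sketch of a manual lookup, assuming a local checkout at path/to/this-repo (placeholder) and the safetensors package:

import json
from safetensors import safe_open

repo = "path/to/this-repo"  # placeholder for a local checkout of this commit

with open(f"{repo}/model.safetensors.index.json") as f:
    index = json.load(f)

# Find the shard that stores a given tensor, then read only that tensor.
name = "model.layers.24.mlp.down_proj.weight"
shard = index["weight_map"][name]
with safe_open(f"{repo}/{shard}", framework="pt") as st:
    tensor = st.get_tensor(name)
print(shard, tuple(tensor.shape))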