nouamanetazi (HF staff) committed on
Commit 6e2f86a
1 Parent(s): cc50582

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,19 @@
1
+ ---
2
+ library_name: nanotron
3
+ ---
4
+
5
+ # ⚙️ Nano-Mistral
6
+
7
+ Modeling code for Mistral to use with [Nanotron](https://github.com/huggingface/nanotron/)
8
+
9
+ ## 🚀 Quickstart
10
+
11
+ ```python
12
+ # Generate a config file
13
+ python config_tiny_mistral.py
14
+
15
+
16
+ # Run training
17
+ export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
18
+ torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
19
+ ```
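
Note that `--nproc_per_node=8` matches the parallelism layout in the generated config (`dp * pp * tp = 2 * 2 * 2`). A minimal illustrative sketch for double-checking this against the YAML, assuming PyYAML is installed:

```python
# Illustrative sketch: read the generated config and derive the required world size.
import yaml

with open("config_tiny_mistral.yaml") as f:
    cfg = yaml.safe_load(f)

p = cfg["parallelism"]
world_size = p["dp"] * p["pp"] * p["tp"]
print(f"launch with: torchrun --nproc_per_node={world_size} run_train.py ...")  # 2 * 2 * 2 = 8
```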
__pycache__/config_tiny_mistral.cpython-310.pyc ADDED
Binary file (3.99 kB).
 
__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (2.81 kB).
 
__pycache__/modeling_mistral.cpython-310.pyc ADDED
Binary file (24.7 kB).
 
config_tiny_mistral.py ADDED
@@ -0,0 +1,148 @@
1
+ """ Example python script to generate a YAML config file which can be used to run a training with nanotron. Refer to "examples" section in the `/README.md` for more information.
2
+
3
+ Usage:
4
+ ```
5
+ python config_tiny_mistral.py
6
+ ```
7
+ """
8
+ import os
9
+
10
+ from nanotron.config import (
11
+ CheckpointsArgs,
12
+ Config,
13
+ DataArgs,
14
+ GeneralArgs,
15
+ LoggingArgs,
16
+ LRSchedulerArgs,
17
+ ModelArgs,
18
+ OptimizerArgs,
19
+ ParallelismArgs,
20
+ PretrainDatasetsArgs,
21
+ RandomInit,
22
+ TokenizerArgs,
23
+ TokensArgs,
24
+ )
25
+ from nanotron.logging import human_format
26
+ from dataclasses import dataclass
27
+ from typing import Optional
28
+
29
+
30
+ @dataclass
31
+ class MistralConfig:
32
+ """Configuration for a MISTRAL model
33
+
34
+ Be careful to keep the typing coherent, as it is used to reconstruct the model from YAML
35
+ """
36
+
37
+ bos_token_id: int = 1
38
+ eos_token_id: int = 2
39
+ hidden_act: str = "silu"
40
+ hidden_size: int = 4096
41
+ initializer_range: float = 0.02
42
+ intermediate_size: int = 11008
43
+ is_mistral_config: bool = True # We use this to help differentiate models in yaml/python conversion
44
+ max_position_embeddings: int = 2048
45
+ num_attention_heads: int = 32
46
+ num_hidden_layers: int = 32
47
+ num_key_value_heads: Optional[int] = None
48
+ pad_token_id: Optional[int] = None
49
+ pretraining_tp: int = 1
50
+ rms_norm_eps: float = 1e-6
51
+ rope_scaling: Optional[dict] = None
52
+ tie_word_embeddings: bool = False
53
+ use_cache: bool = True
54
+ vocab_size: int = 32000
55
+
56
+ def __post_init__(self):
57
+ # for backward compatibility
58
+ if self.num_key_value_heads is None:
59
+ self.num_key_value_heads = self.num_attention_heads
60
+
61
+ model_config = MistralConfig(
62
+ # Config for a tiny test model (the parameter estimate is printed below)
63
+ bos_token_id=1,
64
+ eos_token_id=2,
65
+ hidden_act="silu",
66
+ hidden_size=16,
67
+ initializer_range=0.02,
68
+ intermediate_size=64,
69
+ max_position_embeddings=256,
70
+ num_attention_heads=4,
71
+ num_hidden_layers=2,
72
+ num_key_value_heads=4,
73
+ pretraining_tp=1,
74
+ rms_norm_eps=1e-05,
75
+ rope_scaling=None,
76
+ tie_word_embeddings=True,
77
+ use_cache=True,
78
+ vocab_size=256,
79
+ )
80
+
81
+ num_params = human_format(
82
+ model_config.vocab_size * model_config.hidden_size * 2
83
+ + model_config.num_hidden_layers
84
+ * (
85
+ 3 * model_config.hidden_size * model_config.intermediate_size
86
+ + 4 * model_config.hidden_size * model_config.hidden_size
87
+ )
88
+ ).replace(".", "p")
89
+
90
+ print(f"Model has {num_params} parameters")
91
+
92
+ seed = 42
93
+
94
+ learning_rate = LRSchedulerArgs(
95
+ learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
96
+ )
97
+
98
+ optimizer = OptimizerArgs(
99
+ zero_stage=0,
100
+ weight_decay=0.01,
101
+ clip_grad=1.0,
102
+ accumulate_grad_in_fp32=True,
103
+ adam_eps=1e-08,
104
+ adam_beta1=0.9,
105
+ adam_beta2=0.95,
106
+ torch_adam_is_fused=True,
107
+ learning_rate_scheduler=learning_rate,
108
+ )
109
+
110
+ parallelism = ParallelismArgs(
111
+ dp=2,
112
+ pp=2,
113
+ tp=2,
114
+ pp_engine="1f1b",
115
+ tp_mode="REDUCE_SCATTER",
116
+ tp_linear_async_communication=True,
117
+ recompute_granularity="selective",
118
+ )
119
+
120
+ tokens = TokensArgs(sequence_length=32, train_steps=10, micro_batch_size=2, batch_accumulation_per_replica=1)
121
+
122
+ dataset = PretrainDatasetsArgs(
123
+ hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small", text_column_name="completion"
124
+ )
125
+
126
+ checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
127
+ os.makedirs(checkpoints_path, exist_ok=True)
128
+
129
+ config = Config(
130
+ general=GeneralArgs(project="debug", run="tiny_mistral", seed=seed),
131
+ checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
132
+ parallelism=parallelism,
133
+ model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
134
+ tokenizer=TokenizerArgs("gpt2"),
135
+ optimizer=optimizer,
136
+ logging=LoggingArgs(),
137
+ tokens=tokens,
138
+ data=DataArgs(dataset=dataset, seed=seed),
139
+ profiler=None,
140
+ )
141
+
142
+ if __name__ == "__main__":
143
+ dir = os.path.dirname(__file__)
144
+
145
+ # Save config as YAML file
146
+ config.save_as_yaml(f"{dir}/config_tiny_mistral.yaml")
147
+
148
+ # You can now train a model with this config using `/run_train.py`
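
For reference, the `num_params` estimate above reduces to simple arithmetic for this tiny config. A sketch of the same formula (it is only an estimate; it ignores e.g. the RMSNorm weights):

```python
# Same estimate as in config_tiny_mistral.py, spelled out for the tiny values.
hidden, inter, layers, vocab = 16, 64, 2, 256

embeddings = vocab * hidden * 2                        # token embeddings + lm_head
per_layer = 3 * hidden * inter + 4 * hidden * hidden   # gate/up/down MLP + qkv/o projections
print(embeddings + layers * per_layer)                 # 16384, i.e. ~16.4K
```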
config_tiny_mistral.yaml ADDED
@@ -0,0 +1,90 @@
1
+ checkpoints:
2
+ checkpoint_interval: 10
3
+ checkpoints_path: /fsx/nouamane/projects/nanotron/checkpoints
4
+ checkpoints_path_is_shared_file_system: false
5
+ resume_checkpoint_path: null
6
+ save_initial_state: false
7
+ data:
8
+ dataset:
9
+ dataset_overwrite_cache: false
10
+ dataset_processing_num_proc_per_process: 1
11
+ hf_dataset_config_name: null
12
+ hf_dataset_or_datasets: HuggingFaceH4/testing_alpaca_small
13
+ hf_dataset_splits: train
14
+ text_column_name: completion
15
+ num_loading_workers: 1
16
+ seed: 42
17
+ general:
18
+ benchmark_csv_path: null
19
+ consumed_train_samples: null
20
+ ignore_sanity_checks: false
21
+ project: debug
22
+ run: tiny_mistral
23
+ seed: 42
24
+ step: null
25
+ logging:
26
+ iteration_step_info_interval: 1
27
+ log_level: info
28
+ log_level_replica: info
29
+ model:
30
+ ddp_bucket_cap_mb: 25
31
+ dtype: bfloat16
32
+ init_method:
33
+ std: 0.025
34
+ make_vocab_size_divisible_by: 1
35
+ model_config:
36
+ bos_token_id: 1
37
+ eos_token_id: 2
38
+ hidden_act: silu
39
+ hidden_size: 16
40
+ initializer_range: 0.02
41
+ intermediate_size: 64
42
+ is_mistral_config: true
43
+ max_position_embeddings: 256
44
+ num_attention_heads: 4
45
+ num_hidden_layers: 2
46
+ num_key_value_heads: 4
47
+ pad_token_id: null
48
+ pretraining_tp: 1
49
+ rms_norm_eps: 1.0e-05
50
+ rope_scaling: null
51
+ tie_word_embeddings: true
52
+ use_cache: true
53
+ vocab_size: 256
54
+ optimizer:
55
+ accumulate_grad_in_fp32: true
56
+ adam_beta1: 0.9
57
+ adam_beta2: 0.95
58
+ adam_eps: 1.0e-08
59
+ clip_grad: 1.0
60
+ learning_rate_scheduler:
61
+ learning_rate: 0.0003
62
+ lr_decay_steps: 8
63
+ lr_decay_style: cosine
64
+ lr_warmup_steps: 2
65
+ lr_warmup_style: linear
66
+ min_decay_lr: 1.0e-05
67
+ torch_adam_is_fused: true
68
+ weight_decay: 0.01
69
+ zero_stage: 0
70
+ parallelism:
71
+ dp: 2
72
+ pp: 2
73
+ pp_engine: 1f1b
74
+ recompute_granularity: SELECTIVE
75
+ tp: 2
76
+ tp_linear_async_communication: true
77
+ tp_mode: REDUCE_SCATTER
78
+ profiler: null
79
+ tokenizer:
80
+ tokenizer_max_length: null
81
+ tokenizer_name_or_path: gpt2
82
+ tokenizer_revision: null
83
+ tokens:
84
+ batch_accumulation_per_replica: 1
85
+ limit_test_batches: 0
86
+ limit_val_batches: 0
87
+ micro_batch_size: 2
88
+ sequence_length: 32
89
+ train_steps: 10
90
+ val_check_interval: -1
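
As a rough sketch of the training volume implied by the `tokens` and `parallelism` sections above (assuming the usual convention that the global batch size is `dp * micro_batch_size * batch_accumulation_per_replica`):

```python
# Back-of-the-envelope numbers for this config.
dp = 2
micro_batch_size = 2
batch_accumulation_per_replica = 1
sequence_length = 32
train_steps = 10

global_batch_size = dp * micro_batch_size * batch_accumulation_per_replica  # samples per optimizer step
tokens_per_step = global_batch_size * sequence_length
print(global_batch_size, tokens_per_step, tokens_per_step * train_steps)    # 4 128 1280
```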
dataloader.py ADDED
@@ -0,0 +1,107 @@
1
+ from nanotron.config import (
2
+ PretrainDatasetsArgs,
3
+ )
4
+ from nanotron.dataloader import (
5
+ clm_process,
6
+ dummy_infinite_data_generator,
7
+ get_datasets,
8
+ get_train_dataloader,
9
+ )
10
+ from nanotron.logging import log_rank
11
+ from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks
12
+ from nanotron.trainer import DistributedTrainer
13
+ from nanotron.utils import (
14
+ main_rank_first,
15
+ )
16
+ from nanotron import logging
17
+
18
+ try:
19
+ from huggingface_hub import __version__ as hf_hub_version
20
+ from transformers import AutoTokenizer
21
+ from transformers import __version__ as tf_version
22
+ except ImportError:
23
+ hf_hub_version = None
24
+ tf_version = None
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+ def get_dataloader(trainer: DistributedTrainer):
30
+ """Returns a dataloader for training."""
31
+
32
+ # First, we need to know which ranks to feed the dataloader to
33
+ input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
34
+
35
+ # Case 1: Dummy data generator
36
+ if trainer.config.data.dataset is None:
37
+ log_rank("Using dummy data generator", logger=logger, level=logging.INFO, rank=0)
38
+ dataloader = dummy_infinite_data_generator(
39
+ micro_batch_size=trainer.micro_batch_size,
40
+ sequence_length=trainer.sequence_length,
41
+ input_pp_rank=input_pp_rank,
42
+ output_pp_rank=output_pp_rank,
43
+ vocab_size=trainer.model_config.vocab_size,
44
+ seed=trainer.config.data.seed,
45
+ parallel_context=trainer.parallel_context,
46
+ )()
47
+
48
+ # Case 2: HuggingFace datasets
49
+ elif isinstance(trainer.config.data.dataset, PretrainDatasetsArgs):
50
+ log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
51
+ tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
52
+ log_rank(
53
+ f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
54
+ logger=logger,
55
+ level=logging.INFO,
56
+ rank=0,
57
+ )
58
+
59
+ # We need the 1st device to process the dataset and cache it; the other devices then load from the cache
60
+ with main_rank_first(trainer.parallel_context.world_pg):
61
+ # TODO @nouamanetazi: this may timeout before 1st device finishes processing dataset. Can we have a ctxmanager to modify timeout?
62
+ # TODO: generalise to include validation/test splits
63
+
64
+ # We load the raw dataset
65
+ raw_dataset = get_datasets(
66
+ hf_dataset_or_datasets=trainer.config.data.dataset.hf_dataset_or_datasets,
67
+ splits=trainer.config.data.dataset.hf_dataset_splits,
68
+ )["train"]
69
+
70
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
71
+ tokenizer.pad_token = tokenizer.eos_token
72
+ tokenizer.padding_side = "left"
73
+
74
+ # We apply the Causal Language Modeling preprocessing
75
+ train_dataset = clm_process(
76
+ raw_dataset=raw_dataset,
77
+ tokenizer=tokenizer,
78
+ text_column_name=trainer.config.data.dataset.text_column_name,
79
+ dataset_processing_num_proc_per_process=trainer.config.data.dataset.dataset_processing_num_proc_per_process,
80
+ dataset_overwrite_cache=trainer.config.data.dataset.dataset_overwrite_cache,
81
+ sequence_length=trainer.sequence_length,
82
+ )
83
+
84
+ # We load the processed dataset on the ranks requiring it
85
+ dataloader = get_train_dataloader(
86
+ train_dataset=train_dataset,
87
+ sequence_length=trainer.sequence_length,
88
+ parallel_context=trainer.parallel_context,
89
+ input_pp_rank=input_pp_rank,
90
+ output_pp_rank=output_pp_rank,
91
+ micro_batch_size=trainer.micro_batch_size,
92
+ consumed_train_samples=trainer.consumed_train_samples,
93
+ dataloader_num_workers=trainer.config.data.num_loading_workers,
94
+ seed_worker=trainer.config.data.seed,
95
+ dataloader_drop_last=True,
96
+ )
97
+ # Check if we have enough samples for train_steps
98
+ assert (
99
+ trainer.config.tokens.train_steps - trainer.start_iteration_step
100
+ ) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size() < len(dataloader), (
101
+ f"Dataset is too small for steps ({len(dataloader)} < {(trainer.config.tokens.train_steps - trainer.start_iteration_step) * trainer.global_batch_size // trainer.parallel_context.dp_pg.size()}), "
102
+ f"Try train_steps<={len(dataloader) * trainer.parallel_context.dp_pg.size() // trainer.global_batch_size + trainer.start_iteration_step}"
103
+ )
104
+ else:
105
+ raise ValueError(f"Unhandled case of `self.config.data.dataset`. Got: {trainer.config.data.dataset}")
106
+
107
+ return dataloader
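
The assertion at the end of `get_dataloader` guards against a dataset that is too small for the requested `train_steps`. A small numeric sketch of the same inequality, using a hypothetical dataloader length:

```python
# Hypothetical numbers; len_dataloader stands in for len(dataloader) on one data-parallel rank.
train_steps = 10
start_iteration_step = 0
global_batch_size = 4   # e.g. dp * micro_batch_size * batch_accumulation_per_replica for the tiny config
dp_size = 2
len_dataloader = 64     # hypothetical

needed_batches = (train_steps - start_iteration_step) * global_batch_size // dp_size
assert needed_batches < len_dataloader, "Dataset is too small for train_steps"
max_train_steps = len_dataloader * dp_size // global_batch_size + start_iteration_step
print(needed_batches, max_train_steps)  # 20 32
```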
modeling_mistral.py ADDED
@@ -0,0 +1,1123 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Mistral model.
16
+ """
17
+ from typing import Dict, Optional, Union
18
+
19
+ import torch
20
+ from flash_attn import bert_padding
21
+ from flash_attn.flash_attn_interface import (
22
+ flash_attn_varlen_func,
23
+ flash_attn_with_kvcache,
24
+ )
25
+ from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
26
+ from torch import nn
27
+ from transformers import MistralConfig
28
+ from transformers.activations import ACT2FN
29
+
30
+ from nanotron import distributed as dist
31
+ from nanotron import logging
32
+ from nanotron.config import ParallelismArgs, RecomputeGranularity
33
+ from nanotron.nn.layer_norm import TritonRMSNorm
34
+ from nanotron.logging import log_rank
35
+ from nanotron.models import NanotronModel
36
+ from nanotron.parallel import ParallelContext
37
+ from nanotron.parallel.parameters import NanotronParameter
38
+ from nanotron.parallel.pipeline_parallel.block import (
39
+ PipelineBlock,
40
+ TensorPointer,
41
+ )
42
+ from nanotron.parallel.pipeline_parallel.p2p import P2P
43
+ from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
44
+ from nanotron.parallel.tensor_parallel.nn import (
45
+ TensorParallelColumnLinear,
46
+ TensorParallelEmbedding,
47
+ TensorParallelLinearMode,
48
+ TensorParallelRowLinear,
49
+ )
50
+ from nanotron.random import RandomStates
51
+ from nanotron.utils import checkpoint_method
52
+ from nanotron.generation.generate_store import AttachableStore
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+ class RotaryEmbedding(nn.Module):
58
+ def __init__(self, dim: int, end: int, theta: float = 10000.0):
59
+ super().__init__()
60
+ assert dim % 2 == 0
61
+ self.dim = dim
62
+ self.end = end
63
+ self.theta = theta
64
+ # TODO @nouamane: Figure out why we can't set `DTypeInvariantTensor` ...
65
+ # TODO @thomasw21: Complex buffers break DDP, instead we store float and view them as complex
66
+ self.freqs_cis: torch.Tensor
67
+ self._initialized_buffer = False
68
+
69
+ def init_rotary_embeddings(self):
70
+ if self._initialized_buffer is True:
71
+ # Buffer is already initialized
72
+ return
73
+ self.register_buffer(
74
+ "freqs_cis",
75
+ torch.empty(self.end, self.dim // 2, 2, dtype=torch.float, device="cuda"),
76
+ persistent=False,
77
+ )
78
+ assert self.freqs_cis.device.type == "cuda"
79
+ # TODO @nouamane: Once we figure out how to do the DTypeInvariantTensor, this can be removed and changed to an assert
80
+ if self.freqs_cis.dtype != torch.float:
81
+ self.freqs_cis = self.freqs_cis.to(torch.float)
82
+ assert self.freqs_cis.dtype == torch.float
83
+ freqs = 1.0 / (
84
+ self.theta
85
+ ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim)
86
+ )
87
+ t = torch.arange(self.end, device="cuda")
88
+ freqs = torch.outer(t, freqs).float()
89
+ complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
90
+ freqs = torch.view_as_real(complex_freqs)
91
+ self.freqs_cis.copy_(freqs)
92
+ self._initialized_buffer = True
93
+
94
+ def forward(
95
+ self,
96
+ x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
97
+ position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
98
+ ):
99
+ batch_size, seq_length, num_heads, inner_dim = x.shape
100
+ while (
101
+ position_ids is not None and position_ids[-1, -1] >= self.end
102
+ ) or seq_length >= self.end: # TODO @nouamane: check if this causes cpu-gpu sync
103
+ self.end *= 2
104
+ self._initialized_buffer = False
105
+ if self._initialized_buffer is False:
106
+ print(f"Initializing rotary embeddings with end={self.end}")
107
+ self.init_rotary_embeddings()
108
+ dtype = x.dtype
109
+ assert inner_dim % 2 == 0
110
+ x = x.view(
111
+ batch_size, seq_length, num_heads, inner_dim // 2, 2
112
+ ) # [batch_size, q_length, num_heads, inner_dim]
113
+ if x.dtype == torch.bfloat16:
114
+ x = x.float()
115
+ complex_x = torch.view_as_complex(x) # [batch_size, q_length, num_heads, inner_dim // 2]
116
+ if position_ids is None:
117
+ freqs_cis = self.freqs_cis[None, :seq_length, None, :]
118
+ else:
119
+ # TODO(kunhao): Should None follow the num_heads dimension?
120
+ if position_ids[-1, -1] < 0 or position_ids[-1, -1] >= self.end: # Quick test hopefully
121
+ raise ValueError(f"Position ids must be in the range [0, {self.end}), but got {position_ids}")
122
+ freqs_cis = self.freqs_cis[position_ids][:, :, None, :]
123
+ complex_freqs = torch.view_as_complex(freqs_cis)
124
+ x_out = torch.view_as_real(complex_x * complex_freqs).view(batch_size, seq_length, num_heads, inner_dim)
125
+ return x_out.type(dtype)
126
+
127
+
128
+ class GLUActivation(nn.Module):
129
+ def __init__(self, act_fn_name: str):
130
+ super().__init__()
131
+ self.act = ACT2FN[act_fn_name]
132
+
133
+ def forward(self, merged_states: torch.Tensor):
134
+ gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1)
135
+ return self.act(gate_states) * up_states
136
+
137
+
138
+ class MLP(nn.Module):
139
+ def __init__(
140
+ self,
141
+ config: MistralConfig,
142
+ parallel_config: Optional[ParallelismArgs],
143
+ tp_pg: dist.ProcessGroup,
144
+ ):
145
+ super().__init__()
146
+
147
+ # TODO @thomasw21: refactor so that we store that default in a single place.
148
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
149
+ tp_linear_async_communication = (
150
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
151
+ )
152
+
153
+ gate_up_contiguous_chunks = (
154
+ config.intermediate_size, # shape of gate_linear
155
+ config.intermediate_size, # shape of up_linear
156
+ )
157
+ self.gate_up_proj = TensorParallelColumnLinear(
158
+ config.hidden_size,
159
+ 2 * config.intermediate_size,
160
+ pg=tp_pg,
161
+ mode=tp_mode,
162
+ bias=False,
163
+ async_communication=tp_linear_async_communication,
164
+ contiguous_chunks=gate_up_contiguous_chunks,
165
+ )
166
+
167
+ self.down_proj = TensorParallelRowLinear(
168
+ config.intermediate_size,
169
+ config.hidden_size,
170
+ pg=tp_pg,
171
+ mode=tp_mode,
172
+ bias=False,
173
+ async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
174
+ )
175
+ # TODO @nouamane: why can't we torch.jit.script GLUActivation?
176
+ self.split_silu_mul = GLUActivation(config.hidden_act)
177
+
178
+ def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
179
+ merged_states = self.gate_up_proj(hidden_states)
180
+ hidden_states = self.down_proj(self.split_silu_mul(merged_states))
181
+ return {"hidden_states": hidden_states}
182
+
183
+
184
+ class CoreAttention(nn.Module):
185
+ def __init__(self, config: MistralConfig, parallel_config: Optional[ParallelismArgs], layer_idx: int):
186
+ super().__init__()
187
+ # TODO @thomasw21: GPT has a weird `d_kv` config which I'm guessing is essentially a `d_qkv`
188
+ assert (
189
+ config.hidden_size % config.num_attention_heads == 0
190
+ ), f"Hidden size {config.hidden_size} must be divisible by number of attention heads {config.num_attention_heads}."
191
+ self.d_qk = config.hidden_size // config.num_attention_heads
192
+ self.d_v = config.hidden_size // config.num_attention_heads
193
+
194
+ self.checkpoint_attention = False # Because flash_attn already does checkpointing
195
+
196
+ @checkpoint_method(attr_name="checkpoint_attention")
197
+ def forward(
198
+ self,
199
+ query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim]
200
+ key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
201
+ value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim]
202
+ q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size)
203
+ kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size)
204
+ ):
205
+ # TODO @thomasw21: Compute once, instead of computing for each layer.
206
+ cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
207
+ cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device)
208
+ torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:])
209
+ torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:])
210
+
211
+ # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
212
+ # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
213
+ causal = False if q_sequence_mask.shape[1] == 1 else True
214
+ attn_output = flash_attn_varlen_func(
215
+ q=query_states,
216
+ k=key_states,
217
+ v=value_states,
218
+ cu_seqlens_q=cu_seqlens_q,
219
+ cu_seqlens_k=cu_seqlens_k,
220
+ max_seqlen_q=q_sequence_mask.shape[1],
221
+ max_seqlen_k=kv_sequence_mask.shape[1],
222
+ dropout_p=0.0,
223
+ softmax_scale=None, # This already defaults to the scale I'm interested in
224
+ causal=causal,
225
+ return_attn_probs=False,
226
+ )
227
+
228
+ return attn_output
229
+
230
+
231
+ def pad_to_right(tensor, mask, new_tensor=None):
232
+ """Transform a left-padded tensor into a right-padded tensor. (Useful for prefilling key/value states)
233
+ Args:
234
+ tensor: (batch_size, seqlen, d1, d2)
235
+ mask: (batch_size, seqlen)
236
+ new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
237
+ Returns:
238
+ new_tensor: (batch_size, new_tensor_seqlen, d1, d2)
239
+ right_padded_mask: (batch_size, seqlen)
240
+ """
241
+ # First, we need to find the number of padding for each row
242
+ unpad_seqlens = mask.sum(1)
243
+ # Then, we need to find the maximum length of the tensor
244
+ max_seqlen = mask.shape[1]
245
+ # We can then create the indices to select the padded values
246
+ # The indices are the same for each row
247
+ indices = torch.arange(max_seqlen, device=mask.device)
248
+ # We can then create the mask for the padded values
249
+ right_padded_mask = indices < unpad_seqlens[:, None]
250
+ # We select the useful values
251
+ useful_values = tensor[mask]
252
+ # We create the new tensor (if not provided)
253
+ new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor
254
+ # We fill the new tensor with the useful values
255
+ new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values
256
+ return new_tensor, right_padded_mask
257
+
258
+
259
+ class CausalSelfAttention(nn.Module, AttachableStore):
260
+ def __init__(
261
+ self,
262
+ config: MistralConfig,
263
+ parallel_config: Optional[ParallelismArgs],
264
+ tp_pg: dist.ProcessGroup,
265
+ layer_idx: int,
266
+ ):
267
+ super().__init__()
268
+ # Tensor parallel considerations: We split tensors along head dimension
269
+ assert (
270
+ config.num_attention_heads % tp_pg.size() == 0
271
+ ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})."
272
+ try:
273
+ assert (
274
+ config.num_key_value_heads % tp_pg.size() == 0
275
+ ), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})."
276
+ except AttributeError:
277
+ log_rank(
278
+ "WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads",
279
+ logger=logger,
280
+ level=logging.WARNING,
281
+ rank=0,
282
+ )
283
+ # If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads
284
+ config.num_key_value_heads = config.num_attention_heads
285
+ assert (
286
+ config.num_attention_heads % config.num_key_value_heads == 0
287
+ ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})."
288
+ self.n_local_q_heads = config.num_attention_heads // tp_pg.size()
289
+ self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size()
290
+ self.n_repeats = config.num_attention_heads // config.num_key_value_heads
291
+ self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not
292
+ self.d_qk = config.hidden_size // config.num_attention_heads
293
+ self.d_v = config.hidden_size // config.num_attention_heads
294
+ self.d_model = config.hidden_size
295
+
296
+ # TODO @thomasw21: refactor so that we store that default in a single place.
297
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
298
+ tp_linear_async_communication = (
299
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
300
+ )
301
+
302
+ # build the slice config for self.qkv for save/load
303
+ # shards are done within the contiguous chunk
304
+ qkv_contiguous_chunks = (
305
+ config.num_attention_heads * self.d_qk, # shape of q
306
+ config.num_key_value_heads * self.d_qk, # shape of k
307
+ config.num_key_value_heads * self.d_qk, # shape of v
308
+ )
309
+ self.qkv_proj = TensorParallelColumnLinear(
310
+ self.d_model,
311
+ config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
312
+ pg=tp_pg,
313
+ mode=tp_mode,
314
+ bias=False,
315
+ async_communication=tp_linear_async_communication,
316
+ contiguous_chunks=qkv_contiguous_chunks,
317
+ )
318
+ # TODO(kunhao): We want to have only one version per device and not one version per layer.
319
+ self.rotary_embedding = RotaryEmbedding(
320
+ dim=self.d_qk,
321
+ end=config.max_position_embeddings,
322
+ )
323
+
324
+ # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
325
+ self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, interleaved=True)
326
+
327
+ self.o_proj = TensorParallelRowLinear(
328
+ config.num_attention_heads * self.d_qk,
329
+ self.d_model,
330
+ pg=tp_pg,
331
+ mode=tp_mode,
332
+ bias=False,
333
+ async_communication=tp_linear_async_communication,
334
+ )
335
+
336
+ self.attention = CoreAttention(
337
+ config,
338
+ parallel_config=parallel_config,
339
+ layer_idx=layer_idx,
340
+ )
341
+
342
+ self.prefill_kv_len = (
343
+ config.max_position_embeddings
344
+ ) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings
345
+
346
+ def forward(
347
+ self,
348
+ hidden_states, # [seq_length, batch_size, hidden_size]
349
+ sequence_mask, # [batch_size, seq_length]
350
+ ):
351
+ qkv_states = self.qkv_proj(
352
+ hidden_states
353
+ ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk]
354
+ q_length, batch_size, _ = qkv_states.shape
355
+
356
+ if self.is_gqa:
357
+ query_states, key_states, value_states = torch.split(
358
+ qkv_states,
359
+ [
360
+ self.n_local_q_heads * self.d_qk,
361
+ self.n_local_kv_heads * self.d_qk,
362
+ self.n_local_kv_heads * self.d_qk,
363
+ ],
364
+ dim=-1,
365
+ )
366
+
367
+ query_states = (
368
+ query_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk)
369
+ )
370
+ key_states = (
371
+ key_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
372
+ )
373
+ value_states = (
374
+ value_states.transpose(0, 1).contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk)
375
+ )
376
+ else:
377
+ query_states, key_states, value_states = (
378
+ qkv_states.view(q_length, batch_size, 3, self.n_local_q_heads, self.d_qk)
379
+ .permute(2, 1, 0, 3, 4)
380
+ .contiguous()
381
+ ) # [3, batch_size, seq_length, n_local_q_heads, d_qk]
382
+
383
+ store = self.get_local_store()
384
+ if store is not None: # Inference case
385
+ # Double check that we use store only at inference time
386
+ assert key_states.requires_grad is False
387
+ assert value_states.requires_grad is False
388
+ print("Using store")
389
+ if "position_offsets" in store:
390
+ old_position_offsets = store["position_offsets"]
391
+ position_ids = old_position_offsets[:, None] + sequence_mask
392
+ else:
393
+ position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1
394
+ position_offsets = position_ids[:, -1]
395
+
396
+ # Compute rotary embeddings
397
+ # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
398
+ old_rotary_embed_end = self.rotary_embedding.end
399
+ query_states = self.rotary_embedding(query_states, position_ids=position_ids)
400
+ key_states = self.rotary_embedding(key_states, position_ids=position_ids)
401
+
402
+ if "key" not in store:
403
+ # First inference iteration (Prefill)
404
+ # TODO @nouamane: support custom masking
405
+ # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted
406
+ # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence)
407
+ assert ~(
408
+ sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False
409
+ ).any(), "Can't mask in the middle of the sequence; please make sure any padding is on the left of the sequence"
410
+
411
+ # preallocate k_cache, v_cache to self.prefill_kv_len
412
+ k_cache = torch.zeros(
413
+ (
414
+ batch_size,
415
+ self.prefill_kv_len,
416
+ self.n_local_kv_heads,
417
+ self.d_qk,
418
+ ),
419
+ dtype=query_states.dtype,
420
+ device=query_states.device,
421
+ )
422
+ v_cache = torch.zeros(
423
+ (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v),
424
+ dtype=query_states.dtype,
425
+ device=query_states.device,
426
+ )
427
+ # Remove pad tokens from key_states and concatenate samples in key_unpad
428
+ # cu_seqlens_k is the cumulative sequence lengths of key_states
429
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
430
+ query_states,
431
+ sequence_mask,
432
+ )
433
+ (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
434
+ key_states, sequence_mask
435
+ )
436
+ (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask)
437
+
438
+ output_unpad = flash_attn_varlen_func(
439
+ q=query_unpad, # (total_q, n_local_q_heads, d_qk)
440
+ k=key_unpad, # (total_kv, n_local_kv_heads, d_qk)
441
+ v=value_unpad, # (total_kv, n_local_kv_heads, d_v)
442
+ cu_seqlens_q=cu_seqlens_q,
443
+ cu_seqlens_k=cu_seqlens_k,
444
+ max_seqlen_q=max_seqlen_q,
445
+ max_seqlen_k=max_seqlen_k,
446
+ dropout_p=0.0,
447
+ softmax_scale=None,
448
+ causal=True, # True in prefill phase, False in subsequent phases
449
+ return_attn_probs=False,
450
+ ) # (total_unpadded, n_local_q_heads, d_v)
451
+
452
+ attention_output = bert_padding.pad_input(
453
+ output_unpad, indices_q, batch_size, q_length
454
+ ) # (batch_size, q_length, n_local_q_heads, d_v)
455
+
456
+ pad_to_right(key_states, sequence_mask, new_tensor=k_cache)
457
+ pad_to_right(value_states, sequence_mask, new_tensor=v_cache)
458
+
459
+ else:
460
+ # Pull pre-computed key/value states
461
+ # Subsequent inference iterations (q_length=1)
462
+ k_cache = store["key"]
463
+ v_cache = store["value"]
464
+
465
+ # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values"
466
+ # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache
467
+ if self.rotary_embedding.end > old_rotary_embed_end:
468
+ k_cache = torch.cat(
469
+ [
470
+ k_cache,
471
+ torch.zeros(
472
+ (
473
+ batch_size,
474
+ self.rotary_embedding.end - old_rotary_embed_end,
475
+ self.n_local_kv_heads,
476
+ self.d_qk,
477
+ ),
478
+ dtype=query_states.dtype,
479
+ device=query_states.device,
480
+ ),
481
+ ],
482
+ dim=1,
483
+ )
484
+
485
+ v_cache = torch.cat(
486
+ [
487
+ v_cache,
488
+ torch.zeros(
489
+ (
490
+ batch_size,
491
+ self.rotary_embedding.end - old_rotary_embed_end,
492
+ self.n_local_kv_heads,
493
+ self.d_v,
494
+ ),
495
+ dtype=query_states.dtype,
496
+ device=query_states.device,
497
+ ),
498
+ ],
499
+ dim=1,
500
+ )
501
+
502
+ assert (
503
+ k_cache.shape[1] == self.rotary_embedding.end
504
+ ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
505
+ assert (
506
+ v_cache.shape[1] == self.rotary_embedding.end
507
+ ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}"
508
+
509
+ # [batch_size, seq_length, num_heads, d_qk]
510
+ query_states = query_states.view(
511
+ batch_size, q_length, self.n_local_q_heads, self.d_qk
512
+ ) # [batch_size, q_length, self.n_heads, d_qk]
513
+ kv_length = key_states.shape[1]
514
+ key_states = key_states.view(
515
+ batch_size, kv_length, self.n_local_kv_heads, self.d_qk
516
+ ) # [batch_size, kv_length, self.n_heads, d_qk]
517
+ value_states = value_states.view(
518
+ batch_size, kv_length, self.n_local_kv_heads, self.d_v
519
+ ) # [batch_size, kv_length, self.n_heads, d_v]
520
+
521
+ attention_output = flash_attn_with_kvcache(
522
+ query_states,
523
+ k_cache,
524
+ v_cache,
525
+ key_states,
526
+ value_states,
527
+ rotary_cos=None,
528
+ rotary_sin=None,
529
+ # TODO @nouamane: seems like this doesn't help to indicate padding (for the first iteration it's just 0)
530
+ cache_seqlens=position_offsets.contiguous(),
531
+ softmax_scale=None,
532
+ causal=True,
533
+ rotary_interleaved=False, # GPT-NeoX style
534
+ )
535
+
536
+ store.update(
537
+ {
538
+ "key": k_cache, # flash-attn has updated with new key_states using cache_seqlens
539
+ "value": v_cache,
540
+ "position_offsets": position_offsets,
541
+ }
542
+ )
543
+
544
+ else: # Training case
545
+ # Apply rotary embeddings to query/key states
546
+ # NOTE: The layout is different from models/mistral.py which is [batch_size, num_heads, seq_length, d_qk]
547
+ # Here it is, [batch_size, seq_length, num_heads, d_qk]
548
+ # [2, batch_size, seq_length, num_heads, d_qk]
549
+ key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
550
+ # [batch_size, seq_length, 2, num_heads, d_qk]
551
+ key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous()
552
+ query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states)
553
+ # [batch_size, seq_length, num_heads, d_qk]
554
+ key_states, value_states = torch.split(key_value_states, 1, dim=2)
555
+
556
+ q_sequence_mask = sequence_mask
557
+ kv_sequence_mask = sequence_mask
558
+
559
+ kv_length = key_states.shape[1]
560
+ # [batch_size, seq_length, num_heads, d_qk]
561
+ # Shaping for use in flash-attn's `flash_attn_unpadded_func`
562
+ query_states = query_states.view(
563
+ batch_size * q_length, self.n_local_q_heads, self.d_qk
564
+ ) # [batch_size * q_length, self.n_heads, d_qk]
565
+
566
+ key_states = key_states.view(
567
+ batch_size * kv_length, self.n_local_kv_heads, self.d_qk
568
+ ) # [batch_size * kv_length, self.n_heads, d_qk]
569
+ value_states = value_states.view(
570
+ batch_size * kv_length, self.n_local_kv_heads, self.d_v
571
+ ) # [batch_size * kv_length, self.n_heads, d_v]
572
+
573
+ attention_output = self.attention(
574
+ query_states=query_states,
575
+ key_states=key_states,
576
+ value_states=value_states,
577
+ q_sequence_mask=q_sequence_mask,
578
+ kv_sequence_mask=kv_sequence_mask,
579
+ )
580
+
581
+ attention_output = (
582
+ attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1)
583
+ )
584
+ output = self.o_proj(attention_output)
585
+
586
+ return {"hidden_states": output, "sequence_mask": sequence_mask}
587
+
588
+
589
+ class MistralDecoderLayer(nn.Module):
590
+ def __init__(
591
+ self,
592
+ config: MistralConfig,
593
+ parallel_config: Optional[ParallelismArgs],
594
+ tp_pg: dist.ProcessGroup,
595
+ layer_idx: int,
596
+ ):
597
+ super().__init__()
598
+ self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
599
+ self.attn = CausalSelfAttention(
600
+ config=config,
601
+ parallel_config=parallel_config,
602
+ tp_pg=tp_pg,
603
+ layer_idx=layer_idx,
604
+ )
605
+
606
+ self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
607
+ self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
608
+
609
+ def forward(
610
+ self,
611
+ hidden_states: Union[torch.Tensor, TensorPointer],
612
+ sequence_mask: Union[torch.Tensor, TensorPointer],
613
+ ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
614
+ residual = hidden_states
615
+ hidden_states = self.input_layernorm(hidden_states)
616
+
617
+ output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask)
618
+ hidden_states = output["hidden_states"]
619
+ hidden_states = hidden_states + residual
620
+
621
+ residual = hidden_states
622
+ hidden_states = self.post_attention_layernorm(hidden_states)
623
+ hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
624
+ hidden_states = hidden_states + residual
625
+
626
+ return {
627
+ "hidden_states": hidden_states,
628
+ "sequence_mask": output["sequence_mask"],
629
+ }
630
+
631
+
632
+ class Embedding(nn.Module, AttachableStore):
633
+ def __init__(self, tp_pg: dist.ProcessGroup, config: MistralConfig, parallel_config: Optional[ParallelismArgs]):
634
+ super().__init__()
635
+ self.token_embedding = TensorParallelEmbedding(
636
+ num_embeddings=config.vocab_size,
637
+ embedding_dim=config.hidden_size,
638
+ padding_idx=config.pad_token_id,
639
+ pg=tp_pg,
640
+ mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE,
641
+ )
642
+ self.pg = tp_pg
643
+
644
+ def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length]
645
+ store = self.get_local_store()
646
+ if store is not None:
647
+ if "past_length" in store:
648
+ past_length = store["past_length"]
649
+ else:
650
+ past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0])
651
+
652
+ cumsum_mask = input_mask.cumsum(-1, dtype=torch.long)
653
+ # Store new past_length in store
654
+ store["past_length"] = past_length + cumsum_mask[:, -1]
655
+
656
+ # Format input in `[seq_length, batch_size]` to support high TP with low batch_size
657
+ input_ids = input_ids.transpose(0, 1)
658
+ input_embeds = self.token_embedding(input_ids)
659
+ return {"input_embeds": input_embeds}
660
+
661
+
662
+ class MistralModel(nn.Module):
663
+ """Build pipeline graph"""
664
+
665
+ def __init__(
666
+ self,
667
+ config: MistralConfig,
668
+ parallel_context: ParallelContext,
669
+ parallel_config: Optional[ParallelismArgs],
670
+ ):
671
+ super().__init__()
672
+
673
+ # Declare all the nodes
674
+ self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda"))
675
+ self.config = config
676
+ self.parallel_config = parallel_config
677
+ self.parallel_context = parallel_context
678
+ self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE
679
+ tp_linear_async_communication = (
680
+ parallel_config.tp_linear_async_communication if parallel_config is not None else False
681
+ )
682
+
683
+ self.token_position_embeddings = PipelineBlock(
684
+ p2p=self.p2p,
685
+ module_builder=Embedding,
686
+ module_kwargs={
687
+ "tp_pg": parallel_context.tp_pg,
688
+ "config": config,
689
+ "parallel_config": parallel_config,
690
+ },
691
+ module_input_keys={"input_ids", "input_mask"},
692
+ module_output_keys={"input_embeds"},
693
+ )
694
+
695
+ self.decoder = nn.ModuleList(
696
+ [
697
+ PipelineBlock(
698
+ p2p=self.p2p,
699
+ module_builder=MistralDecoderLayer,
700
+ module_kwargs={
701
+ "config": config,
702
+ "parallel_config": parallel_config,
703
+ "tp_pg": parallel_context.tp_pg,
704
+ "layer_idx": layer_idx,
705
+ },
706
+ module_input_keys={"hidden_states", "sequence_mask"},
707
+ module_output_keys={"hidden_states", "sequence_mask"},
708
+ )
709
+ for layer_idx in range(config.num_hidden_layers)
710
+ ]
711
+ )
712
+
713
+ self.final_layer_norm = PipelineBlock(
714
+ p2p=self.p2p,
715
+ module_builder=TritonRMSNorm,
716
+ module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps},
717
+ module_input_keys={"input"},
718
+ module_output_keys={"hidden_states"},
719
+ ) # TODO
720
+
721
+ self.lm_head = PipelineBlock(
722
+ p2p=self.p2p,
723
+ # Understand that this means that we return sharded logits that are going to need to be gathered
724
+ module_builder=TensorParallelColumnLinear,
725
+ module_kwargs={
726
+ "in_features": config.hidden_size,
727
+ "out_features": config.vocab_size,
728
+ "pg": parallel_context.tp_pg,
729
+ "bias": False,
730
+ # TODO @thomasw21: refactor so that we store that default in a single place.
731
+ "mode": self.tp_mode,
732
+ "async_communication": tp_linear_async_communication,
733
+ },
734
+ module_input_keys={"x"},
735
+ module_output_keys={"logits"},
736
+ )
737
+
738
+ self.cast_to_fp32 = PipelineBlock(
739
+ p2p=self.p2p,
740
+ module_builder=lambda: lambda x: x.float(),
741
+ module_kwargs={},
742
+ module_input_keys={"x"},
743
+ module_output_keys={"output"},
744
+ )
745
+
746
+ def forward(
747
+ self,
748
+ input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
749
+ input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
750
+ ):
751
+ return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0]
752
+
753
+ def forward_with_hidden_states(
754
+ self,
755
+ input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
756
+ input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length]
757
+ ):
758
+ # all tensors are optional as most ranks don't need anything from the dataloader.
759
+
760
+ output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)
761
+
762
+ hidden_encoder_states = {
763
+ "hidden_states": output["input_embeds"],
764
+ "sequence_mask": input_mask,
765
+ }
766
+ for encoder_block in self.decoder:
767
+ hidden_encoder_states = encoder_block(**hidden_encoder_states)
768
+
769
+ hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]
770
+
771
+ sharded_logits = self.lm_head(x=hidden_states)["logits"]
772
+
773
+ fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"]
774
+
775
+ return fp32_sharded_logits, hidden_states
776
+
777
+ def get_block_compute_costs(self):
778
+ """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
779
+ model_config = self.config
780
+ d_ff = model_config.intermediate_size
781
+ d_qkv = model_config.hidden_size // model_config.num_attention_heads
782
+ block_compute_costs = {
783
+ # CausalSelfAttention (qkv proj + attn out) + MLP
784
+ MistralDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
785
+ + 3 * d_ff * model_config.hidden_size,
786
+ # This is the last lm_head
787
+ TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
788
+ }
789
+ return block_compute_costs
790
+
791
+ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
792
+ """Get flops per second for a given model"""
793
+ world_size = self.parallel_context.world_pg.size()
794
+ try:
795
+ num_key_values_heads = self.config.num_key_value_heads
796
+ except AttributeError:
797
+ num_key_values_heads = self.config.num_attention_heads
798
+
799
+ model_flops, hardware_flops = get_flops(
800
+ num_layers=self.config.num_hidden_layers,
801
+ hidden_size=self.config.hidden_size,
802
+ num_heads=self.config.num_attention_heads,
803
+ num_key_value_heads=num_key_values_heads,
804
+ vocab_size=self.config.vocab_size,
805
+ ffn_hidden_size=self.config.intermediate_size,
806
+ seq_len=sequence_length,
807
+ batch_size=global_batch_size,
808
+ recompute_granularity=self.parallel_config.recompute_granularity,
809
+ )
810
+
811
+ model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12)
812
+ hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12)
813
+ return model_flops_per_s, hardware_flops_per_s
814
+
815
+
816
+ @torch.jit.script
817
+ def masked_mean(loss, label_mask, dtype):
818
+ # type: (Tensor, Tensor, torch.dtype) -> Tensor
819
+ return (loss * label_mask).sum(dtype=dtype) / label_mask.sum()
820
+
821
+
822
+ class Loss(nn.Module):
823
+ def __init__(self, tp_pg: dist.ProcessGroup):
824
+ super().__init__()
825
+ self.tp_pg = tp_pg
826
+
827
+ def forward(
828
+ self,
829
+ sharded_logits: torch.Tensor, # [seq_length, batch_size, logits]
830
+ label_ids: torch.Tensor, # [batch_size, seq_length]
831
+ label_mask: torch.Tensor, # [batch_size, seq_length]
832
+ ) -> Dict[str, torch.Tensor]:
833
+ # Megatron by default casts everything to fp32. `--f16-lm-cross-entropy` is an option you can use to keep the current precision.
834
+ # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
835
+ loss = sharded_cross_entropy(
836
+ sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
837
+ ).transpose(0, 1)
838
+ # TODO @thomasw21: It's unclear what kind of normalization we want to do.
839
+ loss = masked_mean(loss, label_mask, dtype=torch.float)
840
+ # I think indexing causes a sync we don't actually want
841
+ # loss = loss[label_mask].sum()
842
+ return {"loss": loss}
843
+
844
+
845
+ class MistralForTraining(NanotronModel):
846
+ def __init__(
847
+ self,
848
+ config: MistralConfig,
849
+ parallel_context: ParallelContext,
850
+ parallel_config: Optional[ParallelismArgs],
851
+ random_states: Optional[RandomStates] = None,
852
+ ):
853
+ super().__init__()
854
+ import warnings
855
+ warnings.warn("This is just a Llama model, not a Mistral one, for demo purposes. Please fix the implementation.")
856
+ self.model = MistralModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config)
857
+ self.loss = PipelineBlock(
858
+ p2p=self.model.p2p,
859
+ module_builder=Loss,
860
+ module_kwargs={"tp_pg": parallel_context.tp_pg},
861
+ module_input_keys={
862
+ "sharded_logits",
863
+ "label_ids",
864
+ "label_mask",
865
+ },
866
+ module_output_keys={"loss"},
867
+ )
868
+ self.parallel_context = parallel_context
869
+ self.config = config
870
+ self.parallel_config = parallel_config
871
+
872
+ def forward(
873
+ self,
874
+ input_ids: Union[torch.Tensor, TensorPointer],
875
+ input_mask: Union[torch.Tensor, TensorPointer],
876
+ label_ids: Union[torch.Tensor, TensorPointer],
877
+ label_mask: Union[torch.Tensor, TensorPointer],
878
+ ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
879
+ sharded_logits = self.model(
880
+ input_ids=input_ids,
881
+ input_mask=input_mask,
882
+ )
883
+ loss = self.loss(
884
+ sharded_logits=sharded_logits,
885
+ label_ids=label_ids,
886
+ label_mask=label_mask,
887
+ )["loss"]
888
+ return {"loss": loss}
889
+
890
+ @torch.no_grad()
891
+ def init_model_randomly(self, init_method, scaled_init_method):
892
+ """Initialize model parameters randomly.
893
+ Args:
894
+ init_method (callable): Used for the embedding/position/qkv weights in attention, the first layer weight of the MLP, and the lm_head
895
+ scaled_init_method (callable): Used for the o_proj weight in attention and the second layer weight of the MLP
896
+
897
+ Note:
898
+ LayerNorm weights are all 0 or 1 depending on `apply_layernorm_1p`
899
+ """
900
+ model = self
901
+ initialized_parameters = set()
902
+ # Handle tensor parallelism
903
+ module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()}
904
+ # Fix the root_model
905
+ module_id_to_prefix[id(model)] = ""
906
+
907
+ for module_name, module in model.named_modules():
908
+ if isinstance(module, TensorParallelColumnLinear):
909
+ # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
910
+ # What it does:
911
+ # - instantiate a buffer of the `full size` in fp32
912
+ # - run init method on it
913
+ # - shard result to get only a specific shard
914
+ # Instead I'm lazy and just going to run init_method, since they are scalar independent
915
+ assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
916
+ name for name, _ in module.named_parameters()
917
+ }
918
+ for param_name, param in module.named_parameters():
919
+ assert isinstance(param, NanotronParameter)
920
+ if param.is_tied:
921
+ tied_info = param.get_tied_info()
922
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
923
+ module_id_to_prefix=module_id_to_prefix
924
+ )
925
+ else:
926
+ full_param_name = f"{module_name}.{param_name}"
927
+
928
+ if full_param_name in initialized_parameters:
929
+ # Already initialized
930
+ continue
931
+
932
+ if "weight" == param_name:
933
+ init_method(param)
934
+ elif "bias" == param_name:
935
+ param.zero_()
936
+ else:
937
+ raise ValueError(f"Who the fuck is {param_name}?")
938
+
939
+ assert full_param_name not in initialized_parameters
940
+ initialized_parameters.add(full_param_name)
941
+ elif isinstance(module, TensorParallelRowLinear):
942
+ # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
943
+ # What it does:
944
+ # - instantiate a buffer of the `full size` in fp32
945
+ # - run init method on it
946
+ # - shard result to get only a specific shard
947
+ # Instead I'm lazy and just going to run init_method, since they are scalar independent
948
+ assert {"weight"} == {name for name, _ in module.named_parameters()} or {"weight"} == {
949
+ name for name, _ in module.named_parameters()
950
+ }
951
+ for param_name, param in module.named_parameters():
952
+ assert isinstance(param, NanotronParameter)
953
+ if param.is_tied:
954
+ tied_info = param.get_tied_info()
955
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
956
+ module_id_to_prefix=module_id_to_prefix
957
+ )
958
+ else:
959
+ full_param_name = f"{module_name}.{param_name}"
960
+
961
+ if full_param_name in initialized_parameters:
962
+ # Already initialized
963
+ continue
964
+
965
+ if "weight" == param_name:
966
+ scaled_init_method(param)
967
+ elif "bias" == param_name:
968
+ param.zero_()
969
+ else:
970
+ raise ValueError(f"Who the fuck is {param_name}?")
971
+
972
+ assert full_param_name not in initialized_parameters
973
+ initialized_parameters.add(full_param_name)
974
+ elif isinstance(module, TritonRMSNorm):
975
+ assert {"weight"} == {name for name, _ in module.named_parameters()}
976
+ for param_name, param in module.named_parameters():
977
+ assert isinstance(param, NanotronParameter)
978
+ if param.is_tied:
979
+ tied_info = param.get_tied_info()
980
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
981
+ module_id_to_prefix=module_id_to_prefix
982
+ )
983
+ else:
984
+ full_param_name = f"{module_name}.{param_name}"
985
+
986
+ if full_param_name in initialized_parameters:
987
+ # Already initialized
988
+ continue
989
+
990
+ if "weight" == param_name:
991
+ # TODO @thomasw21: Sometimes we actually want 0
992
+ param.fill_(1)
993
+ elif "bias" == param_name:
994
+ param.zero_()
995
+ else:
996
+ raise ValueError(f"Who the fuck is {param_name}?")
997
+
998
+ assert full_param_name not in initialized_parameters
999
+ initialized_parameters.add(full_param_name)
1000
+ elif isinstance(module, TensorParallelEmbedding):
1001
+ # TODO @thomasw21: Handle tied embeddings
1002
+ # Somehow Megatron-LM does something super complicated, https://github.com/NVIDIA/Megatron-LM/blob/2360d732a399dd818d40cbe32828f65b260dee11/megatron/core/tensor_parallel/layers.py#L96
1003
+ # What it does:
1004
+ # - instantiate a buffer of the `full size` in fp32
1005
+ # - run init method on it
1006
+ # - shard result to get only a specific shard
1007
+ # Instead I'm lazy and just going to run init_method, since they are scalar independent
1008
+ assert {"weight"} == {name for name, _ in module.named_parameters()}
1009
+
1010
+ assert isinstance(module.weight, NanotronParameter)
1011
+ if module.weight.is_tied:
1012
+ tied_info = module.weight.get_tied_info()
1013
+ full_param_name = tied_info.get_full_name_from_module_id_to_prefix(
1014
+ module_id_to_prefix=module_id_to_prefix
1015
+ )
1016
+ else:
1017
+ full_param_name = f"{module_name}.weight"
1018
+
1019
+ if full_param_name in initialized_parameters:
1020
+ # Already initialized
1021
+ continue
1022
+
1023
+ init_method(module.weight)
1024
+ assert full_param_name not in initialized_parameters
1025
+ initialized_parameters.add(full_param_name)
1026
+
1027
+ assert initialized_parameters == {
1028
+ param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix)
1029
+ if param.is_tied
1030
+ else name
1031
+ for name, param in model.named_parameters()
1032
+ }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}"
1033
+
1034
+ def get_block_compute_costs(self):
1035
+ """Computes the compute cost of each block in the model so that we can do a better job of load balancing."""
1036
+ return self.model.get_block_compute_costs()
1037
+
1038
+ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size):
1039
+ """Get flops per second for a given model"""
1040
+ return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size)
1041
+
1042
+
1043
+ def get_flops(
1044
+ num_layers,
1045
+ hidden_size,
1046
+ num_heads,
1047
+ num_key_value_heads,
1048
+ vocab_size,
1049
+ seq_len,
1050
+ ffn_hidden_size,
1051
+ batch_size=1,
1052
+ recompute_granularity=None,
1053
+ ):
1054
+ """Counts flops in an decoder-only model
1055
+ Args:
1056
+ num_layers: number of decoder layers
1057
+ hidden_size: hidden size of the model
1058
+ num_heads: number of heads in the model
1059
+ num_key_value_heads: number of key/value heads in the model
1060
+ ffn_hidden_size: hidden size of the FFN
1061
+ vocab_size: size of the vocabulary
1062
+ seq_len: sequence length of the decoder
1063
+ batch_size: batch size
1064
+ recompute_granularity: Activation recomputation method. Either None, FULL or SELECTIVE. Check Megatron-LM docs for more info.
1065
+ Returns:
1066
+ model_flops: flops in the model (should be independent of the hardware and model implementation)
1067
+ hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf
1068
+ """
1069
+ if num_key_value_heads is None:
1070
+ num_key_value_heads = num_heads
1071
+ hidden_size_per_head = hidden_size // num_heads
1072
+ # In the following we mark the reduced dimension with parentheses
1073
+ # decoder
1074
+ # self attention
1075
+ ## qkv projection
1076
+ decoder_qkv_proj_flops_fwd = (
1077
+ 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head
1078
+ + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head
1079
+ )
1080
+ ## qk logits
1081
+ decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len
1082
+ ## v logits
1083
+ decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head
1084
+ ## attn out
1085
+ decoder_attn_out_flops_fwd = (
1086
+ 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size
1087
+ )
1088
+ # FF
1089
+ ## 1st layer
1090
+ decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size
1091
+ ## 2nd layer
1092
+ decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size
1093
+
1094
+ decoder_flops_fwd = (
1095
+ decoder_qkv_proj_flops_fwd
1096
+ + decoder_qk_logits_flops_fwd
1097
+ + decoder_v_logits_flops_fwd
1098
+ + decoder_attn_out_flops_fwd
1099
+ + decoder_ffn_1_flops_fwd
1100
+ + decoder_ffn_2_flops_fwd
1101
+ )
1102
+
1103
+ # lm head
1104
+ lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size
1105
+
1106
+ # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to
1107
+ # both input and weight tensors
1108
+ model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd
1109
+
1110
+ if recompute_granularity is None:
1111
+ hardware_flops = model_flops
1112
+ elif recompute_granularity is RecomputeGranularity.FULL:
1113
+ # Note: we don't recompute lm head activs
1114
+ hardware_flops = model_flops + decoder_flops_fwd # + activ recomputation
1115
+ elif recompute_granularity is RecomputeGranularity.SELECTIVE:
1116
+ # all terms with s^2 are flops that are recomputed
1117
+ # ref. appendix A: https://arxiv.org/pdf/2205.05198.pdf
1118
+ recomputed_decoder_flops = decoder_qk_logits_flops_fwd + decoder_v_logits_flops_fwd
1119
+ hardware_flops = model_flops + recomputed_decoder_flops
1120
+ else:
1121
+ raise ValueError("recompute_granularity must be one of 'full' or 'selective'")
1122
+
1123
+ return model_flops, hardware_flops
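
For intuition about `pad_to_right`, here is a self-contained sketch of the same idea in plain PyTorch with toy shapes (it mirrors the logic rather than importing the module above, so it runs without flash-attn or a GPU):

```python
import torch

# Left-padded batch: batch_size=2, seqlen=4, trailing dims d1=d2=1.
mask = torch.tensor([[0, 0, 1, 1],
                     [0, 1, 1, 1]], dtype=torch.bool)
tensor = torch.arange(8, dtype=torch.float).reshape(2, 4, 1, 1)

unpad_seqlens = mask.sum(1)                                   # real tokens per row: [2, 3]
right_padded_mask = torch.arange(4) < unpad_seqlens[:, None]  # valid slots, now left-aligned
new_tensor = torch.zeros_like(tensor)
new_tensor[right_padded_mask] = tensor[mask]                  # move real tokens to the front of each row
print(new_tensor.squeeze())  # tensor([[2., 3., 0., 0.], [5., 6., 7., 0.]])
```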
run_train.py ADDED
@@ -0,0 +1,34 @@
1
+ """
2
+ Nanotron training script.
3
+
4
+ Usage:
5
+ ```
6
+ export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations
7
+ torchrun --nproc_per_node=8 run_train.py --config-file config_tiny_mistral.yaml
8
+ ```
9
+ """
10
+ import argparse
11
+
12
+ from modeling_mistral import MistralForTraining
13
+ from dataloader import get_dataloader
14
+ from nanotron.trainer import DistributedTrainer
15
+ from config_tiny_mistral import MistralConfig
16
+
17
+
18
+
19
+ def get_args():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file")
22
+ return parser.parse_args()
23
+
24
+
25
+ if __name__ == "__main__":
26
+ args = get_args()
27
+ config_file = args.config_file
28
+
29
+ # Load trainer and data
30
+ trainer = DistributedTrainer(config_file, model_class=MistralForTraining, model_config_class=MistralConfig)
31
+ dataloader = get_dataloader(trainer)
32
+
33
+ # Train
34
+ trainer.train(dataloader)
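
To inspect the model configuration the trainer will receive, one can round-trip the generated YAML back into the `MistralConfig` dataclass. A sketch, assuming the YAML has already been generated and nanotron is installed (importing `config_tiny_mistral` also runs its module-level code):

```python
import yaml

from config_tiny_mistral import MistralConfig

with open("config_tiny_mistral.yaml") as f:
    cfg = yaml.safe_load(f)

model_config = MistralConfig(**cfg["model"]["model_config"])
print(model_config.hidden_size, model_config.num_hidden_layers, model_config.vocab_size)  # 16 2 256
```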