eyad-silx committed on
Commit d278d9d · 1 Parent(s): 5253ac8

Update repository

Files changed (49)
  1. LICENSE +21 -0
  2. README.md +227 -0
  3. assets/gpt2_124M_loss.png +0 -0
  4. assets/nanogpt.jpg +0 -0
  5. bench.py +117 -0
  6. config/char_config.py +43 -0
  7. config/dtat_config.py +48 -0
  8. config/enwik8_config.py +46 -0
  9. config/eval_gpt2.py +8 -0
  10. config/eval_gpt2_large.py +8 -0
  11. config/eval_gpt2_medium.py +8 -0
  12. config/eval_gpt2_xl.py +8 -0
  13. config/finetune_shakespeare.py +25 -0
  14. config/train_gpt2.py +25 -0
  15. config/train_shakespeare_char.py +37 -0
  16. configurator.py +47 -0
  17. data/openwebtext/prepare.py +81 -0
  18. data/openwebtext/readme.md +15 -0
  19. data/shakespeare/prepare.py +33 -0
  20. data/shakespeare/readme.md +9 -0
  21. data/shakespeare_char/prepare.py +68 -0
  22. data/shakespeare_char/readme.md +9 -0
  23. model.py +330 -0
  24. model_dtat.py +257 -0
  25. model_modified.py +190 -0
  26. prepare_data.py +37 -0
  27. sample.py +89 -0
  28. scaling_laws.ipynb +0 -0
  29. train.py +336 -0
  30. train_baseline.py +228 -0
  31. train_dtat.py +256 -0
  32. train_enwik8.py +114 -0
  33. transformer_sizing.ipynb +402 -0
  34. wandb/run-20241230_125819-geso4xvw/files/config.yaml +47 -0
  35. wandb/run-20241230_125819-geso4xvw/files/output.log +21 -0
  36. wandb/run-20241230_125819-geso4xvw/files/wandb-metadata.json +43 -0
  37. wandb/run-20241230_125819-geso4xvw/files/wandb-summary.json +1 -0
  38. wandb/run-20241230_125819-geso4xvw/logs/debug-core.log +14 -0
  39. wandb/run-20241230_125819-geso4xvw/logs/debug-internal.log +16 -0
  40. wandb/run-20241230_125819-geso4xvw/logs/debug.log +26 -0
  41. wandb/run-20241230_125819-geso4xvw/run-geso4xvw.wandb +0 -0
  42. wandb/run-20241230_125924-h4hgg9ir/files/config.yaml +47 -0
  43. wandb/run-20241230_125924-h4hgg9ir/files/output.log +29 -0
  44. wandb/run-20241230_125924-h4hgg9ir/files/wandb-metadata.json +43 -0
  45. wandb/run-20241230_125924-h4hgg9ir/files/wandb-summary.json +1 -0
  46. wandb/run-20241230_125924-h4hgg9ir/logs/debug-core.log +14 -0
  47. wandb/run-20241230_125924-h4hgg9ir/logs/debug-internal.log +17 -0
  48. wandb/run-20241230_125924-h4hgg9ir/logs/debug.log +26 -0
  49. wandb/run-20241230_125924-h4hgg9ir/run-h4hgg9ir.wandb +0 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Andrej Karpathy

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md ADDED
@@ -0,0 +1,227 @@

# nanoGPT

![nanoGPT](assets/nanogpt.jpg)

The simplest, fastest repository for training/finetuning medium-sized GPTs. It is a rewrite of [minGPT](https://github.com/karpathy/minGPT) that prioritizes teeth over education. Still under active development, but currently the file `train.py` reproduces GPT-2 (124M) on OpenWebText, running on a single 8XA100 40GB node in about 4 days of training. The code itself is plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it.

![repro124m](assets/gpt2_124M_loss.png)

Because the code is so simple, it is very easy to hack to your needs, train new models from scratch, or finetune pretrained checkpoints (e.g. biggest one currently available as a starting point would be the GPT-2 1.5B model from OpenAI).

## install

```
pip install torch numpy transformers datasets tiktoken wandb tqdm
```

Dependencies:

- [pytorch](https://pytorch.org) <3
- [numpy](https://numpy.org/install/) <3
- `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints)
- `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText)
- `tiktoken` for OpenAI's fast BPE code <3
- `wandb` for optional logging <3
- `tqdm` for progress bars <3

## quick start

If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers:

```sh
python data/shakespeare_char/prepare.py
```

This creates a `train.bin` and `val.bin` in that data directory. Now it is time to train your GPT. The size of it very much depends on the computational resources of your system:

**I have a GPU**. Great, we can quickly train a baby GPT with the settings provided in the [config/train_shakespeare_char.py](config/train_shakespeare_char.py) config file:

```sh
python train.py config/train_shakespeare_char.py
```

If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory:

```sh
python sample.py --out_dir=out-shakespeare-char
```

This generates a few samples, for example:

```
ANGELO:
And cowards it be strawn to my bed,
And thrust the gates of my threats,
Because he that ale away, and hang'd
An one with him.

DUKE VINCENTIO:
I thank your eyes against it.

DUKE VINCENTIO:
Then will answer him to save the malm:
And what have you tyrannous shall do this?

DUKE VINCENTIO:
If you have done evils of all disposition
To end his power, the day of thrust for a common men
That I leave, to fight with over-liking
Hasting in a roseman.
```

lol `¯\_(ツ)_/¯`. Not bad for a character-level model after 3 minutes of training on a GPU. Better results are quite likely obtainable by instead finetuning a pretrained GPT-2 model on this dataset (see finetuning section later).

**I only have a macbook** (or other cheap computer). No worries, we can still train a GPT but we want to dial things down a notch. I recommend getting the bleeding edge PyTorch nightly ([select it here](https://pytorch.org/get-started/locally/) when installing) as it is currently quite likely to make your code more efficient. But even without it, a simple train run could look as follows:

```sh
python train.py config/train_shakespeare_char.py --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0
```

Here, since we are running on CPU instead of GPU we must set both `--device=cpu` and also turn off PyTorch 2.0 compile with `--compile=False`. Then when we evaluate we get a bit more noisy but faster estimate (`--eval_iters=20`, down from 200), our context size is only 64 characters instead of 256, and the batch size only 12 examples per iteration, not 64. We'll also use a much smaller Transformer (4 layers, 4 heads, 128 embedding size), and decrease the number of iterations to 2000 (and correspondingly usually decay the learning rate to around max_iters with `--lr_decay_iters`). Because our network is so small we also ease down on regularization (`--dropout=0.0`). This still runs in about ~3 minutes, but gets us a loss of only 1.88 and therefore also worse samples, but it's still good fun:

```sh
python sample.py --out_dir=out-shakespeare-char --device=cpu
```
Generates samples like this:

```
GLEORKEN VINGHARD III:
Whell's the couse, the came light gacks,
And the for mought you in Aut fries the not high shee
bot thou the sought bechive in that to doth groan you,
No relving thee post mose the wear
```

Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc.

Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device=mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more.

## reproducing GPT-2

A more serious deep learning professional may be more interested in reproducing GPT-2 results. So here we go - we first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText:

```sh
python data/openwebtext/prepare.py
```

This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which hold the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node and run:

```sh
torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
```

This will run for about 4 days using PyTorch Distributed Data Parallel (DDP) and go down to loss of ~2.85. Now, a GPT-2 model just evaluated on OWT gets a val loss of about 3.11, but if you finetune it, it will come down to ~2.85 territory (due to an apparent domain gap), making the two models ~match.

If you're in a cluster environment and you are blessed with multiple GPU nodes you can make GPU go brrrr e.g. across 2 nodes like:

```sh
# Run on the first (master) node with example IP 123.456.123.456:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
# Run on the worker node:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
```

It is a good idea to benchmark your interconnect (e.g. iperf3). In particular, if you don't have Infiniband then also prepend `NCCL_IB_DISABLE=1` to the above launches. Your multinode training will work, but most likely _crawl_. By default checkpoints are periodically written to the `--out_dir`. We can sample from the model by simply `python sample.py`.
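For the interconnect check mentioned above, a typical iperf3 run between the two nodes could look like this (using the same example IP as above):

```sh
# on the master node, start the iperf3 server
iperf3 -s
# on the worker node, measure throughput to the master
iperf3 -c 123.456.123.456
```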

Finally, to train on a single GPU simply run the `python train.py` script. Have a look at all of its args, the script tries to be very readable, hackable and transparent. You'll most likely want to tune a number of those variables depending on your needs.

## baselines

OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows:

```sh
$ python train.py config/eval_gpt2.py
$ python train.py config/eval_gpt2_medium.py
$ python train.py config/eval_gpt2_large.py
$ python train.py config/eval_gpt2_xl.py
```

and observe the following losses on train and val:

| model | params | train loss | val loss |
| ------| ------ | ---------- | -------- |
| gpt2 | 124M | 3.11 | 3.12 |
| gpt2-medium | 350M | 2.85 | 2.84 |
| gpt2-large | 774M | 2.66 | 2.67 |
| gpt2-xl | 1558M | 2.56 | 2.54 |

However, we have to note that GPT-2 was trained on (closed, never released) WebText, while OpenWebText is just a best-effort open reproduction of this dataset. This means there is a dataset domain gap. Indeed, taking the GPT-2 (124M) checkpoint and finetuning on OWT directly for a while reaches loss down to ~2.85. This then becomes the more appropriate baseline w.r.t. reproduction.

## finetuning

Finetuning is no different than training, we just make sure to initialize from a pretrained model and train with a smaller learning rate. For an example of how to finetune a GPT on new text go to `data/shakespeare` and run `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`, using the OpenAI BPE tokenizer from GPT-2. Unlike OpenWebText this will run in seconds. Finetuning can take very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like:

```sh
python train.py config/finetune_shakespeare.py
```

This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. If you're running out of memory try decreasing the model size (they are `{'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}`) or possibly decreasing the `block_size` (context length). The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then sample from it with `python sample.py --out_dir=out-shakespeare`:

```
THEODORE:
Thou shalt sell me to the highest bidder: if I die,
I sell thee to the first; if I go mad,
I sell thee to the second; if I
lie, I sell thee to the third; if I slay,
I sell thee to the fourth: so buy or sell,
I tell thee again, thou shalt not sell my
possession.

JULIET:
And if thou steal, thou shalt not sell thyself.

THEODORE:
I do not steal; I sell the stolen goods.

THEODORE:
Thou know'st not what thou sell'st; thou, a woman,
Thou art ever a victim, a thing of no worth:
Thou hast no right, no right, but to be sold.
```

Whoa there, GPT, entering some dark place over there. I didn't really tune the hyperparameters in the config too much, feel free to try!

## sampling / inference

Use the script `sample.py` to sample either from pre-trained GPT-2 models released by OpenAI, or from a model you trained yourself. For example, here is a way to sample from the largest available `gpt2-xl` model:

```sh
python sample.py \
--init_from=gpt2-xl \
--start="What is the answer to life, the universe, and everything?" \
--num_samples=5 --max_new_tokens=100
```

If you'd like to sample from a model you trained, use the `--out_dir` to point the code appropriately. You can also prompt the model with some text from a file, e.g. ```python sample.py --start=FILE:prompt.txt```.

## efficiency notes

For simple model benchmarking and profiling, `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities.

Note that the code by default uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms / iter to 135ms / iter. Nice work PyTorch team!

## todos

- Investigate and add FSDP instead of DDP
- Eval zero-shot perplexities on standard evals (e.g. LAMBADA? HELM? etc.)
- Finetune the finetuning script, I think the hyperparams are not great
- Schedule for linear batch size increase during training
- Incorporate other embeddings (rotary, alibi)
- Separate out the optim buffers from model params in checkpoints I think
- Additional logging around network health (e.g. gradient clip events, magnitudes)
- Few more investigations around better init etc.

## troubleshooting

Note that by default this repo uses PyTorch 2.0 (i.e. `torch.compile`). This is fairly new and experimental, and not yet available on all platforms (e.g. Windows). If you're running into related error messages try to disable this by adding the `--compile=False` flag. This will slow down the code but at least it will run.

For some context on this repository, GPT, and language modeling it might be helpful to watch my [Zero To Hero series](https://karpathy.ai/zero-to-hero.html). Specifically, the [GPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY) is popular if you have some prior language modeling context.

For more questions/discussions feel free to stop by **#nanoGPT** on Discord:

[![](https://dcbadge.vercel.app/api/server/3zy8kqD9Cp?compact=true&style=flat)](https://discord.gg/3zy8kqD9Cp)

## acknowledgements

All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT!
assets/gpt2_124M_loss.png ADDED
assets/nanogpt.jpg ADDED
bench.py ADDED
@@ -0,0 +1,117 @@
"""
A much shorter version of train.py for benchmarking
"""
import os
from contextlib import nullcontext
import numpy as np
import time
import torch
from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
batch_size = 12
block_size = 1024
bias = False
real_data = True
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
compile = True # use PyTorch 2.0 to compile the model to be faster
profile = False # use pytorch profiler, or just simple benchmarking?
exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# data loading init
if real_data:
    dataset = 'openwebtext'
    data_dir = os.path.join('data', dataset)
    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    def get_batch(split):
        data = train_data # note ignore split in benchmarking script
        ix = torch.randint(len(data) - block_size, (batch_size,))
        x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
        return x, y
else:
    # alternatively, if fixed data is desired to not care about data loading
    x = torch.randint(50304, (batch_size, block_size), device=device)
    y = torch.randint(50304, (batch_size, block_size), device=device)
    get_batch = lambda split: (x, y)

# model init
gptconf = GPTConfig(
    block_size = block_size, # how far back does the model look? i.e. context size
    n_layer = 12, n_head = 12, n_embd = 768, # size of the model
    dropout = 0, # for determinism
    bias = bias,
)
model = GPT(gptconf)
model.to(device)

optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)

if compile:
    print("Compiling model...")
    model = torch.compile(model) # pytorch 2.0

if profile:
    # useful docs on pytorch profiler:
    # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
    # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile
    wait, warmup, active = 5, 5, 5
    num_steps = wait + warmup + active
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
        record_shapes=False,
        profile_memory=False,
        with_stack=False, # incurs an additional overhead, disable if not needed
        with_flops=True,
        with_modules=False, # only for torchscript models atm
    ) as prof:

        X, Y = get_batch('train')
        for k in range(num_steps):
            with ctx:
                logits, loss = model(X, Y)
            X, Y = get_batch('train')
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            print(f"{k}/{num_steps} loss: {lossf:.4f}")

            prof.step() # notify the profiler at end of each step

else:

    # simple benchmarking
    torch.cuda.synchronize()
    for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
        t0 = time.time()
        X, Y = get_batch('train')
        for k in range(num_steps):
            with ctx:
                logits, loss = model(X, Y)
            X, Y = get_batch('train')
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            print(f"{k}/{num_steps} loss: {lossf:.4f}")
        torch.cuda.synchronize()
        t1 = time.time()
        dt = t1-t0
        mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt)
        if stage == 1:
            print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%")
config/char_config.py ADDED
@@ -0,0 +1,43 @@
"""
Configuration for character-level language model on enwik8
Targeting ~44M parameters for comparison with baseline models
"""

# Model configuration
config = {
    # Dataset params
    'dataset': 'enwik8',
    'vocab_size': 256, # Character-level, so 256 possible byte values
    'block_size': 1024, # Context length

    # Model params (tuned for ~44M parameters)
    'n_layer': 12,
    'n_head': 8,
    'n_embd': 512,
    'dropout': 0.1,
    'bias': False, # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

    # Training params
    'learning_rate': 6e-4,
    'max_iters': 100000,
    'weight_decay': 1e-1,
    'beta1': 0.9,
    'beta2': 0.95,
    'grad_clip': 1.0,

    # Learning rate decay settings
    'decay_lr': True,
    'warmup_iters': 2000,
    'lr_decay_iters': 100000,
    'min_lr': 6e-5,

    # Evaluation and logging
    'eval_interval': 500,
    'log_interval': 100,
    'eval_iters': 200,

    # System
    'device': 'cuda', # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
    'dtype': 'bfloat16', # 'float32', 'bfloat16', or 'float16'
    'compile': True, # use PyTorch 2.0 to compile the model to be faster
}
config/dtat_config.py ADDED
@@ -0,0 +1,48 @@
"""
Configuration for Dynamic Token-Aware Transformer (DTAT) on enwik8
"""

class DTATConfig:
    def __init__(self):
        # Model architecture
        self.block_size = 1024
        self.vocab_size = 256 # byte-level vocabulary
        self.n_layer = 12
        self.n_head = 8
        self.n_embd = 512
        self.dropout = 0.1
        self.bias = False

        # DTAT specific parameters
        self.sparse_topk = 32 # Number of tokens to attend to for less important tokens

        # Training parameters
        self.batch_size = 32 # Added batch_size
        self.learning_rate = 6e-4
        self.weight_decay = 1e-1
        self.beta1 = 0.9
        self.beta2 = 0.95
        self.grad_clip = 1.0
        self.warmup_iters = 2000

        # Learning rate schedule
        self.decay_lr = True
        self.lr_decay_iters = 100000
        self.min_lr = 6e-5

        # Training loop
        self.max_iters = 100000
        self.eval_interval = 500
        self.log_interval = 100
        self.eval_iters = 200

        # System
        self.device = 'cuda'
        self.dtype = 'bfloat16'
        self.compile = True

    def get_config(self):
        return self

def get_config():
    return DTATConfig()
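A minimal usage sketch (assuming a training script imports this module as `config.dtat_config`); both entry points hand back the same attribute-style object:

```python
from config.dtat_config import DTATConfig, get_config

cfg = get_config()                   # module-level helper, returns a fresh DTATConfig
# cfg = DTATConfig().get_config()    # equivalent: the instance method just returns self
print(cfg.n_layer, cfg.n_embd, cfg.sparse_topk)  # 12 512 32
```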
config/enwik8_config.py ADDED
@@ -0,0 +1,46 @@
"""
Configuration for enwik8 dataset using NanoGPT architecture
Targeting ~44M parameters for comparison with baseline models
"""

import ml_collections

def get_config():
    config = ml_collections.ConfigDict()

    # model
    config.block_size = 1024
    config.vocab_size = 256 # 256 possible byte values
    config.n_layer = 12
    config.n_head = 8
    config.n_embd = 512
    config.dropout = 0.1
    config.bias = False

    # adamw optimizer
    config.learning_rate = 6e-4
    config.max_iters = 100000
    config.weight_decay = 1e-1
    config.beta1 = 0.9
    config.beta2 = 0.95
    config.grad_clip = 1.0

    # learning rate decay settings
    config.decay_lr = True
    config.warmup_iters = 2000
    config.lr_decay_iters = 100000
    config.min_lr = 6e-5

    # system
    config.device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc.
    config.dtype = 'bfloat16' # 'float32', 'bfloat16', or 'float16'
    config.compile = True # use PyTorch 2.0 to compile the model to be faster

    # data
    config.dataset = 'enwik8'
    config.batch_size = 32
    config.eval_interval = 500
    config.log_interval = 100
    config.eval_iters = 200

    return config
config/eval_gpt2.py ADDED
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=12, n_head=12, n_embd=768
# 124M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2'
config/eval_gpt2_large.py ADDED
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=36, n_head=20, n_embd=1280
# 774M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-large'
config/eval_gpt2_medium.py ADDED
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=24, n_head=16, n_embd=1024
# 350M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-medium'
config/eval_gpt2_xl.py ADDED
@@ -0,0 +1,8 @@
# evaluate the base gpt2
# n_layer=48, n_head=25, n_embd=1600
# 1558M parameters
batch_size = 8
eval_iters = 500 # use more iterations to get good estimate
eval_only = True
wandb_log = False
init_from = 'gpt2-xl'
config/finetune_shakespeare.py ADDED
@@ -0,0 +1,25 @@
import time

out_dir = 'out-shakespeare'
eval_interval = 5
eval_iters = 40
wandb_log = False # feel free to turn on
wandb_project = 'shakespeare'
wandb_run_name = 'ft-' + str(time.time())

dataset = 'shakespeare'
init_from = 'gpt2-xl' # this is the largest GPT-2 model

# only save checkpoints if the validation loss improves
always_save_checkpoint = False

# the number of examples per iter:
# 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter
# shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters
batch_size = 1
gradient_accumulation_steps = 32
max_iters = 20

# finetune at constant LR
learning_rate = 3e-5
decay_lr = False
config/train_gpt2.py ADDED
@@ -0,0 +1,25 @@
# config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB
# launch as the following (e.g. in a screen session) and wait ~5 days:
# $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py

wandb_log = True
wandb_project = 'owt'
wandb_run_name='gpt2-124M'

# these make the total batch size be ~0.5M
# 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520
batch_size = 12
block_size = 1024
gradient_accumulation_steps = 5 * 8

# this makes total number of tokens be 300B
max_iters = 600000
lr_decay_iters = 600000

# eval stuff
eval_interval = 1000
eval_iters = 200
log_interval = 10

# weight decay
weight_decay = 1e-1
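A quick sanity check of the two comments above, using only the numbers in this file (the 300B figure is approximate):

```python
tokens_per_iter = 12 * 1024 * (5 * 8)   # batch_size * block_size * gradient_accumulation_steps
print(tokens_per_iter)                  # 491,520 (~0.5M tokens per iteration)
print(tokens_per_iter * 600000 / 1e9)   # ~294.9, i.e. roughly 300B tokens over max_iters
```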
config/train_shakespeare_char.py ADDED
@@ -0,0 +1,37 @@
# train a miniature character-level shakespeare model
# good for debugging and playing on macbooks and such

out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

wandb_log = False # override via command line if you like
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

dataset = 'shakespeare_char'
gradient_accumulation_steps = 1
batch_size = 64
block_size = 256 # context of up to 256 previous characters

# baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2

learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
lr_decay_iters = 5000 # make equal to max_iters usually
min_lr = 1e-4 # learning_rate / 10 usually
beta2 = 0.99 # make a bit bigger because number of tokens per iter is small

warmup_iters = 100 # not super necessary potentially

# on macbook also add
# device = 'cpu' # run on cpu only
# compile = False # do not torch compile the model
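For reference, each iteration here processes batch_size * block_size * gradient_accumulation_steps tokens, which is why the beta2 comment above calls the per-iteration token count small:

```python
tokens_per_iter = 64 * 256 * 1   # batch_size * block_size * gradient_accumulation_steps
print(tokens_per_iter)           # 16,384 tokens per iteration
```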
configurator.py ADDED
@@ -0,0 +1,47 @@
"""
Poor Man's Configurator. Probably a terrible idea. Example usage:
$ python train.py config/override_file.py --batch_size=32
this will first run config/override_file.py, then override batch_size to 32

The code in this file will be run as follows from e.g. train.py:
>>> exec(open('configurator.py').read())

So it's not a Python module, it's just shuttling this code away from train.py
The code in this script then overrides the globals()

I know people are not going to love this, I just really dislike configuration
complexity and having to prepend config. to every single variable. If someone
comes up with a better simple Python solution I am all ears.
"""

import sys
from ast import literal_eval

for arg in sys.argv[1:]:
    if '=' not in arg:
        # assume it's the name of a config file
        assert not arg.startswith('--')
        config_file = arg
        print(f"Overriding config with {config_file}:")
        with open(config_file) as f:
            print(f.read())
        exec(open(config_file).read())
    else:
        # assume it's a --key=value argument
        assert arg.startswith('--')
        key, val = arg.split('=')
        key = key[2:]
        if key in globals():
            try:
                # attempt to eval it (e.g. if bool, number, or etc)
                attempt = literal_eval(val)
            except (SyntaxError, ValueError):
                # if that goes wrong, just use the string
                attempt = val
            # ensure the types match ok
            assert type(attempt) == type(globals()[key])
            # cross fingers
            print(f"Overriding: {key} = {attempt}")
            globals()[key] = attempt
        else:
            raise ValueError(f"Unknown config key: {key}")
data/openwebtext/prepare.py ADDED
@@ -0,0 +1,81 @@
# saves the openwebtext dataset to a binary file for training. following was helpful:
# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py

import os
from tqdm import tqdm
import numpy as np
import tiktoken
from datasets import load_dataset # huggingface datasets

# number of workers in .map() call
# good number to use is ~order number of cpu cores // 2
num_proc = 8

# number of workers in load_dataset() call
# best number might be different from num_proc above as it also depends on NW speed.
# it is better than 1 usually though
num_proc_load_dataset = num_proc

enc = tiktoken.get_encoding("gpt2")

if __name__ == '__main__':
    # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
    dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset)

    # owt by default only contains the 'train' split, so create a test split
    split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True)
    split_dataset['val'] = split_dataset.pop('test') # rename the test split to val

    # this results in:
    # >>> split_dataset
    # DatasetDict({
    #     train: Dataset({
    #         features: ['text'],
    #         num_rows: 8009762
    #     })
    #     val: Dataset({
    #         features: ['text'],
    #         num_rows: 4007
    #     })
    # })

    # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
    def process(example):
        ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens
        ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe
        # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
        out = {'ids': ids, 'len': len(ids)}
        return out

    # tokenize the dataset
    tokenized = split_dataset.map(
        process,
        remove_columns=['text'],
        desc="tokenizing the splits",
        num_proc=num_proc,
    )

    # concatenate all the ids in each dataset into one large file we can use for training
    for split, dset in tokenized.items():
        arr_len = np.sum(dset['len'], dtype=np.uint64)
        filename = os.path.join(os.path.dirname(__file__), f'{split}.bin')
        dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16)
        arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))
        total_batches = 1024

        idx = 0
        for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
            # Batch together samples for faster write
            batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
            arr_batch = np.concatenate(batch['ids'])
            # Write into mmap
            arr[idx : idx + len(arr_batch)] = arr_batch
            idx += len(arr_batch)
        arr.flush()

    # train.bin is ~17GB, val.bin ~8.5MB
    # train has ~9B tokens (9,035,582,198)
    # val has ~4M tokens (4,434,897)

    # to read the bin files later, e.g. with numpy:
    # m = np.memmap('train.bin', dtype=np.uint16, mode='r')
data/openwebtext/readme.md ADDED
@@ -0,0 +1,15 @@

## openwebtext dataset

after running `prepare.py` (preprocess) we get:

- train.bin is ~17GB, val.bin ~8.5MB
- train has ~9B tokens (9,035,582,198)
- val has ~4M tokens (4,434,897)

this came from 8,013,769 documents in total.

references:

- OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
- [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset
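To spot-check the bins after `prepare.py` finishes, a minimal sketch (run from this directory) that memory-maps `train.bin` and decodes a few tokens with the same GPT-2 BPE tokenizer the script used:

```python
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
m = np.memmap('train.bin', dtype=np.uint16, mode='r')
print(f"{len(m):,} tokens")           # ~9B for the train split
print(enc.decode(m[:200].tolist()))   # first 200 BPE tokens rendered back as text
```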
data/shakespeare/prepare.py ADDED
@@ -0,0 +1,33 @@
import os
import requests
import tiktoken
import numpy as np

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w', encoding='utf-8') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r', encoding='utf-8') as f:
    data = f.read()
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode with tiktoken gpt2 bpe
enc = tiktoken.get_encoding("gpt2")
train_ids = enc.encode_ordinary(train_data)
val_ids = enc.encode_ordinary(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))

# train.bin has 301,966 tokens
# val.bin has 36,059 tokens
data/shakespeare/readme.md ADDED
@@ -0,0 +1,9 @@

# tiny shakespeare

Tiny shakespeare, of the good old char-rnn fame :)

After running `prepare.py`:

- train.bin has 301,966 tokens
- val.bin has 36,059 tokens
data/shakespeare_char/prepare.py ADDED
@@ -0,0 +1,68 @@
"""
Prepare the Shakespeare dataset for character-level language modeling.
So instead of encoding with GPT-2 BPE tokens, we just map characters to ints.
Will save train.bin, val.bin containing the ids, and meta.pkl containing the
encoder and decoder and some other related info.
"""
import os
import pickle
import requests
import numpy as np

# download the tiny shakespeare dataset
input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# create the train and test splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin'))
val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin'))

# save the meta information as well, to help us encode/decode later
meta = {
    'vocab_size': vocab_size,
    'itos': itos,
    'stoi': stoi,
}
with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f:
    pickle.dump(meta, f)

# length of dataset in characters: 1115394
# all the unique characters:
# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
# vocab size: 65
# train has 1003854 tokens
# val has 111540 tokens
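To sanity-check what this script writes, a minimal sketch (paths relative to this directory) that memory-maps `train.bin` and decodes the first few characters back through the `meta.pkl` vocabulary it saved:

```python
import pickle
import numpy as np

data = np.memmap('train.bin', dtype=np.uint16, mode='r')   # encoded character ids
with open('meta.pkl', 'rb') as f:
    meta = pickle.load(f)                                   # {'vocab_size', 'itos', 'stoi'}
print(f"{len(data):,} tokens, vocab size {meta['vocab_size']}")
print(''.join(meta['itos'][int(i)] for i in data[:100]))    # first 100 characters of Shakespeare
```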
data/shakespeare_char/readme.md ADDED
@@ -0,0 +1,9 @@

# tiny shakespeare, character-level

Tiny shakespeare, of the good old char-rnn fame :) Treated at the character level.

After running `prepare.py`:

- train.bin has 1,003,854 tokens
- val.bin has 111,540 tokens
model.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5
+ https://github.com/openai/gpt-2/blob/master/src/model.py
6
+ 2) huggingface/transformers PyTorch implementation:
7
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ """
9
+
10
+ import math
11
+ import inspect
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
+ class LayerNorm(nn.Module):
19
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
20
+
21
+ def __init__(self, ndim, bias):
22
+ super().__init__()
23
+ self.weight = nn.Parameter(torch.ones(ndim))
24
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
25
+
26
+ def forward(self, input):
27
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28
+
29
+ class CausalSelfAttention(nn.Module):
30
+
31
+ def __init__(self, config):
32
+ super().__init__()
33
+ assert config.n_embd % config.n_head == 0
34
+ # key, query, value projections for all heads, but in a batch
35
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
36
+ # output projection
37
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
38
+ # regularization
39
+ self.attn_dropout = nn.Dropout(config.dropout)
40
+ self.resid_dropout = nn.Dropout(config.dropout)
41
+ self.n_head = config.n_head
42
+ self.n_embd = config.n_embd
43
+ self.dropout = config.dropout
44
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
45
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
46
+ if not self.flash:
47
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
48
+ # causal mask to ensure that attention is only applied to the left in the input sequence
49
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
50
+ .view(1, 1, config.block_size, config.block_size))
51
+
52
+ def forward(self, x):
53
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
54
+
55
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
56
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
57
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
58
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
59
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
60
+
61
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
62
+ if self.flash:
63
+ # efficient attention using Flash Attention CUDA kernels
64
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
65
+ else:
66
+ # manual implementation of attention
67
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
68
+ att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
69
+ att = F.softmax(att, dim=-1)
70
+ att = self.attn_dropout(att)
71
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
72
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
73
+
74
+ # output projection
75
+ y = self.resid_dropout(self.c_proj(y))
76
+ return y
77
+
78
+ class MLP(nn.Module):
79
+
80
+ def __init__(self, config):
81
+ super().__init__()
82
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
83
+ self.gelu = nn.GELU()
84
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
85
+ self.dropout = nn.Dropout(config.dropout)
86
+
87
+ def forward(self, x):
88
+ x = self.c_fc(x)
89
+ x = self.gelu(x)
90
+ x = self.c_proj(x)
91
+ x = self.dropout(x)
92
+ return x
93
+
94
+ class Block(nn.Module):
95
+
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
99
+ self.attn = CausalSelfAttention(config)
100
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
101
+ self.mlp = MLP(config)
102
+
103
+ def forward(self, x):
104
+ x = x + self.attn(self.ln_1(x))
105
+ x = x + self.mlp(self.ln_2(x))
106
+ return x
107
+
108
+ @dataclass
109
+ class GPTConfig:
110
+ block_size: int = 1024
111
+ vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
112
+ n_layer: int = 12
113
+ n_head: int = 12
114
+ n_embd: int = 768
115
+ dropout: float = 0.0
116
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
117
+
118
+ class GPT(nn.Module):
119
+
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ assert config.vocab_size is not None
123
+ assert config.block_size is not None
124
+ self.config = config
125
+
126
+ self.transformer = nn.ModuleDict(dict(
127
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
128
+ wpe = nn.Embedding(config.block_size, config.n_embd),
129
+ drop = nn.Dropout(config.dropout),
130
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
131
+ ln_f = LayerNorm(config.n_embd, bias=config.bias),
132
+ ))
133
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134
+ # with weight tying when using torch.compile() some warnings get generated:
135
+ # "UserWarning: functional_call was passed multiple values for tied weights.
136
+ # This behavior is deprecated and will be an error in future versions"
137
+ # not 100% sure what this is, so far seems to be harmless. TODO investigate
138
+ self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying
139
+
140
+ # init all weights
141
+ self.apply(self._init_weights)
142
+ # apply special scaled init to the residual projections, per GPT-2 paper
143
+ for pn, p in self.named_parameters():
144
+ if pn.endswith('c_proj.weight'):
145
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
146
+
147
+ # report number of parameters
148
+ print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
149
+
150
+ def get_num_params(self, non_embedding=True):
151
+ """
152
+ Return the number of parameters in the model.
153
+ For non-embedding count (default), the position embeddings get subtracted.
154
+ The token embeddings would too, except due to the parameter sharing these
155
+ params are actually used as weights in the final layer, so we include them.
156
+ """
157
+ n_params = sum(p.numel() for p in self.parameters())
158
+ if non_embedding:
159
+ n_params -= self.transformer.wpe.weight.numel()
160
+ return n_params
161
+
162
+ def _init_weights(self, module):
163
+ if isinstance(module, nn.Linear):
164
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
165
+ if module.bias is not None:
166
+ torch.nn.init.zeros_(module.bias)
167
+ elif isinstance(module, nn.Embedding):
168
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
169
+
170
+ def forward(self, idx, targets=None):
171
+ device = idx.device
172
+ b, t = idx.size()
173
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
174
+ pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)
175
+
176
+ # forward the GPT model itself
177
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
178
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
179
+ x = self.transformer.drop(tok_emb + pos_emb)
180
+ for block in self.transformer.h:
181
+ x = block(x)
182
+ x = self.transformer.ln_f(x)
183
+
184
+ if targets is not None:
185
+ # if we are given some desired targets also calculate the loss
186
+ logits = self.lm_head(x)
187
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
188
+ else:
189
+ # inference-time mini-optimization: only forward the lm_head on the very last position
190
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
191
+ loss = None
192
+
193
+ return logits, loss
194
+
195
+ def crop_block_size(self, block_size):
196
+ # model surgery to decrease the block size if necessary
197
+ # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
198
+ # but want to use a smaller block size for some smaller, simpler model
199
+ assert block_size <= self.config.block_size
200
+ self.config.block_size = block_size
201
+ self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
202
+ for block in self.transformer.h:
203
+ if hasattr(block.attn, 'bias'):
204
+ block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]
205
+
206
+ @classmethod
207
+ def from_pretrained(cls, model_type, override_args=None):
208
+ assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
209
+ override_args = override_args or {} # default to empty dict
210
+ # only dropout can be overridden see more notes below
211
+ assert all(k == 'dropout' for k in override_args)
212
+ from transformers import GPT2LMHeadModel
213
+ print("loading weights from pretrained gpt: %s" % model_type)
214
+
215
+ # n_layer, n_head and n_embd are determined from model_type
216
+ config_args = {
217
+ 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
218
+ 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
219
+ 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
220
+ 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
221
+ }[model_type]
222
+ print("forcing vocab_size=50257, block_size=1024, bias=True")
223
+ config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
224
+ config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
225
+ config_args['bias'] = True # always True for GPT model checkpoints
226
+ # we can override the dropout rate, if desired
227
+ if 'dropout' in override_args:
228
+ print(f"overriding dropout rate to {override_args['dropout']}")
229
+ config_args['dropout'] = override_args['dropout']
230
+ # create a from-scratch initialized minGPT model
231
+ config = GPTConfig(**config_args)
232
+ model = GPT(config)
233
+ sd = model.state_dict()
234
+ sd_keys = sd.keys()
235
+ sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
236
+
237
+ # init a huggingface/transformers model
238
+ model_hf = GPT2LMHeadModel.from_pretrained(model_type)
239
+ sd_hf = model_hf.state_dict()
240
+
241
+ # copy while ensuring all of the parameters are aligned and match in names and shapes
242
+ sd_keys_hf = sd_hf.keys()
243
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
244
+ sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
245
+ transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
246
+ # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
247
+ # this means that we have to transpose these weights when we import them
248
+ assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
249
+ for k in sd_keys_hf:
250
+ if any(k.endswith(w) for w in transposed):
251
+ # special treatment for the Conv1D weights we need to transpose
252
+ assert sd_hf[k].shape[::-1] == sd[k].shape
253
+ with torch.no_grad():
254
+ sd[k].copy_(sd_hf[k].t())
255
+ else:
256
+ # vanilla copy over the other parameters
257
+ assert sd_hf[k].shape == sd[k].shape
258
+ with torch.no_grad():
259
+ sd[k].copy_(sd_hf[k])
260
+
261
+ return model
262
+
263
+ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
264
+ # start with all of the candidate parameters
265
+ param_dict = {pn: p for pn, p in self.named_parameters()}
266
+ # filter out those that do not require grad
267
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
268
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
269
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
270
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
271
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
272
+ optim_groups = [
273
+ {'params': decay_params, 'weight_decay': weight_decay},
274
+ {'params': nodecay_params, 'weight_decay': 0.0}
275
+ ]
276
+ num_decay_params = sum(p.numel() for p in decay_params)
277
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
278
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
279
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
280
+ # Create AdamW optimizer and use the fused version if it is available
281
+ fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
282
+ use_fused = fused_available and device_type == 'cuda'
283
+ extra_args = dict(fused=True) if use_fused else dict()
284
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
285
+ print(f"using fused AdamW: {use_fused}")
286
+
287
+ return optimizer
288
+
289
+ def estimate_mfu(self, fwdbwd_per_iter, dt):
290
+ """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
291
+ # first estimate the number of flops we do per iteration.
292
+ # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
293
+ N = self.get_num_params()
294
+ cfg = self.config
295
+ L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
296
+ flops_per_token = 6*N + 12*L*H*Q*T
297
+ flops_per_fwdbwd = flops_per_token * T
298
+ flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
299
+ # express our flops throughput as ratio of A100 bfloat16 peak flops
300
+ flops_achieved = flops_per_iter * (1.0/dt) # per second
301
+ flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
302
+ mfu = flops_achieved / flops_promised
303
+ return mfu
304
+
305
+ @torch.no_grad()
306
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
307
+ """
308
+ Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
309
+ the sequence max_new_tokens times, feeding the predictions back into the model each time.
310
+ Most likely you'll want to make sure to be in model.eval() mode of operation for this.
311
+ """
312
+ for _ in range(max_new_tokens):
313
+ # if the sequence context is growing too long we must crop it at block_size
314
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
315
+ # forward the model to get the logits for the index in the sequence
316
+ logits, _ = self(idx_cond)
317
+ # pluck the logits at the final step and scale by desired temperature
318
+ logits = logits[:, -1, :] / temperature
319
+ # optionally crop the logits to only the top k options
320
+ if top_k is not None:
321
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
322
+ logits[logits < v[:, [-1]]] = -float('Inf')
323
+ # apply softmax to convert logits to (normalized) probabilities
324
+ probs = F.softmax(logits, dim=-1)
325
+ # sample from the distribution
326
+ idx_next = torch.multinomial(probs, num_samples=1)
327
+ # append sampled index to the running sequence and continue
328
+ idx = torch.cat((idx, idx_next), dim=1)
329
+
330
+ return idx
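
A minimal usage sketch of the class above (illustrative, not part of the commit): it assumes the transformers and tiktoken packages are installed so that from_pretrained() can fetch the GPT-2 checkpoint, and it mirrors how sample.py drives generation.

    import torch
    import tiktoken
    from model import GPT

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = GPT.from_pretrained('gpt2', dict(dropout=0.0)).to(device)
    model.eval()

    enc = tiktoken.get_encoding("gpt2")
    idx = torch.tensor([enc.encode("Hello, world")], dtype=torch.long, device=device)
    out = model.generate(idx, max_new_tokens=20, temperature=0.8, top_k=200)  # generate() already runs under no_grad
    print(enc.decode(out[0].tolist()))
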
model_dtat.py ADDED
@@ -0,0 +1,257 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class TokenImportanceNetwork(nn.Module):
7
+ """
8
+ Computes importance scores for each token based on:
9
+ 1. Local context patterns
10
+ 2. Token frequency
11
+ 3. Position information
12
+ """
13
+ def __init__(self, config):
14
+ super().__init__()
15
+ self.n_embd = config.n_embd
16
+
17
+ # Local context processing
18
+ self.context_net = nn.Sequential(
19
+ nn.Conv1d(config.n_embd, config.n_embd // 2, kernel_size=3, padding=1),
20
+ nn.ReLU(),
21
+ nn.Conv1d(config.n_embd // 2, 1, kernel_size=1)
22
+ )
23
+
24
+ # Frequency awareness
25
+ self.freq_embedding = nn.Embedding(256, config.n_embd // 4) # 256 possible byte values
26
+
27
+ # Position awareness
28
+ self.pos_embedding = nn.Embedding(config.block_size, config.n_embd // 4)
29
+
30
+ # Final importance score computation
31
+ self.importance_proj = nn.Sequential(
32
+ nn.Linear(config.n_embd + config.n_embd//2, config.n_embd//4),
33
+ nn.ReLU(),
34
+ nn.Linear(config.n_embd//4, 1),
35
+ nn.Sigmoid()
36
+ )
37
+
38
+ def forward(self, x, freq_table, positions):
39
+ B, T, C = x.shape
40
+
41
+ # Process local context
42
+ x_conv = self.context_net(x.transpose(1, 2)).transpose(1, 2) # [B, T, 1]
43
+
44
+ # Get frequency embeddings
45
+ freq_emb = self.freq_embedding(freq_table) # [B, T, C//4]
46
+
47
+ # Get position embeddings
48
+ pos_emb = self.pos_embedding(positions) # [B, T, C//4]
49
+
50
+ # Combine all features
51
+ combined = torch.cat([x, freq_emb, pos_emb], dim=-1)
52
+
53
+ # Compute importance scores
54
+ importance = self.importance_proj(combined) # [B, T, 1]
55
+ return importance
56
+
57
+ class SparseDenseAttention(nn.Module):
58
+ """
59
+ Hybrid attention mechanism that uses:
60
+ - Full attention for important tokens
61
+ - Sparse attention for less important tokens
62
+ """
63
+ def __init__(self, config):
64
+ super().__init__()
65
+ assert config.n_embd % config.n_head == 0
66
+
67
+ self.n_head = config.n_head
68
+ self.n_embd = config.n_embd
69
+ self.dropout = config.dropout
70
+
71
+ # Key, Query, Value projections
72
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
73
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
74
+
75
+ # Dropouts
76
+ self.attn_dropout = nn.Dropout(config.dropout)
77
+ self.resid_dropout = nn.Dropout(config.dropout)
78
+
79
+ # Sparse attention parameters
80
+ self.sparse_topk = getattr(config, 'sparse_topk', 32) # Number of tokens to attend to for less important tokens
81
+
82
+ def forward(self, x, importance_scores, mask=None):
83
+ B, T, C = x.shape
84
+
85
+ # Calculate query, key, values for all heads in batch
86
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
87
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
88
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
89
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
90
+
91
+ # Compute attention scores
92
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
93
+
94
+ # Apply importance scores
95
+ importance_scores = importance_scores.squeeze(-1).unsqueeze(1) # [B, 1, T]
96
+ att = att * importance_scores.unsqueeze(-1) # Scale attention by importance
97
+
98
+ # For less important tokens (importance < threshold), use sparse attention
99
+ sparse_mask = importance_scores < 0.5
100
+ if sparse_mask.any():
101
+ # Keep only top-k values for less important tokens
102
+ topk_values, _ = torch.topk(att.masked_fill(~sparse_mask, -float('inf')),
103
+ k=self.sparse_topk, dim=-1)
104
+ sparse_threshold = topk_values[..., -1, None]
105
+ att = att.masked_fill(
106
+ (att < sparse_threshold) & sparse_mask.unsqueeze(-1),
107
+ -float('inf')
108
+ )
109
+
110
+ # Apply softmax and dropout
111
+ att = F.softmax(att, dim=-1)
112
+ att = self.attn_dropout(att)
113
+
114
+ # Compute output
115
+ y = att @ v # [B, nh, T, hs]
116
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
117
+
118
+ # Output projection
119
+ y = self.resid_dropout(self.c_proj(y))
120
+ return y
121
+
122
+ class Block(nn.Module):
123
+ """
124
+ Transformer block with importance-aware processing
125
+ """
126
+ def __init__(self, config):
127
+ super().__init__()
128
+ self.ln_1 = nn.LayerNorm(config.n_embd)
129
+ self.attn = SparseDenseAttention(config)
130
+ self.ln_2 = nn.LayerNorm(config.n_embd)
131
+
132
+ self.mlp = nn.Sequential(
133
+ nn.Linear(config.n_embd, 4 * config.n_embd),
134
+ nn.GELU(),
135
+ nn.Linear(4 * config.n_embd, config.n_embd),
136
+ nn.Dropout(config.dropout),
137
+ )
138
+
139
+ # Feature amplification
140
+ self.feature_gate = nn.Sequential(
141
+ nn.Linear(config.n_embd, config.n_embd),
142
+ nn.Sigmoid()
143
+ )
144
+
145
+ def forward(self, x, importance_scores):
146
+ # Self-attention with importance awareness
147
+ attn_output = self.attn(self.ln_1(x), importance_scores)
148
+ x = x + attn_output
149
+
150
+ # Feature amplification based on importance
151
+ gate = self.feature_gate(x)
152
+ x = x * (1 + importance_scores * gate)
153
+
154
+ # MLP block
155
+ x = x + self.mlp(self.ln_2(x))
156
+ return x
157
+
158
+ class DTATTransformer(nn.Module):
159
+ """
160
+ Dynamic Token-Aware Transformer (DTAT) for character-level language modeling
161
+ """
162
+ def __init__(self, config):
163
+ super().__init__()
164
+ assert config.vocab_size is not None
165
+ assert config.block_size is not None
166
+ self.config = config
167
+
168
+ self.transformer = nn.ModuleDict(dict(
169
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
170
+ wpe = nn.Embedding(config.block_size, config.n_embd),
171
+ drop = nn.Dropout(config.dropout),
172
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
173
+ ln_f = nn.LayerNorm(config.n_embd),
174
+ ))
175
+
176
+ # Token importance network
177
+ self.importance_net = TokenImportanceNetwork(config)
178
+
179
+ # Output head
180
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
181
+
182
+ # Initialize weights
183
+ self.apply(self._init_weights)
184
+
185
+ # Report number of parameters
186
+ print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
187
+
188
+ def get_num_params(self):
189
+ return sum(p.numel() for p in self.parameters())
190
+
191
+ def _init_weights(self, module):
192
+ if isinstance(module, nn.Linear):
193
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
194
+ if module.bias is not None:
195
+ torch.nn.init.zeros_(module.bias)
196
+ elif isinstance(module, nn.Embedding):
197
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
198
+
199
+ def forward(self, idx, targets=None, freq_table=None):
200
+ device = idx.device
201
+ b, t = idx.size()
202
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
203
+
204
+ # Get token frequencies if not provided
205
+ if freq_table is None:
206
+ freq_table = torch.bincount(idx.view(-1), minlength=self.config.vocab_size)
207
+ freq_table = freq_table.view(1, -1).expand(b, -1)
208
+
209
+ # Generate position indices
210
+ pos = torch.arange(0, t, dtype=torch.long, device=device)
211
+
212
+ # Token embeddings
213
+ tok_emb = self.transformer.wte(idx)
214
+ pos_emb = self.transformer.wpe(pos)
215
+ x = self.transformer.drop(tok_emb + pos_emb)
216
+
217
+ # Compute token importance scores
218
+ importance_scores = self.importance_net(x, freq_table, pos)
219
+
220
+ # Apply transformer blocks with importance awareness
221
+ for block in self.transformer.h:
222
+ x = block(x, importance_scores)
223
+
224
+ x = self.transformer.ln_f(x)
225
+
226
+ # Language modeling head
227
+ logits = self.lm_head(x)
228
+
229
+ # Loss computation
230
+ loss = None
231
+ if targets is not None:
232
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
233
+ # Convert loss to bits per character (bpc)
234
+ loss = loss / math.log(2)
235
+
236
+ return logits, loss, importance_scores
237
+
238
+ @torch.no_grad()
239
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
240
+ for _ in range(max_new_tokens):
241
+ # Crop context if needed
242
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
243
+ # Forward pass
244
+ logits, _, _ = self(idx_cond)
245
+ logits = logits[:, -1, :] / temperature
246
+
247
+ # Optional top-k sampling
248
+ if top_k is not None:
249
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
250
+ logits[logits < v[:, [-1]]] = -float('inf')
251
+
252
+ # Sample from distribution
253
+ probs = F.softmax(logits, dim=-1)
254
+ idx_next = torch.multinomial(probs, num_samples=1)
255
+ idx = torch.cat((idx, idx_next), dim=1)
256
+
257
+ return idx
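
To make the sparse/dense split above concrete, here is a small self-contained illustration (single head, no projections, made-up tensors) of the top-k trick that SparseDenseAttention applies to low-importance queries; it demonstrates the idea rather than reusing the module itself.

    import torch
    import torch.nn.functional as F

    B, T, k = 1, 8, 3
    att = torch.randn(B, T, T)                  # raw attention logits for one head
    importance = torch.rand(B, T)               # per-token importance scores in [0, 1]
    sparse_queries = importance < 0.5           # these queries get sparse attention

    topk_vals, _ = torch.topk(att, k, dim=-1)
    threshold = topk_vals[..., -1:]             # k-th largest logit in each query row
    mask = (att < threshold) & sparse_queries.unsqueeze(-1)
    att = att.masked_fill(mask, float('-inf'))  # keep only the top-k keys for sparse queries
    probs = F.softmax(att, dim=-1)
    print((probs > 0).sum(dim=-1))              # k for sparse queries, T for the rest
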
model_modified.py ADDED
@@ -0,0 +1,190 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ class HierarchicalPositionEncoding(nn.Module):
7
+ """
8
+ Hierarchical Position Encoding that captures position information at multiple scales:
9
+ - Fine-grained local position (token level)
10
+ - Medium-scale position (segment level)
11
+ - Coarse-grained position (document level)
12
+ """
13
+ def __init__(self, d_model, max_len=1024, base=10000):
14
+ super().__init__()
15
+ self.d_model = d_model
16
+ self.max_len = max_len
17
+ self.base = base
18
+
19
+ # Split embedding dimensions for different scales
20
+ self.local_dim = d_model // 2
21
+ self.segment_dim = d_model // 4
22
+ self.doc_dim = d_model - self.local_dim - self.segment_dim
23
+
24
+ # Create position encodings for different scales
25
+ self.register_buffer('local_pe', self._create_pe(max_len, self.local_dim))
26
+ self.register_buffer('segment_pe', self._create_pe(max_len//8, self.segment_dim))
27
+ self.register_buffer('doc_pe', self._create_pe(max_len//32, self.doc_dim))
28
+
29
+ def _create_pe(self, max_len, d_model):
30
+ pe = torch.zeros(max_len, d_model)
31
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
32
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(self.base) / d_model))
33
+ pe[:, 0::2] = torch.sin(position * div_term)
34
+ pe[:, 1::2] = torch.cos(position * div_term)
35
+ return pe.unsqueeze(0)
36
+
37
+ def forward(self, x):
38
+ B, T, C = x.shape
39
+
40
+ # Get positional encodings at different scales
41
+ local_pos = self.local_pe[:, :T, :]
42
+ segment_pos = self.segment_pe[:, :(T//8), :].repeat_interleave(8, dim=1)[:, :T, :]
43
+ doc_pos = self.doc_pe[:, :(T//32), :].repeat_interleave(32, dim=1)[:, :T, :]
44
+
45
+ # Combine all scales
46
+ pos_encoding = torch.cat([local_pos, segment_pos, doc_pos], dim=-1)
47
+ return pos_encoding
48
+
49
+ class MultiScaleAttention(nn.Module):
50
+ """
51
+ Multi-scale attention mechanism that processes information at different temporal scales
52
+ """
53
+ def __init__(self, config):
54
+ super().__init__()
55
+ assert config.n_embd % config.n_head == 0
56
+
57
+ # key, query, value projections for all heads, but in a batch
58
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
59
+ # output projection
60
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
61
+ # regularization
62
+ self.attn_dropout = nn.Dropout(config.dropout)
63
+ self.resid_dropout = nn.Dropout(config.dropout)
64
+ self.n_head = config.n_head
65
+ self.n_embd = config.n_embd
66
+ self.dropout = config.dropout
67
+
68
+ def forward(self, x):
69
+ B, T, C = x.shape # batch size, sequence length, embedding dimensionality
70
+
71
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
72
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
73
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
74
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
75
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
76
+
77
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
78
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ # apply the causal mask so each position only attends to itself and earlier positions
+ causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
+ att = att.masked_fill(causal_mask, float('-inf'))
79
+ att = F.softmax(att, dim=-1)
80
+ att = self.attn_dropout(att)
81
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
82
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
83
+
84
+ # output projection
85
+ y = self.resid_dropout(self.c_proj(y))
86
+ return y
87
+
88
+ class Block(nn.Module):
89
+ def __init__(self, config):
90
+ super().__init__()
91
+ self.ln_1 = nn.LayerNorm(config.n_embd)
92
+ self.attn = MultiScaleAttention(config)
93
+ self.ln_2 = nn.LayerNorm(config.n_embd)
94
+ self.mlp = nn.ModuleDict(dict(
95
+ c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias),
96
+ c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias),
97
+ act = nn.GELU(),
98
+ dropout = nn.Dropout(config.dropout),
99
+ ))
100
+ m = self.mlp
101
+ self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))
102
+
103
+ def forward(self, x):
104
+ x = x + self.attn(self.ln_1(x))
105
+ x = x + self.mlpf(self.ln_2(x))
106
+ return x
107
+
108
+ class GPTModified(nn.Module):
109
+ def __init__(self, config):
110
+ super().__init__()
111
+ assert config.vocab_size is not None
112
+ assert config.block_size is not None
113
+ self.config = config
114
+
115
+ self.transformer = nn.ModuleDict(dict(
116
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
117
+ hpe = HierarchicalPositionEncoding(config.n_embd, config.block_size),
118
+ drop = nn.Dropout(config.dropout),
119
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
120
+ ln_f = nn.LayerNorm(config.n_embd),
121
+ ))
122
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
123
+
124
+ # Initialize weights
125
+ self.apply(self._init_weights)
126
+ # Apply special scaled init to the residual projections, per GPT-2 paper
127
+ for pn, p in self.named_parameters():
128
+ if pn.endswith('c_proj.weight'):
129
+ torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
130
+
131
+ # Report number of parameters
132
+ print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
133
+
134
+ def get_num_params(self, non_embedding=True):
135
+ n_params = sum(p.numel() for p in self.parameters())
136
+ if non_embedding:
137
+ n_params -= self.transformer.wte.weight.numel()
138
+ return n_params
139
+
140
+ def _init_weights(self, module):
141
+ if isinstance(module, nn.Linear):
142
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
143
+ if module.bias is not None:
144
+ torch.nn.init.zeros_(module.bias)
145
+ elif isinstance(module, nn.Embedding):
146
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
147
+
148
+ def forward(self, idx, targets=None):
149
+ device = idx.device
150
+ b, t = idx.size()
151
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
152
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
153
+
154
+ # Forward pass
155
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
156
+ pos_emb = self.transformer.hpe(tok_emb) # position embeddings of shape (b, t, n_embd)
157
+ x = self.transformer.drop(tok_emb + pos_emb)
158
+ for block in self.transformer.h:
159
+ x = block(x)
160
+ x = self.transformer.ln_f(x)
161
+ logits = self.lm_head(x)
162
+
163
+ # If we are given some desired targets also calculate the loss
164
+ loss = None
165
+ if targets is not None:
166
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
167
+
168
+ return logits, loss
169
+
170
+ @torch.no_grad()
171
+ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
172
+ for _ in range(max_new_tokens):
173
+ # If the sequence context is growing too long we must crop it at block_size
174
+ idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
175
+ # Forward the model to get the logits for the index in the sequence
176
+ logits, _ = self(idx_cond)
177
+ # Pluck the logits at the final step and scale by desired temperature
178
+ logits = logits[:, -1, :] / temperature
179
+ # Optionally crop the logits to only the top k options
180
+ if top_k is not None:
181
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
182
+ logits[logits < v[:, [-1]]] = -float('Inf')
183
+ # Apply softmax to convert logits to (normalized) probabilities
184
+ probs = F.softmax(logits, dim=-1)
185
+ # Sample from the distribution
186
+ idx_next = torch.multinomial(probs, num_samples=1)
187
+ # Append sampled index to the running sequence and continue
188
+ idx = torch.cat((idx, idx_next), dim=1)
189
+
190
+ return idx
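
A quick shape check for the hierarchical encoding above (illustrative). As written, the three scales tile cleanly when the sequence length is a multiple of 32, so the example uses T = 64.

    import torch
    from model_modified import HierarchicalPositionEncoding

    hpe = HierarchicalPositionEncoding(d_model=64, max_len=1024)
    x = torch.zeros(1, 64, 64)   # (batch, seq_len, d_model); only the shape is used
    pe = hpe(x)
    print(pe.shape)              # torch.Size([1, 64, 64]): 32 local + 16 segment + 16 document dims
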
prepare_data.py ADDED
@@ -0,0 +1,37 @@
1
+ import os
2
+
3
+ def prepare_enwik8(input_file, output_dir):
4
+ """
5
+ Prepare enwik8 dataset from enwik9:
6
+ - Extract first 100M bytes for enwik8
7
+ - Split into train (90M), val (5M), and test (5M)
8
+ """
9
+ # Create output directory if it doesn't exist
10
+ os.makedirs(output_dir, exist_ok=True)
11
+
12
+ # Read first 100M bytes from enwik9
13
+ with open(input_file, 'rb') as f:
14
+ data = f.read(100_000_000) # Read exactly 100M bytes
15
+
16
+ # Split the data
17
+ train_data = data[:90_000_000] # First 90M bytes
18
+ val_data = data[90_000_000:95_000_000] # Next 5M bytes
19
+ test_data = data[95_000_000:] # Last 5M bytes
20
+
21
+ # Save splits
22
+ splits = {
23
+ 'train.bin': train_data,
24
+ 'val.bin': val_data,
25
+ 'test.bin': test_data
26
+ }
27
+
28
+ for name, split_data in splits.items():
29
+ with open(os.path.join(output_dir, name), 'wb') as f:
30
+ f.write(split_data)
31
+ print(f"Saved {name} ({len(split_data):,} bytes)")
32
+
33
+ if __name__ == "__main__":
34
+ input_file = "enwik9/enwik9"
35
+ output_dir = "data"
36
+ prepare_enwik8(input_file, output_dir)
37
+ print("Dataset preparation completed!")
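
The training scripts read these splits back as flat byte arrays; for example (illustrative, with output_dir = "data" as above):

    import numpy as np
    train_data = np.memmap('data/train.bin', dtype=np.uint8, mode='r')
    print(f"{len(train_data):,} bytes")   # 90,000,000 for the enwik8 train split
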
sample.py ADDED
@@ -0,0 +1,89 @@
1
+ """
2
+ Sample from a trained model
3
+ """
4
+ import os
5
+ import pickle
6
+ from contextlib import nullcontext
7
+ import torch
8
+ import tiktoken
9
+ from model import GPTConfig, GPT
10
+
11
+ # -----------------------------------------------------------------------------
12
+ init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
13
+ out_dir = 'out' # ignored if init_from is not 'resume'
14
+ start = "\n" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
15
+ num_samples = 10 # number of samples to draw
16
+ max_new_tokens = 500 # number of tokens generated in each sample
17
+ temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
18
+ top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability
19
+ seed = 1337
20
+ device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
21
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
22
+ compile = False # use PyTorch 2.0 to compile the model to be faster
23
+ exec(open('configurator.py').read()) # overrides from command line or config file
24
+ # -----------------------------------------------------------------------------
25
+
26
+ torch.manual_seed(seed)
27
+ torch.cuda.manual_seed(seed)
28
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
29
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
30
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
31
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
32
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
33
+
34
+ # model
35
+ if init_from == 'resume':
36
+ # init from a model saved in a specific directory
37
+ ckpt_path = os.path.join(out_dir, 'ckpt.pt')
38
+ checkpoint = torch.load(ckpt_path, map_location=device)
39
+ gptconf = GPTConfig(**checkpoint['model_args'])
40
+ model = GPT(gptconf)
41
+ state_dict = checkpoint['model']
42
+ unwanted_prefix = '_orig_mod.'
43
+ for k,v in list(state_dict.items()):
44
+ if k.startswith(unwanted_prefix):
45
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
46
+ model.load_state_dict(state_dict)
47
+ elif init_from.startswith('gpt2'):
48
+ # init from a given GPT-2 model
49
+ model = GPT.from_pretrained(init_from, dict(dropout=0.0))
50
+
51
+ model.eval()
52
+ model.to(device)
53
+ if compile:
54
+ model = torch.compile(model) # requires PyTorch 2.0 (optional)
55
+
56
+ # look for the meta pickle in case it is available in the dataset folder
57
+ load_meta = False
58
+ if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
59
+ meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
60
+ load_meta = os.path.exists(meta_path)
61
+ if load_meta:
62
+ print(f"Loading meta from {meta_path}...")
63
+ with open(meta_path, 'rb') as f:
64
+ meta = pickle.load(f)
65
+ # TODO want to make this more general to arbitrary encoder/decoder schemes
66
+ stoi, itos = meta['stoi'], meta['itos']
67
+ encode = lambda s: [stoi[c] for c in s]
68
+ decode = lambda l: ''.join([itos[i] for i in l])
69
+ else:
70
+ # ok let's assume gpt-2 encodings by default
71
+ print("No meta.pkl found, assuming GPT-2 encodings...")
72
+ enc = tiktoken.get_encoding("gpt2")
73
+ encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
74
+ decode = lambda l: enc.decode(l)
75
+
76
+ # encode the beginning of the prompt
77
+ if start.startswith('FILE:'):
78
+ with open(start[5:], 'r', encoding='utf-8') as f:
79
+ start = f.read()
80
+ start_ids = encode(start)
81
+ x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
82
+
83
+ # run generation
84
+ with torch.no_grad():
85
+ with ctx:
86
+ for k in range(num_samples):
87
+ y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
88
+ print(decode(y[0].tolist()))
89
+ print('---------------')
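
Typical invocations (illustrative, assuming the usual nanoGPT-style --key=value overrides handled by configurator.py; prompt.txt is a hypothetical file):

    python sample.py --out_dir=out --num_samples=3 --max_new_tokens=200
    python sample.py --init_from=gpt2 --start="What is the answer to life?"
    python sample.py --start=FILE:prompt.txt
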
scaling_laws.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
train.py ADDED
@@ -0,0 +1,336 @@
1
+ """
2
+ This training script can be run both on a single gpu in debug mode,
3
+ and also in a larger training run with distributed data parallel (ddp).
4
+
5
+ To run on a single GPU, example:
6
+ $ python train.py --batch_size=32 --compile=False
7
+
8
+ To run with DDP on 4 gpus on 1 node, example:
9
+ $ torchrun --standalone --nproc_per_node=4 train.py
10
+
11
+ To run with DDP on 8 gpus on each of 2 nodes, example:
12
+ - Run on the first (master) node with example IP 123.456.123.456:
13
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
14
+ - Run on the worker node:
15
+ $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
16
+ (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
17
+ """
18
+
19
+ import os
20
+ import time
21
+ import math
22
+ import pickle
23
+ from contextlib import nullcontext
24
+
25
+ import numpy as np
26
+ import torch
27
+ from torch.nn.parallel import DistributedDataParallel as DDP
28
+ from torch.distributed import init_process_group, destroy_process_group
29
+
30
+ from model import GPTConfig, GPT
31
+
32
+ # -----------------------------------------------------------------------------
33
+ # default config values designed to train a gpt2 (124M) on OpenWebText
34
+ # I/O
35
+ out_dir = 'out'
36
+ eval_interval = 2000
37
+ log_interval = 1
38
+ eval_iters = 200
39
+ eval_only = False # if True, script exits right after the first eval
40
+ always_save_checkpoint = True # if True, always save a checkpoint after each eval
41
+ init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
42
+ # wandb logging
43
+ wandb_log = False # disabled by default
44
+ wandb_project = 'owt'
45
+ wandb_run_name = 'gpt2' # 'run' + str(time.time())
46
+ # data
47
+ dataset = 'openwebtext'
48
+ gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
49
+ batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
50
+ block_size = 1024
51
+ # model
52
+ n_layer = 12
53
+ n_head = 12
54
+ n_embd = 768
55
+ dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
56
+ bias = False # do we use bias inside LayerNorm and Linear layers?
57
+ # adamw optimizer
58
+ learning_rate = 6e-4 # max learning rate
59
+ max_iters = 600000 # total number of training iterations
60
+ weight_decay = 1e-1
61
+ beta1 = 0.9
62
+ beta2 = 0.95
63
+ grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
64
+ # learning rate decay settings
65
+ decay_lr = True # whether to decay the learning rate
66
+ warmup_iters = 2000 # how many steps to warm up for
67
+ lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla
68
+ min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
69
+ # DDP settings
70
+ backend = 'nccl' # 'nccl', 'gloo', etc.
71
+ # system
72
+ device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
73
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
74
+ compile = True # use PyTorch 2.0 to compile the model to be faster
75
+ # -----------------------------------------------------------------------------
76
+ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
77
+ exec(open('configurator.py').read()) # overrides from command line or config file
78
+ config = {k: globals()[k] for k in config_keys} # will be useful for logging
79
+ # -----------------------------------------------------------------------------
80
+
81
+ # various inits, derived attributes, I/O setup
82
+ ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run?
83
+ if ddp:
84
+ init_process_group(backend=backend)
85
+ ddp_rank = int(os.environ['RANK'])
86
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
87
+ ddp_world_size = int(os.environ['WORLD_SIZE'])
88
+ device = f'cuda:{ddp_local_rank}'
89
+ torch.cuda.set_device(device)
90
+ master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
91
+ seed_offset = ddp_rank # each process gets a different seed
92
+ # world_size number of processes will be training simultaneously, so we can scale
93
+ # down the desired gradient accumulation iterations per process proportionally
94
+ assert gradient_accumulation_steps % ddp_world_size == 0
95
+ gradient_accumulation_steps //= ddp_world_size
96
+ else:
97
+ # if not ddp, we are running on a single gpu, and one process
98
+ master_process = True
99
+ seed_offset = 0
100
+ ddp_world_size = 1
101
+ tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
102
+ print(f"tokens per iteration will be: {tokens_per_iter:,}")
103
+
104
+ if master_process:
105
+ os.makedirs(out_dir, exist_ok=True)
106
+ torch.manual_seed(1337 + seed_offset)
107
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
108
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
109
+ device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
110
+ # note: float16 data type will automatically use a GradScaler
111
+ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
112
+ ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
113
+
114
+ # poor man's data loader
115
+ data_dir = os.path.join('data', dataset)
116
+ def get_batch(split):
117
+ # We recreate np.memmap every batch to avoid a memory leak, as per
118
+ # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
119
+ if split == 'train':
120
+ data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
121
+ else:
122
+ data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
123
+ ix = torch.randint(len(data) - block_size, (batch_size,))
124
+ x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
125
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
126
+ if device_type == 'cuda':
127
+ # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
128
+ x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
129
+ else:
130
+ x, y = x.to(device), y.to(device)
131
+ return x, y
132
+
133
+ # init these up here, can override if init_from='resume' (i.e. from a checkpoint)
134
+ iter_num = 0
135
+ best_val_loss = 1e9
136
+
137
+ # attempt to derive vocab_size from the dataset
138
+ meta_path = os.path.join(data_dir, 'meta.pkl')
139
+ meta_vocab_size = None
140
+ if os.path.exists(meta_path):
141
+ with open(meta_path, 'rb') as f:
142
+ meta = pickle.load(f)
143
+ meta_vocab_size = meta['vocab_size']
144
+ print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
145
+
146
+ # model init
147
+ model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size,
148
+ bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line
149
+ if init_from == 'scratch':
150
+ # init a new model from scratch
151
+ print("Initializing a new model from scratch")
152
+ # determine the vocab size we'll use for from-scratch training
153
+ if meta_vocab_size is None:
154
+ print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)")
155
+ model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304
156
+ gptconf = GPTConfig(**model_args)
157
+ model = GPT(gptconf)
158
+ elif init_from == 'resume':
159
+ print(f"Resuming training from {out_dir}")
160
+ # resume training from a checkpoint.
161
+ ckpt_path = os.path.join(out_dir, 'ckpt.pt')
162
+ checkpoint = torch.load(ckpt_path, map_location=device)
163
+ checkpoint_model_args = checkpoint['model_args']
164
+ # force these config attributes to be equal otherwise we can't even resume training
165
+ # the rest of the attributes (e.g. dropout) can stay as desired from command line
166
+ for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
167
+ model_args[k] = checkpoint_model_args[k]
168
+ # create the model
169
+ gptconf = GPTConfig(**model_args)
170
+ model = GPT(gptconf)
171
+ state_dict = checkpoint['model']
172
+ # fix the keys of the state dictionary :(
173
+ # honestly no idea how checkpoints sometimes get this prefix, have to debug more
174
+ unwanted_prefix = '_orig_mod.'
175
+ for k,v in list(state_dict.items()):
176
+ if k.startswith(unwanted_prefix):
177
+ state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
178
+ model.load_state_dict(state_dict)
179
+ iter_num = checkpoint['iter_num']
180
+ best_val_loss = checkpoint['best_val_loss']
181
+ elif init_from.startswith('gpt2'):
182
+ print(f"Initializing from OpenAI GPT-2 weights: {init_from}")
183
+ # initialize from OpenAI GPT-2 weights
184
+ override_args = dict(dropout=dropout)
185
+ model = GPT.from_pretrained(init_from, override_args)
186
+ # read off the created config params, so we can store them into checkpoint correctly
187
+ for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']:
188
+ model_args[k] = getattr(model.config, k)
189
+ # crop down the model block size if desired, using model surgery
190
+ if block_size < model.config.block_size:
191
+ model.crop_block_size(block_size)
192
+ model_args['block_size'] = block_size # so that the checkpoint will have the right value
193
+ model.to(device)
194
+
195
+ # initialize a GradScaler. If enabled=False scaler is a no-op
196
+ scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
197
+
198
+ # optimizer
199
+ optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
200
+ if init_from == 'resume':
201
+ optimizer.load_state_dict(checkpoint['optimizer'])
202
+ checkpoint = None # free up memory
203
+
204
+ # compile the model
205
+ if compile:
206
+ print("compiling the model... (takes a ~minute)")
207
+ unoptimized_model = model
208
+ model = torch.compile(model) # requires PyTorch 2.0
209
+
210
+ # wrap model into DDP container
211
+ if ddp:
212
+ model = DDP(model, device_ids=[ddp_local_rank])
213
+
214
+ # helps estimate an arbitrarily accurate loss over either split using many batches
215
+ @torch.no_grad()
216
+ def estimate_loss():
217
+ out = {}
218
+ model.eval()
219
+ for split in ['train', 'val']:
220
+ losses = torch.zeros(eval_iters)
221
+ for k in range(eval_iters):
222
+ X, Y = get_batch(split)
223
+ with ctx:
224
+ logits, loss = model(X, Y)
225
+ losses[k] = loss.item()
226
+ out[split] = losses.mean()
227
+ model.train()
228
+ return out
229
+
230
+ # learning rate decay scheduler (cosine with warmup)
231
+ def get_lr(it):
232
+ # 1) linear warmup for warmup_iters steps
233
+ if it < warmup_iters:
234
+ return learning_rate * (it + 1) / (warmup_iters + 1)
235
+ # 2) if it > lr_decay_iters, return min learning rate
236
+ if it > lr_decay_iters:
237
+ return min_lr
238
+ # 3) in between, use cosine decay down to min learning rate
239
+ decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
240
+ assert 0 <= decay_ratio <= 1
241
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
242
+ return min_lr + coeff * (learning_rate - min_lr)
243
+
244
+ # logging
245
+ if wandb_log and master_process:
246
+ import wandb
247
+ wandb.init(project=wandb_project, name=wandb_run_name, config=config)
248
+
249
+ # training loop
250
+ X, Y = get_batch('train') # fetch the very first batch
251
+ t0 = time.time()
252
+ local_iter_num = 0 # number of iterations in the lifetime of this process
253
+ raw_model = model.module if ddp else model # unwrap DDP container if needed
254
+ running_mfu = -1.0
255
+ while True:
256
+
257
+ # determine and set the learning rate for this iteration
258
+ lr = get_lr(iter_num) if decay_lr else learning_rate
259
+ for param_group in optimizer.param_groups:
260
+ param_group['lr'] = lr
261
+
262
+ # evaluate the loss on train/val sets and write checkpoints
263
+ if iter_num % eval_interval == 0 and master_process:
264
+ losses = estimate_loss()
265
+ print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
266
+ if wandb_log:
267
+ wandb.log({
268
+ "iter": iter_num,
269
+ "train/loss": losses['train'],
270
+ "val/loss": losses['val'],
271
+ "lr": lr,
272
+ "mfu": running_mfu*100, # convert to percentage
273
+ })
274
+ if losses['val'] < best_val_loss or always_save_checkpoint:
275
+ best_val_loss = losses['val']
276
+ if iter_num > 0:
277
+ checkpoint = {
278
+ 'model': raw_model.state_dict(),
279
+ 'optimizer': optimizer.state_dict(),
280
+ 'model_args': model_args,
281
+ 'iter_num': iter_num,
282
+ 'best_val_loss': best_val_loss,
283
+ 'config': config,
284
+ }
285
+ print(f"saving checkpoint to {out_dir}")
286
+ torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
287
+ if iter_num == 0 and eval_only:
288
+ break
289
+
290
+ # forward backward update, with optional gradient accumulation to simulate larger batch size
291
+ # and using the GradScaler if data type is float16
292
+ for micro_step in range(gradient_accumulation_steps):
293
+ if ddp:
294
+ # in DDP training we only need to sync gradients at the last micro step.
295
+ # the official way to do this is with model.no_sync() context manager, but
296
+ # I really dislike that this bloats the code and forces us to repeat code
297
+ # looking at the source of that context manager, it just toggles this variable
298
+ model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
299
+ with ctx:
300
+ logits, loss = model(X, Y)
301
+ loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
302
+ # immediately async prefetch next batch while model is doing the forward pass on the GPU
303
+ X, Y = get_batch('train')
304
+ # backward pass, with gradient scaling if training in fp16
305
+ scaler.scale(loss).backward()
306
+ # clip the gradient
307
+ if grad_clip != 0.0:
308
+ scaler.unscale_(optimizer)
309
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
310
+ # step the optimizer and scaler if training in fp16
311
+ scaler.step(optimizer)
312
+ scaler.update()
313
+ # flush the gradients as soon as we can, no need for this memory anymore
314
+ optimizer.zero_grad(set_to_none=True)
315
+
316
+ # timing and logging
317
+ t1 = time.time()
318
+ dt = t1 - t0
319
+ t0 = t1
320
+ if iter_num % log_interval == 0 and master_process:
321
+ # get loss as float. note: this is a CPU-GPU sync point
322
+ # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
323
+ lossf = loss.item() * gradient_accumulation_steps
324
+ if local_iter_num >= 5: # let the training loop settle a bit
325
+ mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
326
+ running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
327
+ print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
328
+ iter_num += 1
329
+ local_iter_num += 1
330
+
331
+ # termination conditions
332
+ if iter_num > max_iters:
333
+ break
334
+
335
+ if ddp:
336
+ destroy_process_group()
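
As a worked example of the defaults above: with gradient_accumulation_steps = 5 * 8, batch_size = 12 and block_size = 1024 on a single process, tokens_per_iter = 40 * 1 * 12 * 1024 = 491,520. The warmup-plus-cosine schedule in get_lr() can also be checked in isolation (illustrative sketch of the same formula):

    import math

    learning_rate, min_lr = 6e-4, 6e-5
    warmup_iters, lr_decay_iters = 2000, 600000

    def get_lr(it):
        if it < warmup_iters:
            return learning_rate * (it + 1) / (warmup_iters + 1)
        if it > lr_decay_iters:
            return min_lr
        decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
        return min_lr + coeff * (learning_rate - min_lr)

    for it in (0, 2000, 301000, 600000):
        print(it, f"{get_lr(it):.2e}")   # ramps up to 6.00e-04, ~3.30e-04 halfway through decay, 6.00e-05 at the end
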
train_baseline.py ADDED
@@ -0,0 +1,228 @@
1
+ """
2
+ Training script for baseline NanoGPT model on enwik8 dataset.
3
+ Ensures proper bpc calculation and comparable evaluation with DTAT.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import math
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch.nn.parallel import DistributedDataParallel as DDP
13
+ from torch.distributed import init_process_group, destroy_process_group
14
+ from contextlib import nullcontext
15
+ import wandb
16
+
17
+ from model import GPT, GPTConfig
18
+
19
+ def get_batch(data, block_size, batch_size, device):
20
+ """Generate a small batch of data of inputs x and targets y."""
21
+ ix = torch.randint(len(data) - block_size, (batch_size,))
22
+ x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
23
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
24
+ x, y = x.to(device), y.to(device)
25
+ return x, y
26
+
27
+ def estimate_loss(model, data, eval_iters, block_size, batch_size, device):
28
+ """Estimate loss on data split, ensuring proper bpc calculation."""
29
+ model.eval()
30
+ losses = torch.zeros(eval_iters)
31
+ for k in range(eval_iters):
32
+ X, Y = get_batch(data, block_size, batch_size, device)
33
+ with torch.no_grad():
34
+ logits, loss = model(X, Y)
35
+ # Convert from nats to bpc
36
+ loss = loss / math.log(2)
37
+ losses[k] = loss.item()
38
+ out = losses.mean()
39
+ model.train()
40
+ return out
41
+
42
+ def get_lr(it, config):
43
+ """Get learning rate based on iteration."""
44
+ if it < config.warmup_iters:
45
+ return config.learning_rate * it / config.warmup_iters
46
+ if it > config.lr_decay_iters:
47
+ return config.min_lr
48
+ decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
49
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
50
+ return config.min_lr + coeff * (config.learning_rate - config.min_lr)
51
+
52
+ def main():
53
+ # Initialize distributed training if needed
54
+ ddp = int(os.environ.get('RANK', -1)) != -1
55
+ if ddp:
56
+ init_process_group(backend='nccl')
57
+ ddp_rank = int(os.environ['RANK'])
58
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
59
+ device = f'cuda:{ddp_local_rank}'
60
+ master_process = ddp_rank == 0
61
+ seed_offset = ddp_rank
62
+ else:
63
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
64
+ master_process = True
65
+ seed_offset = 0
66
+
67
+ torch.manual_seed(1337 + seed_offset)
68
+ torch.backends.cuda.matmul.allow_tf32 = True
69
+ torch.backends.cudnn.allow_tf32 = True
70
+ device_type = 'cuda' if 'cuda' in device else 'cpu'
71
+
72
+ # Model configuration (matching paper's 44M parameter target)
73
+ config = GPTConfig(
+ block_size=1024,
+ vocab_size=256, # byte-level vocab
+ n_layer=12,
+ n_head=8,
+ n_embd=512,
+ dropout=0.1,
+ bias=False,
+ )
+ # Training-specific hyperparameters, attached separately (GPTConfig itself only defines model fields)
+ config.learning_rate = 6e-4
+ config.min_lr = 6e-5
+ config.warmup_iters = 2000
+ config.lr_decay_iters = 100000
+ config.max_iters = 100000
+ config.eval_interval = 500
+ config.eval_iters = 200
+ config.batch_size = 32
91
+
92
+ # Initialize wandb for baseline model
93
+ if master_process:
94
+ wandb.init(
95
+ project="enwik8-baseline",
96
+ config={
97
+ "architecture": "NanoGPT-Baseline",
98
+ "dataset": "enwik8",
99
+ "batch_size": config.batch_size,
100
+ "learning_rate": config.learning_rate,
101
+ "warmup_iters": config.warmup_iters,
102
+ "block_size": config.block_size,
103
+ "n_layer": config.n_layer,
104
+ "n_head": config.n_head,
105
+ "n_embd": config.n_embd,
106
+ "dropout": config.dropout,
107
+ }
108
+ )
109
+
110
+ # Data loading
111
+ print("Loading data...")
112
+ data_dir = os.path.join('data')
113
+ train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint8, mode='r')
114
+ val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint8, mode='r')
115
+
116
+ # Model initialization
117
+ print("Initializing model...")
118
+ model = GPT(config)
119
+ model.to(device)
120
+
121
+ # Optimizer
122
+ optimizer = torch.optim.AdamW(
123
+ model.parameters(),
124
+ lr=config.learning_rate,
125
+ betas=(0.9, 0.95),
126
+ weight_decay=0.1,
127
+ )
128
+
129
+ if ddp:
130
+ model = DDP(model, device_ids=[ddp_local_rank])
131
+
132
+ # Training loop
133
+ print("Starting training...")
134
+ best_val_loss = float('inf')
135
+ iter_num = 0
136
+
137
+ while True:
138
+ lr = get_lr(iter_num, config)
139
+ for param_group in optimizer.param_groups:
140
+ param_group['lr'] = lr
141
+
142
+ # Get batch and timing
143
+ t0 = time.time()
144
+ X, Y = get_batch(train_data, config.block_size, config.batch_size, device)
145
+
146
+ # Forward pass
147
+ logits, loss = model(X, Y)
148
+ # Convert loss to bpc
149
+ loss = loss / math.log(2)
150
+
151
+ # Backward pass
152
+ optimizer.zero_grad(set_to_none=True)
153
+ loss.backward()
154
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
155
+ optimizer.step()
156
+
157
+ # Timing and logging
158
+ t1 = time.time()
159
+ dt = t1 - t0
160
+
161
+ if iter_num % 100 == 0 and master_process:
162
+ # Log metrics to wandb
163
+ metrics = {
164
+ "train/loss": loss.item(),
165
+ "train/bpc": loss.item(),
166
+ "train/grad_norm": grad_norm.item(),
167
+ "train/learning_rate": lr,
168
+ "train/tokens_per_sec": config.batch_size * config.block_size / dt,
169
+ "train/iteration": iter_num,
170
+ }
171
+ wandb.log(metrics)
172
+
173
+ print(f"iter {iter_num}: loss {loss.item():.4f}, bpc {loss.item():.4f}, "
174
+ f"grad_norm {grad_norm:.2f}, lr {lr:.2e}")
175
+
176
+ # Evaluation
177
+ if iter_num % config.eval_interval == 0:
178
+ val_loss = estimate_loss(
179
+ model, val_data, config.eval_iters,
180
+ config.block_size, config.batch_size, device
181
+ )
182
+
183
+ if master_process:
184
+ # Log validation metrics
185
+ val_metrics = {
186
+ "val/loss": val_loss,
187
+ "val/bpc": val_loss,
188
+ "val/iteration": iter_num,
189
+ }
190
+ wandb.log(val_metrics)
191
+
192
+ print(f"step {iter_num}: val loss {val_loss:.4f}, val bpc {val_loss:.4f}")
193
+
194
+ # Save best model
195
+ if val_loss < best_val_loss:
196
+ best_val_loss = val_loss
197
+ if master_process:
198
+ print(f"Saving best model with val_loss: {best_val_loss:.4f}")
199
+ checkpoint = {
200
+ 'model_state_dict': model.state_dict(),
201
+ 'optimizer_state_dict': optimizer.state_dict(),
202
+ 'config': config,
203
+ 'iter_num': iter_num,
204
+ 'best_val_loss': best_val_loss,
205
+ }
206
+ torch.save(checkpoint, 'best_model_baseline.pt')
207
+
208
+ # Log best model to wandb
209
+ wandb.save('best_model_baseline.pt')
210
+ wandb.run.summary["best_val_loss"] = best_val_loss
211
+ wandb.run.summary["best_val_bpc"] = best_val_loss
212
+ wandb.run.summary["best_model_iter"] = iter_num
213
+
214
+ iter_num += 1
215
+
216
+ # End training if we reach max_iters
217
+ if iter_num > config.max_iters:
218
+ break
219
+
220
+ # Clean up
221
+ if ddp:
222
+ destroy_process_group()
223
+
224
+ if master_process:
225
+ wandb.finish()
226
+
227
+ if __name__ == '__main__':
228
+ main()
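
A quick sanity check of the nats-to-bpc conversion used in estimate_loss() and the training loop: a cross-entropy of L nats per character corresponds to L / ln(2) bits per character.

    import math
    loss_nats = 0.80
    print(loss_nats / math.log(2))   # ~1.154 bits per character (bpc)
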
train_dtat.py ADDED
@@ -0,0 +1,256 @@
1
+ """
2
+ Training script for Dynamic Token-Aware Transformer (DTAT) on enwik8 dataset.
3
+ Based on NanoGPT's training structure with modifications for token importance awareness.
4
+ """
5
+
6
+ import os
7
+ import time
8
+ import math
9
+ import pickle
10
+ from contextlib import nullcontext
11
+ import numpy as np
12
+ import torch
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ from torch.distributed import init_process_group, destroy_process_group
15
+ import matplotlib.pyplot as plt
16
+ import wandb
17
+
18
+ from model_dtat import DTATTransformer
19
+ from config.dtat_config import get_config
20
+
21
+ # -----------------------------------------------------------------------------
22
+ # I/O
23
+ def get_batch(data, block_size, batch_size, device):
24
+ """Generate a small batch of data of inputs x and targets y."""
25
+ ix = torch.randint(len(data) - block_size, (batch_size,))
26
+ x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
27
+ y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
28
+ x, y = x.to(device), y.to(device)
29
+ return x, y
30
+
31
+ def compute_freq_table(data, vocab_size=256):
32
+ """Compute frequency table for the dataset."""
33
+ freq = np.bincount(data, minlength=vocab_size)
34
+ return freq / len(data)
35
+
36
+ def visualize_importance(importance_scores, tokens, save_path):
37
+ """Visualize token importance scores and log to wandb."""
38
+ plt.figure(figsize=(15, 5))
39
+ plt.bar(range(len(tokens)), importance_scores.squeeze().cpu())
40
+ plt.title('Token Importance Scores')
41
+ plt.xlabel('Token Position')
42
+ plt.ylabel('Importance Score')
43
+ plt.savefig(save_path)
44
+
45
+ # Log to wandb
46
+ if wandb.run is not None:
47
+ wandb.log({"token_importance": wandb.Image(save_path)})
48
+
49
+ plt.close()
50
+
51
+ # -----------------------------------------------------------------------------
52
+ # Training
53
+
54
+ def estimate_loss(model, data, config):
55
+ out = {}
56
+ model.eval()
57
+ losses = torch.zeros(config.eval_iters)
58
+ for k in range(config.eval_iters):
59
+ X, Y = get_batch(data, config.block_size, config.batch_size, config.device)
60
+ with torch.no_grad():
61
+ logits, loss, _ = model(X, Y)
62
+ losses[k] = loss.item()
63
+ out = losses.mean()
64
+ model.train()
65
+ return out
66
+
67
+ def get_lr(it, config):
68
+ # 1) Linear warmup for warmup_iters steps
69
+ if it < config.warmup_iters:
70
+ return config.learning_rate * it / config.warmup_iters
71
+ # 2) If it > lr_decay_iters, return min learning rate
72
+ if it > config.lr_decay_iters:
73
+ return config.min_lr
74
+ # 3) In between, use cosine decay down to min learning rate
75
+ decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
76
+ assert 0 <= decay_ratio <= 1
77
+ coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
78
+ return config.min_lr + coeff * (config.learning_rate - config.min_lr)
79
+
80
+ def main():
81
+ # Initialize distributed training if needed
82
+ config = get_config()  # load the config up front; the DDP branch below splits its batch size per process
+ ddp = int(os.environ.get('RANK', -1)) != -1
83
+ if ddp:
84
+ init_process_group(backend='nccl')
85
+ ddp_rank = int(os.environ['RANK'])
86
+ ddp_local_rank = int(os.environ['LOCAL_RANK'])
87
+ device = f'cuda:{ddp_local_rank}'
88
+ master_process = ddp_rank == 0
89
+ seed_offset = ddp_rank
90
+ assert config.batch_size % torch.cuda.device_count() == 0
91
+ config.batch_size = config.batch_size // torch.cuda.device_count()
92
+ else:
93
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
94
+ master_process = True
95
+ seed_offset = 0
96
+
97
+ # Set seed for reproducibility
98
+ torch.manual_seed(1337 + seed_offset)
99
+ torch.backends.cuda.matmul.allow_tf32 = True
100
+ torch.backends.cudnn.allow_tf32 = True
101
+ device_type = 'cuda' if 'cuda' in device else 'cpu'
102
+
103
+ # Config was loaded at the top of main(); just attach the runtime device to it
104
105
+ config.device = device
106
+
107
+ # Initialize wandb
108
+ if master_process:
109
+ wandb.init(
110
+ project="enwik8-dtat",
111
+ config={
112
+ "architecture": "DTAT",
113
+ "dataset": "enwik8",
114
+ "batch_size": config.batch_size,
115
+ "learning_rate": config.learning_rate,
116
+ "warmup_iters": config.warmup_iters,
117
+ "block_size": config.block_size,
118
+ "n_layer": config.n_layer,
119
+ "n_head": config.n_head,
120
+ "n_embd": config.n_embd,
121
+ "dropout": config.dropout,
122
+ "sparse_topk": config.sparse_topk,
123
+ }
124
+ )
125
+
126
+ # Data loading
127
+ print("Loading data...")
128
+ data_dir = os.path.join('data')
129
+ train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint8, mode='r')
130
+ val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint8, mode='r')
131
+
132
+ # Compute frequency table for the training data
133
+ freq_table = compute_freq_table(train_data)
134
+
135
+ # Model init
136
+ print("Initializing model...")
137
+ model = DTATTransformer(config)
138
+ model.to(device)
139
+
140
+ # Optimizer
141
+ optimizer = torch.optim.AdamW(
142
+ model.parameters(),
143
+ lr=config.learning_rate,
144
+ betas=(config.beta1, config.beta2),
145
+ weight_decay=config.weight_decay
146
+ )
147
+
148
+ if ddp:
149
+ model = DDP(model, device_ids=[ddp_local_rank])
150
+
151
+ # Training loop
152
+ print("Starting training...")
153
+ best_val_loss = float('inf')
154
+ iter_num = 0
155
+
156
+ while True:
157
+ lr = get_lr(iter_num, config) if config.decay_lr else config.learning_rate
158
+ for param_group in optimizer.param_groups:
159
+ param_group['lr'] = lr
160
+
161
+ # Get batch
162
+ t0 = time.time()
163
+ X, Y = get_batch(train_data, config.block_size, config.batch_size, device)
164
+
165
+ # Forward pass
166
+ logits, loss, importance_scores = model(X, Y)
167
+
168
+ # Calculate additional metrics
169
+ importance_mean = importance_scores.mean().item()
170
+ importance_std = importance_scores.std().item()
171
+
172
+ # Backward pass
173
+ optimizer.zero_grad(set_to_none=True)
174
+ loss.backward()
175
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
176
+ optimizer.step()
177
+
178
+ # Timing and logging
179
+ t1 = time.time()
180
+ dt = t1 - t0
181
+
182
+ if iter_num % config.log_interval == 0 and master_process:
183
+ # Log metrics to wandb
184
+ metrics = {
185
+ "train/loss": loss.item(),
186
+ "train/bpc": loss.item(),
187
+ "train/importance_mean": importance_mean,
188
+ "train/importance_std": importance_std,
189
+ "train/grad_norm": grad_norm.item(),
190
+ "train/learning_rate": lr,
191
+ "train/tokens_per_sec": config.batch_size * config.block_size / dt,
192
+ "train/iteration": iter_num,
193
+ }
194
+ wandb.log(metrics)
195
+
196
+ print(f"iter {iter_num}: loss {loss.item():.4f}, bpc {loss.item():.4f}, "
197
+ f"importance_mean {importance_mean:.3f}, grad_norm {grad_norm:.2f}")
198
+
199
+ # Visualize importance scores periodically
200
+ if iter_num % (config.log_interval * 10) == 0:
201
+ visualize_importance(
202
+ importance_scores[0],
203
+ X[0].cpu().numpy(),
204
+ f'importance_scores_iter_{iter_num}.png'
205
+ )
206
+
207
+ # Evaluation
208
+ if iter_num % config.eval_interval == 0:
209
+ val_loss = estimate_loss(model, val_data, config)
210
+
211
+ # Log validation metrics
212
+ if master_process:
213
+ val_metrics = {
214
+ "val/loss": val_loss,
215
+ "val/bpc": val_loss,
216
+ "val/iteration": iter_num,
217
+ }
218
+ wandb.log(val_metrics)
219
+
220
+ print(f"step {iter_num}: val loss {val_loss:.4f}, val bpc {val_loss:.4f}")
221
+
222
+ # Save best model
223
+ if val_loss < best_val_loss:
224
+ best_val_loss = val_loss
225
+ if master_process:
226
+ print(f"Saving best model with val_loss: {best_val_loss:.4f}")
227
+ checkpoint = {
228
+ 'model_state_dict': model.state_dict(),
229
+ 'optimizer_state_dict': optimizer.state_dict(),
230
+ 'config': config,
231
+ 'iter_num': iter_num,
232
+ 'best_val_loss': best_val_loss,
233
+ }
234
+ torch.save(checkpoint, 'best_model_dtat.pt')
235
+
236
+ # Log best model to wandb
237
+ wandb.save('best_model_dtat.pt')
238
+ wandb.run.summary["best_val_loss"] = best_val_loss
239
+ wandb.run.summary["best_val_bpc"] = best_val_loss
240
+ wandb.run.summary["best_model_iter"] = iter_num
241
+
242
+ iter_num += 1
243
+
244
+ # End training if we reach max_iters
245
+ if iter_num > config.max_iters:
246
+ break
247
+
248
+ # Clean up
249
+ if ddp:
250
+ destroy_process_group()
251
+
252
+ if master_process:
253
+ wandb.finish()
254
+
255
+ if __name__ == '__main__':
256
+ main()
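For a quick sense of the schedule that `get_lr` implements above, here is a minimal standalone sketch (not part of the commit). The warmup and clamp branches and the concrete config values are assumptions, chosen only to mirror the fields the excerpt reads and the `learning_rate=0.0006` / `warmup_iters=2000` values in the run config below.

```python
import math
from types import SimpleNamespace

def lr_schedule(it, config):
    if it < config.warmup_iters:                      # linear warmup (assumed, standard nanoGPT-style)
        return config.learning_rate * it / config.warmup_iters
    if it > config.lr_decay_iters:                    # after the decay window, hold at the floor
        return config.min_lr
    decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # cosine from 1 down to 0
    return config.min_lr + coeff * (config.learning_rate - config.min_lr)

# hypothetical values for illustration only
cfg = SimpleNamespace(learning_rate=6e-4, min_lr=6e-5, warmup_iters=2000, lr_decay_iters=100_000)
for it in (0, 1000, 2000, 51_000, 100_000, 150_000):
    print(f"iter {it:>6}: lr {lr_schedule(it, cfg):.2e}")
# ramps 0 -> 6e-4 over the first 2000 iters, passes ~3.3e-4 mid-decay, settles at 6e-5
```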
train_enwik8.py ADDED
@@ -0,0 +1,114 @@
+ import os
+ import time
+ import math
+ import torch
+ from torch.nn import functional as F
+ from model_modified import GPTModified
+ import numpy as np
+ from contextlib import nullcontext
+ 
+ # Import configurations
+ from config.char_config import config
+ 
+ def load_data(split):
+     """Load binary data from split file."""
+     filename = os.path.join('data', f'{split}.bin')
+     with open(filename, 'rb') as f:
+         data = np.fromfile(f, dtype=np.uint8)
+     return data
+ 
+ def get_batch(data, block_size, batch_size, device):
+     """Generate a small batch of data of inputs x and targets y."""
+     ix = torch.randint(len(data) - block_size, (batch_size,))
+     x = torch.stack([torch.from_numpy(data[i:i+block_size].astype(np.int64)) for i in ix])
+     y = torch.stack([torch.from_numpy(data[i+1:i+1+block_size].astype(np.int64)) for i in ix])
+     x, y = x.to(device), y.to(device)
+     return x, y
+ 
+ def estimate_loss(model, data, eval_iters, block_size, batch_size, device):
+     """Estimate mean loss over eval_iters random batches from a data split."""
+     model.eval()
+     losses = torch.zeros(eval_iters)
+     for k in range(eval_iters):
+         X, Y = get_batch(data, block_size, batch_size, device)
+         with torch.no_grad():
+             logits, loss = model(X, Y)
+         losses[k] = loss.item()
+     model.train()
+     return losses.mean()
+ 
+ def convert_to_bpc(loss):
+     """Convert from natural log (nats) to bits per character (bpc)."""
+     return loss / math.log(2)
+ 
+ def main():
+     # System setup
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     print(f"Using device: {device}")
+ 
+     ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[config['dtype']]
+     ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)
+ 
+     # Data loading
+     print("Loading data...")
+     train_data = load_data('train')
+     val_data = load_data('val')
+ 
+     # Model init
+     print("Initializing model...")
+     model = GPTModified(config)
+     model.to(device)
+ 
+     # Optimizer
+     optimizer = torch.optim.AdamW(
+         model.parameters(),
+         lr=config['learning_rate'],
+         betas=(config['beta1'], config['beta2']),
+         weight_decay=config['weight_decay']
+     )
+ 
+     # Training loop
+     best_val_loss = float('inf')
+     batch_size = 32
+ 
+     for iter_num in range(config['max_iters']):
+         # Sample a batch of data
+         xb, yb = get_batch(train_data, config['block_size'], batch_size, device)
+ 
+         # Forward pass
+         with ctx:
+             logits, loss = model(xb, yb)
+         loss_bpc = convert_to_bpc(loss.item())
+ 
+         # Backward pass
+         optimizer.zero_grad(set_to_none=True)
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), config['grad_clip'])
+         optimizer.step()
+ 
+         # Logging
+         if iter_num % config['log_interval'] == 0:
+             print(f"iter {iter_num}: train loss {loss_bpc:.4f} bpc")
+ 
+         # Evaluation
+         if iter_num % config['eval_interval'] == 0:
+             val_loss = estimate_loss(model, val_data, config['eval_iters'],
+                                      config['block_size'], batch_size, device)
+             val_bpc = convert_to_bpc(val_loss)
+             print(f"iter {iter_num}: val loss {val_bpc:.4f} bpc")
+ 
+             # Save best model (tracked in bpc)
+             if val_bpc < best_val_loss:
+                 best_val_loss = val_bpc
+                 torch.save({
+                     'model_state_dict': model.state_dict(),
+                     'optimizer_state_dict': optimizer.state_dict(),
+                     'iter': iter_num,
+                     'best_val_loss': best_val_loss,
+                 }, 'best_model.pt')
+ 
+ if __name__ == '__main__':
+     main()
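A quick worked check of `convert_to_bpc` (a sketch, not part of the commit): the cross-entropy of a uniform guess over the 256 possible byte values is ln(256) nats, which should convert to exactly 8 bits per character.

```python
import math

def convert_to_bpc(loss):
    """Convert from natural log (nats) to bits per character (bpc)."""
    return loss / math.log(2)

uniform_nats = math.log(256)         # cross-entropy of a uniform distribution over bytes
print(convert_to_bpc(uniform_nats))  # 8.0 bpc, as expected for 8-bit characters
```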
transformer_sizing.ipynb ADDED
@@ -0,0 +1,402 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "### Transformer Theoretical Model\n",
9
+ "\n",
10
+ "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc."
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 1,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "from collections import OrderedDict"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 2,
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "# config_args = {\n",
29
+ "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n",
30
+ "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n",
31
+ "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n",
32
+ "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n",
33
+ "# }[model_type]\n",
34
+ "\n",
35
+ "block_size = 1024\n",
36
+ "vocab_size = 50257\n",
37
+ "n_layer = 12\n",
38
+ "n_head = 12\n",
39
+ "n_embd = 768\n",
40
+ "bias = False\n",
41
+ "assert not bias, \"this notebook assumes bias=False just for simplicity\""
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": 3,
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "name": "stdout",
51
+ "output_type": "stream",
52
+ "text": [
53
+ "we see: 124337664, expected: 124337664, match: True\n",
54
+ "name params ratio (%) \n",
55
+ "emebedding/position 786432 0.6325\n",
56
+ "embedding/token 38597376 31.0424\n",
57
+ "embedding 39383808 31.6749\n",
58
+ "attention/ln 768 0.0006\n",
59
+ "attention/kqv 1769472 1.4231\n",
60
+ "attention/proj 589824 0.4744\n",
61
+ "attention 2360064 1.8981\n",
62
+ "mlp/ln 768 0.0006\n",
63
+ "mlp/ffw 2359296 1.8975\n",
64
+ "mlp/proj 2359296 1.8975\n",
65
+ "mlp 4719360 3.7956\n",
66
+ "block 7079424 5.6937\n",
67
+ "transformer 84953088 68.3245\n",
68
+ "ln_f 768 0.0006\n",
69
+ "dense 0 0.0000\n",
70
+ "total 124337664 100.0000\n"
71
+ ]
72
+ }
73
+ ],
74
+ "source": [
75
+ "def params():\n",
76
+ " \"\"\" estimates the number of parameters in the model\"\"\"\n",
77
+ " out = OrderedDict()\n",
78
+ "\n",
79
+ " # token and position embeddings\n",
80
+ " out['emebedding/position'] = n_embd * block_size\n",
81
+ " out['embedding/token'] = n_embd * vocab_size\n",
82
+ " out['embedding'] = out['emebedding/position'] + out['embedding/token']\n",
83
+ "\n",
84
+ " # attention blocks\n",
85
+ " out['attention/ln'] = n_embd # note, bias=False in our LN\n",
86
+ " out['attention/kqv'] = n_embd * 3*n_embd\n",
87
+ " out['attention/proj'] = n_embd**2\n",
88
+ " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n",
89
+ "\n",
90
+ " # MLP blocks\n",
91
+ " ffw_size = 4*n_embd # feed forward size\n",
92
+ " out['mlp/ln'] = n_embd\n",
93
+ " out['mlp/ffw'] = n_embd * ffw_size\n",
94
+ " out['mlp/proj'] = ffw_size * n_embd\n",
95
+ " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n",
96
+ " \n",
97
+ " # the transformer and the rest of it\n",
98
+ " out['block'] = out['attention'] + out['mlp']\n",
99
+ " out['transformer'] = n_layer * out['block']\n",
100
+ " out['ln_f'] = n_embd # final layernorm\n",
101
+ " out['dense'] = 0 # 0 because of parameter sharing. This layer uses the weights from the embedding layer\n",
102
+ "\n",
103
+ " # total\n",
104
+ " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n",
105
+ "\n",
106
+ " return out\n",
107
+ "\n",
108
+ "# compare our param count to that reported by PyTorch\n",
109
+ "p = params()\n",
110
+ "params_total = p['total']\n",
111
+ "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n",
112
+ "# create a header\n",
113
+ "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n",
114
+ "for k,v in p.items():\n",
115
+ " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n",
116
+ " "
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": 4,
122
+ "metadata": {},
123
+ "outputs": [
124
+ {
125
+ "name": "stdout",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "est checkpoint size: 1.49 GB\n",
129
+ "measured with wc -c ckpt.pt: 1542470366\n",
130
+ "fluff ratio: 103.38%\n"
131
+ ]
132
+ }
133
+ ],
134
+ "source": [
135
+ "# we can now calculate the size of each checkpoint\n",
136
+ "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n",
137
+ "params_bytes = params_total*4\n",
138
+ "params_and_buffers_bytes = params_bytes + 2*params_bytes\n",
139
+ "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n",
140
+ "measured_bytes = 1542470366 # from wc -c ckpt.pt\n",
141
+ "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n",
142
+ "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")"
143
+ ]
144
+ },
145
+ {
146
+ "attachments": {},
147
+ "cell_type": "markdown",
148
+ "metadata": {},
149
+ "source": [
150
+ "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 5,
156
+ "metadata": {},
157
+ "outputs": [
158
+ {
159
+ "name": "stdout",
160
+ "output_type": "stream",
161
+ "text": [
162
+ "memory ratio taken up just for parameters: 3.73%\n"
163
+ ]
164
+ }
165
+ ],
166
+ "source": [
167
+ "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n",
168
+ "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")"
169
+ ]
170
+ },
171
+ {
172
+ "attachments": {},
173
+ "cell_type": "markdown",
174
+ "metadata": {},
175
+ "source": [
176
+ "i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models."
177
+ ]
178
+ },
179
+ {
180
+ "attachments": {},
181
+ "cell_type": "markdown",
182
+ "metadata": {},
183
+ "source": [
184
+ "Let's estimate FLOPs for a single forward pass."
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": 6,
190
+ "metadata": {},
191
+ "outputs": [
192
+ {
193
+ "name": "stdout",
194
+ "output_type": "stream",
195
+ "text": [
196
+ "name flops ratio (%) \n",
197
+ "attention/kqv 3623878656 1.2426\n",
198
+ "attention/scores 1610612736 0.5522\n",
199
+ "attention/reduce 1610612736 0.5522\n",
200
+ "attention/proj 1207959552 0.4142\n",
201
+ "attention 8053063680 2.7612\n",
202
+ "mlp/ffw1 4831838208 1.6567\n",
203
+ "mlp/ffw2 4831838208 1.6567\n",
204
+ "mlp 9663676416 3.3135\n",
205
+ "block 17716740096 6.0747\n",
206
+ "transformer 212600881152 72.8963\n",
207
+ "dense 79047426048 27.1037\n",
208
+ "forward_total 291648307200 100.0000\n",
209
+ "backward_total 583296614400 200.0000\n",
210
+ "total 874944921600 300.0000\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "def flops():\n",
216
+ " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n",
217
+ " # we count actual FLOPs, not MACs. Hence 2* all over the place\n",
218
+ " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n",
219
+ "\n",
220
+ " out = OrderedDict()\n",
221
+ " head_size = n_embd // n_head\n",
222
+ "\n",
223
+ " # attention blocks\n",
224
+ " # 1) the projection to key, query, values\n",
225
+ " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n",
226
+ " # 2) calculating the attention scores\n",
227
+ " out['attention/scores'] = 2 * block_size * block_size * n_embd\n",
228
+ " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
229
+ " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n",
230
+ " # 4) the final linear projection\n",
231
+ " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n",
232
+ " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n",
233
+ "\n",
234
+ " # MLP blocks\n",
235
+ " ffw_size = 4*n_embd # feed forward size\n",
236
+ " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n",
237
+ " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n",
238
+ " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n",
239
+ "\n",
240
+ " # the transformer and the rest of it\n",
241
+ " out['block'] = out['attention'] + out['mlp']\n",
242
+ " out['transformer'] = n_layer * out['block']\n",
243
+ " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n",
244
+ "\n",
245
+ " # forward,backward,total\n",
246
+ " out['forward_total'] = out['transformer'] + out['dense']\n",
247
+ " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n",
248
+ " out['total'] = out['forward_total'] + out['backward_total']\n",
249
+ "\n",
250
+ " return out\n",
251
+ " \n",
252
+ "# compare our param count to that reported by PyTorch\n",
253
+ "f = flops()\n",
254
+ "flops_total = f['forward_total']\n",
255
+ "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n",
256
+ "for k,v in f.items():\n",
257
+ " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n",
258
+ " "
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 7,
264
+ "metadata": {},
265
+ "outputs": [
266
+ {
267
+ "name": "stdout",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n"
271
+ ]
272
+ }
273
+ ],
274
+ "source": [
275
+ "# now here is an estimate copy pasted from the PaLM paper\n",
276
+ "# this formula is often used to calculate MFU (model flops utilization)\n",
277
+ "def palm_flops():\n",
278
+ " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n",
279
+ " # non-embedding model parameters. note that we do not subtract the\n",
280
+ " # embedding/token params because those are tied and get used in the last layer.\n",
281
+ " N = params()['total'] - params()['emebedding/position']\n",
282
+ " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n",
283
+ " mf_per_token = 6*N + 12*L*H*Q*T\n",
284
+ " mf = mf_per_token * block_size\n",
285
+ " return mf\n",
286
+ "\n",
287
+ "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")"
288
+ ]
289
+ },
290
+ {
291
+ "attachments": {},
292
+ "cell_type": "markdown",
293
+ "metadata": {},
294
+ "source": [
295
+ "Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. We get:"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 8,
301
+ "metadata": {},
302
+ "outputs": [
303
+ {
304
+ "name": "stdout",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "fraction of A100 used: 37.14%\n"
308
+ ]
309
+ }
310
+ ],
311
+ "source": [
312
+ "# here is what we currently roughly measure\n",
313
+ "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n",
314
+ "measured_time = 0.755 # in seconds per iteration\n",
315
+ "measured_throughput = batch_size / measured_time\n",
316
+ "flops_achieved = f['total'] * measured_throughput\n",
317
+ "\n",
318
+ "# A100 is cited to be 312 TFLOPS of bloat16 running on tensor cores\n",
319
+ "a100_flops_promised = 312e12\n",
320
+ "\n",
321
+ "# the fraction of the A100 that we are using:\n",
322
+ "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")"
323
+ ]
324
+ },
325
+ {
326
+ "attachments": {},
327
+ "cell_type": "markdown",
328
+ "metadata": {},
329
+ "source": [
330
+ "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU."
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 9,
336
+ "metadata": {},
337
+ "outputs": [
338
+ {
339
+ "name": "stdout",
340
+ "output_type": "stream",
341
+ "text": [
342
+ "time needed to train the model: 3.46 days\n"
343
+ ]
344
+ }
345
+ ],
346
+ "source": [
347
+ "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n",
348
+ "model_size = params()['total'] # this is number of parameters, N\n",
349
+ "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n",
350
+ "a100_flops = 312e12 # 312 TFLOPS\n",
351
+ "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n",
352
+ "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n",
353
+ "flops_needed = 6 * model_size * tokens_num # 6ND\n",
354
+ "time_needed_s = flops_needed / flops_throughput # in seconds\n",
355
+ "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")"
356
+ ]
357
+ },
358
+ {
359
+ "attachments": {},
360
+ "cell_type": "markdown",
361
+ "metadata": {},
362
+ "source": [
363
+ "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)."
364
+ ]
365
+ },
366
+ {
367
+ "attachments": {},
368
+ "cell_type": "markdown",
369
+ "metadata": {},
370
+ "source": [
371
+ "Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later."
372
+ ]
373
+ }
374
+ ],
375
+ "metadata": {
376
+ "kernelspec": {
377
+ "display_name": "pytorch2",
378
+ "language": "python",
379
+ "name": "python3"
380
+ },
381
+ "language_info": {
382
+ "codemirror_mode": {
383
+ "name": "ipython",
384
+ "version": 3
385
+ },
386
+ "file_extension": ".py",
387
+ "mimetype": "text/x-python",
388
+ "name": "python",
389
+ "nbconvert_exporter": "python",
390
+ "pygments_lexer": "ipython3",
391
+ "version": "3.10.8"
392
+ },
393
+ "orig_nbformat": 4,
394
+ "vscode": {
395
+ "interpreter": {
396
+ "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222"
397
+ }
398
+ }
399
+ },
400
+ "nbformat": 4,
401
+ "nbformat_minor": 2
402
+ }
wandb/run-20241230_125819-geso4xvw/files/config.yaml ADDED
@@ -0,0 +1,47 @@
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.18.6
4
+ m: []
5
+ python_version: 3.11.7
6
+ t:
7
+ "1":
8
+ - 1
9
+ - 55
10
+ - 105
11
+ "2":
12
+ - 1
13
+ - 55
14
+ - 105
15
+ "3":
16
+ - 16
17
+ - 23
18
+ - 55
19
+ "4": 3.11.7
20
+ "5": 0.18.6
21
+ "8":
22
+ - 3
23
+ - 5
24
+ "12": 0.18.6
25
+ "13": windows-amd64
26
+ architecture:
27
+ value: DTAT
28
+ batch_size:
29
+ value: 32
30
+ block_size:
31
+ value: 1024
32
+ dataset:
33
+ value: enwik8
34
+ dropout:
35
+ value: 0.1
36
+ learning_rate:
37
+ value: 0.0006
38
+ n_embd:
39
+ value: 512
40
+ n_head:
41
+ value: 8
42
+ n_layer:
43
+ value: 12
44
+ sparse_topk:
45
+ value: 32
46
+ warmup_iters:
47
+ value: 2000
wandb/run-20241230_125819-geso4xvw/files/output.log ADDED
@@ -0,0 +1,21 @@
1
+ Loading data...
2
+ Initializing model...
3
+ Traceback (most recent call last):
4
+ File "C:\sakana\enwik8-model\train_dtat.py", line 256, in <module>
5
+ main()
6
+ File "C:\sakana\enwik8-model\train_dtat.py", line 137, in main
7
+ model = DTATTransformer(config)
8
+ ^^^^^^^^^^^^^^^^^^^^^^^
9
+ File "C:\sakana\enwik8-model\model_dtat.py", line 172, in __init__
10
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
11
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
12
+ File "C:\sakana\enwik8-model\model_dtat.py", line 172, in <listcomp>
13
+ h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
14
+ ^^^^^^^^^^^^^
15
+ File "C:\sakana\enwik8-model\model_dtat.py", line 129, in __init__
16
+ self.attn = SparseDenseAttention(config)
17
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
18
+ File "C:\sakana\enwik8-model\model_dtat.py", line 80, in __init__
19
+ self.sparse_topk = config.get('sparse_topk', 32) # Number of tokens to attend to for less important tokens
20
+ ^^^^^^^^^^
21
+ AttributeError: 'DTATConfig' object has no attribute 'get'
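The run above died with `AttributeError: 'DTATConfig' object has no attribute 'get'`: dict-style `.get()` was called on an attribute-style config object. A minimal sketch of the usual remedy (the `DTATConfig` stand-in here is hypothetical, not the real config class) is to read optional fields with `getattr` and a default:

```python
from dataclasses import dataclass

@dataclass
class DTATConfig:          # hypothetical stand-in for the real config object
    n_embd: int = 512
    sparse_topk: int = 32

config = DTATConfig()
# config.get('sparse_topk', 32) -> AttributeError: a dataclass is not a dict
sparse_topk = getattr(config, 'sparse_topk', 32)  # attribute access with a fallback default
print(sparse_topk)  # 32
```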
wandb/run-20241230_125819-geso4xvw/files/wandb-metadata.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "os": "Windows-10-10.0.26100-SP0",
3
+ "python": "3.11.7",
4
+ "startedAt": "2024-12-30T10:58:19.924711Z",
5
+ "program": "C:\\sakana\\enwik8-model\\train_dtat.py",
6
+ "codePath": "train_dtat.py",
7
+ "git": {
8
+ "remote": "https://github.com/karpathy/nanoGPT.git",
9
+ "commit": "93a43d9a5c22450bbf06e78da2cb6eeef084b717"
10
+ },
11
+ "email": "mitel40181@gholar.com",
12
+ "root": "C:\\sakana\\enwik8-model",
13
+ "host": "SILX",
14
+ "username": "silxs",
15
+ "executable": "C:\\fcc-intro-to-llms\\cuda\\Scripts\\python.exe",
16
+ "codePathLocal": "train_dtat.py",
17
+ "cpu_count": 8,
18
+ "cpu_count_logical": 16,
19
+ "gpu": "NVIDIA GeForce RTX 3050 Laptop GPU",
20
+ "gpu_count": 1,
21
+ "disk": {
22
+ "/": {
23
+ "total": "487147769856",
24
+ "used": "485680205824"
25
+ }
26
+ },
27
+ "memory": {
28
+ "total": "16387997696"
29
+ },
30
+ "cpu": {
31
+ "count": 8,
32
+ "countLogical": 16
33
+ },
34
+ "gpu_nvidia": [
35
+ {
36
+ "name": "NVIDIA GeForce RTX 3050 Laptop GPU",
37
+ "memoryTotal": "4294967296",
38
+ "cudaCores": 2048,
39
+ "architecture": "Ampere"
40
+ }
41
+ ],
42
+ "cudaVersion": "12.6"
43
+ }
wandb/run-20241230_125819-geso4xvw/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"_wandb":{"runtime":1}}
wandb/run-20241230_125819-geso4xvw/logs/debug-core.log ADDED
@@ -0,0 +1,14 @@
1
+ {"time":"2024-12-30T12:58:19.192321+02:00","level":"INFO","msg":"started logging, with flags","port-filename":"C:\\Users\\silxs\\AppData\\Local\\Temp\\tmpf0be7_0z\\port-16680.txt","pid":16680,"debug":false,"disable-analytics":false}
2
+ {"time":"2024-12-30T12:58:19.192321+02:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false}
3
+ {"time":"2024-12-30T12:58:19.1989835+02:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":16680}
4
+ {"time":"2024-12-30T12:58:19.1989835+02:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":53467,"Zone":""}}
5
+ {"time":"2024-12-30T12:58:19.3765291+02:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:53468"}
6
+ {"time":"2024-12-30T12:58:19.9252228+02:00","level":"INFO","msg":"handleInformInit: received","streamId":"geso4xvw","id":"127.0.0.1:53468"}
7
+ {"time":"2024-12-30T12:58:20.0386713+02:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"geso4xvw","id":"127.0.0.1:53468"}
8
+ {"time":"2024-12-30T12:58:21.8885719+02:00","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"127.0.0.1:53468"}
9
+ {"time":"2024-12-30T12:58:21.8891443+02:00","level":"INFO","msg":"server is shutting down"}
10
+ {"time":"2024-12-30T12:58:21.8891443+02:00","level":"INFO","msg":"connection: Close: initiating connection closure","id":"127.0.0.1:53468"}
11
+ {"time":"2024-12-30T12:58:21.8891443+02:00","level":"INFO","msg":"connection: Close: connection successfully closed","id":"127.0.0.1:53468"}
12
+ {"time":"2024-12-30T12:58:25.467087+02:00","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"127.0.0.1:53468"}
13
+ {"time":"2024-12-30T12:58:25.467087+02:00","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"127.0.0.1:53468"}
14
+ {"time":"2024-12-30T12:58:25.467087+02:00","level":"INFO","msg":"server is closed"}
wandb/run-20241230_125819-geso4xvw/logs/debug-internal.log ADDED
@@ -0,0 +1,16 @@
1
+ {"time":"2024-12-30T12:58:19.9262449+02:00","level":"INFO","msg":"using version","core version":"0.18.6"}
2
+ {"time":"2024-12-30T12:58:19.9267933+02:00","level":"INFO","msg":"created symlink","path":"C:\\sakana\\enwik8-model\\wandb\\run-20241230_125819-geso4xvw\\logs\\debug-core.log"}
3
+ {"time":"2024-12-30T12:58:20.0381603+02:00","level":"INFO","msg":"created new stream","id":"geso4xvw"}
4
+ {"time":"2024-12-30T12:58:20.0386713+02:00","level":"INFO","msg":"stream: started","id":"geso4xvw"}
5
+ {"time":"2024-12-30T12:58:20.0386713+02:00","level":"INFO","msg":"handler: started","stream_id":{"value":"geso4xvw"}}
6
+ {"time":"2024-12-30T12:58:20.0386713+02:00","level":"INFO","msg":"sender: started","stream_id":"geso4xvw"}
7
+ {"time":"2024-12-30T12:58:20.0386713+02:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"geso4xvw"}}
8
+ {"time":"2024-12-30T12:58:20.9024895+02:00","level":"INFO","msg":"Starting system monitor"}
9
+ {"time":"2024-12-30T12:58:21.8891443+02:00","level":"INFO","msg":"stream: closing","id":"geso4xvw"}
10
+ {"time":"2024-12-30T12:58:21.8891443+02:00","level":"INFO","msg":"Stopping system monitor"}
11
+ {"time":"2024-12-30T12:58:21.8901726+02:00","level":"INFO","msg":"Stopped system monitor"}
12
+ {"time":"2024-12-30T12:58:24.8986622+02:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
13
+ {"time":"2024-12-30T12:58:25.4660418+02:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"geso4xvw"}}
14
+ {"time":"2024-12-30T12:58:25.4660418+02:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"geso4xvw"}}
15
+ {"time":"2024-12-30T12:58:25.4660418+02:00","level":"INFO","msg":"sender: closed","stream_id":"geso4xvw"}
16
+ {"time":"2024-12-30T12:58:25.4665528+02:00","level":"INFO","msg":"stream: closed","id":"geso4xvw"}
wandb/run-20241230_125819-geso4xvw/logs/debug.log ADDED
@@ -0,0 +1,26 @@
1
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_setup.py:_flush():79] Current SDK version is 0.18.6
2
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_setup.py:_flush():79] Configure stats pid to 16680
3
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_setup.py:_flush():79] Loading settings from C:\Users\silxs\.config\wandb\settings
4
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_setup.py:_flush():79] Loading settings from C:\sakana\enwik8-model\wandb\settings
5
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_setup.py:_flush():79] Loading settings from environment variables: {}
6
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_setup.py:_flush():79] Applying setup settings: {'mode': None, '_disable_service': None}
7
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_setup.py:_flush():79] Inferring run settings from compute environment: {'program_relpath': 'train_dtat.py', 'program_abspath': 'C:\\sakana\\enwik8-model\\train_dtat.py', 'program': 'C:\\sakana\\enwik8-model\\train_dtat.py'}
8
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_setup.py:_flush():79] Applying login settings: {}
9
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_init.py:_log_setup():533] Logging user logs to C:\sakana\enwik8-model\wandb\run-20241230_125819-geso4xvw\logs\debug.log
10
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_init.py:_log_setup():534] Logging internal logs to C:\sakana\enwik8-model\wandb\run-20241230_125819-geso4xvw\logs\debug-internal.log
11
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_init.py:init():619] calling init triggers
12
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_init.py:init():626] wandb.init called with sweep_config: {}
13
+ config: {'architecture': 'DTAT', 'dataset': 'enwik8', 'batch_size': 32, 'learning_rate': 0.0006, 'warmup_iters': 2000, 'block_size': 1024, 'n_layer': 12, 'n_head': 8, 'n_embd': 512, 'dropout': 0.1, 'sparse_topk': 32}
14
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_init.py:init():669] starting backend
15
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [wandb_init.py:init():673] sending inform_init request
16
+ 2024-12-30 12:58:19,921 INFO MainThread:16680 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=spawn, using: spawn
17
+ 2024-12-30 12:58:19,924 INFO MainThread:16680 [wandb_init.py:init():686] backend started and connected
18
+ 2024-12-30 12:58:19,927 INFO MainThread:16680 [wandb_init.py:init():781] updated telemetry
19
+ 2024-12-30 12:58:19,977 INFO MainThread:16680 [wandb_init.py:init():814] communicating run to backend with 90.0 second timeout
20
+ 2024-12-30 12:58:20,892 INFO MainThread:16680 [wandb_init.py:init():867] starting run threads in backend
21
+ 2024-12-30 12:58:21,272 INFO MainThread:16680 [wandb_run.py:_console_start():2451] atexit reg
22
+ 2024-12-30 12:58:21,272 INFO MainThread:16680 [wandb_run.py:_redirect():2299] redirect: wrap_raw
23
+ 2024-12-30 12:58:21,272 INFO MainThread:16680 [wandb_run.py:_redirect():2364] Wrapping output streams.
24
+ 2024-12-30 12:58:21,272 INFO MainThread:16680 [wandb_run.py:_redirect():2389] Redirects installed.
25
+ 2024-12-30 12:58:21,277 INFO MainThread:16680 [wandb_init.py:init():911] run started, returning control to user process
26
+ 2024-12-30 12:58:21,889 WARNING MsgRouterThr:16680 [router.py:message_loop():75] message_loop has been closed
wandb/run-20241230_125819-geso4xvw/run-geso4xvw.wandb ADDED
Binary file (3.19 kB).
 
wandb/run-20241230_125924-h4hgg9ir/files/config.yaml ADDED
@@ -0,0 +1,47 @@
1
+ _wandb:
2
+ value:
3
+ cli_version: 0.18.6
4
+ m: []
5
+ python_version: 3.11.7
6
+ t:
7
+ "1":
8
+ - 1
9
+ - 55
10
+ - 105
11
+ "2":
12
+ - 1
13
+ - 55
14
+ - 105
15
+ "3":
16
+ - 16
17
+ - 23
18
+ - 55
19
+ "4": 3.11.7
20
+ "5": 0.18.6
21
+ "8":
22
+ - 3
23
+ - 5
24
+ "12": 0.18.6
25
+ "13": windows-amd64
26
+ architecture:
27
+ value: DTAT
28
+ batch_size:
29
+ value: 32
30
+ block_size:
31
+ value: 1024
32
+ dataset:
33
+ value: enwik8
34
+ dropout:
35
+ value: 0.1
36
+ learning_rate:
37
+ value: 0.0006
38
+ n_embd:
39
+ value: 512
40
+ n_head:
41
+ value: 8
42
+ n_layer:
43
+ value: 12
44
+ sparse_topk:
45
+ value: 32
46
+ warmup_iters:
47
+ value: 2000
wandb/run-20241230_125924-h4hgg9ir/files/output.log ADDED
@@ -0,0 +1,29 @@
1
+ Loading data...
2
+ Initializing model...
3
+ number of parameters: 42.40M
4
+ Starting training...
5
+ Traceback (most recent call last):
6
+ File "C:\sakana\enwik8-model\train_dtat.py", line 256, in <module>
7
+ main()
8
+ File "C:\sakana\enwik8-model\train_dtat.py", line 166, in main
9
+ logits, loss, importance_scores = model(X, Y)
10
+ ^^^^^^^^^^^
11
+ File "C:\fcc-intro-to-llms\cuda\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
12
+ return self._call_impl(*args, **kwargs)
13
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
14
+ File "C:\fcc-intro-to-llms\cuda\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
15
+ return forward_call(*args, **kwargs)
16
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
17
+ File "C:\sakana\enwik8-model\model_dtat.py", line 218, in forward
18
+ importance_scores = self.importance_net(x, freq_table, pos)
19
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
20
+ File "C:\fcc-intro-to-llms\cuda\Lib\site-packages\torch\nn\modules\module.py", line 1736, in _wrapped_call_impl
21
+ return self._call_impl(*args, **kwargs)
22
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
23
+ File "C:\fcc-intro-to-llms\cuda\Lib\site-packages\torch\nn\modules\module.py", line 1747, in _call_impl
24
+ return forward_call(*args, **kwargs)
25
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
26
+ File "C:\sakana\enwik8-model\model_dtat.py", line 51, in forward
27
+ combined = torch.cat([x, freq_emb, pos_emb], dim=-1)
28
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
29
+ RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1024 but got size 256 for tensor number 1 in the list.
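The second run fails inside `torch.cat([x, freq_emb, pos_emb], dim=-1)`: `x` covers 1024 positions while the frequency embedding was built for 256, and `torch.cat` only lets sizes differ on the concatenation dimension. A minimal shape-level sketch (hypothetical sizes, not the actual model_dtat.py code):

```python
import torch

B, T, C = 2, 1024, 512                  # hypothetical batch, sequence length, width
x = torch.randn(B, T, C)
freq_emb = torch.randn(B, 256, 32)      # built for the wrong sequence length
pos_emb = torch.randn(B, T, 32)

# torch.cat([x, freq_emb, pos_emb], dim=-1) raises the RuntimeError from the log,
# because dim 1 (the sequence dimension) disagrees: 1024 vs 256.
# Every auxiliary embedding has to cover the same T positions as x:
freq_emb = torch.randn(B, T, 32)        # e.g. recompute or slice it at length T
combined = torch.cat([x, freq_emb, pos_emb], dim=-1)
print(combined.shape)                   # torch.Size([2, 1024, 576])
```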
wandb/run-20241230_125924-h4hgg9ir/files/wandb-metadata.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "os": "Windows-10-10.0.26100-SP0",
3
+ "python": "3.11.7",
4
+ "startedAt": "2024-12-30T10:59:24.719225Z",
5
+ "program": "C:\\sakana\\enwik8-model\\train_dtat.py",
6
+ "codePath": "train_dtat.py",
7
+ "git": {
8
+ "remote": "https://github.com/karpathy/nanoGPT.git",
9
+ "commit": "93a43d9a5c22450bbf06e78da2cb6eeef084b717"
10
+ },
11
+ "email": "mitel40181@gholar.com",
12
+ "root": "C:\\sakana\\enwik8-model",
13
+ "host": "SILX",
14
+ "username": "silxs",
15
+ "executable": "C:\\fcc-intro-to-llms\\cuda\\Scripts\\python.exe",
16
+ "codePathLocal": "train_dtat.py",
17
+ "cpu_count": 8,
18
+ "cpu_count_logical": 16,
19
+ "gpu": "NVIDIA GeForce RTX 3050 Laptop GPU",
20
+ "gpu_count": 1,
21
+ "disk": {
22
+ "/": {
23
+ "total": "487147769856",
24
+ "used": "485685227520"
25
+ }
26
+ },
27
+ "memory": {
28
+ "total": "16387997696"
29
+ },
30
+ "cpu": {
31
+ "count": 8,
32
+ "countLogical": 16
33
+ },
34
+ "gpu_nvidia": [
35
+ {
36
+ "name": "NVIDIA GeForce RTX 3050 Laptop GPU",
37
+ "memoryTotal": "4294967296",
38
+ "cudaCores": 2048,
39
+ "architecture": "Ampere"
40
+ }
41
+ ],
42
+ "cudaVersion": "12.6"
43
+ }
wandb/run-20241230_125924-h4hgg9ir/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
1
+ {"_wandb":{"runtime":4}}
wandb/run-20241230_125924-h4hgg9ir/logs/debug-core.log ADDED
@@ -0,0 +1,14 @@
1
+ {"time":"2024-12-30T12:59:23.8853518+02:00","level":"INFO","msg":"started logging, with flags","port-filename":"C:\\Users\\silxs\\AppData\\Local\\Temp\\tmpmcp7jkur\\port-36980.txt","pid":36980,"debug":false,"disable-analytics":false}
2
+ {"time":"2024-12-30T12:59:23.8853518+02:00","level":"INFO","msg":"FeatureState","shutdownOnParentExitEnabled":false}
3
+ {"time":"2024-12-30T12:59:23.8919931+02:00","level":"INFO","msg":"Will exit if parent process dies.","ppid":36980}
4
+ {"time":"2024-12-30T12:59:23.8919931+02:00","level":"INFO","msg":"server is running","addr":{"IP":"127.0.0.1","Port":53707,"Zone":""}}
5
+ {"time":"2024-12-30T12:59:24.0739714+02:00","level":"INFO","msg":"connection: ManageConnectionData: new connection created","id":"127.0.0.1:53716"}
6
+ {"time":"2024-12-30T12:59:24.7197359+02:00","level":"INFO","msg":"handleInformInit: received","streamId":"h4hgg9ir","id":"127.0.0.1:53716"}
7
+ {"time":"2024-12-30T12:59:24.831581+02:00","level":"INFO","msg":"handleInformInit: stream started","streamId":"h4hgg9ir","id":"127.0.0.1:53716"}
8
+ {"time":"2024-12-30T12:59:28.8801803+02:00","level":"INFO","msg":"handleInformTeardown: server teardown initiated","id":"127.0.0.1:53716"}
9
+ {"time":"2024-12-30T12:59:28.8801803+02:00","level":"INFO","msg":"connection: Close: initiating connection closure","id":"127.0.0.1:53716"}
10
+ {"time":"2024-12-30T12:59:28.8801803+02:00","level":"INFO","msg":"server is shutting down"}
11
+ {"time":"2024-12-30T12:59:28.8801803+02:00","level":"INFO","msg":"connection: Close: connection successfully closed","id":"127.0.0.1:53716"}
12
+ {"time":"2024-12-30T12:59:53.2992062+02:00","level":"INFO","msg":"handleInformTeardown: server shutdown complete","id":"127.0.0.1:53716"}
13
+ {"time":"2024-12-30T12:59:53.2992062+02:00","level":"INFO","msg":"connection: ManageConnectionData: connection closed","id":"127.0.0.1:53716"}
14
+ {"time":"2024-12-30T12:59:53.2992062+02:00","level":"INFO","msg":"server is closed"}
wandb/run-20241230_125924-h4hgg9ir/logs/debug-internal.log ADDED
@@ -0,0 +1,17 @@
1
+ {"time":"2024-12-30T12:59:24.7202464+02:00","level":"INFO","msg":"using version","core version":"0.18.6"}
2
+ {"time":"2024-12-30T12:59:24.7207602+02:00","level":"INFO","msg":"created symlink","path":"C:\\sakana\\enwik8-model\\wandb\\run-20241230_125924-h4hgg9ir\\logs\\debug-core.log"}
3
+ {"time":"2024-12-30T12:59:24.8310365+02:00","level":"INFO","msg":"created new stream","id":"h4hgg9ir"}
4
+ {"time":"2024-12-30T12:59:24.831581+02:00","level":"INFO","msg":"stream: started","id":"h4hgg9ir"}
5
+ {"time":"2024-12-30T12:59:24.831581+02:00","level":"INFO","msg":"sender: started","stream_id":"h4hgg9ir"}
6
+ {"time":"2024-12-30T12:59:24.831581+02:00","level":"INFO","msg":"handler: started","stream_id":{"value":"h4hgg9ir"}}
7
+ {"time":"2024-12-30T12:59:24.831581+02:00","level":"INFO","msg":"writer: Do: started","stream_id":{"value":"h4hgg9ir"}}
8
+ {"time":"2024-12-30T12:59:25.363056+02:00","level":"INFO","msg":"Starting system monitor"}
9
+ {"time":"2024-12-30T12:59:28.8801803+02:00","level":"INFO","msg":"stream: closing","id":"h4hgg9ir"}
10
+ {"time":"2024-12-30T12:59:28.8801803+02:00","level":"INFO","msg":"Stopping system monitor"}
11
+ {"time":"2024-12-30T12:59:28.8812132+02:00","level":"INFO","msg":"Stopped system monitor"}
12
+ {"time":"2024-12-30T12:59:29.7987804+02:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
13
+ {"time":"2024-12-30T12:59:50.82286+02:00","level":"INFO","msg":"api: retrying error","error":"Post \"https://api.wandb.ai/files/mitel40181-silx/enwik8-dtat/h4hgg9ir/file_stream\": dial tcp 35.186.228.49:443: connectex: A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond."}
14
+ {"time":"2024-12-30T12:59:53.2987018+02:00","level":"INFO","msg":"handler: closed","stream_id":{"value":"h4hgg9ir"}}
15
+ {"time":"2024-12-30T12:59:53.2987018+02:00","level":"INFO","msg":"sender: closed","stream_id":"h4hgg9ir"}
16
+ {"time":"2024-12-30T12:59:53.2987018+02:00","level":"INFO","msg":"writer: Close: closed","stream_id":{"value":"h4hgg9ir"}}
17
+ {"time":"2024-12-30T12:59:53.2987018+02:00","level":"INFO","msg":"stream: closed","id":"h4hgg9ir"}
wandb/run-20241230_125924-h4hgg9ir/logs/debug.log ADDED
@@ -0,0 +1,26 @@
1
+ 2024-12-30 12:59:24,715 INFO MainThread:36980 [wandb_setup.py:_flush():79] Current SDK version is 0.18.6
2
+ 2024-12-30 12:59:24,715 INFO MainThread:36980 [wandb_setup.py:_flush():79] Configure stats pid to 36980
3
+ 2024-12-30 12:59:24,715 INFO MainThread:36980 [wandb_setup.py:_flush():79] Loading settings from C:\Users\silxs\.config\wandb\settings
4
+ 2024-12-30 12:59:24,715 INFO MainThread:36980 [wandb_setup.py:_flush():79] Loading settings from C:\sakana\enwik8-model\wandb\settings
5
+ 2024-12-30 12:59:24,715 INFO MainThread:36980 [wandb_setup.py:_flush():79] Loading settings from environment variables: {}
6
+ 2024-12-30 12:59:24,715 INFO MainThread:36980 [wandb_setup.py:_flush():79] Applying setup settings: {'mode': None, '_disable_service': None}
7
+ 2024-12-30 12:59:24,716 INFO MainThread:36980 [wandb_setup.py:_flush():79] Inferring run settings from compute environment: {'program_relpath': 'train_dtat.py', 'program_abspath': 'C:\\sakana\\enwik8-model\\train_dtat.py', 'program': 'C:\\sakana\\enwik8-model\\train_dtat.py'}
8
+ 2024-12-30 12:59:24,716 INFO MainThread:36980 [wandb_setup.py:_flush():79] Applying login settings: {}
9
+ 2024-12-30 12:59:24,716 INFO MainThread:36980 [wandb_init.py:_log_setup():533] Logging user logs to C:\sakana\enwik8-model\wandb\run-20241230_125924-h4hgg9ir\logs\debug.log
10
+ 2024-12-30 12:59:24,716 INFO MainThread:36980 [wandb_init.py:_log_setup():534] Logging internal logs to C:\sakana\enwik8-model\wandb\run-20241230_125924-h4hgg9ir\logs\debug-internal.log
11
+ 2024-12-30 12:59:24,716 INFO MainThread:36980 [wandb_init.py:init():619] calling init triggers
12
+ 2024-12-30 12:59:24,716 INFO MainThread:36980 [wandb_init.py:init():626] wandb.init called with sweep_config: {}
13
+ config: {'architecture': 'DTAT', 'dataset': 'enwik8', 'batch_size': 32, 'learning_rate': 0.0006, 'warmup_iters': 2000, 'block_size': 1024, 'n_layer': 12, 'n_head': 8, 'n_embd': 512, 'dropout': 0.1, 'sparse_topk': 32}
14
+ 2024-12-30 12:59:24,716 INFO MainThread:36980 [wandb_init.py:init():669] starting backend
15
+ 2024-12-30 12:59:24,716 INFO MainThread:36980 [wandb_init.py:init():673] sending inform_init request
16
+ 2024-12-30 12:59:24,718 INFO MainThread:36980 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=spawn, using: spawn
17
+ 2024-12-30 12:59:24,719 INFO MainThread:36980 [wandb_init.py:init():686] backend started and connected
18
+ 2024-12-30 12:59:24,722 INFO MainThread:36980 [wandb_init.py:init():781] updated telemetry
19
+ 2024-12-30 12:59:24,755 INFO MainThread:36980 [wandb_init.py:init():814] communicating run to backend with 90.0 second timeout
20
+ 2024-12-30 12:59:25,357 INFO MainThread:36980 [wandb_init.py:init():867] starting run threads in backend
21
+ 2024-12-30 12:59:25,623 INFO MainThread:36980 [wandb_run.py:_console_start():2451] atexit reg
22
+ 2024-12-30 12:59:25,623 INFO MainThread:36980 [wandb_run.py:_redirect():2299] redirect: wrap_raw
23
+ 2024-12-30 12:59:25,624 INFO MainThread:36980 [wandb_run.py:_redirect():2364] Wrapping output streams.
24
+ 2024-12-30 12:59:25,624 INFO MainThread:36980 [wandb_run.py:_redirect():2389] Redirects installed.
25
+ 2024-12-30 12:59:25,626 INFO MainThread:36980 [wandb_init.py:init():911] run started, returning control to user process
26
+ 2024-12-30 12:59:28,880 WARNING MsgRouterThr:36980 [router.py:message_loop():75] message_loop has been closed
wandb/run-20241230_125924-h4hgg9ir/run-h4hgg9ir.wandb ADDED
Binary file (4.22 kB).