---
license: mit
---

### SuperHOT Prototype 2 w/ 4-8K Context

This is the second prototype of SuperHOT, this time with 4K context and no RLHF. In my testing, it holds up all the way to 6K without breaking down, and I made the change with the intention of reaching 8K, so I'll assume it will go to 8K even though I only trained on 4K sequences.

In order to use the 8K context, you will need to apply the monkeypatch I have added in this repo -- without it, it will not work. The patch is very simple, and you can make the changes yourself (a rough sketch follows the list):

- Increase the `max_position_embeddings` to 8192 to stretch the sinusoidal encoding
- Stretch the frequency steps by a scale of `0.25`

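
For illustration, here is a minimal, self-contained sketch of what those two changes look like for a rotary embedding module. This is not the actual monkeypatch from this repo; the class name and caching scheme here are my own simplification of a scaled rotary embedding, with the scaled position index `t` carrying both changes.

```python
import torch


class ScaledRotaryEmbedding(torch.nn.Module):
    """Sketch of dilated/interpolated RoPE: positions are scaled by 0.25
    so that an 8192-token sequence lands inside the original 0..2048 range."""

    def __init__(self, dim, max_position_embeddings=8192, base=10000, scale=0.25):
        super().__init__()
        # Standard RoPE inverse frequencies for each pair of dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Change 1: build the cos/sin cache out to 8192 positions.
        # Change 2: scale the positions by 0.25, stretching the frequency steps.
        t = torch.arange(max_position_embeddings).float() * scale
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :])
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :])

    def forward(self, x, seq_len):
        # Return the cached cos/sin tables truncated to the current sequence length.
        return (
            self.cos_cached[:, :, :seq_len].to(x.dtype),
            self.sin_cached[:, :, :seq_len].to(x.dtype),
        )
```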

The intuition is to keep the model within the learned positions of the pre-trained model, since the model may be overfit to the token-position relationship (not my idea, it is [Ofir Press'](https://ofir.io/)). By interpolating the encodings, we remain within the bounds of the pre-trained model (working with the overfitting rather than against it). The monkeypatch will work on the pre-trained model without fine-tuning, but you will need to fine-tune, as the results will not be that good without it.

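
To make the scaling concrete, here is the standard RoPE angle with the dilation factor applied (the symbols t, i, d, and s are generic RoPE notation, not anything defined in this repo). With s = 0.25, position 8000 behaves like position 2000, well inside the original 2048-token range:

```latex
\theta_i = 10000^{-2i/d}, \qquad
\mathrm{angle}(t, i) = (s \cdot t)\,\theta_i, \qquad s = 0.25
```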

It can probably be made even better with a few other modifications I am testing (swapping softmax for ReLU, increasing the head dimension).

In my testing, I tried random positional encoding, but I was not able to replicate the results of [Jianlin Su](https://kexue.fm/archives/9444), so maybe I did it incorrectly. I also tried shifted positions, log n scaling, log-sigmoid, and increasing the head dimension, though this dilated RoPE (DoPE :) ) is the only one that worked for me consistently. Note that these are all based on fine-tuning, since the goal is to extend the context of the pre-trained model; pre-training will paint a different picture.

I trained the LoRA with the following configuration (a rough code sketch follows the list):

- 1200 samples (~400 samples over 2048 sequence length)
- learning rate of 3e-4
- 3 epochs
- The exported modules are:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - all bias
- Rank = 2
- Alpha = 8
- no dropout
- weight decay of 0.1
- AdamW with beta1 of 0.9, beta2 of 0.99, and epsilon of 1e-5
- Trained on 4-bit base model |
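
Below is a rough reconstruction of that configuration using PEFT's `LoraConfig` and the `transformers` `TrainingArguments`. The actual training script is not included here, and the base model id is a placeholder assumption, so treat this as an illustrative sketch of the hyperparameters above rather than the exact setup.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments

# Load the base model in 4-bit, matching "Trained on 4-bit base model".
# The model id below is an assumption for illustration only.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-13b",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

# LoRA settings from the list above: rank 2, alpha 8, no dropout,
# all biases trained, applied to the attention projections.
lora_config = LoraConfig(
    r=2,
    lora_alpha=8,
    lora_dropout=0.0,
    bias="all",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

# Optimizer and schedule settings from the list above.
training_args = TrainingArguments(
    output_dir="superhot-lora",
    num_train_epochs=3,
    learning_rate=3e-4,
    weight_decay=0.1,
    adam_beta1=0.9,
    adam_beta2=0.99,
    adam_epsilon=1e-5,
)
```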