kaiokendev
license: mit

SuperHOT Prototype 2 w/ 4-8K Context

This is the second prototype of SuperHOT, this time trained with 4K context and no RLHF. In my testing, it goes all the way to 6K without breaking down. I made the change with the intention of reaching 8K, so I assume it will work up to 8K, although I only trained on 4K sequences.

To use the 8K context, you need to apply the monkeypatch included in this repo. Without it, the extended context will not work. The patch is very simple, and you can make the changes yourself:

  • Increase max_position_embeddings to 8192 to stretch the sinusoidal position embeddings
  • Scale the frequency steps by a factor of 0.25 (i.e., multiply each position index by 1/4)
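As a rough illustration of these two changes, here is a pure-Python sketch of a rotary-embedding cos/sin cache. It assumes LLaMA-style RoPE with base 10000 and a hypothetical head dimension of 128. build_rope_cache and apply_rope are illustrative names, not the monkeypatch in this repo, and feature pairs are rotated in the interleaved (2i, 2i+1) layout of the original RoPE formulation rather than the half-split layout some implementations use.

```python
import math


def build_rope_cache(max_position_embeddings=8192, head_dim=128,
                     scale=0.25, base=10000.0):
    """Hypothetical helper: return (cos, sin) tables indexed [position][freq]."""
    # Standard RoPE inverse frequencies base^(-2i/d), one per feature pair.
    inv_freq = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    cos_table, sin_table = [], []
    # Change 1: build the table for 8192 positions instead of 2048.
    for pos in range(max_position_embeddings):
        # Change 2: compress each position by `scale`, so positions 0..8191
        # map onto 0..2047.75.
        angles = [(pos * scale) * f for f in inv_freq]
        cos_table.append([math.cos(a) for a in angles])
        sin_table.append([math.sin(a) for a in angles])
    return cos_table, sin_table


def apply_rope(x, cos_row, sin_row):
    """Rotate each (x[2i], x[2i+1]) pair of a query/key vector by its angle."""
    out = []
    for i, (c, s) in enumerate(zip(cos_row, sin_row)):
        x0, x1 = x[2 * i], x[2 * i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out


# Usage: the scaled 8192-row cache reuses angles from an unscaled 2048-row one.
cos_s, sin_s = build_rope_cache()
cos_u, sin_u = build_rope_cache(max_position_embeddings=2048, scale=1.0)
print(cos_s[4000] == cos_u[1000])
```

With scale=0.25, row 4000 of the 8192-row cache is identical to row 1000 of an unscaled 2048-row cache, so the model never receives a rotation angle larger than the ones it saw at its original context length.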

The intuition is to keep the model within the positions it learned during pre-training, since the model may be overfit on the token-position relationship (not my idea; credit to Ofir Press). By interpolating the encodings, we stay within the bounds of the pre-trained model and work with the overfitting rather than against it. The monkeypatch works on the pre-trained model without fine-tuning, but the results are not very good without it, so you will want to fine-tune.
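To make the interpolation concrete, assume a LLaMA-style base model pre-trained on 2048-token sequences, with RoPE base b = 10000 and head dimension d (these values are assumptions, not stated above). With scale factor s, the rotation angle for position m and frequency index i works out as:

```latex
\theta_{m,i} = (s\,m)\, b^{-2i/d}, \qquad s = \tfrac{1}{4},\; b = 10000,\; i = 0, \dots, \tfrac{d}{2} - 1

0 \le m \le 8191 \;\Longrightarrow\; 0 \le s\,m \le 2047.75 < 2048

\theta_{4m',\,i}\big|_{s = 1/4} = m'\, b^{-2i/d} = \theta_{m',\,i}\big|_{s = 1}
```

Every position up to 8191 maps to an effective position below 2048. Every fourth position lands exactly on an integer position the base model has already seen, while the positions in between fall at fractional offsets it never saw exactly.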

It can probably do even better with a few other modifications I am testing (swapping softmax for ReLU, increasing the head dimension).

In my testing, I tried random positional encoding, but I was not able to replicate Jianlin Su's results, so maybe I did it incorrectly. I also tried shifted positions, log n scaling, log-sigmoid, and increasing the head dimension, but this dilated RoPE (DoPE :) ) is the only one that worked consistently for me. Note that these are all based on fine-tuning, since the goal is to extend the context of the pre-trained model; pre-training will paint a different picture.

I trained the LoRA with the following configuration:

  • 1200 samples (~400 of them longer than 2048 tokens)
  • learning rate of 3e-4
  • 3 epochs
  • The exported modules are:
    • q_proj
    • k_proj
    • v_proj
    • o_proj
    • all biases
  • Rank = 2
  • Alpha = 8
  • no dropout
  • weight decay of 0.1
  • AdamW with beta1 of 0.9, beta2 of 0.99, and epsilon of 1e-5
  • Trained on a 4-bit base model
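For reference, the settings above can be collected into a single object. This is a plain-Python sketch with hypothetical names (SuperHOTLoraRun, lora_params_per_module), not the training script used here. In peft, for example, the LoRA fields would roughly correspond to LoraConfig's r, lora_alpha, lora_dropout, bias, and target_modules arguments, and the optimizer fields to torch.optim.AdamW's lr, betas, eps, and weight_decay.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SuperHOTLoraRun:
    """Hypothetical container mirroring the hyperparameters listed above."""
    num_samples: int = 1200
    learning_rate: float = 3e-4
    epochs: int = 3
    target_modules: tuple = ("q_proj", "k_proj", "v_proj", "o_proj")
    bias: str = "all"             # biases are trained as well
    r: int = 2                    # LoRA rank
    lora_alpha: int = 8
    lora_dropout: float = 0.0     # no dropout
    weight_decay: float = 0.1
    betas: tuple = (0.9, 0.99)    # AdamW beta1, beta2
    eps: float = 1e-5
    load_in_4bit: bool = True     # base model quantized to 4 bits

    @property
    def lora_scaling(self) -> float:
        # Standard LoRA scaling applied to the low-rank update: alpha / r.
        return self.lora_alpha / self.r

    def lora_params_per_module(self, d_in: int, d_out: int) -> int:
        # A is r x d_in and B is d_out x r, so the update adds r * (d_in + d_out).
        return self.r * (d_in + d_out)


cfg = SuperHOTLoraRun()
print(cfg.lora_scaling)                        # alpha / r
print(cfg.lora_params_per_module(4096, 4096))  # hypothetical 4096-wide projection
```

With rank 2 and alpha 8, the low-rank update is scaled by 8 / 2 = 4. For a hypothetical 4096 x 4096 projection, each adapted module adds only 2 * (4096 + 4096) = 16,384 trainable parameters.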