kaiokendev committed
Commit 67bf26a
1 Parent(s): 65084ac

Upload lora

README.md CHANGED
@@ -1,3 +1,34 @@
---
license: mit
---

### SuperHOT Prototype 2 w/ 4-8K Context

This is a second prototype of SuperHOT, this time with 4K context and no RLHF. In my testing, it can go all the way to 6K without breaking down, and since I made the change with the intention of reaching 8K, I expect it will also reach 8K even though I only trained on 4K sequences.

In order to use the 8K context, you will need to apply the monkeypatch I have added in this repo -- without it, it will not work. The patch is very simple, and you can make the changes yourself (a usage sketch follows the list):
- Increase `max_position_embeddings` to 8192 to stretch the sinusoidal encoding
- Stretch the frequency steps by a scale of `0.25`
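
As a concrete example, this is roughly how the patch would be applied before building the model and attaching this LoRA. It is a minimal sketch, not a script from this repo: the base model path and LoRA path are placeholders, and the use of `peft` to load the adapter is an assumption.

```python
# Minimal sketch: patch RoPE first, then load the base model and this LoRA.
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

from llama_rope_scaled_monkey_patch import replace_llama_rope_with_scaled_rope

# Swap the stock rotary embedding for the scaled one (8192 positions, 0.25 scale).
# This must happen before the model is instantiated, otherwise it has no effect.
replace_llama_rope_with_scaled_rope()

base_model_path = "path/to/llama-base"  # placeholder
lora_path = "path/to/this-lora"         # placeholder

tokenizer = LlamaTokenizer.from_pretrained(base_model_path)
model = LlamaForCausalLM.from_pretrained(base_model_path)
model = PeftModel.from_pretrained(model, lora_path)
```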

The intuition is to keep the model within the positions learned by the pre-trained model, since the model may be overfit on the token-position relationship (not my idea -- see [Ofir Press](https://ofir.io/)). By interpolating the encodings, we remain within the bounds of the pre-trained model (working with the overfitting rather than against it). The monkeypatch will work on the pre-trained model without fine-tuning, but the results will not be that good without it.
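
To make the interpolation concrete, here is a tiny sketch (not part of this repo) of how the `0.25` scale compresses 8K worth of positions back into the 0-2048 window the base model saw during pre-training:

```python
import torch

# Position indices the patched model will see at 8K context
positions = torch.arange(8192, dtype=torch.float32)

# The 0.25 scale maps them into [0, 2048), i.e. within the positions the
# base model was pre-trained on, before the rotary angles are computed.
scaled = positions * 0.25
print(scaled.min().item(), scaled.max().item())  # 0.0 2047.75
```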

It can probably be made even better with a few other modifications which I am testing (swapping softmax for ReLU, increasing the head dimension).

In my testing, I tried random positional encoding, but I was not able to replicate the results of [Jianlin Su](https://kexue.fm/archives/9444), so maybe I did it incorrectly. I also tried shifted positions, log n scaling, log-sigmoid, and increasing the head dimension, but this dilated RoPE (DoPE :) ) is the only one which worked for me consistently. Note that these are all based on fine-tuning, since the goal is to extend the context of the pre-trained model; pre-training will paint a different picture.

I trained the LoRA with the following configuration (sketched in code after the list):
- 1200 samples (~400 samples over 2048 sequence length)
- learning rate of 3e-4
- 3 epochs
- The exported modules are:
  - q_proj
  - k_proj
  - v_proj
  - o_proj
  - all bias
- Rank = 2
- Alpha = 8
- no dropout
- weight decay of 0.1
- AdamW beta1 of 0.9 and beta2 0.99, epsilon of 1e-5
- Trained on 4-bit base model
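
The training script itself is not included in this repo, but a `peft`/`transformers` configuration matching the settings above would look roughly like this (a sketch; the 4-bit loading of the base model and the surrounding training loop are left out, and `output_dir` is a placeholder):

```python
from peft import LoraConfig
from transformers import TrainingArguments

# LoRA settings mirroring the list above (and adapter_config.json below)
lora_config = LoraConfig(
    r=2,
    lora_alpha=8,
    lora_dropout=0.0,
    bias="all",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# Optimizer and schedule settings mirroring the list above
training_args = TrainingArguments(
    output_dir="superhot-lora",  # placeholder
    learning_rate=3e-4,
    num_train_epochs=3,
    weight_decay=0.1,
    adam_beta1=0.9,
    adam_beta2=0.99,
    adam_epsilon=1e-5,
)
```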
adapter_config.json ADDED
@@ -0,0 +1,19 @@
{
  "base_model_name_or_path": "",
  "bias": "all",
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "lora_alpha": 8,
  "lora_dropout": 0,
  "modules_to_save": null,
  "peft_type": "LORA",
  "r": 2,
  "target_modules": [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj"
  ],
  "task_type": "CAUSAL_LM"
}
adapter_model.bin ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:76133dc631ac8dc28341c45f8f469cc603174cb1d16c728f65b33778f8f497e4
size 17579562
llama_rope_scaled_monkey_patch.py ADDED
@@ -0,0 +1,63 @@
import torch
import transformers
import transformers.models.llama.modeling_llama


class ScaledRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Stretch the positional range to 8K regardless of what the caller passes in.
        max_position_embeddings = 8192

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )

        # Dilate the positions by 0.25 so 8192 positions map into the 0-2048
        # range the base model was pre-trained on.
        self.scale = 1 / 4
        t *= self.scale

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :], persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :], persistent=False
        )

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(
                self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
            )
            # Apply the same dilation when the cache has to be rebuilt.
            t *= self.scale
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer(
                "cos_cached", emb.cos()[None, None, :, :], persistent=False
            )
            self.register_buffer(
                "sin_cached", emb.sin()[None, None, :, :], persistent=False
            )
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def replace_llama_rope_with_scaled_rope():
    # Swap the stock rotary embedding for the scaled version; must be called
    # before the LLaMA model is instantiated.
    transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = (
        ScaledRotaryEmbedding
    )