---
license: llama2
datasets:
- bigcode/starcoderdata
- togethercomputer/RedPajama-Data-1T
- tiiuae/falcon-refinedweb
metrics:
- code_eval
- accuracy
pipeline_tag: text-generation
tags:
- code
- text-generation-inference
model-index: 
- name: long_llama_code_7b
  results:
  - task: 
      name: Code Generation
      type: code-generation
    dataset:
      name: "HumanEval" 
      type: openai_humaneval
    metrics:
    - name: pass@1
      type: pass@1
      value: 0.286
      verified: false
  - task: 
      name: Math Reasoning
      type: reasoning
    dataset:
      name: "GSM8K-Python" 
      type: gsm8k
    metrics:
    - name: pass@1
      type: pass@1
      value: 0.249
      verified: false
  - task: 
      name: Math Reasoning
      type: reasoning
    dataset:
      name: "GSM8K" 
      type: gsm8k
    metrics:
    - name: pass@1
      type: pass@1
      value: 0.174
      verified: false
  - task: 
      name: Knowledge
      type: knowledge
    dataset:
      name: "MMLU" 
      type: mmlu
    metrics:
    - name: accuracy
      type: accuracy
      value: 0.399
      verified: false
---


# LongLLaMA: Focused Transformer Training for Context Scaling


<div align="center">


<table>

  <tr>
  <td align="center">
    <span style="font-size:300%">{</span>
    </td>
    <td align="center">
    <span style="font-size:115%">
    <b>
    <a href="https://huggingface.co/syzymon/long_llama_code_7b" style="margin-bottom:30px">LongLLaMA Code-7B</a>
    </b>
    </span>
    </td>
    <td align="center">
    <span style="font-size:300%">}</span>
    </td>

 </tr>
</table>


</div>


<div align="center">

 [TLDR](#TLDR) | [Overview](#Overview) | [Results](#Results) | [Usage](#Usage) | [Authors](#Authors) | [Citation](#Citation) | [License](#License) | [Acknowledgments](#Acknowledgments)
 
 [FoT continued pretraining](https://github.com/CStanKonrad/long_llama/tree/main/fot_continued_pretraining) | [Instruction tuning](https://github.com/CStanKonrad/long_llama/tree/main/instruction_fine_tuning)

</div>



## TLDR
This repository contains the research preview of **LongLLaMA, a large language model capable of handling long contexts of 256k tokens or even more**. 

LongLLaMA-Code is built upon the foundation of [Code Llama](https://huggingface.co/codellama/CodeLlama-7b-hf).  

LongLLaMA-Code has **improved reasoning capabilities** compared to CodeLlama. In particular, **GSM8K math reasoning improves from 13% to 17.4% after continued pre-training alone, with no in-distribution fine-tuning**.

<p align="center" width="100%">
<img src="https://raw.githubusercontent.com/CStanKonrad/long_llama/main/assets/results.png" alt="LongLLaMA" style="width: 70%; min-width: 300px; display: block; margin: auto;">
</p>

## Overview

### Base models
[Focused Transformer: Contrastive Training for Context Scaling](https://arxiv.org/abs/2307.03170) (FoT) presents a simple method for endowing language models with the ability to handle contexts of possibly millions of tokens while training on significantly shorter inputs. FoT permits a subset of attention layers to access a memory cache of (key, value) pairs to extend the context length. The distinctive aspect of FoT is its training procedure, which draws from contrastive learning. Specifically, we deliberately expose the memory attention layers to both relevant and irrelevant keys (like negative samples from unrelated documents). This strategy incentivizes the model to differentiate keys connected with semantically diverse values, thereby enhancing their structure. This, in turn, makes it possible to extrapolate the effective context length far beyond what is seen in training.
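
To make the idea of a memory attention layer more concrete, here is a toy sketch in plain Python. It is not the LongLLaMA implementation (which operates on PyTorch tensors); the helpers `attend` and `memory_attend` are hypothetical names introduced only for illustration. The sketch assumes that a memory layer simply runs softmax attention over the cached (key, value) pairs together with the local ones.

```python
import math

def attend(query, keys, values):
    """Single-query scaled dot-product attention over lists of vectors."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    top = max(scores)  # subtract the maximum for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    weights = [w / total for w in weights]
    dim = len(values[0])
    output = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return output, weights

def memory_attend(query, local_kv, memory_kv):
    """Attend over memory (key, value) pairs followed by local ones.

    Assumption: memory and local entries share a single softmax.
    """
    keys = [k for k, _ in memory_kv] + [k for k, _ in local_kv]
    values = [v for _, v in memory_kv] + [v for _, v in local_kv]
    return attend(query, keys, values)

# One local pair and two memory pairs; the second memory key matches the query best.
local_kv = [([1.0, 0.0], [1.0, 0.0])]
memory_kv = [([0.0, 1.0], [0.0, 1.0]), ([4.0, 0.0], [5.0, 5.0])]
output, weights = memory_attend([2.0, 0.0], local_kv, memory_kv)
print(weights.index(max(weights)))  # index of the most attended key
```

In this toy setting the irrelevant memory key receives almost no weight. FoT training exposes the memory layers to such irrelevant keys so that the model learns to separate them from relevant ones.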


**LongLLaMA** is an [OpenLLaMA](https://github.com/openlm-research/open_llama) model finetuned with the FoT method,
with three layers used for context extension. **Crucially, LongLLaMA is able to extrapolate far beyond the 8k context length seen in training. For example, in the passkey retrieval task, it can handle inputs of length 256k**.  
**LongLLaMA-Code** is a [Code Llama](https://huggingface.co/codellama/CodeLlama-7b-hf) model finetuned with the FoT method.


#### Model card
<div align="center">

|  | [LongLLaMA-3B](https://huggingface.co/syzymon/long_llama_3b) | [LongLLaMA-3Bv1.1](https://huggingface.co/syzymon/long_llama_3b_v1_1) | [LongLLaMA Code-7B](https://huggingface.co/syzymon/long_llama_code_7b) |
|----------------|----------|----------|-----------|
| Source model         | [OpenLLaMA-3B](https://huggingface.co/openlm-research/open_llama_3b_easylm)      | [OpenLLaMA-3Bv2](https://huggingface.co/openlm-research/open_llama_3b_v2_easylm) | [CodeLLaMA-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf)       |
| Source model tokens     | 1T      | 1T | 2T + 0.5T       |
| Fine-tuning context | 8K      | **32K** | **32K** |
| Fine-tuning tokens  | 10B     | 5B | **35B**     |
| Memory layers         |  6, 12, 18        |   6, 12, 18        |  8, 16, 24        |

</div>

## Results

<p align="center" width="100%">
<img src="https://raw.githubusercontent.com/CStanKonrad/long_llama/main/assets/full_results.png" alt="LongLLaMA" style="width: 70%; min-width: 300px; display: block; margin: auto;">
</p>

## Usage

See also: 
* [Colab with LongLLaMA-Instruct-3Bv1.1](https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_instruct_colab.ipynb).
* [Colab with an example usage of base LongLLaMA](https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_colab.ipynb).

### Requirements
```
pip install --upgrade pip
pip install git+https://github.com/huggingface/transformers.git@main sentencepiece accelerate
```

### Loading model
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("syzymon/long_llama_code_7b")
model = AutoModelForCausalLM.from_pretrained("syzymon/long_llama_code_7b", 
                                            torch_dtype=torch.float32, 
                                            trust_remote_code=True)
```

### Input handling and generation
LongLLaMA uses the Hugging Face interface. A long input given to the model is
split into context windows and loaded into the memory cache.
```python
prompt = "My name is Julien and I like to"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
outputs = model(input_ids=input_ids)
```
During the model call, one can provide the parameter `last_context_length` which specifies the number of tokens left in the last context window. Tuning this parameter can improve generation as the first layers do not have access to memory. See details in [How LongLLaMA handles long inputs](#How-LongLLaMA-handles-long-inputs).

```python
generation_output = model.generate(
    input_ids=input_ids,
    max_new_tokens=1024,
    num_beams=1,
    last_context_length=3072,
    do_sample=True,
    temperature=1.0,
)
print(tokenizer.decode(generation_output[0]))
```

### Additional configuration
LongLLaMA has several other parameters:
* `mem_layers` specifies the layers endowed with memory (it should be either an empty list or a list of all memory layers specified in the description of the checkpoint).
* `mem_dtype` allows changing the type of the memory cache.
* `mem_attention_grouping` can trade off speed for reduced memory usage. 
  When equal to `(4, 2048)`, the memory layers will process at most $4 \times 2048$ queries at once ($4$ heads and $2048$ queries for each head).

```python
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_code_7b")
model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_code_7b", torch_dtype=torch.float32, 
    mem_layers=[], 
    mem_dtype='bfloat16',
    trust_remote_code=True,
    mem_attention_grouping=(4, 2048),
)
```


### Drop-in use with LLaMA code
LongLLaMA checkpoints can also be used as a drop-in replacement for LLaMA checkpoints in the [Hugging Face implementation of LLaMA](https://huggingface.co/docs/transformers/main/model_doc/llama), but in this case, they will be limited to the original context length.

```python
from transformers import LlamaTokenizer, LlamaForCausalLM
import torch

tokenizer = LlamaTokenizer.from_pretrained("syzymon/long_llama_code_7b")
model = LlamaForCausalLM.from_pretrained("syzymon/long_llama_code_7b", torch_dtype=torch.float32)
```


### How LongLLaMA handles long inputs
Inputs over $ctx=2048$ ($ctx=4096$ for LongLLaMA Code) tokens are automatically split into windows $w_1, \ldots, w_m$. The first $m-2$ windows contain $ctx$ tokens each, $w_{m-1}$ has no more than $2048$ tokens, and $w_m$ contains the number of tokens specified by `last_context_length`. The model processes the windows one by one, extending the memory cache after each one. If `use_cache` is `True`, the last window is not loaded into the memory cache but into the local (generation) cache.
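
The splitting rule above can be sketched in a few lines of plain Python. The function `split_into_windows` is a hypothetical helper written for this explanation, not the code used by the model. It assumes that the last window takes exactly `last_context_length` tokens and that the remaining tokens are filled front to back, so only $w_{m-1}$ can be shorter than $ctx$.

```python
def split_into_windows(num_tokens, ctx, last_context_length):
    """Return window lengths [w_1, ..., w_m] for an input of num_tokens tokens.

    Hypothetical helper; assumes 1 <= last_context_length <= ctx.
    """
    if not 1 <= last_context_length <= ctx:
        raise ValueError("last_context_length must be between 1 and ctx")
    if num_tokens <= ctx:
        return [num_tokens]  # assumption: short inputs are not split
    remaining = num_tokens - last_context_length
    full_windows, remainder = divmod(remaining, ctx)
    windows = [ctx] * full_windows
    if remainder:
        windows.append(remainder)  # this is w_{m-1}
    windows.append(last_context_length)  # this is w_m
    return windows

print(split_into_windows(10000, ctx=4096, last_context_length=1024))
# [4096, 4096, 784, 1024]
```

In this example, all windows except the last one would end up in the memory cache when `use_cache` is `True`.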

The memory cache stores $(key, value)$ pairs for each head of the specified memory layers `mem_layers`. In addition to this, it stores attention masks. 

If `use_cache=True` (which is the case in generation), LongLLaMA will use two caches: the memory cache for the specified layers and the local (generation) cache for all layers. When the local cache exceeds $2048$ elements, its content is moved to the memory cache for the memory layers.
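
The interplay of the two caches can be pictured with a toy model. The class `TwoLevelCache` below is hypothetical and stores plain token positions instead of (key, value) tensors. It assumes that the whole local cache is moved to the memory cache as soon as the limit is exceeded; the actual implementation may differ in detail.

```python
class TwoLevelCache:
    """Toy model of the local (generation) cache and the memory cache."""

    def __init__(self, local_limit=2048):
        self.local_limit = local_limit
        self.local = []   # recent entries, visible to all layers
        self.memory = []  # older entries, visible only to the memory layers

    def append(self, entry):
        self.local.append(entry)
        # Assumption: once the limit is exceeded, the local cache is flushed whole.
        if len(self.local) > self.local_limit:
            self.memory.extend(self.local)
            self.local.clear()

cache = TwoLevelCache(local_limit=4)  # tiny limit to keep the example readable
for position in range(11):
    cache.append(position)
print(cache.memory)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(cache.local)   # [10]
```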

For simplicity, context extension is realized with a memory cache and full attention in this repo. Replacing this simple mechanism with a KNN search over an external database is possible with systems like [Faiss](https://github.com/facebookresearch/faiss). This could enable further context length scaling. We leave this as future work.


## Authors
- [Szymon Tworkowski](https://scholar.google.com/citations?user=1V8AeXYAAAAJ&hl=en)
- [Konrad Staniszewski](https://scholar.google.com/citations?user=CM6PCBYAAAAJ)
- [Mikołaj Pacek](https://scholar.google.com/citations?user=eh6iEbQAAAAJ&hl=en&oi=ao)
- [Henryk Michalewski](https://scholar.google.com/citations?user=YdHW1ycAAAAJ&hl=en)
- [Yuhuai Wu](https://scholar.google.com/citations?user=bOQGfFIAAAAJ&hl=en)
- [Piotr Miłoś](https://scholar.google.pl/citations?user=Se68XecAAAAJ&hl=pl&oi=ao)


## Citation
To cite this work, please use:
```bibtex
@misc{tworkowski2023focused,
      title={Focused Transformer: Contrastive Training for Context Scaling}, 
      author={Szymon Tworkowski and Konrad Staniszewski and Mikołaj Pacek and Yuhuai Wu and Henryk Michalewski and Piotr Miłoś},
      year={2023},
      eprint={2307.03170},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```


## License
For LongLLaMA Code, see the [codellama/CodeLlama-7b-hf](https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/LICENSE) license.  
Some of the examples use external code (see headers of files for copyright notices and licenses).

## Acknowledgments
Special thanks to [Keiran Paster](https://twitter.com/keirp1) for providing immensely valuable suggestions about the pre-training data.

We gratefully acknowledge the TPU Research Cloud program, which was instrumental to our research by providing significant computational resources. We are also grateful to Xinyang Geng and Hao Liu for releasing [OpenLLaMA](https://github.com/openlm-research/open_llama) checkpoints and the [EasyLM](https://github.com/young-geng/EasyLM) library.

We would like to thank [Xiaosong He](https://github.com/hxs91) for suggestions on how to improve the explanations of the cross-batch code.