# Model Card for fcyin/llama2_7B_base_lofit_truthfulqa

This is a Llama-2-7b model fine-tuned on TruthfulQA using Localized Fine-tuning on LLM Representations (LoFiT; https://arxiv.org/abs/2406.01563). The checkpoint modifies the attention outputs of 96 attention heads (10% of all attention heads in the model).
### Model Description

- **License:** MIT
- **Finetuned from model:** meta-llama/Llama-2-7b-hf

### Model Sources

- **Repository:** https://github.com/fc2869/lo-fit
- **Paper:** https://arxiv.org/abs/2406.01563
## Uses

To use this checkpoint, clone the LoFiT GitHub repo (https://github.com/fc2869/lo-fit), then run the following snippet from the repo root to evaluate it on TruthfulQA.
```
import torch
from transformers import AutoTokenizer

from models.modeling_llama import LlamaForCausalLM
from utils.evaluate import evaluate_tqa
from utils.dataloaders import TQA

checkpoint = 'fcyin/llama2_7B_base_lofit_truthfulqa'
model_name = 'llama2_7B'
device = 'cuda'
cache_dir = './'
applied_module = 'attention'  # LoFiT is applied to the attention outputs
torch_dtype = torch.float32

# Load the LoFiT checkpoint through the repo's custom loader
model = LlamaForCausalLM.custom_from_pretrained(
    checkpoint,
    device_map=device,
    cache_dir=cache_dir,
    applied_module=applied_module,
    torch_dtype=torch_dtype,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the TruthfulQA split used by the repo
dataloader = TQA(
    iti_split_dir='./dataset/truthfulqa',
    fold_num=0,
    data_gen_seed=42,
)
dataset = dataloader.load_data()

# Evaluate on the test fold with the multiple-choice (MC) metrics
evaluate_tqa(
    fname='./',
    eval_dataset=dataset['test'],
    model_name=model_name,
    metrics=['mc'],
    tokenizer=tokenizer,
    model=model,
)
```
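This runs the multiple-choice TruthfulQA evaluation on the test fold (`metrics=['mc']` selects the multiple-choice metrics, MC1/MC2 in the paper). The checkpoint is a causal LM, so ordinary text generation should also work. A minimal sketch, continuing from the snippet above and assuming the repo's `LlamaForCausalLM` keeps the standard Hugging Face `generate()` interface; the prompt is an illustrative TruthfulQA-style question:

```
# Generation sketch: assumes the custom LlamaForCausalLM inherits the
# standard Hugging Face generate() API (untested assumption).
prompt = "What happens if you crack your knuckles a lot?"
inputs = tokenizer(prompt, return_tensors='pt').to(device)
output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```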
## Training Details

Please refer to the [paper](https://arxiv.org/abs/2406.01563) for training details.
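For intuition: LoFiT keeps all pretrained weights frozen, first selects a small set of task-relevant attention heads, and then learns one offset (bias) vector per selected head that is added to that head's attention output; for this checkpoint, 96 heads were tuned. The sketch below is a minimal, hypothetical PyTorch illustration of the offset step, not the repo's implementation; `LoFiTOffsets`, `selected`, and the tensor shapes are made up for illustration.

```
import torch
import torch.nn as nn

class LoFiTOffsets(nn.Module):
    """Hypothetical sketch: one trainable offset vector per selected head."""

    def __init__(self, selected_heads, head_dim):
        super().__init__()
        # Only these offsets are trained; the base model stays frozen.
        self.offsets = nn.ParameterDict({
            f"{layer}_{head}": nn.Parameter(torch.zeros(head_dim))
            for layer, head in selected_heads
        })

    def add_offset(self, layer, head, head_output):
        # Add the learned offset to this head's attention output, if selected.
        key = f"{layer}_{head}"
        if key in self.offsets:
            head_output = head_output + self.offsets[key]
        return head_output

# Illustration: Llama-2-7b has 32 layers x 32 heads with head_dim = 128,
# so LoFiT selects 96 of the 1024 heads. Placeholder (layer, head) pairs:
selected = [(0, 3), (5, 12), (17, 0)]
offsets = LoFiTOffsets(selected, head_dim=128)
head_out = torch.randn(2, 16, 128)  # (batch, seq_len, head_dim)
head_out = offsets.add_offset(layer=5, head=12, head_output=head_out)
```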