|
--- |
|
library_name: peft |
|
--- |
|
|
|
# モデル概要 |
|
|
|
[meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)を日本語データ([taka-yayoi/databricks-dolly-15k-ja](https://huggingface.co/datasets/taka-yayoi/databricks-dolly-15k-ja))を用いてインストラクションチューニングしました. |
|
|
|
# 使用方法 |
|
|
|
```python |
|
import torch |
|
from peft import PeftModel |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
|
|
# モデルの読み込み |
|
model = AutoModelForCausalLM.from_pretrained( |
|
"meta-llama/Llama-2-7b-hf", |
|
quantization_config=BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_use_double_quant=True, |
|
bnb_4bit_quant_type="nf4", |
|
bnb_4bit_compute_dtype=torch.bfloat16 |
|
), |
|
device_map={"":0} |
|
) |
|
|
|
# トークナイザーの読み込み |
|
tokenizer = AutoTokenizer.from_pretrained( |
|
"asaoka/Llama-2-7b-hf-qlora-dolly15k-japanese", |
|
) |
|
|
|
# LoRAの読み込み |
|
model = PeftModel.from_pretrained( |
|
model, |
|
"asaoka/Llama-2-7b-hf-qlora-dolly15k-japanese", |
|
device_map={"":0} |
|
) |
|
model.eval() |
|
|
|
# プロンプトの準備 |
|
prompt = "### Instruction: 富士山とは?\n\n### Response: " |
|
|
|
# 推論の実行 |
|
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0") |
|
with torch.no_grad(): |
|
outputs = model.generate(**inputs, max_new_tokens=100) |
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True)) |
|
``` |
|
|
|
使用方法は,[「Google Colab で Llama-2-7B のQLoRA ファインチューニングを試す」](https://note.com/npaka/n/na7c631175111#f2af0e53-4ef3-4288-b152-6524f1b940a7)を参照しました. |
|
|
|
# トークナイザーの日本語への拡張 |
|
|
|
### 1. 日本語のトークナイザーを学習 |
|
|
|
トークナイザーの学習は,[ce-lery/japanese-mistral-300m-base](https://huggingface.co/ce-lery/japanese-mistral-300m-base)を参照しました. |
|
|
|
### 2. Llama-2-7b-hfのトークナイザーと日本語のトークナイザーをマージ |
|
|
|
トークナイザーのマージは,[「日本語が話せるLlamaモデルをDIYする」](https://qiita.com/Taiyou2000/items/3229d320c252d6de33c7)を参照しました. |
|
|
|
# トレーニング方法 |
|
|
|
- ファインチューニング:インストラクションチューニング + QLoRA(4bitLoRA) |
|
|
|
トレーニング方法は,[「MetaのLlama 2をDatabricksでQLoRAを使ってファインチューニングしてみる」](https://qiita.com/taka_yayoi/items/a973fa2d08062224d422)を参照しました. |
|
|
|
# JGLUEスコア |
|
|
|
| タスク | Llama-2-7b-hf | This Model | |
|
|:-|:-|:-| |
|
| jcommonsenseqa-1.1-0.6(acc) | 0.7274 | 0.7060 | |
|
|
|
[JGLUEスコア](https://aclanthology.org/2022.lrec-1.317/)は,Stability AI社の[lm-evaluation-harness](https://github.com/Stability-AI/lm-evaluation-harness)を用いて |
|
算出しました.JGLUEスコアの算出に用いたスクリプトを下記に示します. |
|
|
|
- Llama-2-7b-hf |
|
|
|
```bash |
|
!python main.py \ |
|
--model hf-causal-experimental \ |
|
--model_args pretrained=meta-llama/Llama-2-7b-hf \ |
|
--tasks jcommonsenseqa-1.1-0.6 \ |
|
--num_fewshot 3 \ |
|
--device cuda \ |
|
--output_path ./results.json |
|
``` |
|
|
|
- This Model |
|
|
|
|
|
```bash |
|
!python main.py \ |
|
--model hf-causal-experimental \ |
|
--model_args pretrained=meta-llama/Llama-2-7b-hf,peft=asaoka/Llama-2-7b-hf-qlora-dolly15k-japanese \ |
|
--tasks jcommonsenseqa-1.1-0.6 \ |
|
--num_fewshot 3 \ |
|
--device cuda \ |
|
--output_path ./results.json |
|
``` |
|
|
|
|