File size: 4,769 Bytes
990b55a 7586afb 0143b56 990b55a 7586afb 0143b56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
---
license: llama2
language:
- ja
- en
---
## モデル概要
Llama2-13bに日本語語彙を追加して継続事前学習を行った大喜利言語モデルです。
事前学習後に大喜利データでFine-tuningしています。
本モデルは[AWS LLM 開発支援プログラム](https://aws.amazon.com/jp/local/llm-development-support-program/)の支援を受けております。
継続事前学習ではAWS Trainium を搭載したインスタンス [trn1.32xlarge](https://aws.amazon.com/jp/ec2/instance-types/trn1/)×4の並列学習を行っております。
* License: [LLAMA 2 COMMUNITY LICENSE](https://github.com/facebookresearch/llama/blob/main/LICENSE)
* Library: [neuronx-nemo-megatron](https://github.com/aws-neuron/neuronx-nemo-megatron)
### トークナイザー
Llama2の本来のトークナイザーに含まれる語彙32,000に日本語の語彙をBPEで学習して13,046追加し、総語彙サイズは45,046です。
語彙を追加する際、漢字一文字のトークンは、常用漢字と学習データ内で出現頻度が高いものに絞っています。
数字と文字がペアのトークンや記号と文字がペアのトークンが学習されないように、予めデータから数字や記号を取り除いています。
## 学習データ
以下のコーパスを使用して、事前学習を行いました。その際のトークン数は650億トークンでした。
* [C4](https://huggingface.co/datasets/mc4)の日本語データ
* [CC-100](https://huggingface.co/datasets/cc100)の日本語データ
* [OSCAR](https://huggingface.co/datasets/oscar)の日本語データ
* Wikipediaの[日本語ダンプデータ](https://ja.wikipedia.org/wiki/%E3%83%A1%E3%82%A4%E3%83%B3%E3%83%9A%E3%83%BC%E3%82%B8)と[英語ダンプデータ](https://en.wikipedia.org/wiki/Main_Page)
* 自社データ
## 使用方法
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "watashiha/Watashiha-Llama-2-13B-Ogiri-sft"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
if torch.cuda.is_available():
model = model.to("cuda")
odai = "マジシャンのショーでアシスタントが消えたまま戻ってこない時の一言。"
text = f"""
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
入力の文は大喜利のお題です。お題に沿った面白いボケを生成してください。
### 入力:
{odai}
### 応答:
"""
text = text.lstrip()
with torch.no_grad():
token_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
output_ids = model.generate(
token_ids,
do_sample=True,
min_new_tokens=1,
max_new_tokens=64,
top_p=0.9,
top_k=50,
temperature=0.8,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
output = tokenizer.decode(output_ids.tolist()[0], skip_special_tokens=True)
print(output)
"""
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。
### 指示:
入力の文は大喜利のお題です。お題に沿った面白いボケを生成してください。
### 入力:
マジシャンのショーでアシスタントが消えたまま戻ってこない時の一言。
### 応答:
帰りの電車賃あげるから
"""
```
### AWSのinf2.xlargeで動かす方法
AWSの[inf2インスタンス](https://aws.amazon.com/jp/ec2/instance-types/inf2/)はパラメータ数が10Bを超えるモデルをGPUインスタンスと比べ安価に運用できます(2024/1/24現在)。
こちらは近日、モデルとソースコードを公開予定です。
## 性能比較
以下は大喜利Fine-tuningしたモデルが出力させたボケを、ケータイ大喜利レジェンドに4段階で評価してもらった結果です。
圏外:お題を日本語として理解できていない
1本:お題を理解はできているがボケとして成立していない(面白みがない)
2本:ボケとして成立している(面白みがある)
3本:面白い(一定以上の面白さがある)
| | 圏外 | 1本 | 2本 | 3本 |
|--------------|------|-----|-----|-----|
| Watashiha-Llama-2-13B-Ogiri-sft | 75 | 133 | 209 | 81 |
| [watashiha-gpt-6b](https://huggingface.co/watashiha/watashiha-gpt-6b) | 77 | 204 | 175 | 44 | |