Safetensors

Mixtral8X7B Instructの日本語生成を安定させるためのLoraです。

目的

Mixtral-8x7Bは高性能な言語モデルですが、日本語出力に多言語が混入するcode-switchingがよく見られます。 元の性能を維持しながら、日本語生成を安定させる方法として、Loraの効果を検証しました。

学習データセット

学習データセットとして、下記のDPOデータセットを使用しています。
DPO trainingはVRAM消費が多く、今回はchosenのデータを使用したsft学習しています。

Chatbot Arena Conversations JA (calm2) Dataset
:cyberagent/chatbot-arena-ja-calm2-7b-chat-experimental
指示文 : lmsys/chatbot_arena_conversationsのユーザ入力(CC-BY 4.0)を利用。
指示文の和訳 : facebookの翻訳モデル(MIT License)が使用されています。
応答文 : calm2-7b-chat(Apache 2.0)の出力です。

evaluation

大きな性能低下がないことを確認しました

##Lora

num_fewshot: 2, batch_size: 1

Task Version Metric Value Stderr
jsquad-1.1-0.3 1.1 exact_match 72.3323
f1 85.4772
jcommonsenseqa-1.1-0.3 1.1 acc 0.7498 ± 0.0130
acc_norm 0.4138 ± 0.0147

num_fewshot: 2, batch_size: 1

Task Version Metric Value Stderr
jnli-1.1-0.3 1.1 acc 0.5912 ± 0.0100
acc_norm 0.4108 ± 0.0100
marc_ja-1.1-0.3 1.1 acc 0.9620 ± 0.0025
acc_norm 0.9620 ± 0.0025
jaqket_v2-0.1-0.3 0.1 exact_match 71.6495
f1 79.4725

##Base model

num_fewshot: 3,3, batch_size: 1

Task Version Metric Value Stderr
jsquad-1.1-0.3 1.1 exact_match 68.1225
f1 83.5285
jcommonsenseqa-1.1-0.3 1.1 acc 0.7766 ± 0.0125
acc_norm 0.4629 ± 0.0149

num_fewshot: 2, batch_size: 1

Task Version Metric Value Stderr
jnli-1.1-0.3 1.1 acc 0.6228 ± 0.0098
acc_norm 0.5288 ± 0.0101
marc_ja-1.1-0.3 1.1 acc 0.9630 ± 0.0025
acc_norm 0.9630 ± 0.0025
jaqket_v2-0.1-0.3 0.1 exact_match 67.9553
f1 78.7550

その他

Lora学習時のcontext長は4096tokenまでですが、4k token以上の出力も可能です。

注:bf16での使用を想定しています。 量子化推論する場合は、bf16でモデルを読み込んだ状態でLora適応またはマージ、その後に量子化してください。

2/8更新 学習強度が1/3と、2/3のcheck pointも公開しました
こちらのほうがベースモデルの汎化性能維持できている可能性があります

learningstrength0.3
num_fewshot: 2,2, batch_size: 1

Task Version Metric Value Stderr
jsquad-1.1-0.3 1.1 exact_match 72.1747
f1 85.3325
jcommonsenseqa-1.1-0.3 1.1 acc 0.7534 ± 0.0129
acc_norm 0.4111 ± 0.0147

learningstrength0.6
num_fewshot: 2,2, batch_size: 1

Task Version Metric Value Stderr
jsquad-1.1-0.3 1.1 exact_match 72.3548
f1 85.5144
jcommonsenseqa-1.1-0.3 1.1 acc 0.7480 ± 0.0130
acc_norm 0.4111 ± 0.0147
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference API
Unable to determine this model's library. Check the docs .

Dataset used to train aixsatoshi/Mixtral-8x7B-ja-Lora-sft-ChatbotArenaJAcalm2