Gemma-2-2B-SearchHelper-20240906
このモデルは某検索AIサイトのプロサーチ的なところをマネするために作りました。
Model Details
このモデルは質問に対して、回答ではなく、質問を解決するための手順をリスト型で返すモデルです。
返答される手順は日本語でおおよそ3段階程度です。
軽量なGemma-2-2Bモデルをllama-3.1-405Bにより合成されたデータセットでfine-tuningしました。
unslothを使用してloraで学習しています。
Uses
通常通り、transformersからも使用できます。
model_name = "kurogane/Gemma-2-2B-SearchHelper-20240906"
# トークナイザーの読み込み
tokenizer = AutoTokenizer.from_pretrained(
model_name,
)
# モデルの読み込み
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
)
model = model.to('cuda')
使用例
input_text = "Pythonの勉強の仕方を教えてください。"
messages = [
{"role": "user", "content": input_text},
]
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True).to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=256, eos_token_id=tokenizer.encode('<end_of_turn>'))
print(tokenizer.decode(outputs[0]))
出力例
<bos><start_of_turn>user
Unityでエディタ拡張を作る方法を教えてください。<end_of_turn>
<start_of_turn>model
["Unityの公式ドキュメントでエディタ拡張のチュートリアルを探す","Stack OverflowでUnityエディタ拡張に関する質問を検索する","YouTubeでUnityエディタ拡張のチュートリアル動画を探す" ]<end_of_turn>
<bos><start_of_turn>user
Pythonの勉強の仕方を教えてください。<end_of_turn>
<start_of_turn>model
["Pythonの公式ドキュメントを検索し、基本的な文法やデータ構造を学ぶ","オンラインコースやチュートリアルサイト(例:Udemy、Coursera)でPythonの基礎から応用までを学ぶ","GitHubやStack Overflowなどの開発者コミュニティで、実際のプロジェクトやコード例を参考にする" ]<end_of_turn>
jsonモジュールなどで出力をリスト化することもできます。
input_ids = tokenizer.apply_chat_template(messages, tokenize=False,)
s_output_only = tokenizer.decode(outputs[0]).replace(input_ids, "").replace("<end_of_turn>", "").replace("<start_of_turn>model", "")
json_text = json.loads(s_output_only)
print(json_text)
出力例
['Pythonの公式ドキュメントを検索し、基本的な文法やデータ構造を学ぶ', 'オンラインコースやチュートリアルサイト(例:Udemy、Coursera)でPythonの基礎から応用までを学ぶ', 'GitHubやStack Overflowなどの開発者コミュニティで、実際のプロジェクトやコード例を参考にする']
Bias, Risks, and Limitations
出力は不完全かもしれません。2Bモデルであること、15000件程度の合成データセットのみを使用して学習していることから、確実な手順を推定できないかもしれません。また、検索に不適切な手順を提示する恐れがあります。
Recommendations
後日、合成データセットを公開する予定です。7Bモデルなどの十分に賢いモデルのfine tuningに使用してみたほうがこのモデルより適切に推定してくれるかもしれません。
Training Details
Training Data
後日公開予定です。llama-3.1-405Bを使用して作成した役15000件の合成データセットを使用しています。
また、instructionには、各種データセットの最初のチャットデータのみを使用しました。
以下のモデルからランダムに抽出しています。
- bigbio/med_qa
- Aratako/Synthetic-JP-Conversations-Magpie-Nemotron-4-10k
- kunishou/oasst1-chat-44k-ja
- truthfulqa/truthful_qa
Training Procedure
Unslothにぶち込みました。a6000でだいたい35分位かかりました。
Training Hyperparameters
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
trainer = SFTTrainer(
model = model,
tokenizer = tokenizer,
train_dataset = dataset,
dataset_text_field = "text",
max_seq_length = max_seq_length,
dataset_num_proc = 2,
packing = False, # Can make training 5x faster for short sequences.
args = TrainingArguments(
per_device_train_batch_size = 8,
gradient_accumulation_steps = 8,
warmup_steps = 5,
num_train_epochs = 1, # Set this for 1 full training run.
# max_steps = 60,
learning_rate = 2e-4,
fp16 = not is_bfloat16_supported(),
bf16 = is_bfloat16_supported(),
logging_steps = 1,
optim = "adamw_8bit",
weight_decay = 0.01,
lr_scheduler_type = "linear",
seed = 3407,
output_dir = "outputs",
),
)
Model Card Authors/Contact
バーチャル一般人のクロガネと言います!
Please follow me!!↓
https://x.com/Kurogane_8_Gk
Framework versions
- PEFT 0.12.0
- Downloads last month
- 0