misdelivery/gemma-2-baku-2b-it-strange-merge
google/gemma-2-2b-jpn-itのChat Vectorをrinna/gemma-2-baku-2bにマージした結果、けっこうまともな出力と、とても変な不具合が特徴のモデルになりました。 前後の文脈と関係なく、勝手にとあるトークンが生成されます。 研究用等にどうぞ。
Inference
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
model = AutoModelForCausalLM.from_pretrained('misdelivery/gemma-2-baku-2b-it-strange-merge', torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained('misdelivery/gemma-2-baku-2b-it-strange-merge')
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
messages = [
{"role": "user", "content": "あなたはパフェ専門店の店員です。垂直な層が織りなすハーモニーを重視した秋の新作パフェを考えてください。"}
]
input = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
print(input)
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
output_ids = model.generate(input_ids,
max_new_tokens=1024,
streamer=streamer)
Model Merge
参考:jovyanさんのChat Vectorを使って日本語LLMをチャットモデルに改造する記事
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
base_model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2-2b",
torch_dtype=torch.bfloat16,
device_map="cpu",
)
inst_model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2-2b-jpn-it",
torch_dtype=torch.bfloat16,
device_map="cpu",
)
cp_model = AutoModelForCausalLM.from_pretrained(
"rinna/gemma-2-baku-2b",
torch_dtype=torch.bfloat16,
device_map="cpu",
)
skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]
for k, v in cp_model.state_dict().items():
if (k in skip_layers) or ("layernorm" in k):
continue
chat_vector = inst_model.state_dict()[k] - base_model.state_dict()[k]
new_v = v + chat_vector.to(v.device)
v.copy_(new_v)
- Downloads last month
- 4