---
library_name: transformers
tags: []
---

# gemma2-9b-sft

base_model: gemma2_9b

Instruction-tuned on 10k self-instruct examples generated by gemma2 27b.

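The self-instruct data itself is not published here. As a rough illustration only, the sketch below shows one way self-instruct records could be turned into the chat `messages` format used in the Use section. The record fields `instruction`, `input`, and `output` are assumptions (a common self-instruct/Alpaca-style schema), not the actual format of this model's training data. The assistant turn uses the role `assistant`, which Gemma 2's chat template renders as its `model` turn.

```python
from typing import TypedDict


class Record(TypedDict, total=False):
    # Hypothetical self-instruct record schema (assumed, not from this model card)
    instruction: str
    input: str
    output: str


def record_to_messages(record: Record) -> list[dict[str, str]]:
    """Convert one hypothetical self-instruct record into chat messages for SFT."""
    prompt = record["instruction"]
    # Append the optional input field below the instruction when it is present
    if record.get("input"):
        prompt = f"{prompt}\n\n{record['input']}"
    return [
        {"role": "user", "content": prompt},
        # Gemma 2's chat template maps the "assistant" role to its "model" turn
        {"role": "assistant", "content": record["output"]},
    ]


if __name__ == "__main__":
    example: Record = {
        "instruction": "Summarize the following sentence.",
        "input": "It was sunny today, so I enjoyed a walk in the park.",
        "output": "I walked in the park on a sunny day.",
    }
    for message in record_to_messages(example):
        print(message["role"], "->", message["content"])
```

In a typical SFT pipeline, each resulting messages list would then be rendered with `tokenizer.apply_chat_template(messages, tokenize=False)`, without `add_generation_prompt`, since the answer turn is already included.
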
## Eval

Evaluated on elyza-task-100.

## Use

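```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical placeholder: replace with this model's Hub repo id or a local path
model_name = "path/to/gemma2-9b-sft"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # automatically place the model on available GPUs
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"  # left-pad so every batched prompt ends right before generation

# Example sampling settings (assumed values; the original card does not specify them)
generate_configs = {
    "temperature": 0.7,
    "max_new_tokens": 512,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
}

messages_list = [
    [
        # "Please list five ideas for regaining enthusiasm for work."
        {"role": "user", "content": "仕事の熱意を取り戻すためのアイデアを5つ挙げてください。"}
    ]
]

prompts = [
    tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    for messages in messages_list
]

# Assuming this model keeps Gemma 2's chat template, which already inserts <bos>,
# skip the tokenizer's own special tokens to avoid a duplicated <bos>
inputs = tokenizer(
    prompts, return_tensors="pt", padding=True, add_special_tokens=False
).to(model.device)

outputs = model.generate(
    **inputs,
    do_sample=True,  # required for temperature/top_p/top_k to take effect
    temperature=generate_configs["temperature"],
    max_new_tokens=generate_configs["max_new_tokens"],
    top_p=generate_configs["top_p"],
    top_k=generate_configs["top_k"],
    repetition_penalty=generate_configs["repetition_penalty"],
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Drop the prompt tokens and decode only the generated continuation
generated = outputs[:, inputs["input_ids"].shape[1]:]
for text in tokenizer.batch_decode(generated, skip_special_tokens=True):
    print(text)
```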