---
license: gemma
datasets:
- kunishou/databricks-dolly-15k-ja
language:
- ja
base_model:
- google/gemma-2-2b
library_name: transformers
tags:
- text-generation-inference
- transformers
---

## How to Generate Outputs
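The script below loads the `google/gemma-2-2b` base model, applies the LoRA adapter from this repository with `PeftModel`, and generates an output for every task in a JSONL file. It assumes `torch`, `transformers`, `peft`, and `tqdm` are installed and that a CUDA GPU is available.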
```python
import json

import torch
from peft import PeftModel
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_task_outputs(input_jsonl_path, output_jsonl_path):
    # Load the base model and tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-2b",  # base model
        torch_dtype=torch.float16,
        device_map={"": 0}
    )
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

    # Apply the LoRA adapter from this repository
    model = PeftModel.from_pretrained(
        model,
        "YOUR_USERNAME/YOUR_REPO_NAME"
    )
    model.eval()

    # Load the input tasks (one JSON object per line)
    tasks = []
    with open(input_jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            tasks.append(json.loads(line))

    # Generate an output for each task
    results = []
    for task in tqdm(tasks):
        input_text = task["input"]
        prompt = f"入力: {input_text}\n出力: "

        inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,  # pass input_ids and attention_mask
                max_length=512,
                do_sample=False,  # greedy decoding, so no temperature is needed
                repetition_penalty=1.2
            )

        # Keep only the text after the "出力: " marker
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        output_text = generated_text.split("出力: ")[-1].strip()

        results.append({
            "task_id": task["task_id"],
            "output": output_text
        })

    # Save the results as JSON Lines
    with open(output_jsonl_path, 'w', encoding='utf-8') as f:
        for result in results:
            json.dump(result, f, ensure_ascii=False)
            f.write('\n')


# Example usage
input_file = "path/to/input.jsonl"
output_file = "path/to/output.jsonl"
generate_task_outputs(input_file, output_file)
```
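
For reference, here is a minimal sketch of the input file the script expects: each line is a JSON object with `task_id` and `input` fields, and each line of the output file contains `task_id` and `output`. The sample task contents below are illustrative only.

```python
import json

# Hypothetical sample tasks; only the "task_id" and "input" fields are read
# by generate_task_outputs above.
sample_tasks = [
    {"task_id": 0, "input": "日本で一番高い山は何ですか？"},
    {"task_id": 1, "input": "次の文を英語に翻訳してください: こんにちは"},
]

with open("path/to/input.jsonl", "w", encoding="utf-8") as f:
    for task in sample_tasks:
        json.dump(task, f, ensure_ascii=False)
        f.write("\n")

# After running generate_task_outputs, each line of output.jsonl is an object
# of the form {"task_id": ..., "output": ...}.
```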