---
license: apache-2.0
datasets:
- kinokokoro/ichikara-instruction-003
data_license:
- cc
language:
- ja
base_model:
- llm-jp/llm-jp-3-13b
library_name: transformers
tags:
- text-generation-inference
- transformers
---

# Sample Use

```python
import json
import os
import time

import torch
from tqdm import tqdm

# NOTE: BASE_DIR and load_model() are not defined in this snippet; define them
# in your own script before running it.
MODEL_DIR = os.path.join(BASE_DIR, "fine_tuned_model")


def generate_predictions(model, tokenizer, input_file, output_file):
    """Run batched greedy generation over a JSONL task file and save the results.

    Side effect: the tokenizer is switched to left padding (and given a pad
    token if it has none) because batched generation needs both.

    Args:
        model: Causal language model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        input_file: Path to a JSONL file; every record must contain the keys
            "task_id" and "input".
        output_file: Path of the JSONL file to write; one
            ``{"task_id": ..., "output": ...}`` record per task.
    """
    BATCH_SIZE = 16  # number of prompts generated per forward pass

    print(f"入力ファイルを読み込み中: {input_file}")
    with open(input_file, 'r', encoding='utf-8') as f:
        # Skip blank lines so a trailing newline does not break json.loads.
        tasks = [json.loads(line) for line in f if line.strip()]

    # generate() continues from the last position of each row, so padding has
    # to sit on the left or the model would continue from pad tokens.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        # padding=True needs a pad token; reuse EOS when none is configured.
        tokenizer.pad_token = tokenizer.eos_token

    results = []
    print("推論を実行中...")

    for i in tqdm(range(0, len(tasks), BATCH_SIZE)):
        batch_tasks = tasks[i:i + BATCH_SIZE]
        prompts = [f"入力: {task['input']}\n出力: " for task in batch_tasks]

        # Tokenize the whole batch and move it to the device the model is on.
        inputs = tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                # Without the mask, pad tokens are attended to as real input.
                attention_mask=inputs.attention_mask,
                # Budget counts generated tokens only, so long prompts still
                # leave room for an answer.
                max_new_tokens=512,
                # Greedy decoding; sampling-only options (temperature, top_k,
                # top_p) and beam-only early_stopping have no effect here.
                do_sample=False,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )

        # With left padding every prompt fills the first `prompt_length`
        # positions, so everything after that is newly generated text.
        prompt_length = inputs.input_ids.shape[1]
        generated_texts = tokenizer.batch_decode(
            outputs[:, prompt_length:],
            skip_special_tokens=True,
        )

        # generate() returns one row per prompt, in order.
        for task, generated_text in zip(batch_tasks, generated_texts):
            results.append({
                "task_id": task["task_id"],
                "output": generated_text.strip(),
            })

    print(f"結果を保存中: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        for result in results:
            json.dump(result, f, ensure_ascii=False)
            f.write('\n')


def main():
    """Load the model, run inference on the input file and report elapsed time."""
    # Release cached GPU memory held by PyTorch before starting.
    torch.cuda.empty_cache()

    start_time = time.time()

    model, tokenizer = load_model()
    input_file = "{$file_path}"
    output_file = os.path.join(BASE_DIR, "{$file_path}")
    generate_predictions(model, tokenizer, input_file, output_file)

    # Report total wall-clock time in minutes.
    elapsed_time = time.time() - start_time
    print(f"総実行時間: {elapsed_time / 60:.2f}分")


if __name__ == "__main__":
    main()
```