masamori committed on
Commit
f331232
1 Parent(s): 2687299

Update README

Files changed (1)
  1. README.md +87 -3
README.md CHANGED
@@ -1,3 +1,87 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ datasets:
+ - kinokokoro/ichikara-instruction-003
+ language:
+ - ja
+ base_model:
+ - llm-jp/llm-jp-3-13b
+ library_name: transformers
+ tags:
+ - text-generation-inference
+ - transformers
+ ---
+
+ # Sample Use
+ ```python
+ import json
+ import os
+
+ import torch
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ MODEL_DIR = os.path.join("model_dir")
+
+ def load_model():
+     print("Loading model and tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
+     tokenizer.padding_side = "left"  # left-pad so batched generation continues from each prompt
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_DIR,
+         torch_dtype=torch.float16,
+         device_map={"": 0},  # place the whole model on GPU 0
+         use_cache=True,      # enable the KV cache
+     )
+
+     model.eval()  # evaluation mode
+     return model, tokenizer
+
+ def generate_predictions(model, tokenizer, input_file, output_file):
+     BATCH_SIZE = 8  # number of prompts processed per batch
+
+     print(f"Reading input file: {input_file}")
+     tasks = []
+     with open(input_file, 'r', encoding='utf-8') as f:
+         for line in f:
+             tasks.append(json.loads(line))
+
+     results = []
+     print("Running inference...")
+
+     # batched inference
+     for i in tqdm(range(0, len(tasks), BATCH_SIZE)):
+         batch_tasks = tasks[i:i + BATCH_SIZE]
+         prompts = [f"入力: {task['input']}\n出力: " for task in batch_tasks]
+
+         inputs = tokenizer(
+             prompts,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             max_length=512
+         ).to(model.device)
+
+         with torch.no_grad():
+             outputs = model.generate(
+                 **inputs,                 # input_ids and attention_mask
+                 max_new_tokens=512,
+                 do_sample=False,          # greedy decoding
+                 repetition_penalty=1.2,
+                 pad_token_id=tokenizer.pad_token_id,
+                 num_return_sequences=1,
+                 use_cache=True            # reuse the KV cache while decoding
+             )
+
+         # post-process each sequence in the batch
+         for j, output in enumerate(outputs):
+             generated_text = tokenizer.decode(output, skip_special_tokens=True)
+             output_text = generated_text.split("出力: ")[-1].strip()
+
+             results.append({
+                 "task_id": batch_tasks[j]["task_id"],
+                 "output": output_text
+             })
+
+     print(f"Saving results to: {output_file}")
+     with open(output_file, 'w', encoding='utf-8') as f:
+         for result in results:
+             json.dump(result, f, ensure_ascii=False)
+             f.write('\n')
+ ```