Sakalti commited on
Commit
7730746
·
verified ·
1 Parent(s): f7961d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -32
app.py CHANGED
@@ -1,36 +1,10 @@
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
- from mergekit.config import MergeConfig
4
- from mergekit.merge import Merger
5
 
6
- # MergeKit の設定を読み込む
7
- merge_config = MergeConfig.from_dict({
8
- "slices": [
9
- {
10
- "sources": [
11
- {"model": "Qwen/Qwen2.5-0.5B-Instruct", "layer_range": [0, 23]},
12
- {"model": "Qwen/Qwen2.5-0.5B-Instruct", "layer_range": [0, 23]}
13
- ]
14
- }
15
- ],
16
- "merge_method": "slerp",
17
- "base_model": "Qwen/Qwen2.5-1.5B-Instruct",
18
- "parameters": {
19
- "t": [
20
- {"filter": "self_attn", "value": [0, 0.5, 0.3, 0.7, 1]},
21
- {"filter": "mlp", "value": [1, 0.5, 0.7, 0.3, 0]},
22
- {"value": 0.5}
23
- ]
24
- },
25
- "dtype": "bfloat16"
26
- })
27
-
28
- # マージされたモデルを生成
29
- merger = Merger(merge_config)
30
- merged_model = merger.merge()
31
-
32
- # トークナイザーの読み込み
33
- tokenizer = AutoTokenizer.from_pretrained(merge_config.base_model)
34
 
35
  # 生成用の関数
36
  def respond(input_text, system_message, max_new_tokens, temperature, top_p):
@@ -44,7 +18,7 @@ def respond(input_text, system_message, max_new_tokens, temperature, top_p):
44
  print(inputs['input_ids'].shape)
45
 
46
  # モデルに入力を渡して生成
47
- outputs = merged_model.generate(
48
  input_ids=inputs['input_ids'],
49
  attention_mask=inputs['attention_mask'],
50
  max_new_tokens=max_new_tokens, # 新規トークンの最大数
@@ -62,7 +36,7 @@ def respond(input_text, system_message, max_new_tokens, temperature, top_p):
62
 
63
  # Gradioインターフェースの作成
64
  with gr.Blocks() as demo:
65
- gr.Markdown("## Sakalti/aquamarine チャットボット")
66
 
67
  # 追加の入力フィールドをリストで設定
68
  additional_inputs = [
 
1
  import gradio as gr
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
3
 
4
+ # モデルとトークナイザーの読み込み
5
+ model_name = "Sakalti/chromet1-0.48b"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ model = AutoModelForCausalLM.from_pretrained(model_name, ignore_mismatched_sizes=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # 生成用の関数
10
  def respond(input_text, system_message, max_new_tokens, temperature, top_p):
 
18
  print(inputs['input_ids'].shape)
19
 
20
  # モデルに入力を渡して生成
21
+ outputs = model.generate(
22
  input_ids=inputs['input_ids'],
23
  attention_mask=inputs['attention_mask'],
24
  max_new_tokens=max_new_tokens, # 新規トークンの最大数
 
36
 
37
  # Gradioインターフェースの作成
38
  with gr.Blocks() as demo:
39
+ gr.Markdown("## chromet チャットボット")
40
 
41
  # 追加の入力フィールドをリストで設定
42
  additional_inputs = [