heegyu commited on
Commit
924483b
1 Parent(s): 56a76eb

min-new-token 추가

Browse files
Files changed (2) hide show
  1. app.py +34 -19
  2. test.ipynb +120 -139
app.py CHANGED
@@ -2,13 +2,40 @@ import gradio as gr
2
  import torch
3
  import random
4
  import time
5
- from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- generator = pipeline(
8
- 'text-generation',
9
- model="heegyu/bluechat-v0",
10
- device="cuda:0" if torch.cuda.is_available() else 'cpu'
11
- )
12
 
13
  def query(message, chat_history, max_turn=4):
14
  prompt = []
@@ -21,19 +48,7 @@ def query(message, chat_history, max_turn=4):
21
  prompt.append(f"<usr> {message}")
22
  prompt = "\n".join(prompt) + "\n<bot>"
23
 
24
- output = generator(
25
- prompt,
26
- # repetition_penalty=1.3,
27
- # no_repeat_ngram_size=2,
28
- eos_token_id=2, # \n
29
- max_new_tokens=128,
30
- do_sample=True,
31
- top_p=0.9,
32
- )[0]['generated_text']
33
-
34
- print(output)
35
-
36
- response = output[len(prompt):]
37
  return response.strip()
38
 
39
  with gr.Blocks() as demo:
 
2
  import torch
3
  import random
4
  import time
5
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
6
+
7
+
8
+ model_name="heegyu/bluechat-v0"
9
+ device="cuda:0" if torch.cuda.is_available() else 'cpu'
10
+ model = AutoModelForCausalLM.from_pretrained(model_name)
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+
13
+ # generator = pipeline(
14
+ # 'text-generation',
15
+ # model="heegyu/bluechat-v0",
16
+ # device="cuda:0" if torch.cuda.is_available() else 'cpu'
17
+ # )
18
+
19
+ def get_message(prompt, min_new_tokens=16, max_turn=4):
20
+ prompt = prompt.strip()
21
+ ids = tokenizer(prompt, return_tensors="pt").to(device)
22
+ min_length = ids['input_ids'].shape[1] + min_new_tokens
23
+
24
+ output = model.generate(
25
+ **ids,
26
+ no_repeat_ngram_size=3,
27
+ eos_token_id=2, # 375=\n 2=</s>, 0:open-end
28
+ max_new_tokens=128,
29
+ min_length=min_length,
30
+ do_sample=True,
31
+ top_p=0.7,
32
+ early_stopping=True
33
+ ) # [0]['generated_text']
34
+
35
+ output = tokenizer.decode(output.cpu()[0])
36
+ print(output)
37
+ return output[len(prompt):]
38
 
 
 
 
 
 
39
 
40
  def query(message, chat_history, max_turn=4):
41
  prompt = []
 
48
  prompt.append(f"<usr> {message}")
49
  prompt = "\n".join(prompt) + "\n<bot>"
50
 
51
+ response = get_message(prompt, 8)
 
 
 
 
 
 
 
 
 
 
 
 
52
  return response.strip()
53
 
54
  with gr.Blocks() as demo:
test.ipynb CHANGED
@@ -2,161 +2,42 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 2,
6
  "metadata": {},
7
- "outputs": [
8
- {
9
- "name": "stderr",
10
- "output_type": "stream",
11
- "text": [
12
- "/opt/anaconda3/lib/python3.9/site-packages/huggingface_hub/utils/_hf_folder.py:92: UserWarning: A token has been found in `/Users/casa/.huggingface/token`. This is the old path where tokens were stored. The new location is `/Users/casa/.cache/huggingface/token` which is configurable using `HF_HOME` environment variable. Your token has been copied to this new location. You can now safely delete the old token file manually or use `huggingface-cli logout`.\n",
13
- " warnings.warn(\n"
14
- ]
15
- },
16
- {
17
- "data": {
18
- "application/vnd.jupyter.widget-view+json": {
19
- "model_id": "e42b34cf3f07417592f26316fea86e1a",
20
- "version_major": 2,
21
- "version_minor": 0
22
- },
23
- "text/plain": [
24
- "Downloading (…)lve/main/config.json: 0%| | 0.00/944 [00:00<?, ?B/s]"
25
- ]
26
- },
27
- "metadata": {},
28
- "output_type": "display_data"
29
- },
30
- {
31
- "data": {
32
- "application/vnd.jupyter.widget-view+json": {
33
- "model_id": "4f89d76d6b7e4cf59a9dd631bd739221",
34
- "version_major": 2,
35
- "version_minor": 0
36
- },
37
- "text/plain": [
38
- "Downloading pytorch_model.bin: 0%| | 0.00/1.66G [00:00<?, ?B/s]"
39
- ]
40
- },
41
- "metadata": {},
42
- "output_type": "display_data"
43
- },
44
- {
45
- "data": {
46
- "application/vnd.jupyter.widget-view+json": {
47
- "model_id": "a690f8b53a204d489f4d53a937068ac6",
48
- "version_major": 2,
49
- "version_minor": 0
50
- },
51
- "text/plain": [
52
- "Downloading (…)neration_config.json: 0%| | 0.00/111 [00:00<?, ?B/s]"
53
- ]
54
- },
55
- "metadata": {},
56
- "output_type": "display_data"
57
- },
58
- {
59
- "data": {
60
- "application/vnd.jupyter.widget-view+json": {
61
- "model_id": "14302bef459f485a998d908b131f43ec",
62
- "version_major": 2,
63
- "version_minor": 0
64
- },
65
- "text/plain": [
66
- "Downloading (…)okenizer_config.json: 0%| | 0.00/771 [00:00<?, ?B/s]"
67
- ]
68
- },
69
- "metadata": {},
70
- "output_type": "display_data"
71
- },
72
- {
73
- "data": {
74
- "application/vnd.jupyter.widget-view+json": {
75
- "model_id": "33826da838e1402581f62fafd3657b90",
76
- "version_major": 2,
77
- "version_minor": 0
78
- },
79
- "text/plain": [
80
- "Downloading (…)olve/main/vocab.json: 0%| | 0.00/1.27M [00:00<?, ?B/s]"
81
- ]
82
- },
83
- "metadata": {},
84
- "output_type": "display_data"
85
- },
86
- {
87
- "data": {
88
- "application/vnd.jupyter.widget-view+json": {
89
- "model_id": "3ebc87d16a79449998bcb21e33d2ec0b",
90
- "version_major": 2,
91
- "version_minor": 0
92
- },
93
- "text/plain": [
94
- "Downloading (…)olve/main/merges.txt: 0%| | 0.00/925k [00:00<?, ?B/s]"
95
- ]
96
- },
97
- "metadata": {},
98
- "output_type": "display_data"
99
- },
100
- {
101
- "data": {
102
- "application/vnd.jupyter.widget-view+json": {
103
- "model_id": "d70c4a2755d04e0d995686f9425b49f8",
104
- "version_major": 2,
105
- "version_minor": 0
106
- },
107
- "text/plain": [
108
- "Downloading (…)/main/tokenizer.json: 0%| | 0.00/3.07M [00:00<?, ?B/s]"
109
- ]
110
- },
111
- "metadata": {},
112
- "output_type": "display_data"
113
- },
114
- {
115
- "data": {
116
- "application/vnd.jupyter.widget-view+json": {
117
- "model_id": "cd341cbb7ff445daa312695cc9be1a13",
118
- "version_major": 2,
119
- "version_minor": 0
120
- },
121
- "text/plain": [
122
- "Downloading (…)cial_tokens_map.json: 0%| | 0.00/96.0 [00:00<?, ?B/s]"
123
- ]
124
- },
125
- "metadata": {},
126
- "output_type": "display_data"
127
- }
128
- ],
129
  "source": [
130
  "import torch\n",
131
  "import random\n",
132
  "import time\n",
133
- "from transformers import pipeline\n",
134
  "\n",
135
- "generator = pipeline(\n",
136
- " 'text-generation',\n",
137
- " model=\"heegyu/bluechat-v0\",\n",
138
- " device=\"cuda:0\" if torch.cuda.is_available() else 'cpu'\n",
139
- ")"
140
  ]
141
  },
142
  {
143
  "cell_type": "code",
144
- "execution_count": 32,
145
  "metadata": {},
146
  "outputs": [],
147
  "source": [
148
  "\n",
149
- "def query(prompt, max_turn=4):\n",
150
- " output = generator(\n",
151
- " prompt.strip(),\n",
152
- " # no_repeat_ngram_size=2,\n",
153
- " eos_token_id=0, # 375=\\n 2=</s>, 0:open-end\n",
 
 
154
  " max_new_tokens=128,\n",
 
155
  " do_sample=True,\n",
156
  " top_p=0.7,\n",
157
  " early_stopping=True\n",
158
- " )[0]['generated_text']\n",
159
- "\n",
160
  " print(output)\n",
161
  "\n",
162
  " # response = output[len(prompt):]\n",
@@ -165,19 +46,34 @@
165
  },
166
  {
167
  "cell_type": "code",
168
- "execution_count": 33,
169
  "metadata": {},
170
  "outputs": [
 
 
 
 
 
 
 
171
  {
172
  "name": "stdout",
173
  "output_type": "stream",
174
  "text": [
 
175
  "0 : 안녕하세요</s>\n",
176
  "1 : 반가워요</s>\n",
177
  "0 : 요즘 좋아하는 음악 있으신가요?</s>\n",
178
  "1 : 최근에 들어서인지 너무 많이 들어요</s>\n",
179
  "0 : 음 주로 어떤거요?</s>\n",
180
- "1 : 최근에 들어올린 음악은 무엇인가요?0 : 네 키키 제가 좋아하는 곡은 바로 아이유에요1 : 아 아이유 노래 정말 좋네요0 : 아이유 노래 참 좋아요1 : 아이유 노래 진짜 좋아요0 : 아 진짜 아이유 노래 잘부르세요1 : 네 아이유 노래 좋아요0 : 아이유 노래 진짜 좋죠1 : 아 진짜 좋네요0 : 아이유 노래는 참 좋아요1 : 아이유 노래 정말 좋아요0 : 아이유 노래 정말 좋아요1 : 아이유 노래 정말 좋아요0 : 아이유 노래 진짜 좋아요1 : 아이유 노래 정말 좋아요0 : 아 진짜 좋아요1 : 아 진짜 좋아요0 : 아이유 노래\n"
 
 
 
 
 
 
 
181
  ]
182
  }
183
  ],
@@ -191,6 +87,91 @@
191
  "1 : \n",
192
  "\"\"\")"
193
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  }
195
  ],
196
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 40,
6
  "metadata": {},
7
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  "source": [
9
  "import torch\n",
10
  "import random\n",
11
  "import time\n",
12
+ "from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\n",
13
  "\n",
14
+ "model_name=\"heegyu/bluechat-v0\"\n",
15
+ "device=\"cuda:0\" if torch.cuda.is_available() else 'cpu'\n",
16
+ "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
17
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
 
18
  ]
19
  },
20
  {
21
  "cell_type": "code",
22
+ "execution_count": 54,
23
  "metadata": {},
24
  "outputs": [],
25
  "source": [
26
  "\n",
27
+ "def query(prompt, min_new_tokens=16, max_turn=4):\n",
28
+ " ids = tokenizer(prompt.strip(), return_tensors=\"pt\").to(device)\n",
29
+ " min_length = ids['input_ids'].shape[1] + min_new_tokens\n",
30
+ " output = model.generate(\n",
31
+ " **ids,\n",
32
+ " no_repeat_ngram_size=3,\n",
33
+ " eos_token_id=2, # 375=\\n 2=</s>, 0:open-end\n",
34
  " max_new_tokens=128,\n",
35
+ " min_length=min_length,\n",
36
  " do_sample=True,\n",
37
  " top_p=0.7,\n",
38
  " early_stopping=True\n",
39
+ " ) # [0]['generated_text']\n",
40
+ " output = tokenizer.decode(output.cpu()[0])\n",
41
  " print(output)\n",
42
  "\n",
43
  " # response = output[len(prompt):]\n",
 
46
  },
47
  {
48
  "cell_type": "code",
49
+ "execution_count": 42,
50
  "metadata": {},
51
  "outputs": [
52
+ {
53
+ "name": "stderr",
54
+ "output_type": "stream",
55
+ "text": [
56
+ "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n"
57
+ ]
58
+ },
59
  {
60
  "name": "stdout",
61
  "output_type": "stream",
62
  "text": [
63
+ "\n",
64
  "0 : 안녕하세요</s>\n",
65
  "1 : 반가워요</s>\n",
66
  "0 : 요즘 좋아하는 음악 있으신가요?</s>\n",
67
  "1 : 최근에 들어서인지 너무 많이 들어요</s>\n",
68
  "0 : 음 주로 어떤거요?</s>\n",
69
+ "1 : \n",
70
+ " music : music songs 수록곡을 즐겨들어요</s><bot> 앗 어떤 장르를 주로 들으시나요?</s>\n",
71
+ "1 : music songs 좋죠</s>\n",
72
+ "bot> 저도 요즘 들어 좋아하게 된 곡들 위주로 들어요 ㅎㅎ</s>\n",
73
+ "2 : music songs 어떤 노래들 자주 들어요?</s>\n",
74
+ "bot> 저 music songs someone이 제일 좋더라구요 ㅎㅎ</s>\n",
75
+ "1 : music songs는 어떤 곡들 주로 들어요?</s>\n",
76
+ "bot> 저 music songs는 주로 music songs를 많이 들어요 ㅎㅎ</s>\n"
77
  ]
78
  }
79
  ],
 
87
  "1 : \n",
88
  "\"\"\")"
89
  ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 48,
94
+ "metadata": {},
95
+ "outputs": [
96
+ {
97
+ "name": "stderr",
98
+ "output_type": "stream",
99
+ "text": [
100
+ "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
101
+ ]
102
+ },
103
+ {
104
+ "name": "stdout",
105
+ "output_type": "stream",
106
+ "text": [
107
+ "<usr> 안녕하세요\n",
108
+ "<bot> 안녕하세요~ 저녁 드셨나요? ㅎㅎ? ㅎㅎ</s>\n"
109
+ ]
110
+ }
111
+ ],
112
+ "source": [
113
+ "query(\"\"\"\n",
114
+ "<usr> 안녕하세요\n",
115
+ "<bot>\n",
116
+ "\"\"\", 8)"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": 55,
122
+ "metadata": {},
123
+ "outputs": [
124
+ {
125
+ "name": "stderr",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
129
+ ]
130
+ },
131
+ {
132
+ "name": "stdout",
133
+ "output_type": "stream",
134
+ "text": [
135
+ "<usr> 안녕하세요 식사 하셨나요?\n",
136
+ "<bot> 안녕하세요 네~ 점심 먹었어요 식사하셨나요?\n",
137
+ "네~ 뭐드셨나요?</s>\n"
138
+ ]
139
+ }
140
+ ],
141
+ "source": [
142
+ "query(\"\"\"\n",
143
+ "<usr> 안녕하세요 식사 하셨나요?\n",
144
+ "<bot>\n",
145
+ "\"\"\", 8)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 63,
151
+ "metadata": {},
152
+ "outputs": [
153
+ {
154
+ "name": "stderr",
155
+ "output_type": "stream",
156
+ "text": [
157
+ "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
158
+ ]
159
+ },
160
+ {
161
+ "name": "stdout",
162
+ "output_type": "stream",
163
+ "text": [
164
+ "<usr> 창업에 관심이 있나요?\n",
165
+ "<bot> 네! 근데 요즘 창업에 대한 관심이 많이 떨어지더라구요</s>\n"
166
+ ]
167
+ }
168
+ ],
169
+ "source": [
170
+ "query(\"\"\"\n",
171
+ "<usr> 창업에 관심이 있나요?\n",
172
+ "<bot>\n",
173
+ "\"\"\", 8)"
174
+ ]
175
  }
176
  ],
177
  "metadata": {