yumyeom commited on
Commit
b2c1876
·
1 Parent(s): 7138209

경량화 모델 및 기타 file commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ model/gaepago-20-lite filter=lfs diff=lfs merge=lfs -text
37
+ model/gaepago-20-lite/model_quant_int8.pt filter=lfs diff=lfs merge=lfs -text
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -4,7 +4,7 @@
4
  from transformers import AutoModelForAudioClassification
5
  from transformers import AutoFeatureExtractor
6
  from transformers import pipeline
7
- from datasets import Dataset
8
  import gradio as gr
9
  import torch
10
 
@@ -13,7 +13,10 @@ MODEL_NAME = "Gae8J/gaepago-20"
13
  DATASET_NAME = "Gae8J/modeling_v1"
14
 
15
  # Import Model & feature extractor
16
- model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
 
 
 
17
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
18
 
19
  # 모델 cpu로 변경하여 진행
@@ -27,9 +30,12 @@ def gaepago_fn(tmp_audio_dir):
27
  ,sampling_rate=audio_dataset[0]["audio"]["sampling_rate"]
28
  ,return_tensors="pt")
29
  with torch.no_grad():
30
- logits = model(**inputs).logits
 
 
 
31
  predicted_class_ids = torch.argmax(logits).item()
32
- predicted_label = model.config.id2label[predicted_class_ids]
33
 
34
  return predicted_label
35
 
@@ -47,4 +53,4 @@ with main_api:
47
  b1.click(gaepago_fn, inputs=audio, outputs=transcription)
48
  # examples = gr.Examples(examples=example_list,
49
  # inputs=[audio])
50
- main_api.launch()
 
4
  from transformers import AutoModelForAudioClassification
5
  from transformers import AutoFeatureExtractor
6
  from transformers import pipeline
7
+ from datasets import Dataset, Audio
8
  import gradio as gr
9
  import torch
10
 
 
13
  DATASET_NAME = "Gae8J/modeling_v1"
14
 
15
  # Import Model & feature extractor
16
+ # model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
17
+ from transformers import AutoConfig
18
+ config = AutoConfig.from_pretrained(MODEL_NAME)
19
+ model = torch.jit.load(f"./model/gaepago-20-lite/model_quant_int8.pt")
20
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
21
 
22
  # 모델 cpu로 변경하여 진행
 
30
  ,sampling_rate=audio_dataset[0]["audio"]["sampling_rate"]
31
  ,return_tensors="pt")
32
  with torch.no_grad():
33
+ # logits = model(**inputs).logits
34
+ logits = model(**inputs)["logits"]
35
+ # predicted_class_ids = torch.argmax(logits).item()
36
+ # predicted_label = model.config.id2label[predicted_class_ids]
37
  predicted_class_ids = torch.argmax(logits).item()
38
+ predicted_label = config.id2label[predicted_class_ids]
39
 
40
  return predicted_label
41
 
 
53
  b1.click(gaepago_fn, inputs=audio, outputs=transcription)
54
  # examples = gr.Examples(examples=example_list,
55
  # inputs=[audio])
56
+ main_api.launch(share=True)
.ipynb_checkpoints/eval_and_inference-checkpoint.ipynb ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "544a588c-68ff-440f-be5c-389f1f02a0b7",
6
+ "metadata": {},
7
+ "source": [
8
+ "# example"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "7ef8c97c-cefd-4905-8d63-af303c412d1a",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "MODEL_NAME = \"gaepago-20\"\n",
19
+ "DATASET_NAME = \"Gae8J/modeling_v1\""
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "044499ce-7821-4b59-9f4b-5971b6a24cce",
25
+ "metadata": {},
26
+ "source": [
27
+ "## load dataset (test data)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "id": "e827e3bb-820d-46b3-b2e8-fdb97787bde1",
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "name": "stderr",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/Gae8J___parquet/Gae8J--modeling_v1-b480c78c61a26816/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
41
+ ]
42
+ },
43
+ {
44
+ "data": {
45
+ "application/vnd.jupyter.widget-view+json": {
46
+ "model_id": "f078fd108d2044b48a961bee6ed49747",
47
+ "version_major": 2,
48
+ "version_minor": 0
49
+ },
50
+ "text/plain": [
51
+ " 0%| | 0/3 [00:00<?, ?it/s]"
52
+ ]
53
+ },
54
+ "metadata": {},
55
+ "output_type": "display_data"
56
+ }
57
+ ],
58
+ "source": [
59
+ "from datasets import load_dataset, Audio\n",
60
+ "\n",
61
+ "dataset = load_dataset(DATASET_NAME)\n",
62
+ "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
63
+ "test_data = dataset['test']\n",
64
+ "sampling_rate = test_data.features[\"audio\"].sampling_rate"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "id": "d0c16b3d-32dd-4e61-86bd-e21232840e98",
70
+ "metadata": {},
71
+ "source": [
72
+ "## run"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 5,
78
+ "id": "d504778d-4ba3-43d3-b22b-76ce838a5edf",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "from transformers import AutoModelForAudioClassification\n",
83
+ "from transformers import AutoFeatureExtractor\n",
84
+ "import torch\n",
85
+ "\n",
86
+ "model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)\n",
87
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)\n",
88
+ "\n",
89
+ "preds = []\n",
90
+ "gts = []\n",
91
+ "for i in range(len(test_data)):\n",
92
+ " inputs = feature_extractor(test_data[i][\"audio\"][\"array\"], sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
93
+ " with torch.no_grad():\n",
94
+ " logits = model(**inputs).logits\n",
95
+ " predicted_class_ids = torch.argmax(logits).item()\n",
96
+ " predicted_label = model.config.id2label[predicted_class_ids]\n",
97
+ " preds.append(predicted_label)\n",
98
+ " gts.append(model.config.id2label[test_data[i]['label']])"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "id": "f200bec5-c2d9-4549-8bb8-1400c484f499",
104
+ "metadata": {},
105
+ "source": [
106
+ "## performance"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 6,
112
+ "id": "be97683d-da60-4d23-abc9-0be9b86cd636",
113
+ "metadata": {},
114
+ "outputs": [
115
+ {
116
+ "name": "stdout",
117
+ "output_type": "stream",
118
+ "text": [
119
+ " precision recall f1-score support\n",
120
+ "\n",
121
+ " bark 0.56 0.62 0.59 8\n",
122
+ " growling 1.00 0.83 0.91 6\n",
123
+ " howl 0.75 0.86 0.80 7\n",
124
+ " panting 1.00 0.80 0.89 10\n",
125
+ " whimper 0.38 0.43 0.40 7\n",
126
+ "\n",
127
+ " accuracy 0.71 38\n",
128
+ " macro avg 0.74 0.71 0.72 38\n",
129
+ "weighted avg 0.75 0.71 0.72 38\n",
130
+ "\n"
131
+ ]
132
+ }
133
+ ],
134
+ "source": [
135
+ "from sklearn.metrics import classification_report\n",
136
+ "test_performance = classification_report(gts, preds)\n",
137
+ "print(test_performance)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "id": "ea3ee48d-19c7-4f9d-9c2c-4b03d4748acb",
143
+ "metadata": {},
144
+ "source": [
145
+ "## load dataset (validation data)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 7,
151
+ "id": "33e5051e-75a2-4523-905c-fe1dbc81eda2",
152
+ "metadata": {},
153
+ "outputs": [
154
+ {
155
+ "name": "stderr",
156
+ "output_type": "stream",
157
+ "text": [
158
+ "WARNING:datasets.builder:Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/Gae8J___parquet/Gae8J--modeling_v1-b480c78c61a26816/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
159
+ ]
160
+ },
161
+ {
162
+ "data": {
163
+ "application/vnd.jupyter.widget-view+json": {
164
+ "model_id": "cf5cfe439c174b8284b4668419af6dca",
165
+ "version_major": 2,
166
+ "version_minor": 0
167
+ },
168
+ "text/plain": [
169
+ " 0%| | 0/3 [00:00<?, ?it/s]"
170
+ ]
171
+ },
172
+ "metadata": {},
173
+ "output_type": "display_data"
174
+ }
175
+ ],
176
+ "source": [
177
+ "from datasets import load_dataset, Audio\n",
178
+ "\n",
179
+ "dataset = load_dataset(DATASET_NAME)\n",
180
+ "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
181
+ "test_data = dataset['validation']\n",
182
+ "sampling_rate = test_data.features[\"audio\"].sampling_rate"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "markdown",
187
+ "id": "36bee3b3-e66f-46dc-8030-cef3cb62ff97",
188
+ "metadata": {},
189
+ "source": [
190
+ "## run"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 9,
196
+ "id": "914a471c-5d76-482b-a4f3-3c5eeebdd697",
197
+ "metadata": {},
198
+ "outputs": [],
199
+ "source": [
200
+ "from transformers import AutoModelForAudioClassification\n",
201
+ "import torch\n",
202
+ "\n",
203
+ "model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)\n",
204
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)\n",
205
+ "\n",
206
+ "preds = []\n",
207
+ "gts = []\n",
208
+ "for i in range(len(test_data)):\n",
209
+ " inputs = feature_extractor(test_data[i][\"audio\"][\"array\"], sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
210
+ " with torch.no_grad():\n",
211
+ " logits = model(**inputs).logits\n",
212
+ " predicted_class_ids = torch.argmax(logits).item()\n",
213
+ " predicted_label = model.config.id2label[predicted_class_ids]\n",
214
+ " preds.append(predicted_label)\n",
215
+ " gts.append(model.config.id2label[test_data[i]['label']])"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "id": "4f1d5bab-4f88-4628-918e-d14b29c2143b",
221
+ "metadata": {},
222
+ "source": [
223
+ "## performance"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 10,
229
+ "id": "26e0c704-b5b6-4bf0-8b58-1e3615b76cb7",
230
+ "metadata": {},
231
+ "outputs": [
232
+ {
233
+ "name": "stdout",
234
+ "output_type": "stream",
235
+ "text": [
236
+ " precision recall f1-score support\n",
237
+ "\n",
238
+ " bark 0.75 0.67 0.71 9\n",
239
+ " growling 1.00 0.71 0.83 7\n",
240
+ " howl 0.86 0.86 0.86 7\n",
241
+ " panting 1.00 0.70 0.82 10\n",
242
+ " whimper 0.54 1.00 0.70 7\n",
243
+ "\n",
244
+ " accuracy 0.78 40\n",
245
+ " macro avg 0.83 0.79 0.78 40\n",
246
+ "weighted avg 0.84 0.78 0.78 40\n",
247
+ "\n"
248
+ ]
249
+ }
250
+ ],
251
+ "source": [
252
+ "from sklearn.metrics import classification_report\n",
253
+ "valid_performance = classification_report(gts, preds)\n",
254
+ "print(valid_performance)"
255
+ ]
256
+ }
257
+ ],
258
+ "metadata": {
259
+ "kernelspec": {
260
+ "display_name": "g3p8",
261
+ "language": "python",
262
+ "name": "g3p8"
263
+ },
264
+ "language_info": {
265
+ "codemirror_mode": {
266
+ "name": "ipython",
267
+ "version": 3
268
+ },
269
+ "file_extension": ".py",
270
+ "mimetype": "text/x-python",
271
+ "name": "python",
272
+ "nbconvert_exporter": "python",
273
+ "pygments_lexer": "ipython3",
274
+ "version": "3.7.9"
275
+ }
276
+ },
277
+ "nbformat": 4,
278
+ "nbformat_minor": 5
279
+ }
.ipynb_checkpoints/eval_and_inference_lite_v1-checkpoint.ipynb ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "544a588c-68ff-440f-be5c-389f1f02a0b7",
6
+ "metadata": {},
7
+ "source": [
8
+ "# example"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "7ef8c97c-cefd-4905-8d63-af303c412d1a",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "MODEL_NAME = \"gaepago-20-lite\"\n",
19
+ "DATASET_NAME = \"Gae8J/modeling_v1\""
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "044499ce-7821-4b59-9f4b-5971b6a24cce",
25
+ "metadata": {},
26
+ "source": [
27
+ "## load dataset (test data)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "id": "e827e3bb-820d-46b3-b2e8-fdb97787bde1",
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "name": "stderr",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "WARNING:datasets.builder:Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/Gae8J___parquet/Gae8J--modeling_v1-b480c78c61a26816/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
41
+ ]
42
+ },
43
+ {
44
+ "data": {
45
+ "application/vnd.jupyter.widget-view+json": {
46
+ "model_id": "4438f0b33464423b92fecc698c1935e5",
47
+ "version_major": 2,
48
+ "version_minor": 0
49
+ },
50
+ "text/plain": [
51
+ " 0%| | 0/3 [00:00<?, ?it/s]"
52
+ ]
53
+ },
54
+ "metadata": {},
55
+ "output_type": "display_data"
56
+ }
57
+ ],
58
+ "source": [
59
+ "from datasets import load_dataset, Audio\n",
60
+ "from transformers import AutoFeatureExtractor\n",
61
+ "dataset = load_dataset(DATASET_NAME)\n",
62
+ "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
63
+ "test_data = dataset['test']\n",
64
+ "sampling_rate = test_data.features[\"audio\"].sampling_rate\n",
65
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 7,
71
+ "id": "779c547a-7e27-4481-8a66-fd9900e41964",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "from transformers import AutoConfig\n",
76
+ "config = AutoConfig.from_pretrained(MODEL_NAME)"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 3,
82
+ "id": "03659af7-3d90-4431-a4ea-a8d99e93602f",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "import torch"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 4,
92
+ "id": "0f58cfcf-ba2d-45e4-b4e9-87df88e9dbad",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "loaded_quantized_model = torch.jit.load(\"gaepago-20-lite/model_quant_int8.pt\")"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "id": "52212656-a3e9-4bd2-ac2d-427acb5795c6",
102
+ "metadata": {},
103
+ "source": [
104
+ "## 모델결과"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 9,
110
+ "id": "3d4f5365-d6f1-4163-9c47-ce8c89e13884",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "preds = []\n",
115
+ "gts = []\n",
116
+ "# quant_logits_list = []\n",
117
+ "for i in range(len(test_data)):\n",
118
+ " inputs = feature_extractor(test_data[i][\"audio\"][\"array\"], sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
119
+ " with torch.no_grad():\n",
120
+ " logits = loaded_quantized_model(**inputs)['logits']\n",
121
+ "# quant_logits_list.append(logits)\n",
122
+ " predicted_class_ids = torch.argmax(logits).item()\n",
123
+ " predicted_label = config.id2label[predicted_class_ids]\n",
124
+ " preds.append(predicted_label)\n",
125
+ " gts.append(config.id2label[test_data[i]['label']])"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 10,
131
+ "id": "93b3c424-bab6-4774-915e-9e9f534f762d",
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "name": "stdout",
136
+ "output_type": "stream",
137
+ "text": [
138
+ " precision recall f1-score support\n",
139
+ "\n",
140
+ " bark 0.5556 0.6250 0.5882 8\n",
141
+ " growling 1.0000 0.8333 0.9091 6\n",
142
+ " howl 0.7500 0.8571 0.8000 7\n",
143
+ " panting 1.0000 0.8000 0.8889 10\n",
144
+ " whimper 0.3750 0.4286 0.4000 7\n",
145
+ "\n",
146
+ " accuracy 0.7105 38\n",
147
+ " macro avg 0.7361 0.7088 0.7172 38\n",
148
+ "weighted avg 0.7452 0.7105 0.7224 38\n",
149
+ "\n"
150
+ ]
151
+ }
152
+ ],
153
+ "source": [
154
+ "from sklearn.metrics import classification_report\n",
155
+ "test_performance = classification_report(gts, preds,digits=4)\n",
156
+ "print(test_performance)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "99a3ea38-54c8-4aed-9bbf-12f98bf09dc5",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": []
166
+ }
167
+ ],
168
+ "metadata": {
169
+ "kernelspec": {
170
+ "display_name": "g3p8",
171
+ "language": "python",
172
+ "name": "g3p8"
173
+ },
174
+ "language_info": {
175
+ "codemirror_mode": {
176
+ "name": "ipython",
177
+ "version": 3
178
+ },
179
+ "file_extension": ".py",
180
+ "mimetype": "text/x-python",
181
+ "name": "python",
182
+ "nbconvert_exporter": "python",
183
+ "pygments_lexer": "ipython3",
184
+ "version": "3.7.9"
185
+ }
186
+ },
187
+ "nbformat": 4,
188
+ "nbformat_minor": 5
189
+ }
.ipynb_checkpoints/text_label-checkpoint.json ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bark": [
3
+ [
4
+ "너무 신나서 어쩌지?",
5
+ "긍정"
6
+ ],
7
+ [
8
+ "집사, 놀아줘!",
9
+ "긍정"
10
+ ],
11
+ [
12
+ "지금 너무 신나!",
13
+ "긍정"
14
+ ],
15
+ [
16
+ "누가 왔나 봐!",
17
+ "긍정"
18
+ ],
19
+ [
20
+ "놀아줘!! 놀아달란말이야!!",
21
+ "긍정"
22
+ ],
23
+ [
24
+ "안녕 🐶",
25
+ "긍정"
26
+ ],
27
+ [
28
+ "난 너를 좋아하는 걸, 그런데 너는 나를 좋아해?",
29
+ "긍정"
30
+ ],
31
+ [
32
+ "주목해줘! 놀자!",
33
+ "긍정"
34
+ ],
35
+ [
36
+ "놀이 시간이야, 같이 놀자!",
37
+ "긍정"
38
+ ],
39
+ [
40
+ "다가오지마!",
41
+ "부정"
42
+ ],
43
+ [
44
+ "뭔가 이상한 소리 들려!",
45
+ "부정"
46
+ ],
47
+ [
48
+ "경계해, 경계해!",
49
+ "부정"
50
+ ],
51
+ [
52
+ "아니야, 아니야!",
53
+ "부정"
54
+ ],
55
+ [
56
+ "건들지마!!!!",
57
+ "부정"
58
+ ],
59
+ [
60
+ "뭔가 불안해, 도와줘!",
61
+ "부정"
62
+ ],
63
+ [
64
+ "주인~ 뭐해~?",
65
+ "중립"
66
+ ],
67
+ [
68
+ "밖에 뭐가 있는 거 같아!",
69
+ "중립"
70
+ ],
71
+ [
72
+ "이리 와봐!",
73
+ "중립"
74
+ ],
75
+ [
76
+ "날 보고있어?",
77
+ "중립"
78
+ ],
79
+ [
80
+ "밖에 뭐 있어?",
81
+ "중립"
82
+ ],
83
+ [
84
+ "이거 내꺼야!",
85
+ "중립"
86
+ ],
87
+ [
88
+ "물 마실래, 마실 것 좀 줘.",
89
+ "중립"
90
+ ],
91
+ [
92
+ "목이 말라, 물 좀 줄래?",
93
+ "중립"
94
+ ]
95
+ ],
96
+ "growling": [
97
+ [
98
+ "나 좀 내버려 둬!",
99
+ "부정"
100
+ ],
101
+ [
102
+ "더 이상 다가오지마!",
103
+ "부정"
104
+ ],
105
+ [
106
+ "너무 까다로워!",
107
+ "부정"
108
+ ],
109
+ [
110
+ "내가 경계하고 있어!",
111
+ "부정"
112
+ ],
113
+ [
114
+ "빨리 이리 와!",
115
+ "부정"
116
+ ],
117
+ [
118
+ "나 너무 화나!",
119
+ "부정"
120
+ ],
121
+ [
122
+ "나 싸울 준비됐어!",
123
+ "부정"
124
+ ],
125
+ [
126
+ "그만 좀 해!",
127
+ "부정"
128
+ ],
129
+ [
130
+ "내게 장난치지마!",
131
+ "부정"
132
+ ],
133
+ [
134
+ "나 지금 너무 짜증나!",
135
+ "부정"
136
+ ],
137
+ [
138
+ "나 지금 안 좋아!",
139
+ "부정"
140
+ ],
141
+ [
142
+ "다가오지마!",
143
+ "부정"
144
+ ],
145
+ [
146
+ "너에게 화난 거야!",
147
+ "부정"
148
+ ],
149
+ [
150
+ "좀 멀리 가!",
151
+ "부정"
152
+ ],
153
+ [
154
+ "나 싸우려고 준비됐어!",
155
+ "부정"
156
+ ],
157
+ [
158
+ "한번 더 건드리면 물어버릴거야!!!",
159
+ "부정"
160
+ ],
161
+ [
162
+ "나한테 이렇게 위협적으로 다가오지마!",
163
+ "부정"
164
+ ],
165
+ [
166
+ "나의 영역을 침범하면 안돼! 이해해줘!",
167
+ "부정"
168
+ ],
169
+ [
170
+ "그만 좀 귀찮게 해! 내가 분명히 경고했잖아!",
171
+ "부정"
172
+ ],
173
+ [
174
+ "불편해, 물러서줘.",
175
+ "부정"
176
+ ],
177
+ [
178
+ "경고하는 거야, 가까이 오지 마.",
179
+ "부정"
180
+ ],
181
+ [
182
+ "좀 너무 가까워, 거리 좀 둬.",
183
+ "부정"
184
+ ],
185
+ [
186
+ "나를 방해하지 마, 신경 써줘.",
187
+ "부정"
188
+ ],
189
+ [
190
+ "내가 불편해, 거리 좀 두고 있어.",
191
+ "부정"
192
+ ],
193
+ [
194
+ "가까이 오지 마.",
195
+ "부정"
196
+ ],
197
+ [
198
+ "나를 방해하지 마, 존중해줘. Respect Me!!",
199
+ "부정"
200
+ ]
201
+ ],
202
+ "howl": [
203
+ [
204
+ "나 여기있어, 봐줘!",
205
+ "중립"
206
+ ],
207
+ [
208
+ "너 어디 갔어?!",
209
+ "중립"
210
+ ],
211
+ [
212
+ "나 너무 외로워!",
213
+ "중립"
214
+ ],
215
+ [
216
+ "이리 와봐, 나 있는 곳으로!",
217
+ "중립"
218
+ ],
219
+ [
220
+ "너 없으면 너무 심심해!",
221
+ "중립"
222
+ ],
223
+ [
224
+ "나도 같이 가고 싶어!",
225
+ "중립"
226
+ ],
227
+ [
228
+ "나 심심해",
229
+ "중립"
230
+ ],
231
+ [
232
+ "어디야? 나 찾아봐!",
233
+ "중립"
234
+ ],
235
+ [
236
+ "언제 오려고 그래?",
237
+ "중립"
238
+ ],
239
+ [
240
+ "나는 여기 있는데!",
241
+ "중립"
242
+ ],
243
+ [
244
+ "빨리 돌아와줘!",
245
+ "중립"
246
+ ],
247
+ [
248
+ "나 혼자 남겨두지 마!",
249
+ "중립"
250
+ ],
251
+ [
252
+ "나 여기있어!! 나좀 봐줘!!!",
253
+ "중립"
254
+ ],
255
+ [
256
+ "나 잘 보고 있어? 나 괜찮아?",
257
+ "중립"
258
+ ],
259
+ [
260
+ "주인, 나 좀 안아줄 수 있을까?",
261
+ "중립"
262
+ ],
263
+ [
264
+ "외로워, 보고 싶어.",
265
+ "중립"
266
+ ],
267
+ [
268
+ "다른 강아지와 '합창'하고 싶어.",
269
+ "중립"
270
+ ],
271
+ [
272
+ "너를 보고싶어, 언제 와?",
273
+ "중립"
274
+ ],
275
+ [
276
+ "무언가 알려고 하는 중이야.",
277
+ "중립"
278
+ ],
279
+ [
280
+ "다른 강아지들이랑 노래하고 싶어.",
281
+ "긍정"
282
+ ]
283
+ ],
284
+ "panting": [
285
+ [
286
+ "더워~ 에어컨 켜줘.",
287
+ "부정"
288
+ ],
289
+ [
290
+ "운동 후 휴식 중이야.",
291
+ "중립"
292
+ ],
293
+ [
294
+ "숨이 차, 좀 도와줘.",
295
+ "부정"
296
+ ],
297
+ [
298
+ "휴식이 필요해, 좀 쉬자.",
299
+ "부정"
300
+ ],
301
+ [
302
+ "너무 더워, 물 좀 줄래?",
303
+ "부정"
304
+ ],
305
+ [
306
+ "너무 더워, 바람 좀 쐬자.",
307
+ "부정"
308
+ ],
309
+ [
310
+ "힘들게 운동했어, 휴식 좀!",
311
+ "부정"
312
+ ],
313
+ [
314
+ "숨이 차, 쉬는 시간이 필요해.",
315
+ "부정"
316
+ ],
317
+ [
318
+ "휴식이 필요해, 조용히 좀...",
319
+ "부정"
320
+ ],
321
+ [
322
+ "물 좀 마시고 싶어, 줄래?",
323
+ "중립"
324
+ ],
325
+ [
326
+ "많이 뛰어서 힘들어, 휴식이 필요해.",
327
+ "부정"
328
+ ],
329
+ [
330
+ "휴식이 필요해, 좀 더 쉬자.",
331
+ "중립"
332
+ ],
333
+ [
334
+ "너무 더워서 물 좀 마시고 싶어.",
335
+ "중립"
336
+ ],
337
+ [
338
+ "좀 더운 �� 같아, 바람 좀 쐬고 싶어.",
339
+ "중립"
340
+ ],
341
+ [
342
+ "지금 좀 쉴 시간이 필요해, 잠시만 기다려.",
343
+ "중립"
344
+ ],
345
+ [
346
+ "지금 진정할 시간이 필요해!!!",
347
+ "중립"
348
+ ],
349
+ [
350
+ "나 지금 너무 신나",
351
+ "긍정"
352
+ ],
353
+ [
354
+ "너랑 놀면 더 재밌을 것 같아",
355
+ "긍정"
356
+ ],
357
+ [
358
+ "나랑 놀지 않을래?",
359
+ "긍정"
360
+ ],
361
+ [
362
+ "밖에 나가면 재미난 일이 있을 것 같아!",
363
+ "긍정"
364
+ ],
365
+ [
366
+ "오늘은 무슨 일이 있을까? 좋은 일이 생길 것 같아!",
367
+ "긍정"
368
+ ],
369
+ [
370
+ "세상 모든 것들이 반가워~",
371
+ "긍정"
372
+ ],
373
+ [
374
+ "너랑 친해지고 싶어~",
375
+ "긍정"
376
+ ],
377
+ [
378
+ "오늘 기분 아주 나이스~",
379
+ "긍정"
380
+ ],
381
+ [
382
+ "세상에서 제일 좋아!!",
383
+ "긍정"
384
+ ],
385
+ [
386
+ "나 지금 기분이가 좋아~",
387
+ "긍정"
388
+ ],
389
+ [
390
+ "너랑 놀고싶어~",
391
+ "긍정"
392
+ ],
393
+ [
394
+ "오늘 되게 행복한 하루다~",
395
+ "긍정"
396
+ ],
397
+ [
398
+ "오늘 내 생일인가? 너무 행복해><",
399
+ "긍정"
400
+ ],
401
+ [
402
+ "만나서 반가워",
403
+ "긍정"
404
+ ],
405
+ [
406
+ "너는 이름이 뭐니?",
407
+ "긍정"
408
+ ],
409
+ [
410
+ "난 너가 좋아!!",
411
+ "긍정"
412
+ ],
413
+ [
414
+ "나 매우 재밌어",
415
+ "긍정"
416
+ ],
417
+ [
418
+ "나랑 같이 놀러 나가자",
419
+ "긍정"
420
+ ]
421
+ ],
422
+ "whimper": [
423
+ [
424
+ "나 너무 두려워",
425
+ "부정"
426
+ ],
427
+ [
428
+ "나 지금 너무 외로워",
429
+ "부정"
430
+ ],
431
+ [
432
+ "나 너무 슬퍼",
433
+ "부정"
434
+ ],
435
+ [
436
+ "나 좀 안아줘",
437
+ "부정"
438
+ ],
439
+ [
440
+ "나 지금 너무 불편해",
441
+ "부정"
442
+ ],
443
+ [
444
+ "나 너무 피곤해",
445
+ "부정"
446
+ ],
447
+ [
448
+ "조금만 더 안아줘",
449
+ "부정"
450
+ ],
451
+ [
452
+ "나 좀 위로해줘",
453
+ "부정"
454
+ ],
455
+ [
456
+ "나 기다리는 중",
457
+ "부정"
458
+ ],
459
+ [
460
+ "외로워서 눈물이 나",
461
+ "부정"
462
+ ],
463
+ [
464
+ "나 상처받았어, 너무 두려워...ㅠㅡㅠ",
465
+ "부정"
466
+ ],
467
+ [
468
+ "나 놀래쪄ㅠㅡㅠ 힝구힝구..",
469
+ "부정"
470
+ ],
471
+ [
472
+ "무셔워... 안아죠~~~",
473
+ "부정"
474
+ ],
475
+ [
476
+ "너무 슬퍼서 맘이 아파... 안아줘...",
477
+ "부정"
478
+ ],
479
+ [
480
+ "나 기분이 너무 안 좋아... 어떻게 해줄래?",
481
+ "부정"
482
+ ],
483
+ [
484
+ "힝...미안해...",
485
+ "부정"
486
+ ],
487
+ [
488
+ "불안해, 곁에 있어줘.",
489
+ "부정"
490
+ ],
491
+ [
492
+ "밖으로 나가고 싶어.",
493
+ "중립"
494
+ ],
495
+ [
496
+ "미안해, 실수했어.",
497
+ "부정"
498
+ ],
499
+ [
500
+ "너무 슬퍼, 위로 좀 해줘.",
501
+ "부정"
502
+ ],
503
+ [
504
+ "스트레스 받았어, 도와줘.",
505
+ "부정"
506
+ ],
507
+ [
508
+ "내가 불안해, 붙어있어줘.",
509
+ "부정"
510
+ ],
511
+ [
512
+ "너무 외로워, 애정을 보여줘.",
513
+ "부정"
514
+ ],
515
+ [
516
+ "산책 좀 가고 싶어.",
517
+ "중립"
518
+ ],
519
+ [
520
+ "정말 슬퍼, 안아줘.",
521
+ "부정"
522
+ ],
523
+ [
524
+ "스트레스가 너무 많아, 안아줘.",
525
+ "부정"
526
+ ]
527
+ ]
528
+ }
.ipynb_checkpoints/text_mapping_example-checkpoint.ipynb ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 13,
6
+ "id": "8f925fb7-86ba-487f-ab85-88754d777860",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import json\n",
13
+ "with open(\"text/text_label.json\",\"r\",encoding='utf-8') as f:\n",
14
+ " text_label = json.load(f)"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 14,
20
+ "id": "d2c0a048-1db7-4236-9f26-539ed31d3d27",
21
+ "metadata": {
22
+ "tags": []
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "import random\n",
27
+ "random.seed(0)\n",
28
+ "def post_process(model_output,text_label):\n",
29
+ " text_list = text_label[model_output]\n",
30
+ " text,sent = random.sample(text_list,1)[0]\n",
31
+ " return {'label' : model_output,\n",
32
+ " 'text' : text,\n",
33
+ " 'sentiment' : sent}"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 15,
39
+ "id": "f8ca0ad8-bc0c-4766-8e13-fe093c5290df",
40
+ "metadata": {
41
+ "tags": []
42
+ },
43
+ "outputs": [
44
+ {
45
+ "data": {
46
+ "text/plain": [
47
+ "{'label': 'bark', 'text': '아니야, 아니야!', 'sentiment': '부정'}"
48
+ ]
49
+ },
50
+ "execution_count": 15,
51
+ "metadata": {},
52
+ "output_type": "execute_result"
53
+ }
54
+ ],
55
+ "source": [
56
+ "model_output = 'bark'\n",
57
+ "post_process(model_output,text_label)"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "id": "da690a64-4dea-4b2a-89c1-23ea8bad955c",
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": []
67
+ }
68
+ ],
69
+ "metadata": {
70
+ "kernelspec": {
71
+ "display_name": "Python 3 (ipykernel)",
72
+ "language": "python",
73
+ "name": "python3"
74
+ },
75
+ "language_info": {
76
+ "codemirror_mode": {
77
+ "name": "ipython",
78
+ "version": 3
79
+ },
80
+ "file_extension": ".py",
81
+ "mimetype": "text/x-python",
82
+ "name": "python",
83
+ "nbconvert_exporter": "python",
84
+ "pygments_lexer": "ipython3",
85
+ "version": "3.10.8"
86
+ }
87
+ },
88
+ "nbformat": 4,
89
+ "nbformat_minor": 5
90
+ }
app.py CHANGED
@@ -13,7 +13,10 @@ MODEL_NAME = "Gae8J/gaepago-20"
13
  DATASET_NAME = "Gae8J/modeling_v1"
14
 
15
  # Import Model & feature extractor
16
- model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
 
 
 
17
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
18
 
19
  # 모델 cpu로 변경하여 진행
@@ -27,9 +30,12 @@ def gaepago_fn(tmp_audio_dir):
27
  ,sampling_rate=audio_dataset[0]["audio"]["sampling_rate"]
28
  ,return_tensors="pt")
29
  with torch.no_grad():
30
- logits = model(**inputs).logits
 
 
 
31
  predicted_class_ids = torch.argmax(logits).item()
32
- predicted_label = model.config.id2label[predicted_class_ids]
33
 
34
  return predicted_label
35
 
@@ -47,4 +53,4 @@ with main_api:
47
  b1.click(gaepago_fn, inputs=audio, outputs=transcription)
48
  # examples = gr.Examples(examples=example_list,
49
  # inputs=[audio])
50
- main_api.launch()
 
13
  DATASET_NAME = "Gae8J/modeling_v1"
14
 
15
  # Import Model & feature extractor
16
+ # model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
17
+ from transformers import AutoConfig
18
+ config = AutoConfig.from_pretrained(MODEL_NAME)
19
+ model = torch.jit.load(f"./model/gaepago-20-lite/model_quant_int8.pt")
20
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
21
 
22
  # 모델 cpu로 변경하여 진행
 
30
  ,sampling_rate=audio_dataset[0]["audio"]["sampling_rate"]
31
  ,return_tensors="pt")
32
  with torch.no_grad():
33
+ # logits = model(**inputs).logits
34
+ logits = model(**inputs)["logits"]
35
+ # predicted_class_ids = torch.argmax(logits).item()
36
+ # predicted_label = model.config.id2label[predicted_class_ids]
37
  predicted_class_ids = torch.argmax(logits).item()
38
+ predicted_label = config.id2label[predicted_class_ids]
39
 
40
  return predicted_label
41
 
 
53
  b1.click(gaepago_fn, inputs=audio, outputs=transcription)
54
  # examples = gr.Examples(examples=example_list,
55
  # inputs=[audio])
56
+ main_api.launch(share=True)
eval_and_inference.ipynb ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "544a588c-68ff-440f-be5c-389f1f02a0b7",
6
+ "metadata": {},
7
+ "source": [
8
+ "# example"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "7ef8c97c-cefd-4905-8d63-af303c412d1a",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "MODEL_NAME = \"gaepago-20\"\n",
19
+ "DATASET_NAME = \"Gae8J/modeling_v1\""
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "044499ce-7821-4b59-9f4b-5971b6a24cce",
25
+ "metadata": {},
26
+ "source": [
27
+ "## load dataset (test data)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "id": "e827e3bb-820d-46b3-b2e8-fdb97787bde1",
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "name": "stderr",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/Gae8J___parquet/Gae8J--modeling_v1-b480c78c61a26816/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
41
+ ]
42
+ },
43
+ {
44
+ "data": {
45
+ "application/vnd.jupyter.widget-view+json": {
46
+ "model_id": "f078fd108d2044b48a961bee6ed49747",
47
+ "version_major": 2,
48
+ "version_minor": 0
49
+ },
50
+ "text/plain": [
51
+ " 0%| | 0/3 [00:00<?, ?it/s]"
52
+ ]
53
+ },
54
+ "metadata": {},
55
+ "output_type": "display_data"
56
+ }
57
+ ],
58
+ "source": [
59
+ "from datasets import load_dataset, Audio\n",
60
+ "\n",
61
+ "dataset = load_dataset(DATASET_NAME)\n",
62
+ "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
63
+ "test_data = dataset['test']\n",
64
+ "sampling_rate = test_data.features[\"audio\"].sampling_rate"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "id": "d0c16b3d-32dd-4e61-86bd-e21232840e98",
70
+ "metadata": {},
71
+ "source": [
72
+ "## run"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": 5,
78
+ "id": "d504778d-4ba3-43d3-b22b-76ce838a5edf",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "from transformers import AutoModelForAudioClassification\n",
83
+ "from transformers import AutoFeatureExtractor\n",
84
+ "import torch\n",
85
+ "\n",
86
+ "model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)\n",
87
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)\n",
88
+ "\n",
89
+ "preds = []\n",
90
+ "gts = []\n",
91
+ "for i in range(len(test_data)):\n",
92
+ " inputs = feature_extractor(test_data[i][\"audio\"][\"array\"], sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
93
+ " with torch.no_grad():\n",
94
+ " logits = model(**inputs).logits\n",
95
+ " predicted_class_ids = torch.argmax(logits).item()\n",
96
+ " predicted_label = model.config.id2label[predicted_class_ids]\n",
97
+ " preds.append(predicted_label)\n",
98
+ " gts.append(model.config.id2label[test_data[i]['label']])"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "id": "f200bec5-c2d9-4549-8bb8-1400c484f499",
104
+ "metadata": {},
105
+ "source": [
106
+ "## performance"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 6,
112
+ "id": "be97683d-da60-4d23-abc9-0be9b86cd636",
113
+ "metadata": {},
114
+ "outputs": [
115
+ {
116
+ "name": "stdout",
117
+ "output_type": "stream",
118
+ "text": [
119
+ " precision recall f1-score support\n",
120
+ "\n",
121
+ " bark 0.56 0.62 0.59 8\n",
122
+ " growling 1.00 0.83 0.91 6\n",
123
+ " howl 0.75 0.86 0.80 7\n",
124
+ " panting 1.00 0.80 0.89 10\n",
125
+ " whimper 0.38 0.43 0.40 7\n",
126
+ "\n",
127
+ " accuracy 0.71 38\n",
128
+ " macro avg 0.74 0.71 0.72 38\n",
129
+ "weighted avg 0.75 0.71 0.72 38\n",
130
+ "\n"
131
+ ]
132
+ }
133
+ ],
134
+ "source": [
135
+ "from sklearn.metrics import classification_report\n",
136
+ "test_performance = classification_report(gts, preds)\n",
137
+ "print(test_performance)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "id": "ea3ee48d-19c7-4f9d-9c2c-4b03d4748acb",
143
+ "metadata": {},
144
+ "source": [
145
+ "## load dataset (validation data)"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": 7,
151
+ "id": "33e5051e-75a2-4523-905c-fe1dbc81eda2",
152
+ "metadata": {},
153
+ "outputs": [
154
+ {
155
+ "name": "stderr",
156
+ "output_type": "stream",
157
+ "text": [
158
+ "WARNING:datasets.builder:Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/Gae8J___parquet/Gae8J--modeling_v1-b480c78c61a26816/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
159
+ ]
160
+ },
161
+ {
162
+ "data": {
163
+ "application/vnd.jupyter.widget-view+json": {
164
+ "model_id": "cf5cfe439c174b8284b4668419af6dca",
165
+ "version_major": 2,
166
+ "version_minor": 0
167
+ },
168
+ "text/plain": [
169
+ " 0%| | 0/3 [00:00<?, ?it/s]"
170
+ ]
171
+ },
172
+ "metadata": {},
173
+ "output_type": "display_data"
174
+ }
175
+ ],
176
+ "source": [
177
+ "from datasets import load_dataset, Audio\n",
178
+ "\n",
179
+ "dataset = load_dataset(DATASET_NAME)\n",
180
+ "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
181
+ "test_data = dataset['validation']\n",
182
+ "sampling_rate = test_data.features[\"audio\"].sampling_rate"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "markdown",
187
+ "id": "36bee3b3-e66f-46dc-8030-cef3cb62ff97",
188
+ "metadata": {},
189
+ "source": [
190
+ "## run"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 9,
196
+ "id": "914a471c-5d76-482b-a4f3-3c5eeebdd697",
197
+ "metadata": {},
198
+ "outputs": [],
199
+ "source": [
200
+ "from transformers import AutoModelForAudioClassification\n",
201
+ "import torch\n",
202
+ "\n",
203
+ "model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)\n",
204
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)\n",
205
+ "\n",
206
+ "preds = []\n",
207
+ "gts = []\n",
208
+ "for i in range(len(test_data)):\n",
209
+ " inputs = feature_extractor(test_data[i][\"audio\"][\"array\"], sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
210
+ " with torch.no_grad():\n",
211
+ " logits = model(**inputs).logits\n",
212
+ " predicted_class_ids = torch.argmax(logits).item()\n",
213
+ " predicted_label = model.config.id2label[predicted_class_ids]\n",
214
+ " preds.append(predicted_label)\n",
215
+ " gts.append(model.config.id2label[test_data[i]['label']])"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "id": "4f1d5bab-4f88-4628-918e-d14b29c2143b",
221
+ "metadata": {},
222
+ "source": [
223
+ "## performance"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": 10,
229
+ "id": "26e0c704-b5b6-4bf0-8b58-1e3615b76cb7",
230
+ "metadata": {},
231
+ "outputs": [
232
+ {
233
+ "name": "stdout",
234
+ "output_type": "stream",
235
+ "text": [
236
+ " precision recall f1-score support\n",
237
+ "\n",
238
+ " bark 0.75 0.67 0.71 9\n",
239
+ " growling 1.00 0.71 0.83 7\n",
240
+ " howl 0.86 0.86 0.86 7\n",
241
+ " panting 1.00 0.70 0.82 10\n",
242
+ " whimper 0.54 1.00 0.70 7\n",
243
+ "\n",
244
+ " accuracy 0.78 40\n",
245
+ " macro avg 0.83 0.79 0.78 40\n",
246
+ "weighted avg 0.84 0.78 0.78 40\n",
247
+ "\n"
248
+ ]
249
+ }
250
+ ],
251
+ "source": [
252
+ "from sklearn.metrics import classification_report\n",
253
+ "valid_performance = classification_report(gts, preds)\n",
254
+ "print(valid_performance)"
255
+ ]
256
+ }
257
+ ],
258
+ "metadata": {
259
+ "kernelspec": {
260
+ "display_name": "g3p8",
261
+ "language": "python",
262
+ "name": "g3p8"
263
+ },
264
+ "language_info": {
265
+ "codemirror_mode": {
266
+ "name": "ipython",
267
+ "version": 3
268
+ },
269
+ "file_extension": ".py",
270
+ "mimetype": "text/x-python",
271
+ "name": "python",
272
+ "nbconvert_exporter": "python",
273
+ "pygments_lexer": "ipython3",
274
+ "version": "3.7.9"
275
+ }
276
+ },
277
+ "nbformat": 4,
278
+ "nbformat_minor": 5
279
+ }
eval_and_inference_lite_v1.ipynb ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "544a588c-68ff-440f-be5c-389f1f02a0b7",
6
+ "metadata": {},
7
+ "source": [
8
+ "# example"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "7ef8c97c-cefd-4905-8d63-af303c412d1a",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "MODEL_NAME = \"gaepago-20-lite\"\n",
19
+ "DATASET_NAME = \"Gae8J/modeling_v1\""
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "044499ce-7821-4b59-9f4b-5971b6a24cce",
25
+ "metadata": {},
26
+ "source": [
27
+ "## load dataset (test data)"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "id": "e827e3bb-820d-46b3-b2e8-fdb97787bde1",
34
+ "metadata": {},
35
+ "outputs": [
36
+ {
37
+ "name": "stderr",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "WARNING:datasets.builder:Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/Gae8J___parquet/Gae8J--modeling_v1-b480c78c61a26816/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
41
+ ]
42
+ },
43
+ {
44
+ "data": {
45
+ "application/vnd.jupyter.widget-view+json": {
46
+ "model_id": "4438f0b33464423b92fecc698c1935e5",
47
+ "version_major": 2,
48
+ "version_minor": 0
49
+ },
50
+ "text/plain": [
51
+ " 0%| | 0/3 [00:00<?, ?it/s]"
52
+ ]
53
+ },
54
+ "metadata": {},
55
+ "output_type": "display_data"
56
+ }
57
+ ],
58
+ "source": [
59
+ "from datasets import load_dataset, Audio\n",
60
+ "from transformers import AutoFeatureExtractor\n",
61
+ "dataset = load_dataset(DATASET_NAME)\n",
62
+ "dataset = dataset.cast_column(\"audio\", Audio(sampling_rate=16000))\n",
63
+ "test_data = dataset['test']\n",
64
+ "sampling_rate = test_data.features[\"audio\"].sampling_rate\n",
65
+ "feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 7,
71
+ "id": "779c547a-7e27-4481-8a66-fd9900e41964",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "from transformers import AutoConfig\n",
76
+ "config = AutoConfig.from_pretrained(MODEL_NAME)"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 3,
82
+ "id": "03659af7-3d90-4431-a4ea-a8d99e93602f",
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "import torch"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 4,
92
+ "id": "0f58cfcf-ba2d-45e4-b4e9-87df88e9dbad",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "loaded_quantized_model = torch.jit.load(\"gaepago-20-lite/model_quant_int8.pt\")"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "markdown",
101
+ "id": "52212656-a3e9-4bd2-ac2d-427acb5795c6",
102
+ "metadata": {},
103
+ "source": [
104
+ "## 모델결과"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 9,
110
+ "id": "3d4f5365-d6f1-4163-9c47-ce8c89e13884",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "preds = []\n",
115
+ "gts = []\n",
116
+ "# quant_logits_list = []\n",
117
+ "for i in range(len(test_data)):\n",
118
+ " inputs = feature_extractor(test_data[i][\"audio\"][\"array\"], sampling_rate=sampling_rate, return_tensors=\"pt\")\n",
119
+ " with torch.no_grad():\n",
120
+ " logits = loaded_quantized_model(**inputs)['logits']\n",
121
+ "# quant_logits_list.append(logits)\n",
122
+ " predicted_class_ids = torch.argmax(logits).item()\n",
123
+ " predicted_label = config.id2label[predicted_class_ids]\n",
124
+ " preds.append(predicted_label)\n",
125
+ " gts.append(config.id2label[test_data[i]['label']])"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 10,
131
+ "id": "93b3c424-bab6-4774-915e-9e9f534f762d",
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "name": "stdout",
136
+ "output_type": "stream",
137
+ "text": [
138
+ " precision recall f1-score support\n",
139
+ "\n",
140
+ " bark 0.5556 0.6250 0.5882 8\n",
141
+ " growling 1.0000 0.8333 0.9091 6\n",
142
+ " howl 0.7500 0.8571 0.8000 7\n",
143
+ " panting 1.0000 0.8000 0.8889 10\n",
144
+ " whimper 0.3750 0.4286 0.4000 7\n",
145
+ "\n",
146
+ " accuracy 0.7105 38\n",
147
+ " macro avg 0.7361 0.7088 0.7172 38\n",
148
+ "weighted avg 0.7452 0.7105 0.7224 38\n",
149
+ "\n"
150
+ ]
151
+ }
152
+ ],
153
+ "source": [
154
+ "from sklearn.metrics import classification_report\n",
155
+ "test_performance = classification_report(gts, preds,digits=4)\n",
156
+ "print(test_performance)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": null,
162
+ "id": "99a3ea38-54c8-4aed-9bbf-12f98bf09dc5",
163
+ "metadata": {},
164
+ "outputs": [],
165
+ "source": []
166
+ }
167
+ ],
168
+ "metadata": {
169
+ "kernelspec": {
170
+ "display_name": "Python 3 (ipykernel)",
171
+ "language": "python",
172
+ "name": "python3"
173
+ },
174
+ "language_info": {
175
+ "codemirror_mode": {
176
+ "name": "ipython",
177
+ "version": 3
178
+ },
179
+ "file_extension": ".py",
180
+ "mimetype": "text/x-python",
181
+ "name": "python",
182
+ "nbconvert_exporter": "python",
183
+ "pygments_lexer": "ipython3",
184
+ "version": "3.8.16"
185
+ }
186
+ },
187
+ "nbformat": 4,
188
+ "nbformat_minor": 5
189
+ }
model/gaepago-20-lite/.ipynb_checkpoints/config-checkpoint.json ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "gaepago-20",
3
+ "activation_dropout": 0.0,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForSequenceClassification"
10
+ ],
11
+ "attention_dropout": 0.1,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 256,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": false,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "sum",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": false,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_norm": "group",
51
+ "feat_proj_dropout": 0.1,
52
+ "feat_quantizer_dropout": 0.0,
53
+ "final_dropout": 0.0,
54
+ "freeze_feat_extract_train": true,
55
+ "hidden_act": "gelu",
56
+ "hidden_dropout": 0.1,
57
+ "hidden_size": 768,
58
+ "id2label": {
59
+ "0": "howl",
60
+ "1": "growling",
61
+ "2": "bark",
62
+ "3": "panting",
63
+ "4": "whimper"
64
+ },
65
+ "initializer_range": 0.02,
66
+ "intermediate_size": 3072,
67
+ "label2id": {
68
+ "bark": "2",
69
+ "growling": "1",
70
+ "howl": "0",
71
+ "panting": "3",
72
+ "whimper": "4"
73
+ },
74
+ "layer_norm_eps": 1e-05,
75
+ "layerdrop": 0.0,
76
+ "mask_channel_length": 10,
77
+ "mask_channel_min_space": 1,
78
+ "mask_channel_other": 0.0,
79
+ "mask_channel_prob": 0.0,
80
+ "mask_channel_selection": "static",
81
+ "mask_feature_length": 10,
82
+ "mask_feature_min_masks": 0,
83
+ "mask_feature_prob": 0.0,
84
+ "mask_time_length": 10,
85
+ "mask_time_min_masks": 2,
86
+ "mask_time_min_space": 1,
87
+ "mask_time_other": 0.0,
88
+ "mask_time_prob": 0.05,
89
+ "mask_time_selection": "static",
90
+ "model_type": "wav2vec2",
91
+ "no_mask_channel_overlap": false,
92
+ "no_mask_time_overlap": false,
93
+ "num_adapter_layers": 3,
94
+ "num_attention_heads": 12,
95
+ "num_codevector_groups": 2,
96
+ "num_codevectors_per_group": 320,
97
+ "num_conv_pos_embedding_groups": 16,
98
+ "num_conv_pos_embeddings": 128,
99
+ "num_feat_extract_layers": 7,
100
+ "num_hidden_layers": 12,
101
+ "num_negatives": 100,
102
+ "output_hidden_size": 768,
103
+ "pad_token_id": 0,
104
+ "proj_codevector_dim": 256,
105
+ "tdnn_dilation": [
106
+ 1,
107
+ 2,
108
+ 3,
109
+ 1,
110
+ 1
111
+ ],
112
+ "tdnn_dim": [
113
+ 512,
114
+ 512,
115
+ 512,
116
+ 512,
117
+ 1500
118
+ ],
119
+ "tdnn_kernel": [
120
+ 5,
121
+ 3,
122
+ 3,
123
+ 1,
124
+ 1
125
+ ],
126
+ "torch_dtype": "float32",
127
+ "transformers_version": "4.29.2",
128
+ "use_weighted_layer_sum": false,
129
+ "vocab_size": 32,
130
+ "xvector_output_dim": 512
131
+ }
model/gaepago-20-lite/config.json ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "gaepago-20",
3
+ "activation_dropout": 0.0,
4
+ "adapter_kernel_size": 3,
5
+ "adapter_stride": 2,
6
+ "add_adapter": false,
7
+ "apply_spec_augment": true,
8
+ "architectures": [
9
+ "Wav2Vec2ForSequenceClassification"
10
+ ],
11
+ "attention_dropout": 0.1,
12
+ "bos_token_id": 1,
13
+ "classifier_proj_size": 256,
14
+ "codevector_dim": 256,
15
+ "contrastive_logits_temperature": 0.1,
16
+ "conv_bias": false,
17
+ "conv_dim": [
18
+ 512,
19
+ 512,
20
+ 512,
21
+ 512,
22
+ 512,
23
+ 512,
24
+ 512
25
+ ],
26
+ "conv_kernel": [
27
+ 10,
28
+ 3,
29
+ 3,
30
+ 3,
31
+ 3,
32
+ 2,
33
+ 2
34
+ ],
35
+ "conv_stride": [
36
+ 5,
37
+ 2,
38
+ 2,
39
+ 2,
40
+ 2,
41
+ 2,
42
+ 2
43
+ ],
44
+ "ctc_loss_reduction": "sum",
45
+ "ctc_zero_infinity": false,
46
+ "diversity_loss_weight": 0.1,
47
+ "do_stable_layer_norm": false,
48
+ "eos_token_id": 2,
49
+ "feat_extract_activation": "gelu",
50
+ "feat_extract_norm": "group",
51
+ "feat_proj_dropout": 0.1,
52
+ "feat_quantizer_dropout": 0.0,
53
+ "final_dropout": 0.0,
54
+ "freeze_feat_extract_train": true,
55
+ "hidden_act": "gelu",
56
+ "hidden_dropout": 0.1,
57
+ "hidden_size": 768,
58
+ "id2label": {
59
+ "0": "howl",
60
+ "1": "growling",
61
+ "2": "bark",
62
+ "3": "panting",
63
+ "4": "whimper"
64
+ },
65
+ "initializer_range": 0.02,
66
+ "intermediate_size": 3072,
67
+ "label2id": {
68
+ "bark": "2",
69
+ "growling": "1",
70
+ "howl": "0",
71
+ "panting": "3",
72
+ "whimper": "4"
73
+ },
74
+ "layer_norm_eps": 1e-05,
75
+ "layerdrop": 0.0,
76
+ "mask_channel_length": 10,
77
+ "mask_channel_min_space": 1,
78
+ "mask_channel_other": 0.0,
79
+ "mask_channel_prob": 0.0,
80
+ "mask_channel_selection": "static",
81
+ "mask_feature_length": 10,
82
+ "mask_feature_min_masks": 0,
83
+ "mask_feature_prob": 0.0,
84
+ "mask_time_length": 10,
85
+ "mask_time_min_masks": 2,
86
+ "mask_time_min_space": 1,
87
+ "mask_time_other": 0.0,
88
+ "mask_time_prob": 0.05,
89
+ "mask_time_selection": "static",
90
+ "model_type": "wav2vec2",
91
+ "no_mask_channel_overlap": false,
92
+ "no_mask_time_overlap": false,
93
+ "num_adapter_layers": 3,
94
+ "num_attention_heads": 12,
95
+ "num_codevector_groups": 2,
96
+ "num_codevectors_per_group": 320,
97
+ "num_conv_pos_embedding_groups": 16,
98
+ "num_conv_pos_embeddings": 128,
99
+ "num_feat_extract_layers": 7,
100
+ "num_hidden_layers": 12,
101
+ "num_negatives": 100,
102
+ "output_hidden_size": 768,
103
+ "pad_token_id": 0,
104
+ "proj_codevector_dim": 256,
105
+ "tdnn_dilation": [
106
+ 1,
107
+ 2,
108
+ 3,
109
+ 1,
110
+ 1
111
+ ],
112
+ "tdnn_dim": [
113
+ 512,
114
+ 512,
115
+ 512,
116
+ 512,
117
+ 1500
118
+ ],
119
+ "tdnn_kernel": [
120
+ 5,
121
+ 3,
122
+ 3,
123
+ 1,
124
+ 1
125
+ ],
126
+ "torch_dtype": "float32",
127
+ "transformers_version": "4.29.2",
128
+ "use_weighted_layer_sum": false,
129
+ "vocab_size": 32,
130
+ "xvector_output_dim": 512
131
+ }
model/gaepago-20-lite/model_quant_int8.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f76b504bb04245a11ec92b145dcfb53391b2105fa204b082fd5c58a862447769
3
+ size 122374341
model/gaepago-20-lite/preprocessor_config.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_extractor_type": "Wav2Vec2FeatureExtractor",
4
+ "feature_size": 1,
5
+ "padding_side": "right",
6
+ "padding_value": 0.0,
7
+ "return_attention_mask": false,
8
+ "sampling_rate": 16000
9
+ }
requirements.txt CHANGED
@@ -3,15 +3,19 @@ aiohttp==3.8.4
3
  aiosignal==1.3.1
4
  altair==5.0.1
5
  anyio==3.7.0
 
6
  async-timeout==4.0.2
7
  attrs==23.1.0
 
8
  certifi==2023.5.7
 
9
  charset-normalizer==3.1.0
10
  click==8.1.3
11
  cmake==3.26.4
12
  contourpy==1.1.0
13
  cycler==0.11.0
14
  datasets==2.13.0
 
15
  dill==0.3.6
16
  exceptiongroup==1.1.1
17
  fastapi==0.97.0
@@ -27,21 +31,28 @@ httpcore==0.17.2
27
  httpx==0.24.1
28
  huggingface-hub==0.15.1
29
  idna==3.4
 
30
  importlib-resources==5.12.0
31
  Jinja2==3.1.2
 
32
  jsonschema==4.17.3
33
  kiwisolver==1.4.4
 
 
34
  linkify-it-py==2.0.2
35
  lit==16.0.6
 
36
  markdown-it-py==2.2.0
37
  MarkupSafe==2.1.3
38
  matplotlib==3.7.1
39
  mdit-py-plugins==0.3.3
40
  mdurl==0.1.2
41
  mpmath==1.3.0
 
42
  multidict==6.0.4
43
  multiprocess==0.70.14
44
  networkx==3.1
 
45
  numpy==1.24.3
46
  nvidia-cublas-cu11==11.10.3.66
47
  nvidia-cuda-cupti-cu11==11.7.101
@@ -59,7 +70,9 @@ packaging==23.1
59
  pandas==2.0.2
60
  Pillow==9.5.0
61
  pkgutil_resolve_name==1.3.10
 
62
  pyarrow==12.0.1
 
63
  pydantic==1.10.9
64
  pydub==0.25.1
65
  Pygments==2.15.1
@@ -72,11 +85,16 @@ PyYAML==6.0
72
  regex==2023.6.3
73
  requests==2.31.0
74
  safetensors==0.3.1
 
 
75
  semantic-version==2.10.0
76
  six==1.16.0
77
  sniffio==1.3.0
 
 
78
  starlette==0.27.0
79
  sympy==1.12
 
80
  tokenizers==0.13.3
81
  toolz==0.12.0
82
  torch==2.0.1
 
3
  aiosignal==1.3.1
4
  altair==5.0.1
5
  anyio==3.7.0
6
+ appdirs==1.4.4
7
  async-timeout==4.0.2
8
  attrs==23.1.0
9
+ audioread==3.0.0
10
  certifi==2023.5.7
11
+ cffi==1.15.1
12
  charset-normalizer==3.1.0
13
  click==8.1.3
14
  cmake==3.26.4
15
  contourpy==1.1.0
16
  cycler==0.11.0
17
  datasets==2.13.0
18
+ decorator==5.1.1
19
  dill==0.3.6
20
  exceptiongroup==1.1.1
21
  fastapi==0.97.0
 
31
  httpx==0.24.1
32
  huggingface-hub==0.15.1
33
  idna==3.4
34
+ importlib-metadata==6.7.0
35
  importlib-resources==5.12.0
36
  Jinja2==3.1.2
37
+ joblib==1.2.0
38
  jsonschema==4.17.3
39
  kiwisolver==1.4.4
40
+ lazy_loader==0.2
41
+ librosa==0.10.0.post2
42
  linkify-it-py==2.0.2
43
  lit==16.0.6
44
+ llvmlite==0.40.1rc1
45
  markdown-it-py==2.2.0
46
  MarkupSafe==2.1.3
47
  matplotlib==3.7.1
48
  mdit-py-plugins==0.3.3
49
  mdurl==0.1.2
50
  mpmath==1.3.0
51
+ msgpack==1.0.5
52
  multidict==6.0.4
53
  multiprocess==0.70.14
54
  networkx==3.1
55
+ numba==0.57.0
56
  numpy==1.24.3
57
  nvidia-cublas-cu11==11.10.3.66
58
  nvidia-cuda-cupti-cu11==11.7.101
 
70
  pandas==2.0.2
71
  Pillow==9.5.0
72
  pkgutil_resolve_name==1.3.10
73
+ pooch==1.6.0
74
  pyarrow==12.0.1
75
+ pycparser==2.21
76
  pydantic==1.10.9
77
  pydub==0.25.1
78
  Pygments==2.15.1
 
85
  regex==2023.6.3
86
  requests==2.31.0
87
  safetensors==0.3.1
88
+ scikit-learn==1.2.2
89
+ scipy==1.10.1
90
  semantic-version==2.10.0
91
  six==1.16.0
92
  sniffio==1.3.0
93
+ soundfile==0.12.1
94
+ soxr==0.3.5
95
  starlette==0.27.0
96
  sympy==1.12
97
+ threadpoolctl==3.1.0
98
  tokenizers==0.13.3
99
  toolz==0.12.0
100
  torch==2.0.1
text_label.json ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bark": [
3
+ [
4
+ "너무 신나서 어쩌지?",
5
+ "긍정"
6
+ ],
7
+ [
8
+ "집사, 놀아줘!",
9
+ "긍정"
10
+ ],
11
+ [
12
+ "지금 너무 신나!",
13
+ "긍정"
14
+ ],
15
+ [
16
+ "누가 왔나 봐!",
17
+ "긍정"
18
+ ],
19
+ [
20
+ "놀아줘!! 놀아달란말이야!!",
21
+ "긍정"
22
+ ],
23
+ [
24
+ "안녕 🐶",
25
+ "긍정"
26
+ ],
27
+ [
28
+ "난 너를 좋아하는 걸, 그런데 너는 나를 좋아해?",
29
+ "긍정"
30
+ ],
31
+ [
32
+ "주목해줘! 놀자!",
33
+ "긍정"
34
+ ],
35
+ [
36
+ "놀이 시간이야, 같이 놀자!",
37
+ "긍정"
38
+ ],
39
+ [
40
+ "다가오지마!",
41
+ "부정"
42
+ ],
43
+ [
44
+ "뭔가 이상한 소리 들려!",
45
+ "부정"
46
+ ],
47
+ [
48
+ "경계해, 경계해!",
49
+ "부정"
50
+ ],
51
+ [
52
+ "아니야, 아니야!",
53
+ "부정"
54
+ ],
55
+ [
56
+ "건들지마!!!!",
57
+ "부정"
58
+ ],
59
+ [
60
+ "뭔가 불안해, 도와줘!",
61
+ "부정"
62
+ ],
63
+ [
64
+ "주인~ 뭐해~?",
65
+ "중립"
66
+ ],
67
+ [
68
+ "밖에 뭐가 있는 거 같아!",
69
+ "중립"
70
+ ],
71
+ [
72
+ "이리 와봐!",
73
+ "중립"
74
+ ],
75
+ [
76
+ "날 보고있어?",
77
+ "중립"
78
+ ],
79
+ [
80
+ "밖에 뭐 있어?",
81
+ "중립"
82
+ ],
83
+ [
84
+ "이거 내꺼야!",
85
+ "중립"
86
+ ],
87
+ [
88
+ "물 마실래, 마실 것 좀 줘.",
89
+ "중립"
90
+ ],
91
+ [
92
+ "목이 말라, 물 좀 줄래?",
93
+ "중립"
94
+ ]
95
+ ],
96
+ "growling": [
97
+ [
98
+ "나 좀 내버려 둬!",
99
+ "부정"
100
+ ],
101
+ [
102
+ "더 이상 다가오지마!",
103
+ "부정"
104
+ ],
105
+ [
106
+ "너무 까다로워!",
107
+ "부정"
108
+ ],
109
+ [
110
+ "내가 경계하고 있어!",
111
+ "부정"
112
+ ],
113
+ [
114
+ "빨리 이리 와!",
115
+ "부정"
116
+ ],
117
+ [
118
+ "나 너무 화나!",
119
+ "부정"
120
+ ],
121
+ [
122
+ "나 싸울 준비됐어!",
123
+ "부정"
124
+ ],
125
+ [
126
+ "그만 좀 해!",
127
+ "부정"
128
+ ],
129
+ [
130
+ "내게 장난치지마!",
131
+ "부정"
132
+ ],
133
+ [
134
+ "나 지금 너무 짜증나!",
135
+ "부정"
136
+ ],
137
+ [
138
+ "나 지금 안 좋아!",
139
+ "부정"
140
+ ],
141
+ [
142
+ "다가오지마!",
143
+ "부정"
144
+ ],
145
+ [
146
+ "너에게 화난 거야!",
147
+ "부정"
148
+ ],
149
+ [
150
+ "좀 멀리 가!",
151
+ "부정"
152
+ ],
153
+ [
154
+ "나 싸우려고 준비됐어!",
155
+ "부정"
156
+ ],
157
+ [
158
+ "한번 더 건드리면 물어버릴거야!!!",
159
+ "부정"
160
+ ],
161
+ [
162
+ "나한테 이렇게 위협적으로 다가오지마!",
163
+ "부정"
164
+ ],
165
+ [
166
+ "나의 영역을 침범하면 안돼! 이해해줘!",
167
+ "부정"
168
+ ],
169
+ [
170
+ "그만 좀 귀찮게 해! 내가 분명히 경고했잖아!",
171
+ "부정"
172
+ ],
173
+ [
174
+ "불편해, 물러서줘.",
175
+ "부정"
176
+ ],
177
+ [
178
+ "경고하는 거야, 가까이 오지 마.",
179
+ "부정"
180
+ ],
181
+ [
182
+ "좀 너무 가까워, 거리 좀 둬.",
183
+ "부정"
184
+ ],
185
+ [
186
+ "나를 방해하지 마, 신경 써줘.",
187
+ "부정"
188
+ ],
189
+ [
190
+ "내가 불편해, 거리 좀 두고 있어.",
191
+ "부정"
192
+ ],
193
+ [
194
+ "가까이 오지 마.",
195
+ "부정"
196
+ ],
197
+ [
198
+ "나를 방해하지 마, 존중해줘. Respect Me!!",
199
+ "부정"
200
+ ]
201
+ ],
202
+ "howl": [
203
+ [
204
+ "나 여기있어, 봐줘!",
205
+ "중립"
206
+ ],
207
+ [
208
+ "너 어디 갔어?!",
209
+ "중립"
210
+ ],
211
+ [
212
+ "나 너무 외로워!",
213
+ "중립"
214
+ ],
215
+ [
216
+ "이리 와봐, 나 있는 곳으로!",
217
+ "중립"
218
+ ],
219
+ [
220
+ "너 없으면 너무 심심해!",
221
+ "중립"
222
+ ],
223
+ [
224
+ "나도 같이 가고 싶어!",
225
+ "중립"
226
+ ],
227
+ [
228
+ "나 심심해",
229
+ "중립"
230
+ ],
231
+ [
232
+ "어디야? 나 찾아봐!",
233
+ "중립"
234
+ ],
235
+ [
236
+ "언제 오려고 그래?",
237
+ "중립"
238
+ ],
239
+ [
240
+ "나는 여기 있는데!",
241
+ "중립"
242
+ ],
243
+ [
244
+ "빨리 돌아와줘!",
245
+ "중립"
246
+ ],
247
+ [
248
+ "나 혼자 남겨두지 마!",
249
+ "중립"
250
+ ],
251
+ [
252
+ "나 여기있어!! 나좀 봐줘!!!",
253
+ "중립"
254
+ ],
255
+ [
256
+ "나 잘 보고 있어? 나 괜찮아?",
257
+ "중립"
258
+ ],
259
+ [
260
+ "주인, 나 좀 안아줄 수 있을까?",
261
+ "중립"
262
+ ],
263
+ [
264
+ "외로워, 보고 싶어.",
265
+ "중립"
266
+ ],
267
+ [
268
+ "다른 강아지와 '합창'하고 싶어.",
269
+ "중립"
270
+ ],
271
+ [
272
+ "너를 보고싶어, 언제 와?",
273
+ "중립"
274
+ ],
275
+ [
276
+ "무언가 알려고 하는 중이야.",
277
+ "중립"
278
+ ],
279
+ [
280
+ "다른 강아지들이랑 노래하고 싶어.",
281
+ "긍정"
282
+ ]
283
+ ],
284
+ "panting": [
285
+ [
286
+ "더워~ 에어컨 켜줘.",
287
+ "부정"
288
+ ],
289
+ [
290
+ "운동 후 휴식 중이야.",
291
+ "중립"
292
+ ],
293
+ [
294
+ "숨이 차, 좀 도와줘.",
295
+ "부정"
296
+ ],
297
+ [
298
+ "휴식이 필요해, 좀 쉬자.",
299
+ "부정"
300
+ ],
301
+ [
302
+ "너무 더워, 물 좀 줄래?",
303
+ "부정"
304
+ ],
305
+ [
306
+ "너무 더워, 바람 좀 쐬자.",
307
+ "부정"
308
+ ],
309
+ [
310
+ "힘들게 운동했어, 휴식 좀!",
311
+ "부정"
312
+ ],
313
+ [
314
+ "숨이 차, 쉬는 시간이 필요해.",
315
+ "부정"
316
+ ],
317
+ [
318
+ "휴식이 필요해, 조용히 좀...",
319
+ "부정"
320
+ ],
321
+ [
322
+ "물 좀 마시고 싶어, 줄래?",
323
+ "중립"
324
+ ],
325
+ [
326
+ "많이 뛰어서 힘들어, 휴식이 필요해.",
327
+ "부정"
328
+ ],
329
+ [
330
+ "휴식이 필요해, 좀 더 쉬자.",
331
+ "중립"
332
+ ],
333
+ [
334
+ "너무 더워서 물 좀 마시고 싶어.",
335
+ "중립"
336
+ ],
337
+ [
338
+ "좀 더운 �� 같아, 바람 좀 쐬고 싶어.",
339
+ "중립"
340
+ ],
341
+ [
342
+ "지금 좀 쉴 시간이 필요해, 잠시만 기다려.",
343
+ "중립"
344
+ ],
345
+ [
346
+ "지금 진정할 시간이 필요해!!!",
347
+ "중립"
348
+ ],
349
+ [
350
+ "나 지금 너무 신나",
351
+ "긍정"
352
+ ],
353
+ [
354
+ "너랑 놀면 더 재밌을 것 같아",
355
+ "긍정"
356
+ ],
357
+ [
358
+ "나랑 놀지 않을래?",
359
+ "긍정"
360
+ ],
361
+ [
362
+ "밖에 나가면 재미난 일이 있을 것 같아!",
363
+ "긍정"
364
+ ],
365
+ [
366
+ "오늘은 무슨 일이 있을까? 좋은 일이 생길 것 같아!",
367
+ "긍정"
368
+ ],
369
+ [
370
+ "세상 모든 것들이 반가워~",
371
+ "긍정"
372
+ ],
373
+ [
374
+ "너랑 친해지고 싶어~",
375
+ "긍정"
376
+ ],
377
+ [
378
+ "오늘 기분 아주 나이스~",
379
+ "긍정"
380
+ ],
381
+ [
382
+ "세상에서 제일 좋아!!",
383
+ "긍정"
384
+ ],
385
+ [
386
+ "나 지금 기분이가 좋아~",
387
+ "긍정"
388
+ ],
389
+ [
390
+ "너랑 놀고싶어~",
391
+ "긍정"
392
+ ],
393
+ [
394
+ "오늘 되게 행복한 하루다~",
395
+ "긍정"
396
+ ],
397
+ [
398
+ "오늘 내 생일인가? 너무 행복해><",
399
+ "긍정"
400
+ ],
401
+ [
402
+ "만나서 반가워",
403
+ "긍정"
404
+ ],
405
+ [
406
+ "너는 이름이 뭐니?",
407
+ "긍정"
408
+ ],
409
+ [
410
+ "난 너가 좋아!!",
411
+ "긍정"
412
+ ],
413
+ [
414
+ "나 매우 재밌어",
415
+ "긍정"
416
+ ],
417
+ [
418
+ "나랑 같이 놀러 나가자",
419
+ "긍정"
420
+ ]
421
+ ],
422
+ "whimper": [
423
+ [
424
+ "나 너무 두려워",
425
+ "부정"
426
+ ],
427
+ [
428
+ "나 지금 너무 외로워",
429
+ "부정"
430
+ ],
431
+ [
432
+ "나 너무 슬퍼",
433
+ "부정"
434
+ ],
435
+ [
436
+ "나 좀 안아줘",
437
+ "부정"
438
+ ],
439
+ [
440
+ "나 지금 너무 불편해",
441
+ "부정"
442
+ ],
443
+ [
444
+ "나 너무 피곤해",
445
+ "부정"
446
+ ],
447
+ [
448
+ "조금만 더 안아줘",
449
+ "부정"
450
+ ],
451
+ [
452
+ "나 좀 위로해줘",
453
+ "부정"
454
+ ],
455
+ [
456
+ "나 기다리는 중",
457
+ "부정"
458
+ ],
459
+ [
460
+ "외로워서 눈물이 나",
461
+ "부정"
462
+ ],
463
+ [
464
+ "나 상처받았어, 너무 두려워...ㅠㅡㅠ",
465
+ "부정"
466
+ ],
467
+ [
468
+ "나 놀래쪄ㅠㅡㅠ 힝구힝구..",
469
+ "부정"
470
+ ],
471
+ [
472
+ "무셔워... 안아죠~~~",
473
+ "부정"
474
+ ],
475
+ [
476
+ "너무 슬퍼서 맘이 아파... 안아줘...",
477
+ "부정"
478
+ ],
479
+ [
480
+ "나 기분이 너무 안 좋아... 어떻게 해줄래?",
481
+ "부정"
482
+ ],
483
+ [
484
+ "힝...미안해...",
485
+ "부정"
486
+ ],
487
+ [
488
+ "불안해, 곁에 있어줘.",
489
+ "부정"
490
+ ],
491
+ [
492
+ "밖으로 나가고 싶어.",
493
+ "중립"
494
+ ],
495
+ [
496
+ "미안해, 실수했어.",
497
+ "부정"
498
+ ],
499
+ [
500
+ "너무 슬퍼, 위로 좀 해줘.",
501
+ "부정"
502
+ ],
503
+ [
504
+ "스트레스 받았어, 도와줘.",
505
+ "부정"
506
+ ],
507
+ [
508
+ "내가 불안해, 붙어있어줘.",
509
+ "부정"
510
+ ],
511
+ [
512
+ "너무 외로워, 애정을 보여줘.",
513
+ "부정"
514
+ ],
515
+ [
516
+ "산책 좀 가고 싶어.",
517
+ "중립"
518
+ ],
519
+ [
520
+ "정말 슬퍼, 안아줘.",
521
+ "부정"
522
+ ],
523
+ [
524
+ "스트레스가 너무 많아, 안아줘.",
525
+ "부정"
526
+ ]
527
+ ]
528
+ }
text_mapping_example.ipynb ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 13,
6
+ "id": "8f925fb7-86ba-487f-ab85-88754d777860",
7
+ "metadata": {
8
+ "tags": []
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "import json\n",
13
+ "with open(\"text/text_label.json\",\"r\",encoding='utf-8') as f:\n",
14
+ " text_label = json.load(f)"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 14,
20
+ "id": "d2c0a048-1db7-4236-9f26-539ed31d3d27",
21
+ "metadata": {
22
+ "tags": []
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "import random\n",
27
+ "random.seed(0)\n",
28
+ "def post_process(model_output,text_label):\n",
29
+ " text_list = text_label[model_output]\n",
30
+ " text,sent = random.sample(text_list,1)[0]\n",
31
+ " return {'label' : model_output,\n",
32
+ " 'text' : text,\n",
33
+ " 'sentiment' : sent}"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 15,
39
+ "id": "f8ca0ad8-bc0c-4766-8e13-fe093c5290df",
40
+ "metadata": {
41
+ "tags": []
42
+ },
43
+ "outputs": [
44
+ {
45
+ "data": {
46
+ "text/plain": [
47
+ "{'label': 'bark', 'text': '아니야, 아니야!', 'sentiment': '부정'}"
48
+ ]
49
+ },
50
+ "execution_count": 15,
51
+ "metadata": {},
52
+ "output_type": "execute_result"
53
+ }
54
+ ],
55
+ "source": [
56
+ "model_output = 'bark'\n",
57
+ "post_process(model_output,text_label)"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "id": "da690a64-4dea-4b2a-89c1-23ea8bad955c",
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": []
67
+ }
68
+ ],
69
+ "metadata": {
70
+ "kernelspec": {
71
+ "display_name": "Python 3 (ipykernel)",
72
+ "language": "python",
73
+ "name": "python3"
74
+ },
75
+ "language_info": {
76
+ "codemirror_mode": {
77
+ "name": "ipython",
78
+ "version": 3
79
+ },
80
+ "file_extension": ".py",
81
+ "mimetype": "text/x-python",
82
+ "name": "python",
83
+ "nbconvert_exporter": "python",
84
+ "pygments_lexer": "ipython3",
85
+ "version": "3.8.16"
86
+ }
87
+ },
88
+ "nbformat": 4,
89
+ "nbformat_minor": 5
90
+ }