DuyTa commited on
Commit
c6b1960
·
1 Parent(s): 82e8e84
src/EDA.ipynb ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 6,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import glob\n",
10
+ "\n",
11
+ "def count_files_by_extension(path, extension):\n",
12
+ " \"\"\"\n",
13
+ " path : root path to check ,\n",
14
+ " extension : .wav , ...\n",
15
+ " \"\"\"\n",
16
+ "\n",
17
+ " files = glob.glob(f\"{path}/*.{extension}\")\n",
18
+ " return len(files)\n",
19
+ "\n",
20
+ "\n",
21
+ "root_path = \"./vin_data/vlsp2020_train_set_02/\"\n"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 7,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "num_wav_files = count_files_by_extension(root_path, \"wav\")"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 8,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "num_txt_files = count_files_by_extension(root_path, \"txt\")"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 9,
45
+ "metadata": {},
46
+ "outputs": [
47
+ {
48
+ "name": "stdout",
49
+ "output_type": "stream",
50
+ "text": [
51
+ "Số lượng file WAV: 56427\n",
52
+ "Số lượng file text: 56427\n"
53
+ ]
54
+ }
55
+ ],
56
+ "source": [
57
+ "print(f\"Số lượng file WAV: {num_wav_files}\")\n",
58
+ "print(f\"Số lượng file text: {num_txt_files}\")"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 10,
64
+ "metadata": {},
65
+ "outputs": [
66
+ {
67
+ "name": "stdout",
68
+ "output_type": "stream",
69
+ "text": [
70
+ "Tần số mẫu (sample rate): 16000 Hz\n",
71
+ "Số kênh (channels): 1\n"
72
+ ]
73
+ }
74
+ ],
75
+ "source": [
76
+ "import os\n",
77
+ "import random\n",
78
+ "import wave\n",
79
+ "\n",
80
+ "\n",
81
+ "def get_random_wav_file_info(folder_path):\n",
82
+ " wav_files = glob.glob(f\"{folder_path}/*.wav\")\n",
83
+ " \n",
84
+ " if not wav_files:\n",
85
+ " return None, None\n",
86
+ " \n",
87
+ " random_wav_file = random.choice(wav_files)\n",
88
+ " \n",
89
+ " with wave.open(random_wav_file, 'rb') as wav_file:\n",
90
+ " sample_rate = wav_file.getframerate()\n",
91
+ " channels = wav_file.getnchannels()\n",
92
+ " \n",
93
+ " return sample_rate, channels\n",
94
+ "\n",
95
+ "path_to_wav_folder = \"./vin_data/vlsp2020_train_set_02/\"\n",
96
+ "\n",
97
+ "sample_rate, channels = get_random_wav_file_info(path_to_wav_folder)\n",
98
+ "\n",
99
+ "if sample_rate is not None and channels is not None:\n",
100
+ " print(f\"Tần số mẫu (sample rate): {sample_rate} Hz\")\n",
101
+ " print(f\"Số kênh (channels): {channels}\")\n",
102
+ "else:\n",
103
+ " print(\"Nothing.\")\n"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": 13,
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "import os\n",
113
+ "import csv\n",
114
+ "from tqdm import tqdm\n",
115
+ "\n",
116
+ "def create_csv_from_wav_folder(folder_path, output_csv_file):\n",
117
+ " wav_files = glob.glob(f\"{folder_path}/*.wav\")\n",
118
+ "\n",
119
+ " if not wav_files:\n",
120
+ " print(\"Không có file WAV nào trong thư mục.\")\n",
121
+ " return\n",
122
+ "\n",
123
+ " # Mở tệp CSV đầu ra và tạo bộ đếm số lượng file WAV\n",
124
+ " with open(output_csv_file, mode='w', newline='') as csv_file:\n",
125
+ " csv_writer = csv.writer(csv_file)\n",
126
+ " csv_writer.writerow(['path', 'name','sentence'])\n",
127
+ "\n",
128
+ " for wav_file_path in tqdm(wav_files):\n",
129
+ "\n",
130
+ " text_file_path = os.path.splitext(wav_file_path)[0] + \".txt\"\n",
131
+ " if os.path.exists(text_file_path):\n",
132
+ " with open(text_file_path, 'r') as txt_file:\n",
133
+ " text_content = txt_file.read()\n",
134
+ " else:\n",
135
+ " text_content = \"Not found.\"\n",
136
+ "\n",
137
+ " csv_writer.writerow([wav_file_path, os.path.basename(wav_file_path), sample_rate, channels, text_content])\n"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 14,
143
+ "metadata": {},
144
+ "outputs": [
145
+ {
146
+ "name": "stderr",
147
+ "output_type": "stream",
148
+ "text": [
149
+ "100%|██████████| 56427/56427 [00:37<00:00, 1492.44it/s]\n"
150
+ ]
151
+ }
152
+ ],
153
+ "source": [
154
+ "output_csv_file = \"vin.csv\"\n",
155
+ "path_to_wav_folder = \"./vin_data/vlsp2020_train_set_02/\"\n",
156
+ "create_csv_from_wav_folder(path_to_wav_folder, output_csv_file)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "code",
161
+ "execution_count": 34,
162
+ "metadata": {},
163
+ "outputs": [
164
+ {
165
+ "data": {
166
+ "text/html": [
167
+ "<div>\n",
168
+ "<style scoped>\n",
169
+ " .dataframe tbody tr th:only-of-type {\n",
170
+ " vertical-align: middle;\n",
171
+ " }\n",
172
+ "\n",
173
+ " .dataframe tbody tr th {\n",
174
+ " vertical-align: top;\n",
175
+ " }\n",
176
+ "\n",
177
+ " .dataframe thead th {\n",
178
+ " text-align: right;\n",
179
+ " }\n",
180
+ "</style>\n",
181
+ "<table border=\"1\" class=\"dataframe\">\n",
182
+ " <thead>\n",
183
+ " <tr style=\"text-align: right;\">\n",
184
+ " <th></th>\n",
185
+ " <th>path</th>\n",
186
+ " <th>name</th>\n",
187
+ " <th>sentence</th>\n",
188
+ " </tr>\n",
189
+ " </thead>\n",
190
+ " <tbody>\n",
191
+ " <tr>\n",
192
+ " <th>0</th>\n",
193
+ " <td>./vin_data/vlsp2020_train_set_02/spkyut-201907...</td>\n",
194
+ " <td>spkyut-20190730-utt000000716.wav</td>\n",
195
+ " <td>cây cam canh là loại cây ăn quả dễ trồng dễ ch...</td>\n",
196
+ " </tr>\n",
197
+ " <tr>\n",
198
+ " <th>1</th>\n",
199
+ " <td>./vin_data/vlsp2020_train_set_02/database_sa3_...</td>\n",
200
+ " <td>database_sa3_1_150h_15Jan2020_cleaned_utt_0000...</td>\n",
201
+ " <td>những đặc sản vùng miền nổi tiếng như miến don...</td>\n",
202
+ " </tr>\n",
203
+ " <tr>\n",
204
+ " <th>2</th>\n",
205
+ " <td>./vin_data/vlsp2020_train_set_02/speaker_544-0...</td>\n",
206
+ " <td>speaker_544-069450-1.wav</td>\n",
207
+ " <td>trước thông tin này trương nam thành chia sẻ c...</td>\n",
208
+ " </tr>\n",
209
+ " <tr>\n",
210
+ " <th>3</th>\n",
211
+ " <td>./vin_data/vlsp2020_train_set_02/database_sa1_...</td>\n",
212
+ " <td>database_sa1_Jan08_Mar19_cleaned_utt_000005361...</td>\n",
213
+ " <td>giống như những nữ hoàng á</td>\n",
214
+ " </tr>\n",
215
+ " <tr>\n",
216
+ " <th>4</th>\n",
217
+ " <td>./vin_data/vlsp2020_train_set_02/database_sa2_...</td>\n",
218
+ " <td>database_sa2_Jan4_Feb29_cleaned_utt_0000154206...</td>\n",
219
+ " <td>thay vì phun toàn bộ cánh đồng bằng hóa chất c...</td>\n",
220
+ " </tr>\n",
221
+ " </tbody>\n",
222
+ "</table>\n",
223
+ "</div>"
224
+ ],
225
+ "text/plain": [
226
+ " path \\\n",
227
+ "0 ./vin_data/vlsp2020_train_set_02/spkyut-201907... \n",
228
+ "1 ./vin_data/vlsp2020_train_set_02/database_sa3_... \n",
229
+ "2 ./vin_data/vlsp2020_train_set_02/speaker_544-0... \n",
230
+ "3 ./vin_data/vlsp2020_train_set_02/database_sa1_... \n",
231
+ "4 ./vin_data/vlsp2020_train_set_02/database_sa2_... \n",
232
+ "\n",
233
+ " name \\\n",
234
+ "0 spkyut-20190730-utt000000716.wav \n",
235
+ "1 database_sa3_1_150h_15Jan2020_cleaned_utt_0000... \n",
236
+ "2 speaker_544-069450-1.wav \n",
237
+ "3 database_sa1_Jan08_Mar19_cleaned_utt_000005361... \n",
238
+ "4 database_sa2_Jan4_Feb29_cleaned_utt_0000154206... \n",
239
+ "\n",
240
+ " sentence \n",
241
+ "0 cây cam canh là loại cây ăn quả dễ trồng dễ ch... \n",
242
+ "1 những đặc sản vùng miền nổi tiếng như miến don... \n",
243
+ "2 trước thông tin này trương nam thành chia sẻ c... \n",
244
+ "3 giống như những nữ hoàng á \n",
245
+ "4 thay vì phun toàn bộ cánh đồng bằng hóa chất c... "
246
+ ]
247
+ },
248
+ "execution_count": 34,
249
+ "metadata": {},
250
+ "output_type": "execute_result"
251
+ }
252
+ ],
253
+ "source": [
254
+ "import pandas as pd \n",
255
+ "data = pd.read_csv('vin_test.csv')\n",
256
+ "data.head(5)"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 30,
262
+ "metadata": {},
263
+ "outputs": [],
264
+ "source": [
265
+ "import csv\n",
266
+ "import random\n",
267
+ "\n",
268
+ "def split_csv_file(input_file, output_file1, output_file2, ratio):\n",
269
+ " with open(input_file, 'r', newline='', encoding='utf-8') as csvfile:\n",
270
+ " csvreader = csv.reader(csvfile)\n",
271
+ " header = next(csvreader) \n",
272
+ " \n",
273
+ " data = list(csvreader)\n",
274
+ " random.shuffle(data)\n",
275
+ "\n",
276
+ " total_rows = len(data)\n",
277
+ " rows_output_file1 = int(total_rows * ratio)\n",
278
+ " rows_output_file2 = total_rows - rows_output_file1\n",
279
+ " \n",
280
+ " # Split the data into two parts\n",
281
+ " data1 = data[:rows_output_file1]\n",
282
+ " data2 = data[rows_output_file1:]\n",
283
+ "\n",
284
+ " with open(output_file1, 'w', newline='', encoding='utf-8') as csvfile1:\n",
285
+ " csvwriter1 = csv.writer(csvfile1, quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
286
+ " csvwriter1.writerow(header)\n",
287
+ " csvwriter1.writerows(data1)\n",
288
+ "\n",
289
+ " with open(output_file2, 'w', newline='', encoding='utf-8') as csvfile2:\n",
290
+ " csvwriter2 = csv.writer(csvfile2, quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
291
+ " csvwriter2.writerow(header)\n",
292
+ " csvwriter2.writerows(data2)\n",
293
+ "\n",
294
+ "input_file = 'vin.csv'\n",
295
+ "output_file1 = 'vin_train.csv'\n",
296
+ "output_file2 = 'vin_test.csv'\n",
297
+ "ratio = 0.8 \n",
298
+ "\n",
299
+ "split_csv_file(input_file, output_file1, output_file2, ratio)\n"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": null,
305
+ "metadata": {},
306
+ "outputs": [],
307
+ "source": [
308
+ "from datasets import load_dataset, DatasetDict\n",
309
+ "\n",
310
+ "vivos = DatasetDict()"
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": 46,
316
+ "metadata": {},
317
+ "outputs": [],
318
+ "source": [
319
+ "import os\n",
320
+ "import numpy as np\n",
321
+ "\n",
322
+ "import torch\n",
323
+ "import torchaudio\n",
324
+ "\n",
325
+ "import pandas as pd\n",
326
+ "import whisper\n",
327
+ "import torchaudio.transforms as at\n",
328
+ "from pathlib import Path\n",
329
+ "\n",
330
+ "def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:\n",
331
+ " waveform, sr = torchaudio.load(wave_path, normalize=True)\n",
332
+ " if sample_rate != sr:\n",
333
+ " waveform = at.Resample(sr, sample_rate)(waveform)\n",
334
+ " return waveform\n",
335
+ "\n",
336
+ "\n",
337
+ "\n",
338
+ "def get_list_files_vin100h(phase, dataset_path='./vin_data/vlsp2020_train_set_02/', text_max_length=10000, audio_max_sample_length=1000000, sample_rate=16000):\n",
339
+ " audio_transcript_pair_list = []\n",
340
+ " if phase == 'train':\n",
341
+ " csv_file = 'vin_train.csv'\n",
342
+ " else:\n",
343
+ " csv_file = 'vin_test.csv'\n",
344
+ " df = pd.read_csv(csv_file)\n",
345
+ " for index, row in df.iterrows():\n",
346
+ " new_path = Path(row['path'])\n",
347
+ " audio_id = index\n",
348
+ " text = row['sentence']\n",
349
+ " if new_path.exists():\n",
350
+ " audio = load_wave(new_path, sample_rate=sample_rate)[0]\n",
351
+ " # if len(text) > text_max_length or len(audio) > audio_max_sample_length:\n",
352
+ " # print('skip file:', new_path, 'with len text:', len(text), 'and len audio', len(audio))\n",
353
+ " # continue\n",
354
+ " audio_transcript_pair_list.append((audio_id, str(new_path), text))\n",
355
+ " print(audio_transcript_pair_list)\n",
356
+ " return audio, audio_transcript_pair_list\n"
357
+ ]
358
+ },
359
+ {
360
+ "cell_type": "code",
361
+ "execution_count": null,
362
+ "metadata": {},
363
+ "outputs": [],
364
+ "source": [
365
+ "get_list_files_vin100h(phase='train')"
366
+ ]
367
+ }
368
+ ],
369
+ "metadata": {
370
+ "kernelspec": {
371
+ "display_name": "DUY",
372
+ "language": "python",
373
+ "name": "python3"
374
+ },
375
+ "language_info": {
376
+ "codemirror_mode": {
377
+ "name": "ipython",
378
+ "version": 3
379
+ },
380
+ "file_extension": ".py",
381
+ "mimetype": "text/x-python",
382
+ "name": "python",
383
+ "nbconvert_exporter": "python",
384
+ "pygments_lexer": "ipython3",
385
+ "version": "3.9.17"
386
+ },
387
+ "orig_nbformat": 4
388
+ },
389
+ "nbformat": 4,
390
+ "nbformat_minor": 2
391
+ }
src/MITI.ipynb ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 64,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import glob\n",
11
+ "\n",
12
+ "def count_files_by_extension(path, extension):\n",
13
+ " \"\"\"\n",
14
+ " path : root path to check,\n",
15
+ " extension : .wav, ...\n",
16
+ " \"\"\"\n",
17
+ " total_count = 0\n",
18
+ " \n",
19
+ " for foldername, subfolders, filenames in os.walk(path):\n",
20
+ " files = glob.glob(os.path.join(foldername, f\"*.{extension}\"))\n",
21
+ " total_count += len(files)\n",
22
+ " \n",
23
+ " return total_count\n",
24
+ "\n",
25
+ "\n",
26
+ "root_path = \"./Cleaned_MITI/dataset_2\""
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 65,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "num_wav_files = count_files_by_extension(root_path, \"wav\")\n",
36
+ "num_txt_files = count_files_by_extension(root_path, \"txt\")"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 66,
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "name": "stdout",
46
+ "output_type": "stream",
47
+ "text": [
48
+ "Số lượng file WAV: 2099\n",
49
+ "Số lượng file text: 2099\n"
50
+ ]
51
+ }
52
+ ],
53
+ "source": [
54
+ "print(f\"Số lượng file WAV: {num_wav_files}\")\n",
55
+ "print(f\"Số lượng file text: {num_txt_files}\")"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 70,
61
+ "metadata": {},
62
+ "outputs": [
63
+ {
64
+ "name": "stdout",
65
+ "output_type": "stream",
66
+ "text": [
67
+ "Tần số mẫu (sample rate): 44100 Hz\n",
68
+ "Số kênh (channels): 1\n"
69
+ ]
70
+ }
71
+ ],
72
+ "source": [
73
+ "import os\n",
74
+ "import random\n",
75
+ "import wave\n",
76
+ "\n",
77
+ "\n",
78
+ "def get_random_wav_file_info(folder_path):\n",
79
+ " for foldername, subfolders, filenames in os.walk(folder_path): \n",
80
+ " wav_files = glob.glob(f\"{foldername}/*.wav\")\n",
81
+ " \n",
82
+ " if not wav_files:\n",
83
+ " return None, None\n",
84
+ " \n",
85
+ " random_wav_file = random.choice(wav_files)\n",
86
+ " \n",
87
+ " with wave.open(random_wav_file, 'rb') as wav_file:\n",
88
+ " sample_rate = wav_file.getframerate()\n",
89
+ " channels = wav_file.getnchannels()\n",
90
+ " \n",
91
+ " return sample_rate, channels\n",
92
+ "\n",
93
+ "path_to_wav_folder = \"./Cleaned_MITI/dataset_2/\"\n",
94
+ "\n",
95
+ "sample_rate, channels = get_random_wav_file_info(path_to_wav_folder)\n",
96
+ "\n",
97
+ "if sample_rate is not None and channels is not None:\n",
98
+ " print(f\"Tần số mẫu (sample rate): {sample_rate} Hz\")\n",
99
+ " print(f\"Số kênh (channels): {channels}\")\n",
100
+ "else:\n",
101
+ " print(\"Nothing.\")\n"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": null,
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "def remove_special_characters(input_string):\n",
111
+ " special_characters = ['.', ',', '-', '_', \" \"]\n",
112
+ " \n",
113
+ " # Duyệt qua từng ký tự trong chuỗi\n",
114
+ " filtered_string = ''.join([char for char in input_string if char not in special_characters])\n",
115
+ " \n",
116
+ " return filtered_string\n",
117
+ "\n",
118
+ "# Sử dụng hàm\n",
119
+ "input_string = \"Hello, this_is_a-test.string!\"\n",
120
+ "output_string = remove_special_characters(input_string)\n",
121
+ "print(output_string) # Kết quả: \"Hello thisisa teststring\"\n"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": 86,
127
+ "metadata": {},
128
+ "outputs": [
129
+ {
130
+ "name": "stderr",
131
+ "output_type": "stream",
132
+ "text": [
133
+ " 84%|████████▎ | 164/196 [00:00<00:00, 1629.92it/s]"
134
+ ]
135
+ },
136
+ {
137
+ "name": "stderr",
138
+ "output_type": "stream",
139
+ "text": [
140
+ "100%|██████████| 196/196 [00:00<00:00, 1580.86it/s]\n",
141
+ "100%|██████████| 218/218 [00:00<00:00, 1440.12it/s]\n",
142
+ "100%|██████████| 216/216 [00:00<00:00, 1364.20it/s]\n",
143
+ "100%|██████████| 205/205 [00:00<00:00, 1412.14it/s]\n",
144
+ "100%|██████████| 204/204 [00:00<00:00, 1426.29it/s]\n",
145
+ "100%|██████████| 220/220 [00:00<00:00, 1511.87it/s]\n",
146
+ "100%|██████████| 225/225 [00:00<00:00, 1499.30it/s]\n",
147
+ "100%|██████████| 175/175 [00:00<00:00, 1492.85it/s]\n",
148
+ "100%|██████████| 220/220 [00:00<00:00, 1496.34it/s]\n",
149
+ "100%|██████████| 220/220 [00:00<00:00, 1480.81it/s]\n"
150
+ ]
151
+ }
152
+ ],
153
+ "source": [
154
+ "import os\n",
155
+ "import csv\n",
156
+ "from tqdm import tqdm\n",
157
+ "import glob\n",
158
+ "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n",
159
+ "normalizer = BasicTextNormalizer()\n",
160
+ "def create_csv_from_wav_folder(folder_path, output_csv_file):\n",
161
+ " with open(output_csv_file, mode='w', newline='') as csv_file:\n",
162
+ " csv_writer = csv.writer(csv_file)\n",
163
+ " csv_writer.writerow(['path', 'name', 'sentence'])\n",
164
+ "\n",
165
+ " for person_foldername, _, _ in os.walk(folder_path):\n",
166
+ " if \"person_\" in person_foldername:\n",
167
+ " wav_files = glob.glob(os.path.join(person_foldername, \"*.wav\"))\n",
168
+ "\n",
169
+ " for wav_file_path in tqdm(wav_files):\n",
170
+ " wav_filename = os.path.basename(wav_file_path)\n",
171
+ " text_filename = os.path.splitext(wav_filename)[0] + \".txt\"\n",
172
+ " text_file_path = os.path.join(person_foldername, text_filename)\n",
173
+ "\n",
174
+ " if os.path.exists(text_file_path):\n",
175
+ " with open(text_file_path, 'r') as txt_file:\n",
176
+ " text_content = normalizer(txt_file.read())\n",
177
+ " else:\n",
178
+ " text_content = \"Not found.\"\n",
179
+ "\n",
180
+ " csv_writer.writerow([wav_file_path, wav_filename, text_content])\n",
181
+ "\n",
182
+ "root_path = \"./Cleaned_MITI/dataset_2\" \n",
183
+ "output_csv_file = \"MITI.csv\" \n",
184
+ "\n",
185
+ "create_csv_from_wav_folder(root_path, output_csv_file)\n"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": 89,
191
+ "metadata": {},
192
+ "outputs": [
193
+ {
194
+ "data": {
195
+ "text/plain": [
196
+ "2099"
197
+ ]
198
+ },
199
+ "execution_count": 89,
200
+ "metadata": {},
201
+ "output_type": "execute_result"
202
+ }
203
+ ],
204
+ "source": [
205
+ "import pandas as pd \n",
206
+ "data = pd.read_csv('MITI.csv')\n",
207
+ "len(data)"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 90,
213
+ "metadata": {},
214
+ "outputs": [],
215
+ "source": [
216
+ "import csv\n",
217
+ "import random\n",
218
+ "\n",
219
+ "def split_csv_file(input_file, output_file1, output_file2, ratio):\n",
220
+ " with open(input_file, 'r', newline='', encoding='utf-8') as csvfile:\n",
221
+ " csvreader = csv.reader(csvfile)\n",
222
+ " header = next(csvreader) \n",
223
+ " \n",
224
+ " data = list(csvreader)\n",
225
+ " random.shuffle(data)\n",
226
+ "\n",
227
+ " total_rows = len(data)\n",
228
+ " rows_output_file1 = int(total_rows * ratio)\n",
229
+ " rows_output_file2 = total_rows - rows_output_file1\n",
230
+ " \n",
231
+ " # Split the data into two parts\n",
232
+ " data1 = data[:rows_output_file1]\n",
233
+ " data2 = data[rows_output_file1:]\n",
234
+ "\n",
235
+ " with open(output_file1, 'w', newline='', encoding='utf-8') as csvfile1:\n",
236
+ " csvwriter1 = csv.writer(csvfile1, quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
237
+ " csvwriter1.writerow(header)\n",
238
+ " csvwriter1.writerows(data1)\n",
239
+ "\n",
240
+ " with open(output_file2, 'w', newline='', encoding='utf-8') as csvfile2:\n",
241
+ " csvwriter2 = csv.writer(csvfile2, quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
242
+ " csvwriter2.writerow(header)\n",
243
+ " csvwriter2.writerows(data2)\n",
244
+ "\n",
245
+ "input_file = 'MITI.csv'\n",
246
+ "output_file1 = 'MITI_train.csv'\n",
247
+ "output_file2 = 'MITI_test.csv'\n",
248
+ "ratio = 0.8 \n",
249
+ "\n",
250
+ "split_csv_file(input_file, output_file1, output_file2, ratio)\n"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "metadata": {},
257
+ "outputs": [],
258
+ "source": [
259
+ "from datasets import load_dataset, DatasetDict\n",
260
+ "\n",
261
+ "vivos = DatasetDict()"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": 46,
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "import os\n",
271
+ "import numpy as np\n",
272
+ "\n",
273
+ "import torch\n",
274
+ "import torchaudio\n",
275
+ "\n",
276
+ "import pandas as pd\n",
277
+ "import whisper\n",
278
+ "import torchaudio.transforms as at\n",
279
+ "from pathlib import Path\n",
280
+ "\n",
281
+ "def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:\n",
282
+ " waveform, sr = torchaudio.load(wave_path, normalize=True)\n",
283
+ " if sample_rate != sr:\n",
284
+ " waveform = at.Resample(sr, sample_rate)(waveform)\n",
285
+ " return waveform\n",
286
+ "\n",
287
+ "\n",
288
+ "\n",
289
+ "def get_list_files_vin100h(phase, dataset_path='./vin_data/vlsp2020_train_set_02/', text_max_length=10000, audio_max_sample_length=1000000, sample_rate=16000):\n",
290
+ " audio_transcript_pair_list = []\n",
291
+ " if phase == 'train':\n",
292
+ " csv_file = 'vin_train.csv'\n",
293
+ " else:\n",
294
+ " csv_file = 'vin_test.csv'\n",
295
+ " df = pd.read_csv(csv_file)\n",
296
+ " for index, row in df.iterrows():\n",
297
+ " new_path = Path(row['path'])\n",
298
+ " audio_id = index\n",
299
+ " text = row['sentence']\n",
300
+ " if new_path.exists():\n",
301
+ " audio = load_wave(new_path, sample_rate=sample_rate)[0]\n",
302
+ " # if len(text) > text_max_length or len(audio) > audio_max_sample_length:\n",
303
+ " # print('skip file:', new_path, 'with len text:', len(text), 'and len audio', len(audio))\n",
304
+ " # continue\n",
305
+ " audio_transcript_pair_list.append((audio_id, str(new_path), text))\n",
306
+ " print(audio_transcript_pair_list)\n",
307
+ " return audio, audio_transcript_pair_list\n"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "get_list_files_vin100h(phase='train')"
317
+ ]
318
+ }
319
+ ],
320
+ "metadata": {
321
+ "kernelspec": {
322
+ "display_name": "DUY",
323
+ "language": "python",
324
+ "name": "python3"
325
+ },
326
+ "language_info": {
327
+ "codemirror_mode": {
328
+ "name": "ipython",
329
+ "version": 3
330
+ },
331
+ "file_extension": ".py",
332
+ "mimetype": "text/x-python",
333
+ "name": "python",
334
+ "nbconvert_exporter": "python",
335
+ "pygments_lexer": "ipython3",
336
+ "version": "3.9.17"
337
+ },
338
+ "orig_nbformat": 4
339
+ },
340
+ "nbformat": 4,
341
+ "nbformat_minor": 2
342
+ }
src/download_quantized.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import re
4
+
5
+ from typing import Optional
6
+
7
+ import huggingface_hub
8
+ import requests
9
+
10
+ from tqdm.auto import tqdm
11
+
12
+ _MODELS = (
13
+ "medium"
14
+
15
+ )
16
+
17
+
18
+ def get_assets_path():
19
+ """Returns the path to the assets directory."""
20
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
21
+
22
+
23
+ def get_logger():
24
+ """Returns the module logger."""
25
+ return logging.getLogger("faster_whisper")
26
+
27
+
28
+ def download_model(
29
+ size_or_id: str,
30
+ output_dir: Optional[str] = None,
31
+ local_files_only: bool = False,
32
+ cache_dir: Optional[str] = None,
33
+ ):
34
+ """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
35
+
36
+ The model is downloaded from https://huggingface.co/DuyTa.
37
+
38
+ Args:
39
+ size_or_id: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
40
+ medium, medium.en, large-v1, or large-v2), or a CTranslate2-converted model ID
41
+ from the Hugging Face Hub (e.g. guillaumekln/faster-whisper-large-v2).
42
+ output_dir: Directory where the model should be saved. If not set, the model is saved in
43
+ the cache directory.
44
+ local_files_only: If True, avoid downloading the file and return the path to the local
45
+ cached file if it exists.
46
+ cache_dir: Path to the folder where cached files are stored.
47
+
48
+ Returns:
49
+ The path to the downloaded model.
50
+
51
+ Raises:
52
+ ValueError: if the model size is invalid.
53
+ """
54
+ if re.match(r".*/.*", size_or_id):
55
+ repo_id = size_or_id
56
+ else:
57
+ if size_or_id not in _MODELS:
58
+ raise ValueError(
59
+ "Invalid model size '%s', expected one of: %s"
60
+ % (size_or_id, ", ".join(_MODELS))
61
+ )
62
+
63
+ #repo_id = "DuyTa/vi-whisper-%s-Lora" % size_or_id
64
+ repo_id = "DuyTa/Vietnamese_ASR"
65
+
66
+ allow_patterns = [
67
+ "config.json",
68
+ "model.bin",
69
+ "tokenizer.json",
70
+ "vocabulary.*",
71
+ ]
72
+
73
+ kwargs = {
74
+ "local_files_only": local_files_only,
75
+ "allow_patterns": allow_patterns,
76
+ "tqdm_class": disabled_tqdm,
77
+ }
78
+
79
+ if output_dir is not None:
80
+ kwargs["local_dir"] = output_dir
81
+ kwargs["local_dir_use_symlinks"] = False
82
+
83
+ if cache_dir is not None:
84
+ kwargs["cache_dir"] = cache_dir
85
+
86
+ try:
87
+ return huggingface_hub.snapshot_download(repo_id, **kwargs)
88
+ except (
89
+ huggingface_hub.utils.HfHubHTTPError,
90
+ requests.exceptions.ConnectionError,
91
+ ) as exception:
92
+ logger = get_logger()
93
+ logger.warning(
94
+ "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
95
+ repo_id,
96
+ exception,
97
+ )
98
+ logger.warning(
99
+ "Trying to load the model directly from the local cache, if it exists."
100
+ )
101
+
102
+ kwargs["local_files_only"] = True
103
+ return huggingface_hub.snapshot_download(repo_id, **kwargs)
104
+
105
+
106
+ def format_timestamp(
107
+ seconds: float,
108
+ always_include_hours: bool = False,
109
+ decimal_marker: str = ".",
110
+ ) -> str:
111
+ assert seconds >= 0, "non-negative timestamp expected"
112
+ milliseconds = round(seconds * 1000.0)
113
+
114
+ hours = milliseconds // 3_600_000
115
+ milliseconds -= hours * 3_600_000
116
+
117
+ minutes = milliseconds // 60_000
118
+ milliseconds -= minutes * 60_000
119
+
120
+ seconds = milliseconds // 1_000
121
+ milliseconds -= seconds * 1_000
122
+
123
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
124
+ return (
125
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
126
+ )
127
+
128
+
129
+ class disabled_tqdm(tqdm):
130
+ def __init__(self, *args, **kwargs):
131
+ kwargs["disable"] = True
132
+ super().__init__(*args, **kwargs)
src/laboratory.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
src/lora_tuning.py ADDED
@@ -0,0 +1,773 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gc
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ from dataclasses import dataclass
8
+ from datetime import datetime
9
+ from pathlib import Path
10
+ from random import randint
11
+ from typing import Any, Dict, List, Union
12
+
13
+ # datasets imports
14
+ import datasets
15
+
16
+ # metric imports
17
+ import evaluate
18
+ import numpy as np
19
+ import torch
20
+ import transformers
21
+ import wandb
22
+
23
+ # accelerate imports
24
+ from accelerate import Accelerator, dispatch_model
25
+ from accelerate.logging import get_logger
26
+ from datasets import Audio, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
27
+
28
+ # hf imports
29
+ from huggingface_hub import Repository
30
+ from torch.utils.data import DataLoader
31
+ from tqdm import tqdm
32
+ from transformers import (
33
+ SchedulerType,
34
+ WhisperForConditionalGeneration,
35
+ WhisperProcessor,
36
+ get_scheduler,
37
+ set_seed,
38
+ )
39
+ from transformers.models.whisper.english_normalizer import BasicTextNormalizer
40
+ from transformers.utils import get_full_repo_name
41
+
42
+ # peft imports
43
+ from peft import AdaLoraConfig, LoraConfig, PeftModel, get_peft_model
44
+
45
+
46
+ logger = get_logger(__name__, log_level="INFO")
47
+
48
+
49
+ def parse_args():
50
+ parser = argparse.ArgumentParser(description="Whisper Fine-Tuning with AdaLora")
51
+ parser.add_argument(
52
+ "--model_name_or_path",
53
+ type=str,
54
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
55
+ required=True,
56
+ )
57
+ parser.add_argument("--language", type=str, help="Language to use for training; e.g., 'Hindi' ", required=True)
58
+ parser.add_argument("--language_abbr", type=str, help="Language to use for training; e.g., 'hi' ", required=True)
59
+ parser.add_argument(
60
+ "--task", type=str, default="transcribe", help="Task to use for training; e.g., 'transcribe' ", required=False
61
+ )
62
+ parser.add_argument(
63
+ "--dataset_name",
64
+ type=str,
65
+ default="mozilla-foundation/common_voice_11_0",
66
+ help="Dataset to use for training; e.g., 'whisper' ",
67
+ required=False,
68
+ )
69
+ parser.add_argument(
70
+ "--dataset_in_streaming_mode",
71
+ action="store_true",
72
+ help="Whether to use streaming mode for the dataset.",
73
+ )
74
+ parser.add_argument(
75
+ "--do_lower_case", action="store_true", help="lowercase the transcribed text before tokenizing"
76
+ )
77
+ parser.add_argument(
78
+ "--do_remove_punctuation", action="store_true", help="remove punctuation from the transcribed text"
79
+ )
80
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
81
+ parser.add_argument(
82
+ "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
83
+ )
84
+ parser.add_argument("--max_audio_input_length", type=float, default=30.0, help="Maximum audio length in seconds.")
85
+ parser.add_argument(
86
+ "--preprocessing_num_workers",
87
+ type=int,
88
+ default=None,
89
+ help="The number of processes to use for the preprocessing.",
90
+ )
91
+ parser.add_argument(
92
+ "--per_device_train_batch_size",
93
+ type=int,
94
+ default=8,
95
+ help="Batch size (per device) for the training dataloader.",
96
+ )
97
+ parser.add_argument(
98
+ "--per_device_eval_batch_size",
99
+ type=int,
100
+ default=8,
101
+ help="Batch size (per device) for the evaluation dataloader.",
102
+ )
103
+ parser.add_argument(
104
+ "--buffer_size",
105
+ type=int,
106
+ default=5000,
107
+ help="Number of samples to prefetch in the streaming mode.",
108
+ )
109
+ parser.add_argument(
110
+ "--dataloader_pin_memory",
111
+ action="store_true",
112
+ help="Whether or not to pin memory for the DataLoader.",
113
+ )
114
+ parser.add_argument(
115
+ "--dataloader_num_workers",
116
+ type=int,
117
+ default=0,
118
+ help="Number of subprocesses to use for data loading.",
119
+ )
120
+ parser.add_argument(
121
+ "--learning_rate",
122
+ type=float,
123
+ default=5e-5,
124
+ help="Initial learning rate (after the potential warmup period) to use.",
125
+ )
126
+ parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
127
+ parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
128
+ parser.add_argument(
129
+ "--max_train_steps",
130
+ type=int,
131
+ default=None,
132
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
133
+ )
134
+ parser.add_argument(
135
+ "--gradient_accumulation_steps",
136
+ type=int,
137
+ default=1,
138
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
139
+ )
140
+ parser.add_argument(
141
+ "--lr_scheduler_type",
142
+ type=SchedulerType,
143
+ default="linear",
144
+ help="The scheduler type to use.",
145
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
146
+ )
147
+ parser.add_argument(
148
+ "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
149
+ )
150
+ parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
151
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
152
+ parser.add_argument(
153
+ "--load_best_model",
154
+ action="store_true",
155
+ help="Whether to load the best model at the end of training",
156
+ )
157
+ parser.add_argument(
158
+ "--with_tracking",
159
+ action="store_true",
160
+ help="Whether to enable experiment trackers for logging.",
161
+ )
162
+ parser.add_argument(
163
+ "--report_to",
164
+ type=str,
165
+ default="all",
166
+ help=(
167
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
168
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
169
+ "Only applicable when `--with_tracking` is passed."
170
+ ),
171
+ )
172
+ parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
173
+ parser.add_argument(
174
+ "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
175
+ )
176
+ parser.add_argument(
177
+ "--checkpointing_steps",
178
+ type=int,
179
+ default=500,
180
+ help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
181
+ )
182
+ parser.add_argument(
183
+ "--logging_steps",
184
+ type=int,
185
+ default=100,
186
+ help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
187
+ )
188
+ parser.add_argument(
189
+ "--evaluation_steps",
190
+ type=int,
191
+ default=500,
192
+ help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
193
+ )
194
+ parser.add_argument(
195
+ "--resume_from_checkpoint",
196
+ type=str,
197
+ default=None,
198
+ help="If the training should continue from a checkpoint folder.",
199
+ )
200
+
201
+ # lora/adalora specific args
202
+ parser.add_argument(
203
+ "--use_peft",
204
+ action="store_true",
205
+ help="Whether to use PEFT",
206
+ )
207
+ parser.add_argument(
208
+ "--use_adalora",
209
+ action="store_true",
210
+ help="Whether to use AdaLoRA or LoRA. If set, uses AdaLoRA instead of the default LoRA.",
211
+ )
212
+ parser.add_argument(
213
+ "--init_r",
214
+ type=int,
215
+ default=12,
216
+ help="Initial AdaLoRA rank",
217
+ )
218
+ parser.add_argument(
219
+ "--target_r",
220
+ type=int,
221
+ default=4,
222
+ help="Target AdaLoRA rank",
223
+ )
224
+ parser.add_argument(
225
+ "--tinit",
226
+ type=int,
227
+ default=200,
228
+ help="number of warmup steps for AdaLoRA wherein no pruning is performed",
229
+ )
230
+ parser.add_argument(
231
+ "--tfinal",
232
+ type=int,
233
+ default=1000,
234
+ help=" fix the resulting budget distribution and fine-tune the model for tfinal steps when using AdaLoRA ",
235
+ )
236
+ parser.add_argument(
237
+ "--delta_t",
238
+ type=int,
239
+ default=10,
240
+ help="interval of steps for AdaLoRA to update rank",
241
+ )
242
+ parser.add_argument(
243
+ "--lora_alpha",
244
+ type=int,
245
+ default=32,
246
+ help="LORA alpha",
247
+ )
248
+ parser.add_argument(
249
+ "--r",
250
+ type=int,
251
+ default=8,
252
+ help="LORA rank",
253
+ )
254
+ parser.add_argument(
255
+ "--lora_dropout",
256
+ type=float,
257
+ default=0.1,
258
+ help="LORA dropout",
259
+ )
260
+ parser.add_argument(
261
+ "--orth_reg_weight",
262
+ type=float,
263
+ default=0.5,
264
+ help="Orthogonal regularization weight",
265
+ )
266
+ parser.add_argument(
267
+ "--debug_mode",
268
+ action="store_true",
269
+ help="Whether to use debug mode",
270
+ )
271
+
272
+ args = parser.parse_args()
273
+
274
+ if args.push_to_hub:
275
+ assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
276
+
277
+ return args
278
+
279
+
280
+ def load_streaming_dataset(dataset_name, dataset_config_name, split, **kwargs):
281
+ if "+" in split:
282
+ # load multiple splits separated by the `+` symbol *with* streaming mode
283
+ dataset_splits = [
284
+ load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs)
285
+ for split_name in split.split("+")
286
+ ]
287
+ # interleave multiple splits to form one dataset
288
+ interleaved_dataset = interleave_datasets(dataset_splits)
289
+ return interleaved_dataset
290
+ else:
291
+ # load a single split *with* streaming mode
292
+ dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)
293
+ return dataset
294
+
295
+
296
+ def prepare_dataset_wrapper(do_lower_case, do_remove_punctuation, processor, normalizer):
297
+ def prepare_dataset(batch):
298
+ # load and (possibly) resample audio data to 16kHz
299
+ audio = batch["audio"]
300
+
301
+ # compute log-Mel input features from input audio array
302
+ batch["input_features"] = processor.feature_extractor(
303
+ audio["array"], sampling_rate=audio["sampling_rate"]
304
+ ).input_features[0]
305
+ # compute input length of audio sample in seconds
306
+ batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
307
+
308
+ # optional pre-processing steps
309
+ transcription = batch["sentence"]
310
+ if do_lower_case:
311
+ transcription = transcription.lower()
312
+ if do_remove_punctuation:
313
+ transcription = normalizer(transcription).strip()
314
+
315
+ # encode target text to label ids
316
+ batch["labels"] = processor.tokenizer(transcription).input_ids
317
+ return batch
318
+
319
+ return prepare_dataset
320
+
321
+
322
+ def save_model_hook(models, weights, output_dir):
323
+ for model in models:
324
+ model.save_pretrained(output_dir)
325
+ # make sure to pop weight so that corresponding model is not saved again
326
+ weights.pop()
327
+
328
+
329
+ def load_model_hook(models, input_dir):
330
+ while len(models) > 0:
331
+ model = models.pop()
332
+ # pop models so that they are not loaded again
333
+ PeftModel.from_pretrained(model.base_model.model, input_dir)
334
+
335
+
336
+ @dataclass
337
+ class DataCollatorSpeechSeq2SeqWithPadding:
338
+ processor: Any
339
+
340
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
341
+ # split inputs and labels since they have to be of different lengths and need different padding methods
342
+ # first treat the audio inputs by simply returning torch tensors
343
+ input_features = [{"input_features": feature["input_features"]} for feature in features]
344
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
345
+
346
+ # get the tokenized label sequences
347
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
348
+ # pad the labels to max length
349
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
350
+
351
+ # replace padding with -100 to ignore loss correctly
352
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
353
+
354
+ # if bos token is appended in previous tokenization step,
355
+ # cut bos token here as it's append later anyways
356
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
357
+ labels = labels[:, 1:]
358
+
359
+ batch["labels"] = labels
360
+
361
+ return batch
362
+
363
+
364
+ def get_audio_length_processor(max_input_length):
365
+ def is_audio_in_length_range(length):
366
+ return length < max_input_length
367
+
368
+ return is_audio_in_length_range
369
+
370
+
371
+ def evaluation_loop(model, eval_dataloader, processor, normalizer, metric, forced_decoder_ids, accelerator):
372
+ model.eval()
373
+ predictions = []
374
+ references = []
375
+ normalized_predictions = []
376
+ normalized_references = []
377
+ for _, batch in enumerate(tqdm(eval_dataloader)):
378
+ with torch.cuda.amp.autocast():
379
+ with torch.no_grad():
380
+ generated_tokens = (
381
+ model.generate(
382
+ input_features=batch["input_features"],
383
+ forced_decoder_ids=forced_decoder_ids,
384
+ max_new_tokens=255,
385
+ )
386
+ .cpu()
387
+ .numpy()
388
+ )
389
+ labels = batch["labels"].cpu().numpy()
390
+ labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
391
+ decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
392
+ decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
393
+ predictions.extend(decoded_preds)
394
+ references.extend(decoded_labels)
395
+ normalized_predictions.extend([normalizer(pred).strip() for pred in decoded_preds])
396
+ normalized_references.extend([normalizer(label).strip() for label in decoded_labels])
397
+ del generated_tokens, labels, batch
398
+ gc.collect()
399
+ wer = 100 * metric.compute(predictions=predictions, references=references)
400
+ normalized_wer = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
401
+ eval_metrics = {"eval/wer": wer, "eval/normalized_wer": normalized_wer}
402
+ if accelerator.get_tracker("wandb"):
403
+ sample_size = min(len(predictions), 256)
404
+ ids = [randint(0, len(predictions) - 1) for p in range(0, sample_size)]
405
+ sample_predictions = [predictions[i] for i in ids]
406
+ sample_references = [references[i] for i in ids]
407
+ sample_normalized_predictions = [normalized_predictions[i] for i in ids]
408
+ sample_normalized_references = [normalized_references[i] for i in ids]
409
+ table_rows = [
410
+ list(r)
411
+ for r in zip(
412
+ sample_predictions, sample_references, sample_normalized_predictions, sample_normalized_references
413
+ )
414
+ ]
415
+ eval_metrics["eval_samples"] = wandb.Table(
416
+ columns=["predictions", "references", "normalized_predictions", "normalized_references"],
417
+ rows=table_rows,
418
+ )
419
+ return eval_metrics
420
+
421
+
422
+ def main():
423
+ args = parse_args()
424
+
425
+ # initialize accelerator
426
+ accelerator = (
427
+ Accelerator(
428
+ log_with=args.report_to,
429
+ project_dir=args.output_dir,
430
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
431
+ )
432
+ if args.with_tracking
433
+ else Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
434
+ )
435
+
436
+ # Make one log on every process with the configuration for debugging.
437
+ logging.basicConfig(
438
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
439
+ datefmt="%m/%d/%Y %H:%M:%S",
440
+ level=logging.INFO,
441
+ )
442
+ logger.info(accelerator.state, main_process_only=False)
443
+ if accelerator.is_local_main_process:
444
+ datasets.utils.logging.set_verbosity_warning()
445
+ transformers.utils.logging.set_verbosity_info()
446
+ else:
447
+ datasets.utils.logging.set_verbosity_error()
448
+ transformers.utils.logging.set_verbosity_error()
449
+
450
+ # If passed along, set the training seed now.
451
+ if args.seed is not None:
452
+ set_seed(args.seed)
453
+
454
+ # Handle the repository creation
455
+ if accelerator.is_main_process:
456
+ if args.push_to_hub:
457
+ if args.hub_model_id is None:
458
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
459
+ else:
460
+ repo_name = args.hub_model_id
461
+ repo = Repository(args.output_dir, clone_from=repo_name)
462
+
463
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
464
+ if "step_*" not in gitignore:
465
+ gitignore.write("step_*\n")
466
+ if "epoch_*" not in gitignore:
467
+ gitignore.write("epoch_*\n")
468
+ elif args.output_dir is not None:
469
+ os.makedirs(args.output_dir, exist_ok=True)
470
+ accelerator.wait_for_everyone()
471
+
472
+ # load dataset either in streaming mode or not
473
+ processor = WhisperProcessor.from_pretrained(args.model_name_or_path, language=args.language, task=args.task)
474
+ normalizer = BasicTextNormalizer()
475
+ prepare_dataset = prepare_dataset_wrapper(args.do_lower_case, args.do_remove_punctuation, processor, normalizer)
476
+ is_audio_in_length_range = get_audio_length_processor(args.max_audio_input_length)
477
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
478
+
479
+ if args.dataset_in_streaming_mode:
480
+ raw_datasets = IterableDatasetDict()
481
+ loading_method = load_streaming_dataset
482
+ else:
483
+ raw_datasets = DatasetDict()
484
+ loading_method = load_dataset
485
+
486
+ if args.debug_mode:
487
+ train_split = "train[:100]"
488
+ test_split = "test[:10]"
489
+ else:
490
+ train_split = "train+validation"
491
+ test_split = "test"
492
+
493
+ raw_datasets["train"] = loading_method(
494
+ args.dataset_name, args.language_abbr, split=train_split, use_auth_token=True
495
+ )
496
+ raw_datasets["test"] = loading_method(args.dataset_name, args.language_abbr, split=test_split, use_auth_token=True)
497
+ raw_datasets = raw_datasets.cast_column("audio", Audio(sampling_rate=16000))
498
+
499
+ logger.info("Dataset loaded: %s", raw_datasets)
500
+ logger.info(f'{raw_datasets["train"][0]}')
501
+
502
+ vectorized_datasets = raw_datasets.map(
503
+ prepare_dataset,
504
+ remove_columns=list(next(iter(raw_datasets.values())).features),
505
+ num_proc=args.preprocessing_num_workers,
506
+ ).with_format("torch")
507
+
508
+ if args.dataset_in_streaming_mode:
509
+ vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
510
+ buffer_size=args.buffer_size,
511
+ seed=args.seed,
512
+ )
513
+
514
+ # filter out audio files that are too long from the training set
515
+ is_audio_in_length_range = get_audio_length_processor(args.max_audio_input_length)
516
+ vectorized_datasets["train"] = vectorized_datasets["train"].filter(
517
+ is_audio_in_length_range, input_columns=["input_length"]
518
+ )
519
+
520
+ # get dataloaders
521
+ train_dataloader = DataLoader(
522
+ vectorized_datasets["train"],
523
+ batch_size=args.per_device_train_batch_size,
524
+ shuffle=True,
525
+ collate_fn=data_collator,
526
+ num_workers=args.dataloader_num_workers,
527
+ pin_memory=args.dataloader_pin_memory,
528
+ )
529
+ eval_dataloader = DataLoader(
530
+ vectorized_datasets["test"],
531
+ batch_size=args.per_device_eval_batch_size,
532
+ collate_fn=data_collator,
533
+ num_workers=args.dataloader_num_workers,
534
+ pin_memory=args.dataloader_pin_memory,
535
+ )
536
+
537
+ # metric
538
+ metric = evaluate.load("wer")
539
+
540
+ # model
541
+ model = WhisperForConditionalGeneration.from_pretrained(args.model_name_or_path, load_in_8bit=True)
542
+ model.config.forced_decoder_ids = None
543
+ model.config.suppress_tokens = []
544
+ if len(set(model.hf_device_map.values()).intersection({"cpu", "disk"})) > 0:
545
+ raise ValueError("Training on CPU or disk is not supported.")
546
+ if len(set(model.hf_device_map.values())) > 1:
547
+ device_map = model.hf_device_map.copy()
548
+ # required because `labels` are on main execution device (0) while the output of `proj_out` is on other device.
549
+ # So, this leads to device mismatch error when calculation cross-entropy between logits and labels.
550
+ # Won't arise during inference as `labels` aren't supplied during that time
551
+ # instead of changing device of one of the tied modules, I have to do this for all tied modules
552
+ # else the execution device of remaining tied modules isn't changed
553
+ device_map["model.decoder.embed_tokens"] = model._hf_hook.execution_device
554
+ device_map["model.decoder.embed_positions"] = model._hf_hook.execution_device
555
+ device_map["proj_out"] = model._hf_hook.execution_device
556
+ dispatch_model(model, device_map=device_map)
557
+
558
+ # preparing peft model
559
+ if args.use_peft:
560
+ from peft import prepare_model_for_int8_training
561
+
562
+ model = prepare_model_for_int8_training(model)
563
+
564
+ # as Whisper model uses Conv layer in encoder, checkpointing disables grad computation
565
+ # to avoid this, make the inputs trainable
566
+ def make_inputs_require_grad(module, input, output):
567
+ output.requires_grad_(True)
568
+
569
+ model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)
570
+
571
+ # wrapping model with adalora tuner
572
+ if args.use_adalora:
573
+ config = AdaLoraConfig(
574
+ init_r=args.init_r,
575
+ target_r=args.target_r,
576
+ beta1=0.85,
577
+ beta2=0.85,
578
+ tinit=args.tinit,
579
+ tfinal=args.tfinal,
580
+ deltaT=args.delta_t,
581
+ lora_alpha=args.lora_alpha,
582
+ lora_dropout=args.lora_dropout,
583
+ target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
584
+ orth_reg_weight=args.orth_reg_weight,
585
+ )
586
+ else:
587
+ config = LoraConfig(
588
+ r=args.r,
589
+ lora_alpha=args.lora_alpha,
590
+ target_modules=["q_proj", "v_proj"],
591
+ lora_dropout=args.lora_dropout,
592
+ )
593
+
594
+ model = get_peft_model(model, config)
595
+ model.print_trainable_parameters()
596
+
597
+ # optimizer
598
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
599
+
600
+ if args.max_train_steps is None:
601
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
602
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
603
+ else:
604
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
605
+
606
+ # scheduler
607
+ lr_scheduler = get_scheduler(
608
+ name=args.lr_scheduler_type,
609
+ optimizer=optimizer,
610
+ num_warmup_steps=args.num_warmup_steps,
611
+ num_training_steps=args.max_train_steps,
612
+ )
613
+
614
+ # Prepare everything with our `accelerator`.
615
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
616
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
617
+ )
618
+
619
+ accelerator.print(model)
620
+
621
+ # Note here that the max steps is adjusted by the accelerator's num_processes
622
+ args.max_train_steps = math.ceil(args.max_train_steps / accelerator.num_processes)
623
+ if args.use_peft and args.use_adalora:
624
+ model.base_model.peft_config["default"].total_step = args.max_train_steps
625
+ # model.base_model.peft_config.total_step = args.max_train_steps
626
+
627
+ # We need to initialize the trackers we use, and also store our configuration.
628
+ # The trackers initializes automatically on the main process.
629
+ if args.with_tracking:
630
+ run_name = f"run-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
631
+ experiment_config = vars(args)
632
+ # TensorBoard cannot log Enums, need the raw value
633
+ experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
634
+ accelerator.init_trackers(
635
+ "Whisper PEFT Fine-Tuning", config=experiment_config, init_kwargs={"wandb": {"name": run_name}}
636
+ )
637
+
638
+ # saving and loading checkpoints for resuming training
639
+ accelerator.register_save_state_pre_hook(save_model_hook)
640
+ accelerator.register_load_state_pre_hook(load_model_hook)
641
+
642
+ total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
643
+ logger.info("***** Running training *****")
644
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
645
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
646
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
647
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
648
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
649
+ # Only show the progress bar once on each machine.
650
+ progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
651
+ global_step = 0
652
+ starting_epoch = 0
653
+ best_metric = None
654
+ resume_step = 0
655
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
656
+
657
+ # Potentially load in the weights and states from a previous save
658
+ if args.resume_from_checkpoint:
659
+ accelerator.load_state(args.resume_from_checkpoint)
660
+ path = os.path.basename(args.resume_from_checkpoint)
661
+ training_difference = os.path.splitext(path)[0]
662
+ global_step = resume_step = int(training_difference.replace("step_", ""))
663
+ starting_epoch = resume_step // len(train_dataloader)
664
+ resume_step -= starting_epoch * len(train_dataloader)
665
+
666
+ # We need to adjust the progress bar to the current step
667
+ progress_bar.update(resume_step)
668
+ for epoch in range(starting_epoch, args.num_train_epochs):
669
+ model.train()
670
+ if args.with_tracking:
671
+ total_loss = 0
672
+ running_loss = 0
673
+ for step, batch in enumerate(accelerator.skip_first_batches(train_dataloader, num_batches=resume_step)):
674
+ with accelerator.accumulate(model):
675
+ outputs = model(**batch)
676
+ loss = outputs.loss
677
+ accelerator.backward(loss)
678
+ optimizer.step()
679
+ lr_scheduler.step()
680
+
681
+ # Update the importance of low-rank matrices
682
+ # and allocate the budget accordingly.
683
+ # This is only needed for AdaLora.
684
+ # Note that this requires parameter gradients.
685
+ # Hence being called before optimizer.zero_grad().
686
+ if args.use_peft and args.use_adalora:
687
+ model.update_and_allocate(global_step)
688
+
689
+ optimizer.zero_grad()
690
+ global_step += 1
691
+ progress_bar.update(1)
692
+
693
+ if args.with_tracking:
694
+ step_loss = accelerator.reduce(loss.detach().clone()).item()
695
+ total_loss += step_loss
696
+ running_loss += step_loss
697
+
698
+ if global_step % args.checkpointing_steps == 0:
699
+ output_dir = os.path.join(args.output_dir, f"step_{global_step}")
700
+ accelerator.save_state(output_dir)
701
+
702
+ if global_step % args.logging_steps == 0:
703
+ if args.with_tracking:
704
+ accelerator.log({"train/running_loss": running_loss / args.logging_steps}, step=global_step)
705
+ running_loss = 0
706
+
707
+ if global_step % args.evaluation_steps == 0:
708
+ eval_metrics = evaluation_loop(
709
+ model, eval_dataloader, processor, normalizer, metric, forced_decoder_ids, accelerator
710
+ )
711
+ if args.with_tracking:
712
+ logger.info(f"Step {global_step} eval metrics: {eval_metrics}")
713
+ accelerator.log(eval_metrics, step=global_step)
714
+ if best_metric is None or eval_metrics["eval/wer"] < best_metric:
715
+ best_metric = eval_metrics["eval/wer"]
716
+ accelerator.save_state(os.path.join(args.output_dir, "best_checkpoint"))
717
+ model.train()
718
+
719
+ if global_step >= args.max_train_steps:
720
+ break
721
+
722
+ if args.with_tracking:
723
+ train_epoch_loss = total_loss / (step + 1)
724
+ logger.info(f"Epoch {epoch} train loss: {train_epoch_loss}")
725
+ accelerator.log({"epoch/train_loss": train_epoch_loss}, step=epoch)
726
+
727
+ if args.push_to_hub and epoch <= args.num_train_epochs - 1:
728
+ accelerator.wait_for_everyone()
729
+ unwrapped_model = accelerator.unwrap_model(model)
730
+ unwrapped_model.save_pretrained(args.output_dir, is_main_process=accelerator.is_main_process)
731
+ # evaluate the model at the end of training
732
+ eval_metrics = evaluation_loop(
733
+ model, eval_dataloader, processor, normalizer, metric, forced_decoder_ids, accelerator
734
+ )
735
+ if args.with_tracking:
736
+ logger.info(f"Step {global_step} eval metrics: {eval_metrics}")
737
+ accelerator.log(eval_metrics, step=global_step)
738
+ if best_metric is None or eval_metrics["eval/wer"] < best_metric:
739
+ best_metric = eval_metrics["eval/wer"]
740
+ accelerator.save_state(os.path.join(args.output_dir, "best_checkpoint"))
741
+
742
+ if accelerator.is_main_process:
743
+ processor.tokenizer.save_pretrained(args.output_dir)
744
+ repo.push_to_hub(
745
+ commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
746
+ )
747
+
748
+ if args.load_best_model:
749
+ # load the best model
750
+ accelerator.load_state(os.path.join(args.output_dir, "best_checkpoint"))
751
+ model.resize_modules_by_rank_pattern(model.peft_config["default"].rank_pattern, "default")
752
+ eval_metrics = evaluation_loop(
753
+ model, eval_dataloader, processor, normalizer, metric, forced_decoder_ids, accelerator
754
+ )
755
+ if args.with_tracking:
756
+ best_metrics = {"best_" + k: v for k, v in eval_metrics.items()}
757
+ accelerator.log(best_metrics, step=global_step)
758
+
759
+ accelerator.wait_for_everyone()
760
+ unwrapped_model = accelerator.unwrap_model(model)
761
+ unwrapped_model.save_pretrained(args.output_dir, is_main_process=accelerator.is_main_process)
762
+ if accelerator.is_main_process:
763
+ processor.tokenizer.save_pretrained(args.output_dir)
764
+ if args.push_to_hub:
765
+ repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
766
+
767
+ with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
768
+ eval_metrics.pop("eval_samples")
769
+ json.dump(eval_metrics, f)
770
+
771
+
772
+ if __name__ == "__main__":
773
+ main()
src/merge_lora.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+ import os
4
+
5
+ from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizerFast,\
6
+ WhisperProcessor
7
+ from peft import PeftModel, PeftConfig
8
+ from utils.utils import print_arguments, add_arguments
9
+
10
+ parser = argparse.ArgumentParser(description=__doc__)
11
+ add_arg = functools.partial(add_arguments, argparser=parser)
12
+ add_arg("lora_model", type=str, default="output/whisper-tiny/checkpoint-best/", help="微调保存的模型路径")
13
+ add_arg('output_dir', type=str, default='models/', help="合并模型的保存目录")
14
+ add_arg("local_files_only", type=bool, default=False, help="是否只在本地加载模型,不尝试下载")
15
+ args = parser.parse_args()
16
+ print_arguments(args)
17
+
18
+ assert os.path.exists(args.lora_model), f"模型文件{args.lora_model}不存在"
19
+
20
+ peft_config = PeftConfig.from_pretrained(args.lora_model)
21
+ #
22
+ base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, device_map={"": "cpu"},
23
+ local_files_only=args.local_files_only)
24
+
25
+ model = PeftModel.from_pretrained(base_model, args.lora_model, local_files_only=args.local_files_only)
26
+ feature_extractor = WhisperFeatureExtractor.from_pretrained(peft_config.base_model_name_or_path,
27
+ local_files_only=args.local_files_only)
28
+ tokenizer = WhisperTokenizerFast.from_pretrained(peft_config.base_model_name_or_path,
29
+ local_files_only=args.local_files_only)
30
+ processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path,
31
+ local_files_only=args.local_files_only)
32
+
33
+
34
+ model = model.merge_and_unload()
35
+ model.train(False)
36
+
37
+ save_directory = os.path.join(args.output_dir, f'{os.path.basename(peft_config.base_model_name_or_path)}-finetune')
38
+ os.makedirs(save_directory, exist_ok=True)
39
+
40
+ model.save_pretrained(save_directory)
41
+ feature_extractor.save_pretrained(save_directory)
42
+ tokenizer.save_pretrained(save_directory)
43
+ processor.save_pretrained(save_directory)
44
+ print(f'合并模型保持在:{save_directory}')
src/prepare_data.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import datasets
4
+ from datasets import DatasetDict, load_dataset, concatenate_datasets
5
+ from tqdm import tqdm
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoFeatureExtractor,
9
+ AutoModelForSpeechSeq2Seq,
10
+ AutoTokenizer,
11
+ set_seed,
12
+ )
13
+ from transformers.utils.versions import require_version
14
+ from transformers.utils import check_min_version
15
+ from tqdm import tqdm
16
+
17
+ from audiomentations import (
18
+ AddBackgroundNoise,
19
+ AddGaussianNoise,
20
+ Compose,
21
+ Gain,
22
+ OneOf,
23
+ PitchShift,
24
+ PolarityInversion,
25
+ TimeStretch,
26
+ )
27
+
28
+
29
+ check_min_version("4.27.0.dev0")
30
+
31
+ require_version(
32
+ "datasets>=1.18.0",
33
+ "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt",
34
+ )
35
+
36
+ logger = logging.getLogger(__name__)
37
+ from datasets import Dataset, DatasetDict
38
+ import torchaudio
39
+ from torchaudio import transforms as at
40
+ import pandas as pd
41
+ import torch
42
+ from pathlib import Path
43
+ import random
44
+ def main():
45
+ # Set seed before initializing model.
46
+ set_seed(42)
47
+
48
+ # 5. Load pretrained model, tokenizer, and feature extractor
49
+ #
50
+ # Distributed training:
51
+ # The .from_pretrained methods guarantee that only one local process can concurrently
52
+ config = AutoConfig.from_pretrained(
53
+ "openai/whisper-medium", revision="main", use_auth_token=True
54
+ )
55
+
56
+ config.update({"forced_decoder_ids": None, "suppress_tokens": None})
57
+
58
+ # *****************************SpecAugment for whisper models
59
+ # if getattr(config, "model_type", None) == "whisper":
60
+ config.update({"apply_spec_augment": True})
61
+
62
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
63
+ "openai/whisper-medium",
64
+ revision="main",
65
+ use_auth_token=True,
66
+ )
67
+ tokenizer = AutoTokenizer.from_pretrained(
68
+ "openai/whisper-medium",
69
+ use_fast=True,
70
+ revision="main",
71
+ use_auth_token=True,
72
+ )
73
+
74
+ tokenizer.set_prefix_tokens(language="vi", task="transcribe")
75
+
76
+ # 7. Preprocessing the datasets.
77
+ # We need to read the audio files as arrays and tokenize the targets.
78
+ max_input_length = 30.0 * 16000
79
+ min_input_length = 0.0 * 16000
80
+ audio_column_name = "audio"
81
+ num_workers = 16
82
+ text_column_name = "text"
83
+ model_input_name = feature_extractor.model_input_names[0]
84
+
85
+ # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
86
+ forward_attention_mask = True
87
+
88
+ # noise_dir = "../noise/ESC-50-master/audio/"
89
+ # define augmentation
90
+ augmentation = Compose(
91
+ [
92
+ TimeStretch(min_rate=0.9, max_rate=1.1, p=0.2, leave_length_unchanged=True),
93
+ Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.1),
94
+ PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
95
+ ]
96
+ )
97
+
98
+ def augment_dataset(batch):
99
+ # load and (possibly) resample audio data to 16kHz
100
+ sample = batch["audio"]
101
+
102
+ # apply augmentation
103
+ augmented_waveform = augmentation(
104
+ sample, sample_rate=16000
105
+ )
106
+ batch["audio"]["array"] = augmented_waveform
107
+ return batch
108
+
109
+ def prepare_dataset(batch):
110
+ # process audio
111
+ sample = batch[audio_column_name]
112
+ inputs = feature_extractor(
113
+ sample,
114
+ sampling_rate= 16000,
115
+ return_attention_mask=forward_attention_mask,
116
+ )
117
+ # process audio length
118
+ batch[model_input_name] = inputs.get(model_input_name)[0]
119
+ batch["input_length"] = len(sample)
120
+ if forward_attention_mask:
121
+ batch["attention_mask"] = inputs.get("attention_mask")[0]
122
+
123
+ # process targets
124
+ input_str = batch[text_column_name]
125
+ batch["labels"] = tokenizer(input_str).input_ids
126
+ return batch
127
+
128
+
129
+ def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
130
+ waveform, sr = torchaudio.load(wave_path, normalize=True)
131
+ if sample_rate != sr:
132
+ waveform = at.Resample(sr, sample_rate)(waveform)
133
+ return waveform
134
+
135
+
136
+ def get_list_files_MITI(phase, sample_rate=16000, audio_max_sample_length=480000, fraction=0.15):
137
+ audio_list = []
138
+ text_list = []
139
+ if phase == 'train':
140
+ csv_file = 'vin_train.csv'
141
+ else:
142
+ csv_file = 'vin_test.csv'
143
+ df = pd.read_csv(csv_file)
144
+
145
+ # Calculate the number of samples to select based on the fraction
146
+ num_samples = int(len(df) * fraction)
147
+
148
+ # Randomly select the indices of samples
149
+ selected_indices = random.sample(range(len(df)), num_samples)
150
+
151
+ for index, row in tqdm(df.iterrows()):
152
+ if index not in selected_indices:
153
+ continue
154
+
155
+ new_path = Path(row['path'])
156
+ audio_id = index
157
+ text = row['sentence']
158
+
159
+ if new_path.exists():
160
+ audio = load_wave(new_path, sample_rate=sample_rate)[0]
161
+ if len(audio) > audio_max_sample_length or len(audio) < 0:
162
+ print('skip file:', new_path, 'with len audio', len(audio))
163
+ continue
164
+ audio_list.append(audio)
165
+ text_list.append(text)
166
+
167
+ return audio_list, text_list
168
+
169
+ # Assuming you have two CSV files, 'vin_train.csv' and 'vin_test.csv', in the same directory
170
+
171
+ # Get the training dataset
172
+ train_audio, train_text = get_list_files_MITI(phase='train')
173
+
174
+ # Get the testing dataset
175
+ test_audio, test_text = get_list_files_MITI(phase='test')
176
+
177
+ # Create the Dataset objects
178
+ train_dataset = Dataset.from_dict({"audio": train_audio, "text": train_text})
179
+ test_dataset = Dataset.from_dict({"audio": test_audio, "text": test_text})
180
+
181
+ # Create the DatasetDict
182
+ vin_100h = DatasetDict({"train": train_dataset, "test": test_dataset})
183
+
184
+
185
+
186
+
187
+
188
+ print(vin_100h)
189
+
190
+
191
+
192
+
193
+
194
+ vectorized_datasets = vin_100h.map(
195
+ prepare_dataset,
196
+ remove_columns=["audio", "text"],
197
+ num_proc=1,
198
+ desc="preprocess train dataset",
199
+ )
200
+
201
+
202
+ print(vectorized_datasets)
203
+
204
+ vectorized_datasets.save_to_disk(
205
+ "./vin_10h", num_proc=1
206
+ )
207
+
208
+ return
209
+
210
+
211
+ if __name__ == "__main__":
212
+ main()
src/realtime.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! python3.7
2
+
3
+ import argparse
4
+ import io
5
+ import os
6
+ import speech_recognition as sr
7
+ import whisperx
8
+ import torch
9
+
10
+ from datetime import datetime, timedelta
11
+ from queue import Queue
12
+ from tempfile import NamedTemporaryFile
13
+ from time import sleep
14
+ from sys import platform
15
+
16
+
17
+ def main():
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--model", default="Vietnamese_ASR/ct2ranslate", help="Size of model or the local path for model ",
20
+ type=str)
21
+ parser.add_argument("--non_english", action='store_true',
22
+ help="Don't use the English model.")
23
+ parser.add_argument("--language", default="vi", help="The language to infer the model with whisper", type=str)
24
+ parser.add_argument("--device", default="cpu",
25
+ help="Choose device for inference "
26
+ , type=str)
27
+ parser.add_argument("--energy_threshold", default=900,
28
+ help="Energy level for mic to detect.", type=int)
29
+ parser.add_argument("--record_timeout", default=0.6,
30
+ help="How real-time the recording is in seconds.", type=float)
31
+ parser.add_argument("--phrase_timeout", default=3,
32
+ help="How much empty space between recordings before we "
33
+ "consider it a new line in the transcription.", type=float)
34
+ if 'linux' in platform:
35
+ parser.add_argument("--default_microphone", default='pulse',
36
+ help="Default microphone name for SpeechRecognition. "
37
+ "Run this with 'list' to view available Microphones.", type=str)
38
+ args = parser.parse_args()
39
+
40
+
41
+ # The last time a recording was retreived from the queue.
42
+ phrase_time = None
43
+ # Current raw audio bytes.
44
+ last_sample = bytes()
45
+ # Thread safe Queue for passing data from the threaded recording callback.
46
+ data_queue = Queue()
47
+ # We use SpeechRecognizer to record our audio because it has a nice feauture where it can detect when speech ends.
48
+ recorder = sr.Recognizer()
49
+ recorder.energy_threshold = args.energy_threshold
50
+ # Definitely do this, dynamic energy compensation lowers the energy threshold dramtically to a point where the SpeechRecognizer never stops recording.
51
+ recorder.dynamic_energy_threshold = False
52
+
53
+ # Important for linux users.
54
+ # Prevents permanent application hang and crash by using the wrong Microphone
55
+ if 'linux' in platform:
56
+ mic_name = args.default_microphone
57
+ if not mic_name or mic_name == 'list':
58
+ print("Available microphone devices are: ")
59
+ for index, name in enumerate(sr.Microphone.list_microphone_names()):
60
+ print(f"Microphone with name \"{name}\" found")
61
+ return
62
+ else:
63
+ for index, name in enumerate(sr.Microphone.list_microphone_names()):
64
+ if mic_name in name:
65
+ source = sr.Microphone(sample_rate=16000, device_index=index)
66
+ break
67
+ else:
68
+ source = sr.Microphone(sample_rate=16000)
69
+
70
+ # Load / Download model
71
+ model = args.model
72
+ # if args.model != "large" and not args.non_english:
73
+ # model = model + ".en"
74
+ audio_model = whisperx.load_model(model, device=args.device, compute_type="float16", language = args.language)
75
+
76
+ record_timeout = args.record_timeout
77
+ phrase_timeout = args.phrase_timeout
78
+
79
+ temp_file = NamedTemporaryFile().name
80
+ transcription = ['']
81
+
82
+ with source:
83
+ recorder.adjust_for_ambient_noise(source)
84
+
85
+ def record_callback(_, audio:sr.AudioData) -> None:
86
+ """
87
+ Threaded callback function to recieve audio data when recordings finish.
88
+ audio: An AudioData containing the recorded bytes.
89
+ """
90
+ # Grab the raw bytes and push it into the thread safe queue.
91
+ data = audio.get_raw_data()
92
+ data_queue.put(data)
93
+
94
+ # Create a background thread that will pass us raw audio bytes.
95
+ # We could do this manually but SpeechRecognizer provides a nice helper.
96
+ recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
97
+
98
+ # Cue the user that we're ready to go.
99
+ print("Model loaded.\n")
100
+
101
+ while True:
102
+ try:
103
+ now = datetime.utcnow()
104
+ # Pull raw recorded audio from the queue.
105
+ if not data_queue.empty():
106
+ phrase_complete = False
107
+ # If enough time has passed between recordings, consider the phrase complete.
108
+ # Clear the current working audio buffer to start over with the new data.
109
+ if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
110
+ last_sample = bytes()
111
+ phrase_complete = True
112
+ # This is the last time we received new audio data from the queue.
113
+ phrase_time = now
114
+
115
+ # Concatenate our current audio data with the latest audio data.
116
+ while not data_queue.empty():
117
+ data = data_queue.get()
118
+ last_sample += data
119
+
120
+ # Use AudioData to convert the raw data to wav data.
121
+ audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE, source.SAMPLE_WIDTH)
122
+ wav_data = io.BytesIO(audio_data.get_wav_data())
123
+
124
+ # Write wav data to the temporary file as bytes.
125
+ with open(temp_file, 'w+b') as f:
126
+ f.write(wav_data.read())
127
+
128
+ # Read the transcription.
129
+ result = audio_model.transcribe(temp_file, language="en",batch_size = 8)
130
+ text = result['segments'][0]['text'].strip()
131
+
132
+ # If we detected a pause between recordings, add a new item to our transcripion.
133
+ # Otherwise edit the existing one.
134
+ if phrase_complete:
135
+ transcription.append(text)
136
+ else:
137
+ transcription[-1] = text
138
+
139
+ # Clear the console to reprint the updated transcription.
140
+ os.system('cls' if os.name=='nt' else 'clear')
141
+ for line in transcription:
142
+ print(line)
143
+ # Flush stdout.
144
+ print('', end='', flush=True)
145
+
146
+ # Infinite loops are bad for processors, must sleep.
147
+ sleep(0.25)
148
+ except KeyboardInterrupt:
149
+ break
150
+
151
+ print("\n\nTranscription:")
152
+ for line in transcription:
153
+ print(line)
154
+
155
+
156
+ if __name__ == "__main__":
157
+ main()
src/requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/peft.git@main
2
+ bitsandbytes
3
+ accelerate
4
+ loralib
5
+ librosa
6
+ datasets>=2.6.1
7
+ evaluate>=0.3.0
8
+ jiwer
9
+ tensorboard
10
+ soundfile==0.12.1
11
+ git+https://github.com/m-bain/whisperX.git
12
+ #nvidia-cudnn-cu11-8.7.0.84 need
13
+ lightning-fabric
14
+ pyaudio
15
+ SpeechRecognition
src/test_whisper.ipynb ADDED
@@ -0,0 +1,1546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "data": {
10
+ "application/vnd.jupyter.widget-view+json": {
11
+ "model_id": "9d7b03aae28b4282b143eb17c3d8d687",
12
+ "version_major": 2,
13
+ "version_minor": 0
14
+ },
15
+ "text/plain": [
16
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
17
+ ]
18
+ },
19
+ "metadata": {},
20
+ "output_type": "display_data"
21
+ }
22
+ ],
23
+ "source": [
24
+ "from huggingface_hub import notebook_login\n",
25
+ "\n",
26
+ "notebook_login()"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 2,
32
+ "metadata": {},
33
+ "outputs": [
34
+ {
35
+ "name": "stderr",
36
+ "output_type": "stream",
37
+ "text": [
38
+ "Found cached dataset vivos (/home/tesla/.cache/huggingface/datasets/vivos/default/1.1.0/ab59078eb266c1a0ea856786ba56b5b8d56f29b42dfb37d92115cf81a7b1a5e0)\n",
39
+ "Found cached dataset vivos (/home/tesla/.cache/huggingface/datasets/vivos/default/1.1.0/ab59078eb266c1a0ea856786ba56b5b8d56f29b42dfb37d92115cf81a7b1a5e0)\n"
40
+ ]
41
+ }
42
+ ],
43
+ "source": [
44
+ "from datasets import load_dataset, DatasetDict\n",
45
+ "\n",
46
+ "vivos = DatasetDict()\n",
47
+ "\n",
48
+ "vivos[\"train\"] = load_dataset(\"vivos\", split=\"train\", use_auth_token=True)\n",
49
+ "vivos[\"test\"] = load_dataset(\"vivos\", split=\"test\", use_auth_token=True)\n",
50
+ "\n",
51
+ "\n"
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 38,
57
+ "metadata": {},
58
+ "outputs": [
59
+ {
60
+ "data": {
61
+ "text/plain": [
62
+ "DatasetDict({\n",
63
+ " train: Dataset({\n",
64
+ " features: ['speaker_id', 'path', 'audio', 'sentence'],\n",
65
+ " num_rows: 11660\n",
66
+ " })\n",
67
+ " test: Dataset({\n",
68
+ " features: ['speaker_id', 'path', 'audio', 'sentence'],\n",
69
+ " num_rows: 760\n",
70
+ " })\n",
71
+ "})"
72
+ ]
73
+ },
74
+ "execution_count": 38,
75
+ "metadata": {},
76
+ "output_type": "execute_result"
77
+ }
78
+ ],
79
+ "source": [
80
+ "vivos"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 3,
86
+ "metadata": {},
87
+ "outputs": [],
88
+ "source": [
89
+ "vivos_clean = vivos.remove_columns([\"speaker_id\", \"path\"])"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 40,
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "data": {
99
+ "text/plain": [
100
+ "DatasetDict({\n",
101
+ " train: Dataset({\n",
102
+ " features: ['audio', 'sentence'],\n",
103
+ " num_rows: 11660\n",
104
+ " })\n",
105
+ " test: Dataset({\n",
106
+ " features: ['audio', 'sentence'],\n",
107
+ " num_rows: 760\n",
108
+ " })\n",
109
+ "})"
110
+ ]
111
+ },
112
+ "execution_count": 40,
113
+ "metadata": {},
114
+ "output_type": "execute_result"
115
+ }
116
+ ],
117
+ "source": [
118
+ "vivos_clean"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": 79,
124
+ "metadata": {},
125
+ "outputs": [
126
+ {
127
+ "name": "stdout",
128
+ "output_type": "stream",
129
+ "text": [
130
+ "{'audio': {'path': 'vivos/train/waves/VIVOSSPK27/VIVOSSPK27_084.wav', 'array': array([ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
131
+ " 9.15527344e-05, -5.18798828e-04, -9.15527344e-04]), 'sampling_rate': 16000}, 'sentence': 'CHƯA HẾT ĐI KHIẾU NẠI THÌ NHÀ MẠNG BẢO VỀ ĐẠI LÝ CHỌN SỐ KHÁC ĐI'}\n"
132
+ ]
133
+ }
134
+ ],
135
+ "source": [
136
+ "print(vivos_clean['train'][12])"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 4,
142
+ "metadata": {},
143
+ "outputs": [
144
+ {
145
+ "name": "stderr",
146
+ "output_type": "stream",
147
+ "text": [
148
+ "Found cached dataset common_voice_13_0 (/home/tesla/.cache/huggingface/datasets/mozilla-foundation___common_voice_13_0/vi/13.0.0/2506e9a8950f5807ceae08c2920e814222909fd7f477b74f5d225802e9f04055)\n",
149
+ "Found cached dataset common_voice_13_0 (/home/tesla/.cache/huggingface/datasets/mozilla-foundation___common_voice_13_0/vi/13.0.0/2506e9a8950f5807ceae08c2920e814222909fd7f477b74f5d225802e9f04055)\n"
150
+ ]
151
+ }
152
+ ],
153
+ "source": [
154
+ "\n",
155
+ "common_voice = DatasetDict()\n",
156
+ "\n",
157
+ "common_voice[\"train\"] = load_dataset(\"mozilla-foundation/common_voice_13_0\", \"vi\", split=\"train+validation\", use_auth_token=True)\n",
158
+ "common_voice[\"test\"] = load_dataset(\"mozilla-foundation/common_voice_13_0\", \"vi\", split=\"test\", use_auth_token=True)\n",
159
+ "\n",
160
+ "\n"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": 42,
166
+ "metadata": {},
167
+ "outputs": [
168
+ {
169
+ "data": {
170
+ "text/plain": [
171
+ "DatasetDict({\n",
172
+ " train: Dataset({\n",
173
+ " features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],\n",
174
+ " num_rows: 2854\n",
175
+ " })\n",
176
+ " test: Dataset({\n",
177
+ " features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],\n",
178
+ " num_rows: 1225\n",
179
+ " })\n",
180
+ "})"
181
+ ]
182
+ },
183
+ "execution_count": 42,
184
+ "metadata": {},
185
+ "output_type": "execute_result"
186
+ }
187
+ ],
188
+ "source": [
189
+ "common_voice"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 67,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "common_voice_clean = common_voice.remove_columns([\"client_id\", \"path\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\", \"age\", \"accent\", \"variant\"])\n"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": 89,
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "data": {
208
+ "text/plain": [
209
+ "DatasetDict({\n",
210
+ " train: Dataset({\n",
211
+ " features: ['audio', 'sentence'],\n",
212
+ " num_rows: 2854\n",
213
+ " })\n",
214
+ " test: Dataset({\n",
215
+ " features: ['audio', 'sentence'],\n",
216
+ " num_rows: 1225\n",
217
+ " })\n",
218
+ "})"
219
+ ]
220
+ },
221
+ "execution_count": 89,
222
+ "metadata": {},
223
+ "output_type": "execute_result"
224
+ }
225
+ ],
226
+ "source": [
227
+ "common_voice_clean"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": 1,
233
+ "metadata": {},
234
+ "outputs": [
235
+ {
236
+ "ename": "NameError",
237
+ "evalue": "name 'common_voice_clean' is not defined",
238
+ "output_type": "error",
239
+ "traceback": [
240
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
241
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
242
+ "Cell \u001b[0;32mIn[1], line 9\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[39mreturn\u001b[39;00m example\n\u001b[1;32m 7\u001b[0m common_voice_clear \u001b[39m=\u001b[39m DatasetDict()\n\u001b[0;32m----> 9\u001b[0m common_voice_clear[\u001b[39m\"\u001b[39m\u001b[39mtrain\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m common_voice_clean[\u001b[39m\"\u001b[39m\u001b[39mtrain\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39mmap(convert_to_uppercase)\n\u001b[1;32m 10\u001b[0m common_voice_clear[\u001b[39m\"\u001b[39m\u001b[39mtest\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m common_voice_clean[\u001b[39m\"\u001b[39m\u001b[39mtest\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39mmap(convert_to_uppercase)\n",
243
+ "\u001b[0;31mNameError\u001b[0m: name 'common_voice_clean' is not defined"
244
+ ]
245
+ }
246
+ ],
247
+ "source": [
248
+ "from datasets import DatasetDict\n",
249
+ "\n",
250
+ "def convert_to_uppercase(example):\n",
251
+ " example[\"sentence\"] = example[\"sentence\"].upper()\n",
252
+ " return example\n",
253
+ "\n",
254
+ "common_voice_clear = DatasetDict()\n",
255
+ "\n",
256
+ "common_voice_clear[\"train\"] = common_voice_clean[\"train\"].map(convert_to_uppercase)\n",
257
+ "common_voice_clear[\"test\"] = common_voice_clean[\"test\"].map(convert_to_uppercase)"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": 93,
263
+ "metadata": {},
264
+ "outputs": [
265
+ {
266
+ "name": "stdout",
267
+ "output_type": "stream",
268
+ "text": [
269
+ "{'audio': {'path': '/home/tesla/.cache/huggingface/datasets/downloads/extracted/acb70896120347904e003bb826dcabc1ddd05a02210935cb44ce1c807e8742a5/vi_train_0/common_voice_vi_23901118.mp3', 'array': array([ 0.00000000e+00, 4.20543185e-14, 1.38823347e-14, ...,\n",
270
+ " -8.41874498e-06, -8.36193431e-06, -6.76584477e-06]), 'sampling_rate': 48000}, 'sentence': 'KHI CON CÓ MẸ'}\n"
271
+ ]
272
+ }
273
+ ],
274
+ "source": [
275
+ "print(common_voice_clear['train'][1])"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": 94,
281
+ "metadata": {},
282
+ "outputs": [
283
+ {
284
+ "name": "stderr",
285
+ "output_type": "stream",
286
+ "text": [
287
+ "100%|██████████| 2854/2854 [33:25<00:00, 1.42it/s]\n"
288
+ ]
289
+ }
290
+ ],
291
+ "source": [
292
+ "from pydub import AudioSegment\n",
293
+ "import os\n",
294
+ "\n",
295
+ "from tqdm import tqdm\n",
296
+ "\n",
297
+ "def convert_mp3_to_wav(mp3_path, wav_path, target_sampling_rate):\n",
298
+ " audio = AudioSegment.from_mp3(mp3_path)\n",
299
+ " audio = audio.set_frame_rate(target_sampling_rate)\n",
300
+ " audio.export(wav_path, format='wav')\n",
301
+ "\n",
302
+ "target_sampling_rate = 16000\n",
303
+ "\n",
304
+ "for example in tqdm(common_voice_clear[\"train\"]):\n",
305
+ " mp3_path = example[\"audio\"][\"path\"]\n",
306
+ " wav_path = os.path.splitext(mp3_path)[0] + \".wav\"\n",
307
+ " convert_mp3_to_wav(mp3_path, wav_path, target_sampling_rate)\n",
308
+ " example[\"audio\"][\"path\"] = wav_path\n",
309
+ "\n"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": 95,
315
+ "metadata": {},
316
+ "outputs": [],
317
+ "source": [
318
+ "import datasets\n",
319
+ "from datasets import Audio\n",
320
+ "\n",
321
+ "common_voice_clean = common_voice_clean.cast_column(\"audio\", Audio(sampling_rate=16000))"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": 47,
327
+ "metadata": {},
328
+ "outputs": [],
329
+ "source": [
330
+ "concat = DatasetDict()"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": 96,
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "concat[\"train\"] = datasets.concatenate_datasets([common_voice_clean[\"train\"], vivos_clean[\"train\"]])\n",
340
+ "\n",
341
+ "#concat['test']= datasets.concatenate_datasets([common_voice_clean[\"test\"], vivos_clean[\"test\"]])\n",
342
+ "concat['test']= vivos_clean[\"test\"]\n"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": 97,
348
+ "metadata": {},
349
+ "outputs": [
350
+ {
351
+ "data": {
352
+ "text/plain": [
353
+ "DatasetDict({\n",
354
+ " train: Dataset({\n",
355
+ " features: ['audio', 'sentence'],\n",
356
+ " num_rows: 14514\n",
357
+ " })\n",
358
+ " test: Dataset({\n",
359
+ " features: ['audio', 'sentence'],\n",
360
+ " num_rows: 760\n",
361
+ " })\n",
362
+ "})"
363
+ ]
364
+ },
365
+ "execution_count": 97,
366
+ "metadata": {},
367
+ "output_type": "execute_result"
368
+ }
369
+ ],
370
+ "source": [
371
+ "concat"
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": 98,
377
+ "metadata": {},
378
+ "outputs": [],
379
+ "source": [
380
+ "from transformers import WhisperFeatureExtractor\n",
381
+ "\n",
382
+ "feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-small\")\n"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": 99,
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "from transformers import WhisperTokenizerFast\n",
392
+ "\n",
393
+ "tokenizer = WhisperTokenizerFast.from_pretrained(\"openai/whisper-small\", language=\"Vietnamese\", task=\"transcribe\")\n"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": 80,
399
+ "metadata": {},
400
+ "outputs": [
401
+ {
402
+ "name": "stdout",
403
+ "output_type": "stream",
404
+ "text": [
405
+ "Input: KHÔNG CÓ AI BÁC BỎ QUYỀN ĐÓ\n",
406
+ "Decoded w/ special: <|startoftranscript|><|notimestamps|>KHÔNG CÓ AI BÁC BỎ QUYỀN ĐÓ<|endoftext|>\n",
407
+ "Decoded w/out special: KHÔNG CÓ AI BÁC BỎ QUYỀN ĐÓ\n",
408
+ "Are equal: True\n"
409
+ ]
410
+ }
411
+ ],
412
+ "source": [
413
+ "input_str = concat[\"train\"][8550][\"sentence\"]\n",
414
+ "labels = tokenizer(input_str).input_ids\n",
415
+ "decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)\n",
416
+ "decoded_str = tokenizer.decode(labels, skip_special_tokens=True)\n",
417
+ "\n",
418
+ "print(f\"Input: {input_str}\")\n",
419
+ "print(f\"Decoded w/ special: {decoded_with_special}\")\n",
420
+ "print(f\"Decoded w/out special: {decoded_str}\")\n",
421
+ "print(f\"Are equal: {input_str == decoded_str}\")\n"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "code",
426
+ "execution_count": 100,
427
+ "metadata": {},
428
+ "outputs": [],
429
+ "source": [
430
+ "from transformers import WhisperProcessor\n",
431
+ "\n",
432
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", language=\"Vietnamese\", task=\"transcribe\")\n"
433
+ ]
434
+ },
435
+ {
436
+ "cell_type": "code",
437
+ "execution_count": 19,
438
+ "metadata": {},
439
+ "outputs": [],
440
+ "source": [
441
+ "from datasets import Audio\n",
442
+ "\n",
443
+ "concat = concat.cast_column(\"audio\", Audio(sampling_rate=16000))"
444
+ ]
445
+ },
446
+ {
447
+ "cell_type": "code",
448
+ "execution_count": 59,
449
+ "metadata": {},
450
+ "outputs": [
451
+ {
452
+ "name": "stdout",
453
+ "output_type": "stream",
454
+ "text": [
455
+ "{'audio': {'path': 'vivos/train/waves/VIVOSSPK12/VIVOSSPK12_R077.wav', 'array': array([ 0.00000000e+00, 0.00000000e+00, -3.05175781e-05, ...,\n",
456
+ " 1.31225586e-03, 1.12915039e-03, 1.55639648e-03]), 'sampling_rate': 16000}, 'sentence': 'KIÊN GIANG'}\n"
457
+ ]
458
+ }
459
+ ],
460
+ "source": [
461
+ "print(concat[\"train\"][4500])"
462
+ ]
463
+ },
464
+ {
465
+ "cell_type": "code",
466
+ "execution_count": 101,
467
+ "metadata": {},
468
+ "outputs": [],
469
+ "source": [
470
+ "def prepare_dataset(batch):\n",
471
+ " # load and resample audio data from 48 to 16kHz\n",
472
+ " audio = batch[\"audio\"]\n",
473
+ "\n",
474
+ " # compute log-Mel input features from input audio array \n",
475
+ " batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
476
+ "\n",
477
+ " # encode target text to label ids \n",
478
+ " batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n",
479
+ " return batch\n"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": 102,
485
+ "metadata": {},
486
+ "outputs": [
487
+ {
488
+ "data": {
489
+ "application/vnd.jupyter.widget-view+json": {
490
+ "model_id": "c35c921e0dde433fb0ef9346310238a3",
491
+ "version_major": 2,
492
+ "version_minor": 0
493
+ },
494
+ "text/plain": [
495
+ "Map (num_proc=6): 0%| | 0/14514 [00:00<?, ? examples/s]"
496
+ ]
497
+ },
498
+ "metadata": {},
499
+ "output_type": "display_data"
500
+ },
501
+ {
502
+ "data": {
503
+ "application/vnd.jupyter.widget-view+json": {
504
+ "model_id": "8c5af4ed5f8141d2b0673972f7616941",
505
+ "version_major": 2,
506
+ "version_minor": 0
507
+ },
508
+ "text/plain": [
509
+ "Map (num_proc=6): 0%| | 0/760 [00:00<?, ? examples/s]"
510
+ ]
511
+ },
512
+ "metadata": {},
513
+ "output_type": "display_data"
514
+ }
515
+ ],
516
+ "source": [
517
+ "concat = concat.map(prepare_dataset, remove_columns=concat.column_names[\"train\"], num_proc=6)"
518
+ ]
519
+ },
520
+ {
521
+ "cell_type": "code",
522
+ "execution_count": 103,
523
+ "metadata": {},
524
+ "outputs": [],
525
+ "source": [
526
+ "import torch\n",
527
+ "\n",
528
+ "from dataclasses import dataclass\n",
529
+ "from typing import Any, Dict, List, Union\n",
530
+ "\n",
531
+ "@dataclass\n",
532
+ "class DataCollatorSpeechSeq2SeqWithPadding:\n",
533
+ " processor: Any\n",
534
+ "\n",
535
+ " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
536
+ " # split inputs and labels since they have to be of different lengths and need different padding methods\n",
537
+ " # first treat the audio inputs by simply returning torch tensors\n",
538
+ " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
539
+ " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
540
+ "\n",
541
+ " # get the tokenized label sequences\n",
542
+ " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
543
+ " # pad the labels to max length\n",
544
+ " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
545
+ "\n",
546
+ " # replace padding with -100 to ignore loss correctly\n",
547
+ " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
548
+ "\n",
549
+ " # if bos token is appended in previous tokenization step,\n",
550
+ " # cut bos token here as it's append later anyways\n",
551
+ " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
552
+ " labels = labels[:, 1:]\n",
553
+ "\n",
554
+ " batch[\"labels\"] = labels\n",
555
+ "\n",
556
+ " return batch\n"
557
+ ]
558
+ },
559
+ {
560
+ "cell_type": "code",
561
+ "execution_count": 104,
562
+ "metadata": {},
563
+ "outputs": [],
564
+ "source": [
565
+ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)"
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "code",
570
+ "execution_count": 105,
571
+ "metadata": {},
572
+ "outputs": [],
573
+ "source": [
574
+ "import os\n",
575
+ "\n",
576
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" "
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "markdown",
581
+ "metadata": {},
582
+ "source": [
583
+ "Train\n"
584
+ ]
585
+ },
586
+ {
587
+ "cell_type": "code",
588
+ "execution_count": 106,
589
+ "metadata": {},
590
+ "outputs": [],
591
+ "source": [
592
+ "import evaluate\n",
593
+ "\n",
594
+ "metric = evaluate.load(\"wer\")\n",
595
+ "\n",
596
+ "\n",
597
+ "def compute_metrics(pred):\n",
598
+ " pred_ids = pred.predictions\n",
599
+ " label_ids = pred.label_ids\n",
600
+ "\n",
601
+ " # replace -100 with the pad_token_id\n",
602
+ " label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
603
+ "\n",
604
+ " # we do not want to group tokens when computing the metrics\n",
605
+ " pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
606
+ " label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
607
+ "\n",
608
+ " wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n",
609
+ "\n",
610
+ " return {\"wer\": wer}\n"
611
+ ]
612
+ },
613
+ {
614
+ "cell_type": "code",
615
+ "execution_count": 107,
616
+ "metadata": {},
617
+ "outputs": [],
618
+ "source": [
619
+ "from transformers import WhisperForConditionalGeneration\n",
620
+ "\n",
621
+ "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-small\")\n"
622
+ ]
623
+ },
624
+ {
625
+ "cell_type": "code",
626
+ "execution_count": 108,
627
+ "metadata": {},
628
+ "outputs": [],
629
+ "source": [
630
+ "model.config.forced_decoder_ids = None\n",
631
+ "model.config.suppress_tokens = []"
632
+ ]
633
+ },
634
+ {
635
+ "cell_type": "code",
636
+ "execution_count": 109,
637
+ "metadata": {},
638
+ "outputs": [],
639
+ "source": [
640
+ "from transformers import Seq2SeqTrainingArguments\n",
641
+ "\n",
642
+ "training_args = Seq2SeqTrainingArguments(\n",
643
+ " output_dir=\"./vi_whisper-small\", # change to a repo name of your choice\n",
644
+ " per_device_train_batch_size=16,\n",
645
+ " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
646
+ " learning_rate=1e-4,\n",
647
+ " warmup_steps=1000,\n",
648
+ " max_steps=8000,\n",
649
+ " gradient_checkpointing=True,\n",
650
+ " fp16=True,\n",
651
+ " evaluation_strategy=\"steps\",\n",
652
+ " per_device_eval_batch_size=8,\n",
653
+ " predict_with_generate=True,\n",
654
+ " generation_max_length=225,\n",
655
+ " save_steps=4000,\n",
656
+ " eval_steps=1000,\n",
657
+ " logging_steps=25,\n",
658
+ " report_to=[\"tensorboard\"],\n",
659
+ " load_best_model_at_end=True,\n",
660
+ " metric_for_best_model=\"wer\",\n",
661
+ " greater_is_better=False,\n",
662
+ " push_to_hub=True,\n",
663
+ ")\n"
664
+ ]
665
+ },
666
+ {
667
+ "cell_type": "code",
668
+ "execution_count": 126,
669
+ "metadata": {},
670
+ "outputs": [
671
+ {
672
+ "name": "stderr",
673
+ "output_type": "stream",
674
+ "text": [
675
+ "/media/tesla/New Volume1/DEMO/DUY/Vietnamese_ASR/./vi_whisper-small is already a clone of https://huggingface.co/DuyTa/vi_whisper-small. Make sure you pull the latest changes with `repo.git_pull()`.\n"
676
+ ]
677
+ },
678
+ {
679
+ "ename": "OSError",
680
+ "evalue": "From https://huggingface.co/DuyTa/vi_whisper-small\n d7893fc..47c00b5 main -> origin/main\nhint: You have divergent branches and need to specify how to reconcile them.\nhint: You can do so by running one of the following commands sometime before\nhint: your next pull:\nhint: \nhint: git config pull.rebase false # merge (the default strategy)\nhint: git config pull.rebase true # rebase\nhint: git config pull.ff only # fast-forward only\nhint: \nhint: You can replace \"git config\" with \"git config --global\" to set a default\nhint: preference for all repositories. You can also pass --rebase, --no-rebase,\nhint: or --ff-only on the command line to override the configured default per\nhint: invocation.\nfatal: Need to specify how to reconcile divergent branches.\n",
681
+ "output_type": "error",
682
+ "traceback": [
683
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
684
+ "\u001b[0;31mCalledProcessError\u001b[0m Traceback (most recent call last)",
685
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:984\u001b[0m, in \u001b[0;36mRepository.git_pull\u001b[0;34m(self, rebase, lfs)\u001b[0m\n\u001b[1;32m 983\u001b[0m \u001b[39mwith\u001b[39;00m _lfs_log_progress():\n\u001b[0;32m--> 984\u001b[0m result \u001b[39m=\u001b[39m run_subprocess(command, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlocal_dir)\n\u001b[1;32m 985\u001b[0m logger\u001b[39m.\u001b[39minfo(result\u001b[39m.\u001b[39mstdout)\n",
686
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/utils/_subprocess.py:83\u001b[0m, in \u001b[0;36mrun_subprocess\u001b[0;34m(command, folder, check, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m folder \u001b[39m=\u001b[39m \u001b[39mstr\u001b[39m(folder)\n\u001b[0;32m---> 83\u001b[0m \u001b[39mreturn\u001b[39;00m subprocess\u001b[39m.\u001b[39;49mrun(\n\u001b[1;32m 84\u001b[0m command,\n\u001b[1;32m 85\u001b[0m stderr\u001b[39m=\u001b[39;49msubprocess\u001b[39m.\u001b[39;49mPIPE,\n\u001b[1;32m 86\u001b[0m stdout\u001b[39m=\u001b[39;49msubprocess\u001b[39m.\u001b[39;49mPIPE,\n\u001b[1;32m 87\u001b[0m check\u001b[39m=\u001b[39;49mcheck,\n\u001b[1;32m 88\u001b[0m encoding\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mutf-8\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 89\u001b[0m errors\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mreplace\u001b[39;49m\u001b[39m\"\u001b[39;49m, \u001b[39m# if not utf-8, replace char by �\u001b[39;49;00m\n\u001b[1;32m 90\u001b[0m cwd\u001b[39m=\u001b[39;49mfolder \u001b[39mor\u001b[39;49;00m os\u001b[39m.\u001b[39;49mgetcwd(),\n\u001b[1;32m 91\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 92\u001b[0m )\n",
687
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/subprocess.py:528\u001b[0m, in \u001b[0;36mrun\u001b[0;34m(input, capture_output, timeout, check, *popenargs, **kwargs)\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[39mif\u001b[39;00m check \u001b[39mand\u001b[39;00m retcode:\n\u001b[0;32m--> 528\u001b[0m \u001b[39mraise\u001b[39;00m CalledProcessError(retcode, process\u001b[39m.\u001b[39margs,\n\u001b[1;32m 529\u001b[0m output\u001b[39m=\u001b[39mstdout, stderr\u001b[39m=\u001b[39mstderr)\n\u001b[1;32m 530\u001b[0m \u001b[39mreturn\u001b[39;00m CompletedProcess(process\u001b[39m.\u001b[39margs, retcode, stdout, stderr)\n",
688
+ "\u001b[0;31mCalledProcessError\u001b[0m: Command '['git', 'pull']' returned non-zero exit status 128.",
689
+ "\nDuring handling of the above exception, another exception occurred:\n",
690
+ "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
691
+ "Cell \u001b[0;32mIn[126], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtransformers\u001b[39;00m \u001b[39mimport\u001b[39;00m Seq2SeqTrainer\n\u001b[0;32m----> 3\u001b[0m trainer \u001b[39m=\u001b[39m Seq2SeqTrainer(\n\u001b[1;32m 4\u001b[0m args\u001b[39m=\u001b[39;49mtraining_args,\n\u001b[1;32m 5\u001b[0m model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m 6\u001b[0m train_dataset\u001b[39m=\u001b[39;49mconcat[\u001b[39m\"\u001b[39;49m\u001b[39mtrain\u001b[39;49m\u001b[39m\"\u001b[39;49m],\n\u001b[1;32m 7\u001b[0m \n\u001b[1;32m 8\u001b[0m \n\u001b[1;32m 9\u001b[0m \n\u001b[1;32m 10\u001b[0m \n\u001b[1;32m 11\u001b[0m eval_dataset\u001b[39m=\u001b[39;49mconcat[\u001b[39m\"\u001b[39;49m\u001b[39mtest\u001b[39;49m\u001b[39m\"\u001b[39;49m],\n\u001b[1;32m 12\u001b[0m data_collator\u001b[39m=\u001b[39;49mdata_collator,\n\u001b[1;32m 13\u001b[0m compute_metrics\u001b[39m=\u001b[39;49mcompute_metrics,\n\u001b[1;32m 14\u001b[0m tokenizer\u001b[39m=\u001b[39;49mprocessor\u001b[39m.\u001b[39;49mfeature_extractor,\n\u001b[1;32m 15\u001b[0m )\n",
692
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/trainer_seq2seq.py:56\u001b[0m, in \u001b[0;36mSeq2SeqTrainer.__init__\u001b[0;34m(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 43\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 44\u001b[0m model: Union[\u001b[39m\"\u001b[39m\u001b[39mPreTrainedModel\u001b[39m\u001b[39m\"\u001b[39m, nn\u001b[39m.\u001b[39mModule] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 54\u001b[0m preprocess_logits_for_metrics: Optional[Callable[[torch\u001b[39m.\u001b[39mTensor, torch\u001b[39m.\u001b[39mTensor], torch\u001b[39m.\u001b[39mTensor]] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 55\u001b[0m ):\n\u001b[0;32m---> 56\u001b[0m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 57\u001b[0m model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m 58\u001b[0m args\u001b[39m=\u001b[39;49margs,\n\u001b[1;32m 59\u001b[0m data_collator\u001b[39m=\u001b[39;49mdata_collator,\n\u001b[1;32m 60\u001b[0m train_dataset\u001b[39m=\u001b[39;49mtrain_dataset,\n\u001b[1;32m 61\u001b[0m eval_dataset\u001b[39m=\u001b[39;49meval_dataset,\n\u001b[1;32m 62\u001b[0m tokenizer\u001b[39m=\u001b[39;49mtokenizer,\n\u001b[1;32m 63\u001b[0m model_init\u001b[39m=\u001b[39;49mmodel_init,\n\u001b[1;32m 64\u001b[0m compute_metrics\u001b[39m=\u001b[39;49mcompute_metrics,\n\u001b[1;32m 65\u001b[0m callbacks\u001b[39m=\u001b[39;49mcallbacks,\n\u001b[1;32m 66\u001b[0m optimizers\u001b[39m=\u001b[39;49moptimizers,\n\u001b[1;32m 67\u001b[0m preprocess_logits_for_metrics\u001b[39m=\u001b[39;49mpreprocess_logits_for_metrics,\n\u001b[1;32m 68\u001b[0m )\n\u001b[1;32m 70\u001b[0m \u001b[39m# Override self.model.generation_config if a GenerationConfig is specified in args.\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[39m# Priority: args.generation_config > model.generation_config > default GenerationConfig.\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mgeneration_config \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
693
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/trainer.py:551\u001b[0m, in \u001b[0;36mTrainer.__init__\u001b[0;34m(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[39m# Create clone of distant repo and output directory if needed\u001b[39;00m\n\u001b[1;32m 550\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mpush_to_hub:\n\u001b[0;32m--> 551\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minit_git_repo(at_init\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 552\u001b[0m \u001b[39m# In case of pull, we need to make sure every process has the latest.\u001b[39;00m\n\u001b[1;32m 553\u001b[0m \u001b[39mif\u001b[39;00m is_torch_tpu_available():\n",
694
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/trainer.py:3449\u001b[0m, in \u001b[0;36mTrainer.init_git_repo\u001b[0;34m(self, at_init)\u001b[0m\n\u001b[1;32m 3446\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 3447\u001b[0m \u001b[39mraise\u001b[39;00m\n\u001b[0;32m-> 3449\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrepo\u001b[39m.\u001b[39;49mgit_pull()\n\u001b[1;32m 3451\u001b[0m \u001b[39m# By default, ignore the checkpoint folders\u001b[39;00m\n\u001b[1;32m 3452\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 3453\u001b[0m \u001b[39mnot\u001b[39;00m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mexists(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39moutput_dir, \u001b[39m\"\u001b[39m\u001b[39m.gitignore\u001b[39m\u001b[39m\"\u001b[39m))\n\u001b[1;32m 3454\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mhub_strategy \u001b[39m!=\u001b[39m HubStrategy\u001b[39m.\u001b[39mALL_CHECKPOINTS\n\u001b[1;32m 3455\u001b[0m ):\n",
695
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:987\u001b[0m, in \u001b[0;36mRepository.git_pull\u001b[0;34m(self, rebase, lfs)\u001b[0m\n\u001b[1;32m 985\u001b[0m logger\u001b[39m.\u001b[39minfo(result\u001b[39m.\u001b[39mstdout)\n\u001b[1;32m 986\u001b[0m \u001b[39mexcept\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError \u001b[39mas\u001b[39;00m exc:\n\u001b[0;32m--> 987\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(exc\u001b[39m.\u001b[39mstderr)\n",
696
+ "\u001b[0;31mOSError\u001b[0m: From https://huggingface.co/DuyTa/vi_whisper-small\n d7893fc..47c00b5 main -> origin/main\nhint: You have divergent branches and need to specify how to reconcile them.\nhint: You can do so by running one of the following commands sometime before\nhint: your next pull:\nhint: \nhint: git config pull.rebase false # merge (the default strategy)\nhint: git config pull.rebase true # rebase\nhint: git config pull.ff only # fast-forward only\nhint: \nhint: You can replace \"git config\" with \"git config --global\" to set a default\nhint: preference for all repositories. You can also pass --rebase, --no-rebase,\nhint: or --ff-only on the command line to override the configured default per\nhint: invocation.\nfatal: Need to specify how to reconcile divergent branches.\n"
697
+ ]
698
+ }
699
+ ],
700
+ "source": [
701
+ "from transformers import Seq2SeqTrainer\n",
702
+ "\n",
703
+ "trainer = Seq2SeqTrainer(\n",
704
+ " args=training_args,\n",
705
+ " model=model,\n",
706
+ " train_dataset=concat[\"train\"],\n",
707
+ "\n",
708
+ " eval_dataset=concat[\"test\"],\n",
709
+ " data_collator=data_collator,\n",
710
+ " compute_metrics=compute_metrics,\n",
711
+ " tokenizer=processor.feature_extractor,\n",
712
+ ")\n"
713
+ ]
714
+ },
715
+ {
716
+ "cell_type": "code",
717
+ "execution_count": 130,
718
+ "metadata": {},
719
+ "outputs": [
720
+ {
721
+ "data": {
722
+ "text/plain": [
723
+ "('./vi_whisper-small/tokenizer_config.json',\n",
724
+ " './vi_whisper-small/special_tokens_map.json',\n",
725
+ " './vi_whisper-small/vocab.json',\n",
726
+ " './vi_whisper-small/merges.txt',\n",
727
+ " './vi_whisper-small/normalizer.json',\n",
728
+ " './vi_whisper-small/added_tokens.json',\n",
729
+ " './vi_whisper-small/tokenizer.json')"
730
+ ]
731
+ },
732
+ "execution_count": 130,
733
+ "metadata": {},
734
+ "output_type": "execute_result"
735
+ }
736
+ ],
737
+ "source": [
738
+ "tokenizer.save_pretrained(\"./vi_whisper-small/\")"
739
+ ]
740
+ },
741
+ {
742
+ "cell_type": "code",
743
+ "execution_count": 31,
744
+ "metadata": {},
745
+ "outputs": [
746
+ {
747
+ "name": "stdout",
748
+ "output_type": "stream",
749
+ "text": [
750
+ "Device 0:\n",
751
+ " Currently allocated memory: 922.884765625 MB\n",
752
+ " Peak memory usage: 922.884765625 MB\n"
753
+ ]
754
+ }
755
+ ],
756
+ "source": [
757
+ "import torch\n",
758
+ "\n",
759
+ "device_count = torch.cuda.device_count()\n",
760
+ "\n",
761
+ "for device in range(device_count):\n",
762
+ " torch.cuda.device(device)\n",
763
+ " allocated_memory = torch.cuda.memory_allocated(device)\n",
764
+ " peak_memory = torch.cuda.max_memory_allocated(device)\n",
765
+ " print(f\"Device {device}:\")\n",
766
+ " print(f\" Currently allocated memory: {allocated_memory / 1024**2} MB\")\n",
767
+ " print(f\" Peak memory usage: {peak_memory / 1024**2} MB\")\n",
768
+ "\n"
769
+ ]
770
+ },
771
+ {
772
+ "cell_type": "code",
773
+ "execution_count": 32,
774
+ "metadata": {},
775
+ "outputs": [
776
+ {
777
+ "name": "stdout",
778
+ "output_type": "stream",
779
+ "text": [
780
+ "Device 0:\n",
781
+ " Name: Tesla T4\n",
782
+ " Max Memory: 14966.375 MB\n"
783
+ ]
784
+ }
785
+ ],
786
+ "source": [
787
+ "device_count = torch.cuda.device_count()\n",
788
+ "\n",
789
+ "for device in range(device_count):\n",
790
+ " properties = torch.cuda.get_device_properties(device)\n",
791
+ " print(f\"Device {device}:\")\n",
792
+ " print(f\" Name: {properties.name}\")\n",
793
+ " print(f\" Max Memory: {properties.total_memory / 1024**2} MB\")\n"
794
+ ]
795
+ },
796
+ {
797
+ "cell_type": "code",
798
+ "execution_count": 111,
799
+ "metadata": {},
800
+ "outputs": [
801
+ {
802
+ "name": "stderr",
803
+ "output_type": "stream",
804
+ "text": [
805
+ "/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
806
+ " warnings.warn(\n"
807
+ ]
808
+ },
809
+ {
810
+ "data": {
811
+ "application/vnd.jupyter.widget-view+json": {
812
+ "model_id": "da5a536a979e4d34bc59eaede9204d06",
813
+ "version_major": 2,
814
+ "version_minor": 0
815
+ },
816
+ "text/plain": [
817
+ " 0%| | 0/8000 [00:00<?, ?it/s]"
818
+ ]
819
+ },
820
+ "metadata": {},
821
+ "output_type": "display_data"
822
+ },
823
+ {
824
+ "name": "stdout",
825
+ "output_type": "stream",
826
+ "text": [
827
+ "{'loss': 3.8537, 'learning_rate': 2.1000000000000002e-06, 'epoch': 0.03}\n",
828
+ "{'loss': 2.2347, 'learning_rate': 4.6e-06, 'epoch': 0.06}\n",
829
+ "{'loss': 1.2627, 'learning_rate': 7.1e-06, 'epoch': 0.08}\n",
830
+ "{'loss': 0.8976, 'learning_rate': 9.600000000000001e-06, 'epoch': 0.11}\n",
831
+ "{'loss': 0.7313, 'learning_rate': 1.2100000000000001e-05, 'epoch': 0.14}\n",
832
+ "{'loss': 0.6526, 'learning_rate': 1.4599999999999999e-05, 'epoch': 0.17}\n",
833
+ "{'loss': 0.7221, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.19}\n",
834
+ "{'loss': 0.6478, 'learning_rate': 1.9500000000000003e-05, 'epoch': 0.22}\n",
835
+ "{'loss': 1.7029, 'learning_rate': 2.19e-05, 'epoch': 0.25}\n",
836
+ "{'loss': 1.1476, 'learning_rate': 2.44e-05, 'epoch': 0.28}\n",
837
+ "{'loss': 0.5837, 'learning_rate': 2.6900000000000003e-05, 'epoch': 0.3}\n",
838
+ "{'loss': 0.5912, 'learning_rate': 2.94e-05, 'epoch': 0.33}\n",
839
+ "{'loss': 0.6872, 'learning_rate': 3.19e-05, 'epoch': 0.36}\n",
840
+ "{'loss': 0.4103, 'learning_rate': 3.4399999999999996e-05, 'epoch': 0.39}\n",
841
+ "{'loss': 0.4293, 'learning_rate': 3.69e-05, 'epoch': 0.41}\n",
842
+ "{'loss': 0.3055, 'learning_rate': 3.94e-05, 'epoch': 0.44}\n",
843
+ "{'loss': 0.311, 'learning_rate': 4.19e-05, 'epoch': 0.47}\n",
844
+ "{'loss': 0.3212, 'learning_rate': 4.44e-05, 'epoch': 0.5}\n",
845
+ "{'loss': 0.2917, 'learning_rate': 4.69e-05, 'epoch': 0.52}\n",
846
+ "{'loss': 0.2975, 'learning_rate': 4.94e-05, 'epoch': 0.55}\n",
847
+ "{'loss': 0.3254, 'learning_rate': 5.19e-05, 'epoch': 0.58}\n",
848
+ "{'loss': 0.2825, 'learning_rate': 5.440000000000001e-05, 'epoch': 0.61}\n",
849
+ "{'loss': 0.2929, 'learning_rate': 5.69e-05, 'epoch': 0.63}\n",
850
+ "{'loss': 0.3056, 'learning_rate': 5.94e-05, 'epoch': 0.66}\n",
851
+ "{'loss': 0.3105, 'learning_rate': 6.19e-05, 'epoch': 0.69}\n",
852
+ "{'loss': 0.3702, 'learning_rate': 6.440000000000001e-05, 'epoch': 0.72}\n",
853
+ "{'loss': 0.2684, 'learning_rate': 6.690000000000001e-05, 'epoch': 0.74}\n",
854
+ "{'loss': 0.2767, 'learning_rate': 6.939999999999999e-05, 'epoch': 0.77}\n",
855
+ "{'loss': 0.315, 'learning_rate': 7.19e-05, 'epoch': 0.8}\n",
856
+ "{'loss': 0.3132, 'learning_rate': 7.44e-05, 'epoch': 0.83}\n",
857
+ "{'loss': 0.3933, 'learning_rate': 7.69e-05, 'epoch': 0.85}\n",
858
+ "{'loss': 0.311, 'learning_rate': 7.94e-05, 'epoch': 0.88}\n",
859
+ "{'loss': 0.3104, 'learning_rate': 8.19e-05, 'epoch': 0.91}\n",
860
+ "{'loss': 0.297, 'learning_rate': 8.44e-05, 'epoch': 0.94}\n",
861
+ "{'loss': 0.3094, 'learning_rate': 8.69e-05, 'epoch': 0.96}\n",
862
+ "{'loss': 0.29, 'learning_rate': 8.94e-05, 'epoch': 0.99}\n",
863
+ "{'loss': 0.2712, 'learning_rate': 9.190000000000001e-05, 'epoch': 1.02}\n",
864
+ "{'loss': 0.262, 'learning_rate': 9.44e-05, 'epoch': 1.05}\n",
865
+ "{'loss': 0.2481, 'learning_rate': 9.69e-05, 'epoch': 1.07}\n",
866
+ "{'loss': 0.249, 'learning_rate': 9.94e-05, 'epoch': 1.1}\n"
867
+ ]
868
+ },
869
+ {
870
+ "data": {
871
+ "application/vnd.jupyter.widget-view+json": {
872
+ "model_id": "ea1e7cf193cc4b5dacd0883200ff6ef6",
873
+ "version_major": 2,
874
+ "version_minor": 0
875
+ },
876
+ "text/plain": [
877
+ " 0%| | 0/95 [00:00<?, ?it/s]"
878
+ ]
879
+ },
880
+ "metadata": {},
881
+ "output_type": "display_data"
882
+ },
883
+ {
884
+ "name": "stdout",
885
+ "output_type": "stream",
886
+ "text": [
887
+ "{'eval_loss': 0.3765707015991211, 'eval_wer': 32.16783216783217, 'eval_runtime': 349.1035, 'eval_samples_per_second': 2.177, 'eval_steps_per_second': 0.272, 'epoch': 1.1}\n",
888
+ "{'loss': 0.2729, 'learning_rate': 9.972857142857144e-05, 'epoch': 1.13}\n",
889
+ "{'loss': 0.267, 'learning_rate': 9.937142857142857e-05, 'epoch': 1.16}\n",
890
+ "{'loss': 0.2617, 'learning_rate': 9.901428571428571e-05, 'epoch': 1.18}\n",
891
+ "{'loss': 0.2613, 'learning_rate': 9.865714285714286e-05, 'epoch': 1.21}\n",
892
+ "{'loss': 0.2736, 'learning_rate': 9.83e-05, 'epoch': 1.24}\n",
893
+ "{'loss': 0.245, 'learning_rate': 9.794285714285714e-05, 'epoch': 1.27}\n",
894
+ "{'loss': 0.2385, 'learning_rate': 9.75857142857143e-05, 'epoch': 1.29}\n",
895
+ "{'loss': 0.258, 'learning_rate': 9.722857142857144e-05, 'epoch': 1.32}\n",
896
+ "{'loss': 0.2623, 'learning_rate': 9.687142857142858e-05, 'epoch': 1.35}\n",
897
+ "{'loss': 0.2346, 'learning_rate': 9.651428571428572e-05, 'epoch': 1.38}\n",
898
+ "{'loss': 0.2376, 'learning_rate': 9.615714285714286e-05, 'epoch': 1.4}\n",
899
+ "{'loss': 0.246, 'learning_rate': 9.58e-05, 'epoch': 1.43}\n",
900
+ "{'loss': 0.2201, 'learning_rate': 9.544285714285715e-05, 'epoch': 1.46}\n",
901
+ "{'loss': 0.2233, 'learning_rate': 9.508571428571429e-05, 'epoch': 1.49}\n",
902
+ "{'loss': 0.2154, 'learning_rate': 9.472857142857143e-05, 'epoch': 1.51}\n",
903
+ "{'loss': 0.2348, 'learning_rate': 9.437142857142857e-05, 'epoch': 1.54}\n",
904
+ "{'loss': 0.2159, 'learning_rate': 9.401428571428572e-05, 'epoch': 1.57}\n",
905
+ "{'loss': 0.2265, 'learning_rate': 9.365714285714286e-05, 'epoch': 1.6}\n",
906
+ "{'loss': 0.2118, 'learning_rate': 9.33e-05, 'epoch': 1.62}\n",
907
+ "{'loss': 0.2223, 'learning_rate': 9.294285714285714e-05, 'epoch': 1.65}\n",
908
+ "{'loss': 0.2, 'learning_rate': 9.258571428571428e-05, 'epoch': 1.68}\n",
909
+ "{'loss': 0.206, 'learning_rate': 9.222857142857142e-05, 'epoch': 1.71}\n",
910
+ "{'loss': 0.1979, 'learning_rate': 9.187142857142858e-05, 'epoch': 1.73}\n",
911
+ "{'loss': 0.2022, 'learning_rate': 9.151428571428572e-05, 'epoch': 1.76}\n",
912
+ "{'loss': 0.2028, 'learning_rate': 9.115714285714286e-05, 'epoch': 1.79}\n",
913
+ "{'loss': 0.2161, 'learning_rate': 9.080000000000001e-05, 'epoch': 1.82}\n",
914
+ "{'loss': 0.1964, 'learning_rate': 9.044285714285715e-05, 'epoch': 1.84}\n",
915
+ "{'loss': 0.2151, 'learning_rate': 9.008571428571429e-05, 'epoch': 1.87}\n",
916
+ "{'loss': 0.2056, 'learning_rate': 8.972857142857143e-05, 'epoch': 1.9}\n",
917
+ "{'loss': 0.189, 'learning_rate': 8.937142857142857e-05, 'epoch': 1.93}\n",
918
+ "{'loss': 0.1944, 'learning_rate': 8.901428571428571e-05, 'epoch': 1.95}\n",
919
+ "{'loss': 0.1834, 'learning_rate': 8.865714285714287e-05, 'epoch': 1.98}\n",
920
+ "{'loss': 0.1557, 'learning_rate': 8.83e-05, 'epoch': 2.01}\n",
921
+ "{'loss': 0.1337, 'learning_rate': 8.794285714285714e-05, 'epoch': 2.04}\n",
922
+ "{'loss': 0.1338, 'learning_rate': 8.75857142857143e-05, 'epoch': 2.06}\n",
923
+ "{'loss': 0.1338, 'learning_rate': 8.722857142857144e-05, 'epoch': 2.09}\n",
924
+ "{'loss': 0.1385, 'learning_rate': 8.687142857142856e-05, 'epoch': 2.12}\n",
925
+ "{'loss': 0.1259, 'learning_rate': 8.651428571428572e-05, 'epoch': 2.15}\n",
926
+ "{'loss': 0.1268, 'learning_rate': 8.615714285714286e-05, 'epoch': 2.18}\n",
927
+ "{'loss': 0.1416, 'learning_rate': 8.58e-05, 'epoch': 2.2}\n"
928
+ ]
929
+ },
930
+ {
931
+ "data": {
932
+ "application/vnd.jupyter.widget-view+json": {
933
+ "model_id": "2b48a513542d43558d8fef6a8ef00629",
934
+ "version_major": 2,
935
+ "version_minor": 0
936
+ },
937
+ "text/plain": [
938
+ " 0%| | 0/95 [00:00<?, ?it/s]"
939
+ ]
940
+ },
941
+ "metadata": {},
942
+ "output_type": "display_data"
943
+ },
944
+ {
945
+ "name": "stdout",
946
+ "output_type": "stream",
947
+ "text": [
948
+ "{'eval_loss': 0.2880653738975525, 'eval_wer': 46.464646464646464, 'eval_runtime': 337.0754, 'eval_samples_per_second': 2.255, 'eval_steps_per_second': 0.282, 'epoch': 2.2}\n",
949
+ "{'loss': 0.1271, 'learning_rate': 8.544285714285715e-05, 'epoch': 2.23}\n",
950
+ "{'loss': 0.1345, 'learning_rate': 8.508571428571429e-05, 'epoch': 2.26}\n",
951
+ "{'loss': 0.149, 'learning_rate': 8.472857142857143e-05, 'epoch': 2.29}\n",
952
+ "{'loss': 0.1289, 'learning_rate': 8.437142857142859e-05, 'epoch': 2.31}\n",
953
+ "{'loss': 0.1391, 'learning_rate': 8.401428571428573e-05, 'epoch': 2.34}\n",
954
+ "{'loss': 0.1532, 'learning_rate': 8.365714285714285e-05, 'epoch': 2.37}\n",
955
+ "{'loss': 0.1283, 'learning_rate': 8.33e-05, 'epoch': 2.4}\n",
956
+ "{'loss': 0.1336, 'learning_rate': 8.294285714285715e-05, 'epoch': 2.42}\n",
957
+ "{'loss': 0.129, 'learning_rate': 8.258571428571429e-05, 'epoch': 2.45}\n",
958
+ "{'loss': 0.1399, 'learning_rate': 8.222857142857144e-05, 'epoch': 2.48}\n",
959
+ "{'loss': 0.1411, 'learning_rate': 8.187142857142858e-05, 'epoch': 2.51}\n",
960
+ "{'loss': 0.1298, 'learning_rate': 8.151428571428572e-05, 'epoch': 2.53}\n",
961
+ "{'loss': 0.1397, 'learning_rate': 8.115714285714286e-05, 'epoch': 2.56}\n",
962
+ "{'loss': 0.1356, 'learning_rate': 8.080000000000001e-05, 'epoch': 2.59}\n",
963
+ "{'loss': 0.1366, 'learning_rate': 8.044285714285714e-05, 'epoch': 2.62}\n",
964
+ "{'loss': 0.1331, 'learning_rate': 8.008571428571429e-05, 'epoch': 2.64}\n",
965
+ "{'loss': 0.1297, 'learning_rate': 7.972857142857143e-05, 'epoch': 2.67}\n",
966
+ "{'loss': 0.1414, 'learning_rate': 7.937142857142857e-05, 'epoch': 2.7}\n",
967
+ "{'loss': 0.1189, 'learning_rate': 7.901428571428571e-05, 'epoch': 2.73}\n",
968
+ "{'loss': 0.1416, 'learning_rate': 7.865714285714287e-05, 'epoch': 2.75}\n",
969
+ "{'loss': 0.1378, 'learning_rate': 7.83e-05, 'epoch': 2.78}\n",
970
+ "{'loss': 0.1305, 'learning_rate': 7.794285714285715e-05, 'epoch': 2.81}\n",
971
+ "{'loss': 0.1571, 'learning_rate': 7.75857142857143e-05, 'epoch': 2.84}\n",
972
+ "{'loss': 0.1285, 'learning_rate': 7.722857142857143e-05, 'epoch': 2.86}\n",
973
+ "{'loss': 0.1339, 'learning_rate': 7.687142857142857e-05, 'epoch': 2.89}\n",
974
+ "{'loss': 0.1216, 'learning_rate': 7.651428571428572e-05, 'epoch': 2.92}\n",
975
+ "{'loss': 0.1321, 'learning_rate': 7.615714285714286e-05, 'epoch': 2.95}\n",
976
+ "{'loss': 0.1259, 'learning_rate': 7.58e-05, 'epoch': 2.97}\n",
977
+ "{'loss': 0.1259, 'learning_rate': 7.544285714285715e-05, 'epoch': 3.0}\n",
978
+ "{'loss': 0.0851, 'learning_rate': 7.508571428571429e-05, 'epoch': 3.03}\n",
979
+ "{'loss': 0.0764, 'learning_rate': 7.472857142857143e-05, 'epoch': 3.06}\n",
980
+ "{'loss': 0.0986, 'learning_rate': 7.438571428571429e-05, 'epoch': 3.08}\n",
981
+ "{'loss': 0.0883, 'learning_rate': 7.402857142857143e-05, 'epoch': 3.11}\n",
982
+ "{'loss': 0.0811, 'learning_rate': 7.367142857142858e-05, 'epoch': 3.14}\n",
983
+ "{'loss': 0.0872, 'learning_rate': 7.331428571428571e-05, 'epoch': 3.17}\n",
984
+ "{'loss': 0.0872, 'learning_rate': 7.295714285714286e-05, 'epoch': 3.19}\n",
985
+ "{'loss': 0.0805, 'learning_rate': 7.26e-05, 'epoch': 3.22}\n",
986
+ "{'loss': 0.0803, 'learning_rate': 7.224285714285714e-05, 'epoch': 3.25}\n",
987
+ "{'loss': 0.0753, 'learning_rate': 7.188571428571428e-05, 'epoch': 3.28}\n",
988
+ "{'loss': 0.0839, 'learning_rate': 7.152857142857144e-05, 'epoch': 3.3}\n"
989
+ ]
990
+ },
991
+ {
992
+ "data": {
993
+ "application/vnd.jupyter.widget-view+json": {
994
+ "model_id": "ce95b9d2a270464fbeab22454f214a1d",
995
+ "version_major": 2,
996
+ "version_minor": 0
997
+ },
998
+ "text/plain": [
999
+ " 0%| | 0/95 [00:00<?, ?it/s]"
1000
+ ]
1001
+ },
1002
+ "metadata": {},
1003
+ "output_type": "display_data"
1004
+ },
1005
+ {
1006
+ "name": "stdout",
1007
+ "output_type": "stream",
1008
+ "text": [
1009
+ "{'eval_loss': 0.279912531375885, 'eval_wer': 22.779072779072777, 'eval_runtime': 345.4945, 'eval_samples_per_second': 2.2, 'eval_steps_per_second': 0.275, 'epoch': 3.3}\n",
1010
+ "{'loss': 0.0885, 'learning_rate': 7.117142857142858e-05, 'epoch': 3.33}\n",
1011
+ "{'loss': 0.0845, 'learning_rate': 7.081428571428572e-05, 'epoch': 3.36}\n",
1012
+ "{'loss': 0.0761, 'learning_rate': 7.045714285714287e-05, 'epoch': 3.39}\n",
1013
+ "{'loss': 0.0756, 'learning_rate': 7.01e-05, 'epoch': 3.41}\n",
1014
+ "{'loss': 0.0859, 'learning_rate': 6.974285714285715e-05, 'epoch': 3.44}\n",
1015
+ "{'loss': 0.0972, 'learning_rate': 6.938571428571429e-05, 'epoch': 3.47}\n",
1016
+ "{'loss': 0.0822, 'learning_rate': 6.902857142857143e-05, 'epoch': 3.5}\n",
1017
+ "{'loss': 0.0892, 'learning_rate': 6.867142857142857e-05, 'epoch': 3.52}\n",
1018
+ "{'loss': 0.0735, 'learning_rate': 6.831428571428572e-05, 'epoch': 3.55}\n",
1019
+ "{'loss': 0.0893, 'learning_rate': 6.795714285714286e-05, 'epoch': 3.58}\n",
1020
+ "{'loss': 0.0869, 'learning_rate': 6.76e-05, 'epoch': 3.61}\n",
1021
+ "{'loss': 0.0877, 'learning_rate': 6.724285714285714e-05, 'epoch': 3.63}\n",
1022
+ "{'loss': 0.07, 'learning_rate': 6.688571428571428e-05, 'epoch': 3.66}\n",
1023
+ "{'loss': 0.0807, 'learning_rate': 6.652857142857142e-05, 'epoch': 3.69}\n",
1024
+ "{'loss': 0.0831, 'learning_rate': 6.617142857142858e-05, 'epoch': 3.72}\n",
1025
+ "{'loss': 0.0836, 'learning_rate': 6.581428571428572e-05, 'epoch': 3.74}\n",
1026
+ "{'loss': 0.0875, 'learning_rate': 6.545714285714286e-05, 'epoch': 3.77}\n",
1027
+ "{'loss': 0.0846, 'learning_rate': 6.510000000000001e-05, 'epoch': 3.8}\n",
1028
+ "{'loss': 0.0779, 'learning_rate': 6.474285714285715e-05, 'epoch': 3.83}\n",
1029
+ "{'loss': 0.0871, 'learning_rate': 6.438571428571429e-05, 'epoch': 3.85}\n",
1030
+ "{'loss': 0.0777, 'learning_rate': 6.402857142857143e-05, 'epoch': 3.88}\n",
1031
+ "{'loss': 0.0856, 'learning_rate': 6.367142857142857e-05, 'epoch': 3.91}\n",
1032
+ "{'loss': 0.083, 'learning_rate': 6.331428571428571e-05, 'epoch': 3.94}\n",
1033
+ "{'loss': 0.0667, 'learning_rate': 6.295714285714286e-05, 'epoch': 3.96}\n",
1034
+ "{'loss': 0.083, 'learning_rate': 6.26e-05, 'epoch': 3.99}\n",
1035
+ "{'loss': 0.0505, 'learning_rate': 6.224285714285714e-05, 'epoch': 4.02}\n",
1036
+ "{'loss': 0.0426, 'learning_rate': 6.18857142857143e-05, 'epoch': 4.05}\n",
1037
+ "{'loss': 0.0453, 'learning_rate': 6.152857142857144e-05, 'epoch': 4.07}\n",
1038
+ "{'loss': 0.0482, 'learning_rate': 6.117142857142858e-05, 'epoch': 4.1}\n",
1039
+ "{'loss': 0.0511, 'learning_rate': 6.081428571428571e-05, 'epoch': 4.13}\n",
1040
+ "{'loss': 0.0583, 'learning_rate': 6.045714285714286e-05, 'epoch': 4.16}\n",
1041
+ "{'loss': 0.0466, 'learning_rate': 6.0100000000000004e-05, 'epoch': 4.19}\n",
1042
+ "{'loss': 0.0502, 'learning_rate': 5.9742857142857144e-05, 'epoch': 4.21}\n",
1043
+ "{'loss': 0.0414, 'learning_rate': 5.938571428571429e-05, 'epoch': 4.24}\n",
1044
+ "{'loss': 0.0501, 'learning_rate': 5.902857142857143e-05, 'epoch': 4.27}\n",
1045
+ "{'loss': 0.0478, 'learning_rate': 5.867142857142858e-05, 'epoch': 4.3}\n",
1046
+ "{'loss': 0.0482, 'learning_rate': 5.8314285714285724e-05, 'epoch': 4.32}\n",
1047
+ "{'loss': 0.0463, 'learning_rate': 5.7957142857142864e-05, 'epoch': 4.35}\n",
1048
+ "{'loss': 0.0513, 'learning_rate': 5.76e-05, 'epoch': 4.38}\n",
1049
+ "{'loss': 0.0546, 'learning_rate': 5.7242857142857144e-05, 'epoch': 4.41}\n"
1050
+ ]
1051
+ },
1052
+ {
1053
+ "data": {
1054
+ "application/vnd.jupyter.widget-view+json": {
1055
+ "model_id": "5b1ec44b408a4de8aa8a3e9702dae453",
1056
+ "version_major": 2,
1057
+ "version_minor": 0
1058
+ },
1059
+ "text/plain": [
1060
+ " 0%| | 0/95 [00:00<?, ?it/s]"
1061
+ ]
1062
+ },
1063
+ "metadata": {},
1064
+ "output_type": "display_data"
1065
+ },
1066
+ {
1067
+ "name": "stdout",
1068
+ "output_type": "stream",
1069
+ "text": [
1070
+ "{'eval_loss': 0.28944167494773865, 'eval_wer': 21.885521885521886, 'eval_runtime': 344.5818, 'eval_samples_per_second': 2.206, 'eval_steps_per_second': 0.276, 'epoch': 4.41}\n",
1071
+ "{'loss': 0.0515, 'learning_rate': 5.6885714285714284e-05, 'epoch': 4.43}\n",
1072
+ "{'loss': 0.0394, 'learning_rate': 5.652857142857143e-05, 'epoch': 4.46}\n",
1073
+ "{'loss': 0.0562, 'learning_rate': 5.617142857142858e-05, 'epoch': 4.49}\n",
1074
+ "{'loss': 0.0532, 'learning_rate': 5.581428571428572e-05, 'epoch': 4.52}\n",
1075
+ "{'loss': 0.0525, 'learning_rate': 5.5457142857142864e-05, 'epoch': 4.54}\n",
1076
+ "{'loss': 0.0553, 'learning_rate': 5.5100000000000004e-05, 'epoch': 4.57}\n",
1077
+ "{'loss': 0.0464, 'learning_rate': 5.474285714285714e-05, 'epoch': 4.6}\n",
1078
+ "{'loss': 0.0425, 'learning_rate': 5.4385714285714284e-05, 'epoch': 4.63}\n",
1079
+ "{'loss': 0.0529, 'learning_rate': 5.402857142857143e-05, 'epoch': 4.65}\n",
1080
+ "{'loss': 0.0534, 'learning_rate': 5.367142857142857e-05, 'epoch': 4.68}\n",
1081
+ "{'loss': 0.0505, 'learning_rate': 5.331428571428572e-05, 'epoch': 4.71}\n",
1082
+ "{'loss': 0.0416, 'learning_rate': 5.295714285714286e-05, 'epoch': 4.74}\n",
1083
+ "{'loss': 0.0438, 'learning_rate': 5.2600000000000005e-05, 'epoch': 4.76}\n",
1084
+ "{'loss': 0.0568, 'learning_rate': 5.224285714285715e-05, 'epoch': 4.79}\n",
1085
+ "{'loss': 0.0519, 'learning_rate': 5.188571428571429e-05, 'epoch': 4.82}\n",
1086
+ "{'loss': 0.0415, 'learning_rate': 5.1528571428571425e-05, 'epoch': 4.85}\n",
1087
+ "{'loss': 0.0502, 'learning_rate': 5.117142857142857e-05, 'epoch': 4.87}\n",
1088
+ "{'loss': 0.0433, 'learning_rate': 5.081428571428571e-05, 'epoch': 4.9}\n",
1089
+ "{'loss': 0.0527, 'learning_rate': 5.045714285714286e-05, 'epoch': 4.93}\n",
1090
+ "{'loss': 0.0434, 'learning_rate': 5.0100000000000005e-05, 'epoch': 4.96}\n",
1091
+ "{'loss': 0.0485, 'learning_rate': 4.9742857142857145e-05, 'epoch': 4.98}\n",
1092
+ "{'loss': 0.0358, 'learning_rate': 4.938571428571429e-05, 'epoch': 5.01}\n",
1093
+ "{'loss': 0.0218, 'learning_rate': 4.902857142857143e-05, 'epoch': 5.04}\n",
1094
+ "{'loss': 0.0245, 'learning_rate': 4.867142857142857e-05, 'epoch': 5.07}\n",
1095
+ "{'loss': 0.0272, 'learning_rate': 4.831428571428572e-05, 'epoch': 5.09}\n",
1096
+ "{'loss': 0.0258, 'learning_rate': 4.795714285714286e-05, 'epoch': 5.12}\n",
1097
+ "{'loss': 0.0228, 'learning_rate': 4.76e-05, 'epoch': 5.15}\n",
1098
+ "{'loss': 0.0275, 'learning_rate': 4.7242857142857145e-05, 'epoch': 5.18}\n",
1099
+ "{'loss': 0.0269, 'learning_rate': 4.6885714285714285e-05, 'epoch': 5.2}\n",
1100
+ "{'loss': 0.0237, 'learning_rate': 4.652857142857143e-05, 'epoch': 5.23}\n",
1101
+ "{'loss': 0.0288, 'learning_rate': 4.617142857142857e-05, 'epoch': 5.26}\n",
1102
+ "{'loss': 0.0269, 'learning_rate': 4.581428571428572e-05, 'epoch': 5.29}\n",
1103
+ "{'loss': 0.0276, 'learning_rate': 4.545714285714286e-05, 'epoch': 5.31}\n",
1104
+ "{'loss': 0.0276, 'learning_rate': 4.5100000000000005e-05, 'epoch': 5.34}\n",
1105
+ "{'loss': 0.0242, 'learning_rate': 4.4742857142857145e-05, 'epoch': 5.37}\n",
1106
+ "{'loss': 0.0238, 'learning_rate': 4.4385714285714285e-05, 'epoch': 5.4}\n",
1107
+ "{'loss': 0.0302, 'learning_rate': 4.402857142857143e-05, 'epoch': 5.42}\n",
1108
+ "{'loss': 0.0253, 'learning_rate': 4.367142857142857e-05, 'epoch': 5.45}\n",
1109
+ "{'loss': 0.0256, 'learning_rate': 4.331428571428572e-05, 'epoch': 5.48}\n",
1110
+ "{'loss': 0.0256, 'learning_rate': 4.295714285714286e-05, 'epoch': 5.51}\n"
1111
+ ]
1112
+ },
1113
+ {
1114
+ "data": {
1115
+ "application/vnd.jupyter.widget-view+json": {
1116
+ "model_id": "2a3e7ace41cd44768ebc093aa571360e",
1117
+ "version_major": 2,
1118
+ "version_minor": 0
1119
+ },
1120
+ "text/plain": [
1121
+ " 0%| | 0/95 [00:00<?, ?it/s]"
1122
+ ]
1123
+ },
1124
+ "metadata": {},
1125
+ "output_type": "display_data"
1126
+ },
1127
+ {
1128
+ "name": "stdout",
1129
+ "output_type": "stream",
1130
+ "text": [
1131
+ "{'eval_loss': 0.3023395836353302, 'eval_wer': 32.2973322973323, 'eval_runtime': 361.3589, 'eval_samples_per_second': 2.103, 'eval_steps_per_second': 0.263, 'epoch': 5.51}\n",
1132
+ "{'loss': 0.0215, 'learning_rate': 4.26e-05, 'epoch': 5.53}\n",
1133
+ "{'loss': 0.0272, 'learning_rate': 4.2242857142857145e-05, 'epoch': 5.56}\n",
1134
+ "{'loss': 0.0268, 'learning_rate': 4.188571428571429e-05, 'epoch': 5.59}\n",
1135
+ "{'loss': 0.028, 'learning_rate': 4.1528571428571425e-05, 'epoch': 5.62}\n",
1136
+ "{'loss': 0.0209, 'learning_rate': 4.117142857142857e-05, 'epoch': 5.64}\n",
1137
+ "{'loss': 0.0258, 'learning_rate': 4.081428571428572e-05, 'epoch': 5.67}\n",
1138
+ "{'loss': 0.0249, 'learning_rate': 4.045714285714286e-05, 'epoch': 5.7}\n",
1139
+ "{'loss': 0.0249, 'learning_rate': 4.0100000000000006e-05, 'epoch': 5.73}\n",
1140
+ "{'loss': 0.0209, 'learning_rate': 3.9742857142857146e-05, 'epoch': 5.75}\n",
1141
+ "{'loss': 0.02, 'learning_rate': 3.9385714285714286e-05, 'epoch': 5.78}\n",
1142
+ "{'loss': 0.0244, 'learning_rate': 3.902857142857143e-05, 'epoch': 5.81}\n",
1143
+ "{'loss': 0.025, 'learning_rate': 3.867142857142857e-05, 'epoch': 5.84}\n",
1144
+ "{'loss': 0.0282, 'learning_rate': 3.831428571428571e-05, 'epoch': 5.86}\n",
1145
+ "{'loss': 0.0271, 'learning_rate': 3.795714285714286e-05, 'epoch': 5.89}\n",
1146
+ "{'loss': 0.0233, 'learning_rate': 3.76e-05, 'epoch': 5.92}\n",
1147
+ "{'loss': 0.0219, 'learning_rate': 3.7242857142857146e-05, 'epoch': 5.95}\n",
1148
+ "{'loss': 0.0232, 'learning_rate': 3.688571428571429e-05, 'epoch': 5.97}\n",
1149
+ "{'loss': 0.019, 'learning_rate': 3.6528571428571426e-05, 'epoch': 6.0}\n",
1150
+ "{'loss': 0.0152, 'learning_rate': 3.617142857142857e-05, 'epoch': 6.03}\n",
1151
+ "{'loss': 0.0111, 'learning_rate': 3.581428571428572e-05, 'epoch': 6.06}\n",
1152
+ "{'loss': 0.0162, 'learning_rate': 3.545714285714286e-05, 'epoch': 6.08}\n",
1153
+ "{'loss': 0.0126, 'learning_rate': 3.51e-05, 'epoch': 6.11}\n",
1154
+ "{'loss': 0.012, 'learning_rate': 3.4742857142857146e-05, 'epoch': 6.14}\n",
1155
+ "{'loss': 0.0153, 'learning_rate': 3.4385714285714286e-05, 'epoch': 6.17}\n",
1156
+ "{'loss': 0.0133, 'learning_rate': 3.402857142857143e-05, 'epoch': 6.19}\n",
1157
+ "{'loss': 0.0112, 'learning_rate': 3.367142857142857e-05, 'epoch': 6.22}\n",
1158
+ "{'loss': 0.0187, 'learning_rate': 3.331428571428571e-05, 'epoch': 6.25}\n",
1159
+ "{'loss': 0.0134, 'learning_rate': 3.295714285714286e-05, 'epoch': 6.28}\n",
1160
+ "{'loss': 0.0112, 'learning_rate': 3.26e-05, 'epoch': 6.31}\n",
1161
+ "{'loss': 0.0096, 'learning_rate': 3.2242857142857146e-05, 'epoch': 6.33}\n",
1162
+ "{'loss': 0.0112, 'learning_rate': 3.1885714285714286e-05, 'epoch': 6.36}\n",
1163
+ "{'loss': 0.0146, 'learning_rate': 3.1528571428571426e-05, 'epoch': 6.39}\n",
1164
+ "{'loss': 0.0106, 'learning_rate': 3.117142857142857e-05, 'epoch': 6.42}\n",
1165
+ "{'loss': 0.01, 'learning_rate': 3.081428571428572e-05, 'epoch': 6.44}\n",
1166
+ "{'loss': 0.0117, 'learning_rate': 3.0457142857142856e-05, 'epoch': 6.47}\n",
1167
+ "{'loss': 0.0135, 'learning_rate': 3.01e-05, 'epoch': 6.5}\n",
1168
+ "{'loss': 0.0137, 'learning_rate': 2.9742857142857143e-05, 'epoch': 6.53}\n",
1169
+ "{'loss': 0.0089, 'learning_rate': 2.938571428571429e-05, 'epoch': 6.55}\n",
1170
+ "{'loss': 0.0096, 'learning_rate': 2.9028571428571427e-05, 'epoch': 6.58}\n",
1171
+ "{'loss': 0.0111, 'learning_rate': 2.867142857142857e-05, 'epoch': 6.61}\n"
1172
+ ]
1173
+ },
1174
+ {
1175
+ "data": {
1176
+ "application/vnd.jupyter.widget-view+json": {
1177
+ "model_id": "a911edc03d0b43edbd4c59958024dca9",
1178
+ "version_major": 2,
1179
+ "version_minor": 0
1180
+ },
1181
+ "text/plain": [
1182
+ " 0%| | 0/95 [00:00<?, ?it/s]"
1183
+ ]
1184
+ },
1185
+ "metadata": {},
1186
+ "output_type": "display_data"
1187
+ },
1188
+ {
1189
+ "name": "stdout",
1190
+ "output_type": "stream",
1191
+ "text": [
1192
+ "{'eval_loss': 0.3060542345046997, 'eval_wer': 31.015281015281015, 'eval_runtime': 366.4932, 'eval_samples_per_second': 2.074, 'eval_steps_per_second': 0.259, 'epoch': 6.61}\n",
1193
+ "{'loss': 0.0091, 'learning_rate': 2.8314285714285717e-05, 'epoch': 6.64}\n",
1194
+ "{'loss': 0.0075, 'learning_rate': 2.795714285714286e-05, 'epoch': 6.66}\n",
1195
+ "{'loss': 0.0096, 'learning_rate': 2.7600000000000003e-05, 'epoch': 6.69}\n",
1196
+ "{'loss': 0.0071, 'learning_rate': 2.7242857142857143e-05, 'epoch': 6.72}\n",
1197
+ "{'loss': 0.0089, 'learning_rate': 2.6885714285714287e-05, 'epoch': 6.75}\n",
1198
+ "{'loss': 0.0103, 'learning_rate': 2.652857142857143e-05, 'epoch': 6.77}\n",
1199
+ "{'loss': 0.0125, 'learning_rate': 2.6171428571428574e-05, 'epoch': 6.8}\n",
1200
+ "{'loss': 0.0082, 'learning_rate': 2.5814285714285713e-05, 'epoch': 6.83}\n",
1201
+ "{'loss': 0.0079, 'learning_rate': 2.5457142857142857e-05, 'epoch': 6.86}\n",
1202
+ "{'loss': 0.0108, 'learning_rate': 2.51e-05, 'epoch': 6.88}\n",
1203
+ "{'loss': 0.0084, 'learning_rate': 2.4742857142857147e-05, 'epoch': 6.91}\n",
1204
+ "{'loss': 0.0107, 'learning_rate': 2.4385714285714287e-05, 'epoch': 6.94}\n",
1205
+ "{'loss': 0.009, 'learning_rate': 2.402857142857143e-05, 'epoch': 6.97}\n",
1206
+ "{'loss': 0.0081, 'learning_rate': 2.3671428571428574e-05, 'epoch': 6.99}\n",
1207
+ "{'loss': 0.0077, 'learning_rate': 2.3314285714285717e-05, 'epoch': 7.02}\n",
1208
+ "{'loss': 0.0064, 'learning_rate': 2.2957142857142857e-05, 'epoch': 7.05}\n",
1209
+ "{'loss': 0.0079, 'learning_rate': 2.26e-05, 'epoch': 7.08}\n",
1210
+ "{'loss': 0.0063, 'learning_rate': 2.2242857142857144e-05, 'epoch': 7.1}\n",
1211
+ "{'loss': 0.0044, 'learning_rate': 2.1885714285714287e-05, 'epoch': 7.13}\n",
1212
+ "{'loss': 0.0041, 'learning_rate': 2.1528571428571427e-05, 'epoch': 7.16}\n",
1213
+ "{'loss': 0.0048, 'learning_rate': 2.1171428571428574e-05, 'epoch': 7.19}\n",
1214
+ "{'loss': 0.0041, 'learning_rate': 2.0814285714285714e-05, 'epoch': 7.21}\n",
1215
+ "{'loss': 0.0031, 'learning_rate': 2.0457142857142857e-05, 'epoch': 7.24}\n",
1216
+ "{'loss': 0.0026, 'learning_rate': 2.01e-05, 'epoch': 7.27}\n",
1217
+ "{'loss': 0.0031, 'learning_rate': 1.9742857142857144e-05, 'epoch': 7.3}\n",
1218
+ "{'loss': 0.0029, 'learning_rate': 1.9385714285714287e-05, 'epoch': 7.32}\n",
1219
+ "{'loss': 0.0045, 'learning_rate': 1.9028571428571427e-05, 'epoch': 7.35}\n",
1220
+ "{'loss': 0.0024, 'learning_rate': 1.8671428571428574e-05, 'epoch': 7.38}\n",
1221
+ "{'loss': 0.002, 'learning_rate': 1.8314285714285714e-05, 'epoch': 7.41}\n",
1222
+ "{'loss': 0.0038, 'learning_rate': 1.7957142857142858e-05, 'epoch': 7.43}\n",
1223
+ "{'loss': 0.0035, 'learning_rate': 1.76e-05, 'epoch': 7.46}\n",
1224
+ "{'loss': 0.0058, 'learning_rate': 1.7242857142857144e-05, 'epoch': 7.49}\n",
1225
+ "{'loss': 0.0034, 'learning_rate': 1.6885714285714284e-05, 'epoch': 7.52}\n",
1226
+ "{'loss': 0.0036, 'learning_rate': 1.652857142857143e-05, 'epoch': 7.54}\n",
1227
+ "{'loss': 0.0031, 'learning_rate': 1.6171428571428574e-05, 'epoch': 7.57}\n",
1228
+ "{'loss': 0.0041, 'learning_rate': 1.5814285714285714e-05, 'epoch': 7.6}\n",
1229
+ "{'loss': 0.0021, 'learning_rate': 1.5457142857142858e-05, 'epoch': 7.63}\n",
1230
+ "{'loss': 0.0032, 'learning_rate': 1.51e-05, 'epoch': 7.65}\n",
1231
+ "{'loss': 0.0039, 'learning_rate': 1.4742857142857144e-05, 'epoch': 7.68}\n",
1232
+ "{'loss': 0.0028, 'learning_rate': 1.4385714285714286e-05, 'epoch': 7.71}\n"
1233
+ ]
1234
+ },
1235
+ {
1236
+ "data": {
1237
+ "application/vnd.jupyter.widget-view+json": {
1238
+ "model_id": "3ff56df660d1427daf8f920cd57d67fc",
1239
+ "version_major": 2,
1240
+ "version_minor": 0
1241
+ },
1242
+ "text/plain": [
1243
+ " 0%| | 0/95 [00:00<?, ?it/s]"
1244
+ ]
1245
+ },
1246
+ "metadata": {},
1247
+ "output_type": "display_data"
1248
+ },
1249
+ {
1250
+ "name": "stdout",
1251
+ "output_type": "stream",
1252
+ "text": [
1253
+ "{'eval_loss': 0.3143082559108734, 'eval_wer': 27.169127169127172, 'eval_runtime': 357.6908, 'eval_samples_per_second': 2.125, 'eval_steps_per_second': 0.266, 'epoch': 7.71}\n",
1254
+ "{'loss': 0.0032, 'learning_rate': 1.402857142857143e-05, 'epoch': 7.74}\n",
1255
+ "{'loss': 0.0028, 'learning_rate': 1.3671428571428571e-05, 'epoch': 7.76}\n",
1256
+ "{'loss': 0.0033, 'learning_rate': 1.3314285714285715e-05, 'epoch': 7.79}\n",
1257
+ "{'loss': 0.0052, 'learning_rate': 1.2957142857142856e-05, 'epoch': 7.82}\n",
1258
+ "{'loss': 0.0023, 'learning_rate': 1.2600000000000001e-05, 'epoch': 7.85}\n",
1259
+ "{'loss': 0.0034, 'learning_rate': 1.2242857142857143e-05, 'epoch': 7.87}\n",
1260
+ "{'loss': 0.0026, 'learning_rate': 1.1885714285714286e-05, 'epoch': 7.9}\n",
1261
+ "{'loss': 0.0022, 'learning_rate': 1.1528571428571428e-05, 'epoch': 7.93}\n",
1262
+ "{'loss': 0.0037, 'learning_rate': 1.1171428571428571e-05, 'epoch': 7.96}\n",
1263
+ "{'loss': 0.003, 'learning_rate': 1.0814285714285715e-05, 'epoch': 7.98}\n",
1264
+ "{'loss': 0.0009, 'learning_rate': 1.0457142857142856e-05, 'epoch': 8.01}\n",
1265
+ "{'loss': 0.0006, 'learning_rate': 1.0100000000000002e-05, 'epoch': 8.04}\n",
1266
+ "{'loss': 0.0022, 'learning_rate': 9.742857142857143e-06, 'epoch': 8.07}\n",
1267
+ "{'loss': 0.0005, 'learning_rate': 9.385714285714287e-06, 'epoch': 8.09}\n",
1268
+ "{'loss': 0.0007, 'learning_rate': 9.02857142857143e-06, 'epoch': 8.12}\n",
1269
+ "{'loss': 0.0006, 'learning_rate': 8.671428571428572e-06, 'epoch': 8.15}\n",
1270
+ "{'loss': 0.0006, 'learning_rate': 8.314285714285715e-06, 'epoch': 8.18}\n",
1271
+ "{'loss': 0.0005, 'learning_rate': 7.957142857142857e-06, 'epoch': 8.2}\n",
1272
+ "{'loss': 0.0024, 'learning_rate': 7.6e-06, 'epoch': 8.23}\n",
1273
+ "{'loss': 0.0009, 'learning_rate': 7.242857142857143e-06, 'epoch': 8.26}\n",
1274
+ "{'loss': 0.0004, 'learning_rate': 6.885714285714286e-06, 'epoch': 8.29}\n",
1275
+ "{'loss': 0.0006, 'learning_rate': 6.5285714285714285e-06, 'epoch': 8.31}\n",
1276
+ "{'loss': 0.0031, 'learning_rate': 6.171428571428572e-06, 'epoch': 8.34}\n",
1277
+ "{'loss': 0.001, 'learning_rate': 5.814285714285714e-06, 'epoch': 8.37}\n",
1278
+ "{'loss': 0.0013, 'learning_rate': 5.457142857142857e-06, 'epoch': 8.4}\n",
1279
+ "{'loss': 0.0011, 'learning_rate': 5.1e-06, 'epoch': 8.43}\n",
1280
+ "{'loss': 0.0004, 'learning_rate': 4.742857142857144e-06, 'epoch': 8.45}\n",
1281
+ "{'loss': 0.0012, 'learning_rate': 4.385714285714286e-06, 'epoch': 8.48}\n",
1282
+ "{'loss': 0.0013, 'learning_rate': 4.028571428571429e-06, 'epoch': 8.51}\n",
1283
+ "{'loss': 0.0008, 'learning_rate': 3.6714285714285717e-06, 'epoch': 8.54}\n",
1284
+ "{'loss': 0.0009, 'learning_rate': 3.314285714285714e-06, 'epoch': 8.56}\n",
1285
+ "{'loss': 0.0019, 'learning_rate': 2.957142857142857e-06, 'epoch': 8.59}\n",
1286
+ "{'loss': 0.0006, 'learning_rate': 2.6e-06, 'epoch': 8.62}\n",
1287
+ "{'loss': 0.0004, 'learning_rate': 2.242857142857143e-06, 'epoch': 8.65}\n",
1288
+ "{'loss': 0.0005, 'learning_rate': 1.8857142857142858e-06, 'epoch': 8.67}\n",
1289
+ "{'loss': 0.0005, 'learning_rate': 1.5285714285714287e-06, 'epoch': 8.7}\n",
1290
+ "{'loss': 0.001, 'learning_rate': 1.1714285714285715e-06, 'epoch': 8.73}\n",
1291
+ "{'loss': 0.0012, 'learning_rate': 8.142857142857143e-07, 'epoch': 8.76}\n",
1292
+ "{'loss': 0.001, 'learning_rate': 4.571428571428572e-07, 'epoch': 8.78}\n",
1293
+ "{'loss': 0.0014, 'learning_rate': 1.0000000000000001e-07, 'epoch': 8.81}\n"
1294
+ ]
1295
+ },
1296
+ {
1297
+ "data": {
1298
+ "application/vnd.jupyter.widget-view+json": {
1299
+ "model_id": "5120f3a800554c1689dfb561c55440fb",
1300
+ "version_major": 2,
1301
+ "version_minor": 0
1302
+ },
1303
+ "text/plain": [
1304
+ " 0%| | 0/95 [00:00<?, ?it/s]"
1305
+ ]
1306
+ },
1307
+ "metadata": {},
1308
+ "output_type": "display_data"
1309
+ },
1310
+ {
1311
+ "name": "stdout",
1312
+ "output_type": "stream",
1313
+ "text": [
1314
+ "{'eval_loss': 0.318661630153656, 'eval_wer': 27.36337736337736, 'eval_runtime': 356.3989, 'eval_samples_per_second': 2.132, 'eval_steps_per_second': 0.267, 'epoch': 8.81}\n",
1315
+ "{'train_runtime': 48216.7702, 'train_samples_per_second': 2.655, 'train_steps_per_second': 0.166, 'train_loss': 0.13303363310021815, 'epoch': 8.81}\n"
1316
+ ]
1317
+ },
1318
+ {
1319
+ "data": {
1320
+ "text/plain": [
1321
+ "TrainOutput(global_step=8000, training_loss=0.13303363310021815, metrics={'train_runtime': 48216.7702, 'train_samples_per_second': 2.655, 'train_steps_per_second': 0.166, 'train_loss': 0.13303363310021815, 'epoch': 8.81})"
1322
+ ]
1323
+ },
1324
+ "execution_count": 111,
1325
+ "metadata": {},
1326
+ "output_type": "execute_result"
1327
+ }
1328
+ ],
1329
+ "source": [
1330
+ "trainer.train()"
1331
+ ]
1332
+ },
1333
+ {
1334
+ "cell_type": "code",
1335
+ "execution_count": 34,
1336
+ "metadata": {},
1337
+ "outputs": [
1338
+ {
1339
+ "ename": "OSError",
1340
+ "evalue": "It looks like the config file at './whisper-base-vi/pytorch_model.bin' is not a valid JSON file.",
1341
+ "output_type": "error",
1342
+ "traceback": [
1343
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1344
+ "\u001b[0;31mUnicodeDecodeError\u001b[0m Traceback (most recent call last)",
1345
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:702\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 700\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 701\u001b[0m \u001b[39m# Load config dict\u001b[39;00m\n\u001b[0;32m--> 702\u001b[0m config_dict \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_dict_from_json_file(resolved_config_file)\n\u001b[1;32m 703\u001b[0m config_dict[\u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m commit_hash\n",
1346
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:793\u001b[0m, in \u001b[0;36mPretrainedConfig._dict_from_json_file\u001b[0;34m(cls, json_file)\u001b[0m\n\u001b[1;32m 792\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(json_file, \u001b[39m\"\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\u001b[39m, encoding\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mas\u001b[39;00m reader:\n\u001b[0;32m--> 793\u001b[0m text \u001b[39m=\u001b[39m reader\u001b[39m.\u001b[39;49mread()\n\u001b[1;32m 794\u001b[0m \u001b[39mreturn\u001b[39;00m json\u001b[39m.\u001b[39mloads(text)\n",
1347
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/codecs.py:322\u001b[0m, in \u001b[0;36mBufferedIncrementalDecoder.decode\u001b[0;34m(self, input, final)\u001b[0m\n\u001b[1;32m 321\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuffer \u001b[39m+\u001b[39m \u001b[39minput\u001b[39m\n\u001b[0;32m--> 322\u001b[0m (result, consumed) \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_buffer_decode(data, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49merrors, final)\n\u001b[1;32m 323\u001b[0m \u001b[39m# keep undecoded input until the next call\u001b[39;00m\n",
1348
+ "\u001b[0;31mUnicodeDecodeError\u001b[0m: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte",
1349
+ "\nDuring handling of the above exception, another exception occurred:\n",
1350
+ "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
1351
+ "Cell \u001b[0;32mIn[34], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m pt_model \u001b[39m=\u001b[39m WhisperForConditionalGeneration\u001b[39m.\u001b[39;49mfrom_pretrained(\u001b[39m\"\u001b[39;49m\u001b[39m./whisper-base-vi/pytorch_model.bin\u001b[39;49m\u001b[39m\"\u001b[39;49m, from_tf\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 2\u001b[0m pt_model\u001b[39m.\u001b[39msave_pretrained(\u001b[39m\"\u001b[39m\u001b[39m./whisper-base-vi/vi_whisper.pt\u001b[39m\u001b[39m\"\u001b[39m)\n",
1352
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/modeling_utils.py:2325\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 2323\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 2324\u001b[0m config_path \u001b[39m=\u001b[39m config \u001b[39mif\u001b[39;00m config \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 2325\u001b[0m config, model_kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49mconfig_class\u001b[39m.\u001b[39;49mfrom_pretrained(\n\u001b[1;32m 2326\u001b[0m config_path,\n\u001b[1;32m 2327\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 2328\u001b[0m return_unused_kwargs\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 2329\u001b[0m force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m 2330\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m 2331\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m 2332\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m 2333\u001b[0m token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m 2334\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 2335\u001b[0m subfolder\u001b[39m=\u001b[39;49msubfolder,\n\u001b[1;32m 2336\u001b[0m _from_auto\u001b[39m=\u001b[39;49mfrom_auto_class,\n\u001b[1;32m 2337\u001b[0m _from_pipeline\u001b[39m=\u001b[39;49mfrom_pipeline,\n\u001b[1;32m 2338\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 2339\u001b[0m )\n\u001b[1;32m 2340\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 2341\u001b[0m model_kwargs \u001b[39m=\u001b[39m kwargs\n",
1353
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:590\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 586\u001b[0m kwargs[\u001b[39m\"\u001b[39m\u001b[39mrevision\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m revision\n\u001b[1;32m 588\u001b[0m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_set_token_in_kwargs(kwargs, token)\n\u001b[0;32m--> 590\u001b[0m config_dict, kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49mget_config_dict(pretrained_model_name_or_path, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 591\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config_dict \u001b[39mand\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mcls\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mand\u001b[39;00m config_dict[\u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m!=\u001b[39m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mmodel_type:\n\u001b[1;32m 592\u001b[0m logger\u001b[39m.\u001b[39mwarning(\n\u001b[1;32m 593\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mYou are using a model of type \u001b[39m\u001b[39m{\u001b[39;00mconfig_dict[\u001b[39m'\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m'\u001b[39m]\u001b[39m}\u001b[39;00m\u001b[39m to instantiate a model of type \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 594\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mmodel_type\u001b[39m}\u001b[39;00m\u001b[39m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 595\u001b[0m )\n",
1354
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:617\u001b[0m, in \u001b[0;36mPretrainedConfig.get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 615\u001b[0m original_kwargs \u001b[39m=\u001b[39m copy\u001b[39m.\u001b[39mdeepcopy(kwargs)\n\u001b[1;32m 616\u001b[0m \u001b[39m# Get config dict associated with the base config file\u001b[39;00m\n\u001b[0;32m--> 617\u001b[0m config_dict, kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_get_config_dict(pretrained_model_name_or_path, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 618\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config_dict:\n\u001b[1;32m 619\u001b[0m original_kwargs[\u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m config_dict[\u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m]\n",
1355
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:705\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 703\u001b[0m config_dict[\u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m commit_hash\n\u001b[1;32m 704\u001b[0m \u001b[39mexcept\u001b[39;00m (json\u001b[39m.\u001b[39mJSONDecodeError, \u001b[39mUnicodeDecodeError\u001b[39;00m):\n\u001b[0;32m--> 705\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 706\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mIt looks like the config file at \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mresolved_config_file\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m is not a valid JSON file.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 707\u001b[0m )\n\u001b[1;32m 709\u001b[0m \u001b[39mif\u001b[39;00m is_local:\n\u001b[1;32m 710\u001b[0m logger\u001b[39m.\u001b[39minfo(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mloading configuration file \u001b[39m\u001b[39m{\u001b[39;00mresolved_config_file\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
1356
+ "\u001b[0;31mOSError\u001b[0m: It looks like the config file at './whisper-base-vi/pytorch_model.bin' is not a valid JSON file."
1357
+ ]
1358
+ }
1359
+ ],
1360
+ "source": [
1361
+ "pt_model = WhisperForConditionalGeneration.from_pretrained(\"./whisper-base-vi/pytorch_model.bin\", from_tf=True)\n",
1362
+ "pt_model.save_pretrained(\"./whisper-base-vi/vi_whisper.pt\")"
1363
+ ]
1364
+ },
1365
+ {
1366
+ "cell_type": "code",
1367
+ "execution_count": null,
1368
+ "metadata": {},
1369
+ "outputs": [],
1370
+ "source": [
1371
+ "kwargs = {\n",
1372
+ " \"dataset_tags\": \"vivos-commonvoice\",\n",
1373
+ " \"dataset\": \"Vivos\", \n",
1374
+ " \"language\": \"vi\",\n",
1375
+ " \"model_name\": \"Whisper Small Vi - Duy Ta\", \n",
1376
+ " \"finetuned_from\": \"openai/whisper-small\",\n",
1377
+ " \"tasks\": \"automatic-speech-recognition\",\n",
1378
+ " \"config\" : None\n",
1379
+ "}\n"
1380
+ ]
1381
+ },
1382
+ {
1383
+ "cell_type": "code",
1384
+ "execution_count": 131,
1385
+ "metadata": {},
1386
+ "outputs": [
1387
+ {
1388
+ "name": "stderr",
1389
+ "output_type": "stream",
1390
+ "text": [
1391
+ "Several commits (2) will be pushed upstream.\n",
1392
+ "The progress bars may be unreliable.\n",
1393
+ "error: The destination you provided is not a full refname (i.e.,\n",
1394
+ "starting with \"refs/\"). We tried to guess what you meant by:\n",
1395
+ "\n",
1396
+ "- Looking for a ref that matches 'HEAD' on the remote side.\n",
1397
+ "- Checking if the <src> being pushed ('HEAD')\n",
1398
+ " is a ref in \"refs/{heads,tags}/\". If so we add a corresponding\n",
1399
+ " refs/{heads,tags}/ prefix on the remote side.\n",
1400
+ "\n",
1401
+ "Neither worked, so we gave up. You must fully qualify the ref.\n",
1402
+ "hint: The <src> part of the refspec is a commit object.\n",
1403
+ "hint: Did you mean to create a new branch by pushing to\n",
1404
+ "hint: 'HEAD:refs/heads/HEAD'?\n",
1405
+ "error: failed to push some refs to 'https://huggingface.co/DuyTa/vi_whisper-small'\n",
1406
+ "\n"
1407
+ ]
1408
+ },
1409
+ {
1410
+ "ename": "OSError",
1411
+ "evalue": "error: The destination you provided is not a full refname (i.e.,\nstarting with \"refs/\"). We tried to guess what you meant by:\n\n- Looking for a ref that matches 'HEAD' on the remote side.\n- Checking if the <src> being pushed ('HEAD')\n is a ref in \"refs/{heads,tags}/\". If so we add a corresponding\n refs/{heads,tags}/ prefix on the remote side.\n\nNeither worked, so we gave up. You must fully qualify the ref.\nhint: The <src> part of the refspec is a commit object.\nhint: Did you mean to create a new branch by pushing to\nhint: 'HEAD:refs/heads/HEAD'?\nerror: failed to push some refs to 'https://huggingface.co/DuyTa/vi_whisper-small'\n",
1412
+ "output_type": "error",
1413
+ "traceback": [
1414
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1415
+ "\u001b[0;31mCalledProcessError\u001b[0m Traceback (most recent call last)",
1416
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:1099\u001b[0m, in \u001b[0;36mRepository.git_push\u001b[0;34m(self, upstream, blocking, auto_lfs_prune)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[39mif\u001b[39;00m return_code:\n\u001b[0;32m-> 1099\u001b[0m \u001b[39mraise\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError(return_code, process\u001b[39m.\u001b[39margs, output\u001b[39m=\u001b[39mstdout, stderr\u001b[39m=\u001b[39mstderr)\n\u001b[1;32m 1101\u001b[0m \u001b[39mexcept\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError \u001b[39mas\u001b[39;00m exc:\n",
1417
+ "\u001b[0;31mCalledProcessError\u001b[0m: Command '['git', 'push', '--set-upstream', 'origin', 'HEAD']' returned non-zero exit status 1.",
1418
+ "\nDuring handling of the above exception, another exception occurred:\n",
1419
+ "\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
1420
+ "Cell \u001b[0;32mIn[131], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m trainer\u001b[39m.\u001b[39;49mpush_to_hub(commit_message\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mchange\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
1421
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/trainer.py:3609\u001b[0m, in \u001b[0;36mTrainer.push_to_hub\u001b[0;34m(self, commit_message, blocking, **kwargs)\u001b[0m\n\u001b[1;32m 3606\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpush_in_progress\u001b[39m.\u001b[39m_process\u001b[39m.\u001b[39mkill()\n\u001b[1;32m 3607\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpush_in_progress \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m-> 3609\u001b[0m git_head_commit_url \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrepo\u001b[39m.\u001b[39;49mpush_to_hub(\n\u001b[1;32m 3610\u001b[0m commit_message\u001b[39m=\u001b[39;49mcommit_message, blocking\u001b[39m=\u001b[39;49mblocking, auto_lfs_prune\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m\n\u001b[1;32m 3611\u001b[0m )\n\u001b[1;32m 3612\u001b[0m \u001b[39m# push separately the model card to be independant from the rest of the model\u001b[39;00m\n\u001b[1;32m 3613\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mshould_save:\n",
1422
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:1307\u001b[0m, in \u001b[0;36mRepository.push_to_hub\u001b[0;34m(self, commit_message, blocking, clean_ok, auto_lfs_prune)\u001b[0m\n\u001b[1;32m 1305\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgit_add(auto_lfs_track\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 1306\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgit_commit(commit_message)\n\u001b[0;32m-> 1307\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgit_push(\n\u001b[1;32m 1308\u001b[0m upstream\u001b[39m=\u001b[39;49m\u001b[39mf\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39morigin \u001b[39;49m\u001b[39m{\u001b[39;49;00m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcurrent_branch\u001b[39m}\u001b[39;49;00m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 1309\u001b[0m blocking\u001b[39m=\u001b[39;49mblocking,\n\u001b[1;32m 1310\u001b[0m auto_lfs_prune\u001b[39m=\u001b[39;49mauto_lfs_prune,\n\u001b[1;32m 1311\u001b[0m )\n",
1423
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:1102\u001b[0m, in \u001b[0;36mRepository.git_push\u001b[0;34m(self, upstream, blocking, auto_lfs_prune)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[39mraise\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError(return_code, process\u001b[39m.\u001b[39margs, output\u001b[39m=\u001b[39mstdout, stderr\u001b[39m=\u001b[39mstderr)\n\u001b[1;32m 1101\u001b[0m \u001b[39mexcept\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError \u001b[39mas\u001b[39;00m exc:\n\u001b[0;32m-> 1102\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(exc\u001b[39m.\u001b[39mstderr)\n\u001b[1;32m 1104\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m blocking:\n\u001b[1;32m 1106\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstatus_method\u001b[39m():\n",
1424
+ "\u001b[0;31mOSError\u001b[0m: error: The destination you provided is not a full refname (i.e.,\nstarting with \"refs/\"). We tried to guess what you meant by:\n\n- Looking for a ref that matches 'HEAD' on the remote side.\n- Checking if the <src> being pushed ('HEAD')\n is a ref in \"refs/{heads,tags}/\". If so we add a corresponding\n refs/{heads,tags}/ prefix on the remote side.\n\nNeither worked, so we gave up. You must fully qualify the ref.\nhint: The <src> part of the refspec is a commit object.\nhint: Did you mean to create a new branch by pushing to\nhint: 'HEAD:refs/heads/HEAD'?\nerror: failed to push some refs to 'https://huggingface.co/DuyTa/vi_whisper-small'\n"
1425
+ ]
1426
+ }
1427
+ ],
1428
+ "source": [
1429
+ "trainer.push_to_hub(commit_message=\"change\")"
1430
+ ]
1431
+ },
1432
+ {
1433
+ "cell_type": "code",
1434
+ "execution_count": null,
1435
+ "metadata": {
1436
+ "tags": [
1437
+ "parameters"
1438
+ ]
1439
+ },
1440
+ "outputs": [
1441
+ {
1442
+ "data": {
1443
+ "application/vnd.jupyter.widget-view+json": {
1444
+ "model_id": "b6e666bab7b2450abf3e2adf07679122",
1445
+ "version_major": 2,
1446
+ "version_minor": 0
1447
+ },
1448
+ "text/plain": [
1449
+ "Downloading (…)lve/main/config.json: 0%| | 0.00/1.31k [00:00<?, ?B/s]"
1450
+ ]
1451
+ },
1452
+ "metadata": {},
1453
+ "output_type": "display_data"
1454
+ },
1455
+ {
1456
+ "data": {
1457
+ "application/vnd.jupyter.widget-view+json": {
1458
+ "model_id": "b212026dca9241cf994f9710f0b93c22",
1459
+ "version_major": 2,
1460
+ "version_minor": 0
1461
+ },
1462
+ "text/plain": [
1463
+ "Downloading (…)okenizer_config.json: 0%| | 0.00/838 [00:00<?, ?B/s]"
1464
+ ]
1465
+ },
1466
+ "metadata": {},
1467
+ "output_type": "display_data"
1468
+ }
1469
+ ],
1470
+ "source": [
1471
+ "from transformers import WhisperForConditionalGeneration, WhisperProcessor\n",
1472
+ "\n",
1473
+ "model = WhisperForConditionalGeneration.from_pretrained(\"DuyTa/vi_whisper\")\n",
1474
+ "processor = WhisperProcessor.from_pretrained(\"DuyTa/vi_whisper\")\n"
1475
+ ]
1476
+ },
1477
+ {
1478
+ "cell_type": "code",
1479
+ "execution_count": 36,
1480
+ "metadata": {},
1481
+ "outputs": [
1482
+ {
1483
+ "ename": "RuntimeError",
1484
+ "evalue": "Instantiating a pipeline without a task set raised an error: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './vi_whisper-small'.",
1485
+ "output_type": "error",
1486
+ "traceback": [
1487
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1488
+ "\u001b[0;31mHFValidationError\u001b[0m Traceback (most recent call last)",
1489
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/pipelines/__init__.py:432\u001b[0m, in \u001b[0;36mget_task\u001b[0;34m(model, use_auth_token)\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 432\u001b[0m info \u001b[39m=\u001b[39m model_info(model, token\u001b[39m=\u001b[39;49muse_auth_token)\n\u001b[1;32m 433\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n",
1490
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:110\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[39mif\u001b[39;00m arg_name \u001b[39min\u001b[39;00m [\u001b[39m\"\u001b[39m\u001b[39mrepo_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mfrom_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mto_id\u001b[39m\u001b[39m\"\u001b[39m]:\n\u001b[0;32m--> 110\u001b[0m validate_repo_id(arg_value)\n\u001b[1;32m 112\u001b[0m \u001b[39melif\u001b[39;00m arg_name \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtoken\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mand\u001b[39;00m arg_value \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
1491
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:164\u001b[0m, in \u001b[0;36mvalidate_repo_id\u001b[0;34m(repo_id)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m REPO_ID_REGEX\u001b[39m.\u001b[39mmatch(repo_id):\n\u001b[0;32m--> 164\u001b[0m \u001b[39mraise\u001b[39;00m HFValidationError(\n\u001b[1;32m 165\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mRepo id must use alphanumeric chars or \u001b[39m\u001b[39m'\u001b[39m\u001b[39m-\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m_\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m--\u001b[39m\u001b[39m'\u001b[39m\u001b[39m and \u001b[39m\u001b[39m'\u001b[39m\u001b[39m..\u001b[39m\u001b[39m'\u001b[39m\u001b[39m are\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 166\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m forbidden, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m-\u001b[39m\u001b[39m'\u001b[39m\u001b[39m and \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m'\u001b[39m\u001b[39m cannot start or end the name, max length is 96:\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 167\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mrepo_id\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 168\u001b[0m )\n\u001b[1;32m 170\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39m--\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m repo_id \u001b[39mor\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39m..\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m repo_id:\n",
1492
+ "\u001b[0;31mHFValidationError\u001b[0m: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './vi_whisper-small'.",
1493
+ "\nDuring handling of the above exception, another exception occurred:\n",
1494
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1495
+ "Cell \u001b[0;32mIn[36], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtransformers\u001b[39;00m \u001b[39mimport\u001b[39;00m pipeline\n\u001b[1;32m 2\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mgradio\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mgr\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m pipe \u001b[39m=\u001b[39m pipeline(model\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39m./vi_whisper-small\u001b[39;49m\u001b[39m\"\u001b[39;49m) \n\u001b[1;32m 6\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mtranscribe\u001b[39m(audio):\n\u001b[1;32m 7\u001b[0m text \u001b[39m=\u001b[39m pipe(audio)[\u001b[39m\"\u001b[39m\u001b[39mtext\u001b[39m\u001b[39m\"\u001b[39m]\n",
1496
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/pipelines/__init__.py:726\u001b[0m, in \u001b[0;36mpipeline\u001b[0;34m(task, model, config, tokenizer, feature_extractor, image_processor, framework, revision, use_fast, use_auth_token, device, device_map, torch_dtype, trust_remote_code, model_kwargs, pipeline_class, **kwargs)\u001b[0m\n\u001b[1;32m 721\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(model, \u001b[39mstr\u001b[39m):\n\u001b[1;32m 722\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m 723\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mInferring the task automatically requires to check the hub with a model_id defined as a `str`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 724\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mmodel\u001b[39m}\u001b[39;00m\u001b[39m is not a valid model_id.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 725\u001b[0m )\n\u001b[0;32m--> 726\u001b[0m task \u001b[39m=\u001b[39m get_task(model, use_auth_token)\n\u001b[1;32m 728\u001b[0m \u001b[39m# Retrieve the task\u001b[39;00m\n\u001b[1;32m 729\u001b[0m \u001b[39mif\u001b[39;00m task \u001b[39min\u001b[39;00m custom_tasks:\n",
1497
+ "File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/pipelines/__init__.py:434\u001b[0m, in \u001b[0;36mget_task\u001b[0;34m(model, use_auth_token)\u001b[0m\n\u001b[1;32m 432\u001b[0m info \u001b[39m=\u001b[39m model_info(model, token\u001b[39m=\u001b[39muse_auth_token)\n\u001b[1;32m 433\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m--> 434\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mInstantiating a pipeline without a task set raised an error: \u001b[39m\u001b[39m{\u001b[39;00me\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 435\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m info\u001b[39m.\u001b[39mpipeline_tag:\n\u001b[1;32m 436\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m 437\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mThe model \u001b[39m\u001b[39m{\u001b[39;00mmodel\u001b[39m}\u001b[39;00m\u001b[39m does not seem to have a correct `pipeline_tag` set to infer the task automatically\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 438\u001b[0m )\n",
1498
+ "\u001b[0;31mRuntimeError\u001b[0m: Instantiating a pipeline without a task set raised an error: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './vi_whisper-small'."
1499
+ ]
1500
+ }
1501
+ ],
1502
+ "source": [
1503
+ "from transformers import pipeline\n",
1504
+ "import gradio as gr\n",
1505
+ "\n",
1506
+ "pipe = pipeline(model=\"./vi_whisper-small\") \n",
1507
+ "\n",
1508
+ "def transcribe(audio):\n",
1509
+ " text = pipe(audio)[\"text\"]\n",
1510
+ " return text\n",
1511
+ "\n",
1512
+ "iface = gr.Interface(\n",
1513
+ " fn=transcribe,\n",
1514
+ " inputs=gr.Audio(source=\"upload\", type=\"filepath\"),\n",
1515
+ " outputs=\"text\",\n",
1516
+ " title=\"Whisper Base Vietnamese\",\n",
1517
+ " description=\"Realtime demo for Vietnamese speech recognition using a fine-tuned Whisper base model.\",\n",
1518
+ ")\n",
1519
+ "\n",
1520
+ "iface.launch()"
1521
+ ]
1522
+ }
1523
+ ],
1524
+ "metadata": {
1525
+ "kernelspec": {
1526
+ "display_name": "DUY",
1527
+ "language": "python",
1528
+ "name": "python3"
1529
+ },
1530
+ "language_info": {
1531
+ "codemirror_mode": {
1532
+ "name": "ipython",
1533
+ "version": 3
1534
+ },
1535
+ "file_extension": ".py",
1536
+ "mimetype": "text/x-python",
1537
+ "name": "python",
1538
+ "nbconvert_exporter": "python",
1539
+ "pygments_lexer": "ipython3",
1540
+ "version": "3.9.17"
1541
+ },
1542
+ "orig_nbformat": 4
1543
+ },
1544
+ "nbformat": 4,
1545
+ "nbformat_minor": 2
1546
+ }
src/training.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import interpreter_login
2
+
3
+ from datasets import load_dataset, DatasetDict, load_from_disk
4
+
5
+ from transformers import WhisperProcessor
6
+ from transformers import WhisperForConditionalGeneration
7
+ from transformers import Seq2SeqTrainingArguments
8
+ from transformers import Seq2SeqTrainer
9
+ from transformers import EarlyStoppingCallback
10
+ from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
11
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
12
+
13
+ from peft import prepare_model_for_int8_training
14
+ from peft import PeftModel, LoraModel, LoraConfig, get_peft_model
15
+
16
+ import torch
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Union
20
+
21
+ import evaluate
22
+
23
+ import os
24
+
25
+ class SavePeftModelCallback(TrainerCallback):
26
+ def on_save(
27
+ self,
28
+ args: TrainingArguments,
29
+ state: TrainerState,
30
+ control: TrainerControl,
31
+ **kwargs,
32
+ ):
33
+ checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
34
+
35
+ peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
36
+ kwargs["model"].save_pretrained(peft_model_path)
37
+
38
+ pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
39
+ if os.path.exists(pytorch_model_path):
40
+ os.remove(pytorch_model_path)
41
+ return control
42
+
43
+ @dataclass
44
+ class DataCollatorSpeechSeq2SeqWithPadding:
45
+ processor: Any
46
+
47
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
48
+ # split inputs and labels since they have to be of different lengths and need different padding methods
49
+ # first treat the audio inputs by simply returning torch tensors
50
+ input_features = [{"input_features": feature["input_features"]} for feature in features]
51
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
52
+
53
+ # get the tokenized label sequences
54
+ label_features = [{"input_ids": feature["labels"]} for feature in features]
55
+
56
+ # ******************This is only in the case of augmented data ***************** Remove if not
57
+ batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])
58
+
59
+ # pad the labels to max length
60
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
61
+
62
+ # replace padding with -100 to ignore loss correctly
63
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
64
+
65
+ # if bos token is appended in previous tokenization step,
66
+ # cut bos token here as it's append later anyways
67
+ if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
68
+ labels = labels[:, 1:]
69
+
70
+ batch["labels"] = labels
71
+
72
+ return batch
73
+
74
+ def compute_metrics(pred):
75
+ pred_ids = pred.predictions
76
+ label_ids = pred.label_ids
77
+
78
+ # replace -100 with the pad_token_id
79
+ label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
80
+
81
+ # we do not want to group tokens when computing the metrics
82
+ pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
83
+ label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
84
+
85
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
86
+
87
+ return {"wer": wer}
88
+
89
+
90
+
91
+ if __name__ == "__main__":
92
+
93
+
94
+ early_stopping_callback = EarlyStoppingCallback(
95
+ early_stopping_patience=3, # Stop training if the metric doesn't improve for 3 evaluations
96
+ early_stopping_threshold=0.0005, # Minimum change in the metric to be considered an improvement
97
+ )
98
+
99
+ # Load Dataset
100
+ processed_dataset = DatasetDict()
101
+ processed_dataset = load_from_disk("./vin_clean")
102
+
103
+
104
+ print(processed_dataset)
105
+
106
+ # load processor
107
+ processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="Vietnamese", task="transcribe")
108
+
109
+
110
+ # intialize data collator
111
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
112
+
113
+ # download metric
114
+ metric = evaluate.load("wer")
115
+
116
+ # Download model in 8bit
117
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium", load_in_8bit=True, device_map="auto")
118
+ model.config.forced_decoder_ids = None
119
+ model.config.suppress_tokens = []
120
+
121
+ # preparing model with PEFT
122
+ model = prepare_model_for_int8_training(model, output_imbedding_layer="proj_out")
123
+
124
+ config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
125
+
126
+ model = get_peft_model(model, config)
127
+ model.print_trainable_parameters()
128
+
129
+
130
+ # Define trainnig arguments
131
+ training_args = Seq2SeqTrainingArguments(
132
+ output_dir="./whisper-medium-Lora", # change to a repo name of your choice
133
+ per_device_train_batch_size=32,
134
+ gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size
135
+ learning_rate=5e-5,
136
+ warmup_steps=500,
137
+ max_steps=10000,
138
+ evaluation_strategy="steps",
139
+ gradient_checkpointing=True,
140
+ optim="adamw_torch",
141
+ fp16=True,
142
+ per_device_eval_batch_size=8,
143
+ generation_max_length=225,
144
+ save_steps=2000,
145
+ eval_steps=500,
146
+ logging_steps=25,
147
+ report_to=["tensorboard"],
148
+ predict_with_generate=True,
149
+ # load_best_model_at_end=True,
150
+ metric_for_best_model="wer",
151
+ greater_is_better=False,
152
+ # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
153
+ remove_unused_columns=False,
154
+ label_names=["labels"], # same reason as above
155
+ push_to_hub=False,
156
+ )
157
+
158
+ # initialize trainer
159
+ trainer = Seq2SeqTrainer(
160
+ args=training_args,
161
+ model=model,
162
+ train_dataset=processed_dataset["train"],
163
+ eval_dataset=processed_dataset["test"],
164
+ data_collator=data_collator,
165
+ tokenizer=processor.feature_extractor,
166
+ callbacks=[early_stopping_callback, SavePeftModelCallback],
167
+ )
168
+
169
+
170
+ # start training
171
+ trainer.train()
172
+
173
+
174
+ # set up args and push to hub
175
+ kwargs = {
176
+ "dataset": "vin100h",
177
+ "language": "vi",
178
+ "model_name": "Whisper Medium LoRA - Clean Data",
179
+ "finetuned_from": "openai/whisper-medium",
180
+ "tasks": "automatic-speech-recognition",
181
+ }
182
+
183
+ model.push_to_hub(**kwargs)
src/vin_whisper_medium.ipynb ADDED
@@ -0,0 +1,1164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/whisper/__init__.py\n"
13
+ ]
14
+ }
15
+ ],
16
+ "source": [
17
+ "import whisper\n",
18
+ "print(whisper.__file__)\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "import os\n",
28
+ "\n",
29
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": 3,
35
+ "metadata": {},
36
+ "outputs": [
37
+ {
38
+ "name": "stdout",
39
+ "output_type": "stream",
40
+ "text": [
41
+ "True\n",
42
+ "2\n",
43
+ "0\n",
44
+ "<torch.cuda.device object at 0x7f69e1e31eb0>\n",
45
+ "Tesla T4\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "import torch\n",
51
+ "\n",
52
+ "print(torch.cuda.is_available())\n",
53
+ "\n",
54
+ "\n",
55
+ "print(torch.cuda.device_count())\n",
56
+ "\n",
57
+ "\n",
58
+ "print(torch.cuda.current_device())\n",
59
+ "print(torch.cuda.device(0))\n",
60
+ "\n",
61
+ "print(torch.cuda.get_device_name(0))\n"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": 2,
67
+ "metadata": {},
68
+ "outputs": [
69
+ {
70
+ "data": {
71
+ "application/vnd.jupyter.widget-view+json": {
72
+ "model_id": "f290d4efc37a4112a662c062e621e482",
73
+ "version_major": 2,
74
+ "version_minor": 0
75
+ },
76
+ "text/plain": [
77
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
78
+ ]
79
+ },
80
+ "metadata": {},
81
+ "output_type": "display_data"
82
+ }
83
+ ],
84
+ "source": [
85
+ "from huggingface_hub import notebook_login\n",
86
+ "\n",
87
+ "notebook_login()"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 2,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "model_name_or_path = \"openai/whisper-medium\"\n",
97
+ "task = \"transcribe\""
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": 3,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "dataset_name = \"Vin100h_MITI_private\"\n",
107
+ "language = \"Vietnamese\"\n",
108
+ "language_abbr = \"vi\" # Short hand code for the language we want to fine-tune"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 4,
114
+ "metadata": {},
115
+ "outputs": [
116
+ {
117
+ "name": "stdout",
118
+ "output_type": "stream",
119
+ "text": [
120
+ "DatasetDict({\n",
121
+ " train: Dataset({\n",
122
+ " features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
123
+ " num_rows: 1679\n",
124
+ " })\n",
125
+ " test: Dataset({\n",
126
+ " features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
127
+ " num_rows: 420\n",
128
+ " })\n",
129
+ "})\n",
130
+ "DatasetDict({\n",
131
+ " train: Dataset({\n",
132
+ " features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
133
+ " num_rows: 6735\n",
134
+ " })\n",
135
+ " test: Dataset({\n",
136
+ " features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
137
+ " num_rows: 1688\n",
138
+ " })\n",
139
+ "})\n"
140
+ ]
141
+ }
142
+ ],
143
+ "source": [
144
+ " # Load Dataset\n",
145
+ "from datasets import load_dataset, DatasetDict, load_from_disk\n",
146
+ "processed_dataset = DatasetDict()\n",
147
+ "processed_dataset = load_from_disk(\"./MITI_clean\")\n",
148
+ "processed_dataset2 = load_from_disk(\"./vin_10h/\")\n",
149
+ "\n",
150
+ "print(processed_dataset)\n",
151
+ "print(processed_dataset2)"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "code",
156
+ "execution_count": 49,
157
+ "metadata": {},
158
+ "outputs": [],
159
+ "source": [
160
+ "from datasets import Dataset\n",
161
+ "\n",
162
+ "# Assuming you have already loaded your dataset\n",
163
+ "# processed_dataset2 = ...\n",
164
+ "\n",
165
+ "# Randomly select 5000 indices from the train dataset\n",
166
+ "import random\n",
167
+ "num_samples_train = 5000\n",
168
+ "num_samples_test = 600\n",
169
+ "random_indices_train = random.sample(range(len(processed_dataset2['train'])), num_samples_train)\n",
170
+ "random_indices_test = random.sample(range(len(processed_dataset2['test'])), num_samples_test)\n",
171
+ "\n",
172
+ "# Initialize lists for train dataset\n",
173
+ "input_features_train = []\n",
174
+ "input_length_train = []\n",
175
+ "attention_mask_train = []\n",
176
+ "labels_train = []\n",
177
+ "\n",
178
+ "# Initialize lists for test dataset\n",
179
+ "input_features_test = []\n",
180
+ "input_length_test = []\n",
181
+ "attention_mask_test = []\n",
182
+ "labels_test = []\n",
183
+ "\n",
184
+ "# Populate lists for train dataset\n",
185
+ "for i in random_indices_train:\n",
186
+ " input_features_train.append(processed_dataset2['train'][i]['input_features'])\n",
187
+ " input_length_train.append(processed_dataset2['train'][i]['input_length'])\n",
188
+ " attention_mask_train.append(processed_dataset2['train'][i]['attention_mask'])\n",
189
+ " labels_train.append(processed_dataset2['train'][i]['labels'])\n",
190
+ "\n",
191
+ "# Populate lists for test dataset\n",
192
+ "for i in random_indices_test:\n",
193
+ " input_features_test.append(processed_dataset2['test'][i]['input_features'])\n",
194
+ " input_length_test.append(processed_dataset2['test'][i]['input_length'])\n",
195
+ " attention_mask_test.append(processed_dataset2['test'][i]['attention_mask'])\n",
196
+ " labels_test.append(processed_dataset2['test'][i]['labels'])\n",
197
+ "\n",
198
+ "# Create a new dataset with the randomly selected rows\n",
199
+ "random_subset = Dataset.from_dict({\n",
200
+ " 'train': {\n",
201
+ " 'input_features': input_features_train,\n",
202
+ " 'input_length': input_length_train,\n",
203
+ " 'attention_mask': attention_mask_train,\n",
204
+ " 'labels': labels_train,\n",
205
+ " },\n",
206
+ " 'test': {\n",
207
+ " 'input_features': input_features_test,\n",
208
+ " 'input_length': input_length_test,\n",
209
+ " 'attention_mask': attention_mask_test,\n",
210
+ " 'labels': labels_test,\n",
211
+ " }\n",
212
+ "})\n",
213
+ "\n",
214
+ "\n"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": 5,
220
+ "metadata": {},
221
+ "outputs": [],
222
+ "source": [
223
+ "import datasets\n",
224
+ "concat = DatasetDict()\n",
225
+ "concat[\"train\"] = datasets.concatenate_datasets([processed_dataset[\"train\"], processed_dataset2[\"train\"]])\n",
226
+ "concat['test']= datasets.concatenate_datasets([processed_dataset[\"test\"], processed_dataset2[\"test\"]])\n"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": 7,
232
+ "metadata": {},
233
+ "outputs": [
234
+ {
235
+ "data": {
236
+ "text/plain": [
237
+ "DatasetDict({\n",
238
+ " train: Dataset({\n",
239
+ " features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
240
+ " num_rows: 8414\n",
241
+ " })\n",
242
+ " test: Dataset({\n",
243
+ " features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
244
+ " num_rows: 2108\n",
245
+ " })\n",
246
+ "})"
247
+ ]
248
+ },
249
+ "execution_count": 7,
250
+ "metadata": {},
251
+ "output_type": "execute_result"
252
+ }
253
+ ],
254
+ "source": [
255
+ "concat"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": 7,
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": [
264
+ "from transformers import WhisperFeatureExtractor\n",
265
+ "\n",
266
+ "feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": 6,
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": [
275
+ "\n",
276
+ "from transformers import WhisperTokenizer\n",
277
+ "\n",
278
+ "tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 11,
284
+ "metadata": {},
285
+ "outputs": [
286
+ {
287
+ "data": {
288
+ "text/plain": [
289
+ "('./Viet_ASR/tokenizer_config.json',\n",
290
+ " './Viet_ASR/special_tokens_map.json',\n",
291
+ " './Viet_ASR/vocab.json',\n",
292
+ " './Viet_ASR/merges.txt',\n",
293
+ " './Viet_ASR/normalizer.json',\n",
294
+ " './Viet_ASR/added_tokens.json')"
295
+ ]
296
+ },
297
+ "execution_count": 11,
298
+ "metadata": {},
299
+ "output_type": "execute_result"
300
+ }
301
+ ],
302
+ "source": [
303
+ "tokenizer.save_pretrained('./Viet_ASR')"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 8,
309
+ "metadata": {},
310
+ "outputs": [],
311
+ "source": [
312
+ "import torch\n",
313
+ "\n",
314
+ "from dataclasses import dataclass\n",
315
+ "from typing import Any, Dict, List, Union\n",
316
+ "from transformers import WhisperProcessor\n",
317
+ "\n",
318
+ "processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)\n",
319
+ "@dataclass\n",
320
+ "class DataCollatorSpeechSeq2SeqWithPadding:\n",
321
+ " processor: Any\n",
322
+ "\n",
323
+ " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
324
+ " # split inputs and labels since they have to be of different lengths and need different padding methods\n",
325
+ " # first treat the audio inputs by simply returning torch tensors\n",
326
+ " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
327
+ " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
328
+ "\n",
329
+ " # get the tokenized label sequences\n",
330
+ " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
331
+ "\n",
332
+ "\n",
333
+ " # pad the labels to max length\n",
334
+ " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
335
+ "\n",
336
+ " # replace padding with -100 to ignore loss correctly\n",
337
+ " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
338
+ "\n",
339
+ " # if bos token is appended in previous tokenization step,\n",
340
+ " # cut bos token here as it's append later anyways\n",
341
+ " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
342
+ " labels = labels[:, 1:]\n",
343
+ "\n",
344
+ " batch[\"labels\"] = labels\n",
345
+ "\n",
346
+ " return batch\n",
347
+ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "execution_count": 9,
353
+ "metadata": {},
354
+ "outputs": [],
355
+ "source": [
356
+ "import evaluate\n",
357
+ "\n",
358
+ "metric = evaluate.load(\"wer\")"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": 12,
364
+ "metadata": {},
365
+ "outputs": [],
366
+ "source": [
367
+ "from transformers import WhisperForConditionalGeneration\n",
368
+ "\n",
369
+ "model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-medium', load_in_8bit=True, device_map=\"auto\" )"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": 13,
375
+ "metadata": {},
376
+ "outputs": [],
377
+ "source": [
378
+ "model.config.forced_decoder_ids = None\n",
379
+ "model.config.suppress_tokens = []"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": 14,
385
+ "metadata": {},
386
+ "outputs": [
387
+ {
388
+ "data": {
389
+ "text/plain": [
390
+ "<torch.utils.hooks.RemovableHandle at 0x7f1a9445da60>"
391
+ ]
392
+ },
393
+ "execution_count": 14,
394
+ "metadata": {},
395
+ "output_type": "execute_result"
396
+ }
397
+ ],
398
+ "source": [
399
+ "from peft import prepare_model_for_kbit_training\n",
400
+ "\n",
401
+ "model = prepare_model_for_kbit_training(model)\n",
402
+ "def make_inputs_require_grad(module, input, output):\n",
403
+ " output.requires_grad_(True)\n",
404
+ "\n",
405
+ "model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)"
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": 15,
411
+ "metadata": {},
412
+ "outputs": [
413
+ {
414
+ "name": "stdout",
415
+ "output_type": "stream",
416
+ "text": [
417
+ "trainable params: 9,437,184 || all params: 773,295,104 || trainable%: 1.2203858463844612\n"
418
+ ]
419
+ }
420
+ ],
421
+ "source": [
422
+ "from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model\n",
423
+ "#target_modules = [\"k_proj\", \"q_proj\", \"v_proj\", \"out_proj\", \"fc1\", \"fc2\"] #will it better ?\n",
424
+ "target_modules=[\"q_proj\", \"v_proj\"]\n",
425
+ "config = LoraConfig(r=32, lora_alpha=64, target_modules=target_modules, lora_dropout=0.05, bias=\"none\")\n",
426
+ "\n",
427
+ "model = get_peft_model(model, config)\n",
428
+ "model.print_trainable_parameters()"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "code",
433
+ "execution_count": 16,
434
+ "metadata": {},
435
+ "outputs": [],
436
+ "source": [
437
+ "from transformers import Seq2SeqTrainingArguments\n",
438
+ "\n",
439
+ "training_args = Seq2SeqTrainingArguments(\n",
440
+ " output_dir=\"./Vietnamese_ASR\", \n",
441
+ " per_device_train_batch_size=10,\n",
442
+ " #auto_find_batch_size = True,\n",
443
+ " gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size\n",
444
+ " learning_rate=5e-5,\n",
445
+ " warmup_steps=50,\n",
446
+ " num_train_epochs=3,\n",
447
+ " evaluation_strategy=\"epoch\",\n",
448
+ " gradient_checkpointing=True,\n",
449
+ " optim=\"adamw_torch\",\n",
450
+ " fp16=True,\n",
451
+ " per_device_eval_batch_size=8,\n",
452
+ " generation_max_length=225,\n",
453
+ " logging_steps=100,\n",
454
+ " report_to=[\"tensorboard\"],\n",
455
+ " predict_with_generate=True,\n",
456
+ " # load_best_model_at_end=True,\n",
457
+ " greater_is_better=False,\n",
458
+ " save_strategy = \"epoch\",\n",
459
+ " # required as the PeftModel forward doesn't have the signature of the wrapped model's forward\n",
460
+ " remove_unused_columns=False,\n",
461
+ " label_names=[\"labels\"], # same reason as above\n",
462
+ " push_to_hub=True,\n",
463
+ ")"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "code",
468
+ "execution_count": 19,
469
+ "metadata": {},
470
+ "outputs": [
471
+ {
472
+ "name": "stderr",
473
+ "output_type": "stream",
474
+ "text": [
475
+ "/media/tesla/New Volume/DEMO/DUY/Vietnamese_ASR/./Vietnamese_ASR is already a clone of https://huggingface.co/DuyTa/Vietnamese_ASR. Make sure you pull the latest changes with `repo.git_pull()`.\n"
476
+ ]
477
+ }
478
+ ],
479
+ "source": [
480
+ "from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl\n",
481
+ "from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n",
482
+ "\n",
483
+ "\n",
484
+ "class SavePeftModelCallback(TrainerCallback):\n",
485
+ " def on_save(\n",
486
+ " self,\n",
487
+ " args: TrainingArguments,\n",
488
+ " state: TrainerState,\n",
489
+ " control: TrainerControl,\n",
490
+ " **kwargs,\n",
491
+ " ):\n",
492
+ " checkpoint_folder = os.path.join(args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\")\n",
493
+ "\n",
494
+ " peft_model_path = os.path.join(checkpoint_folder, \"adapter_model\")\n",
495
+ " kwargs[\"model\"].save_pretrained(peft_model_path)\n",
496
+ "\n",
497
+ " pytorch_model_path = os.path.join(checkpoint_folder, \"pytorch_model.bin\")\n",
498
+ " if os.path.exists(pytorch_model_path):\n",
499
+ " os.remove(pytorch_model_path)\n",
500
+ " return control\n",
501
+ "\n",
502
+ "\n",
503
+ "trainer = Seq2SeqTrainer(\n",
504
+ " args=training_args,\n",
505
+ " model=model,\n",
506
+ " train_dataset=concat[\"train\"],\n",
507
+ " eval_dataset=concat[\"test\"],\n",
508
+ " data_collator=data_collator,\n",
509
+ " # compute_metrics=compute_metrics,\n",
510
+ " tokenizer=processor.feature_extractor,\n",
511
+ " callbacks=[SavePeftModelCallback],\n",
512
+ ")\n",
513
+ "model.config.use_cache = False # silence the warnings. Please re-enable for inference!"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": 20,
519
+ "metadata": {},
520
+ "outputs": [
521
+ {
522
+ "data": {
523
+ "application/vnd.jupyter.widget-view+json": {
524
+ "model_id": "b0aa0180f6e64eaa8951a4c940aa518f",
525
+ "version_major": 2,
526
+ "version_minor": 0
527
+ },
528
+ "text/plain": [
529
+ " 0%| | 0/1263 [00:00<?, ?it/s]"
530
+ ]
531
+ },
532
+ "metadata": {},
533
+ "output_type": "display_data"
534
+ },
535
+ {
536
+ "name": "stderr",
537
+ "output_type": "stream",
538
+ "text": [
539
+ "/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
540
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
541
+ ]
542
+ },
543
+ {
544
+ "name": "stdout",
545
+ "output_type": "stream",
546
+ "text": [
547
+ "{'loss': 1.9814, 'learning_rate': 4.814509480626546e-05, 'epoch': 0.24}\n",
548
+ "{'loss': 0.6861, 'learning_rate': 4.402308326463314e-05, 'epoch': 0.48}\n",
549
+ "{'loss': 0.3736, 'learning_rate': 3.9901071723000826e-05, 'epoch': 0.71}\n",
550
+ "{'loss': 0.332, 'learning_rate': 3.577906018136851e-05, 'epoch': 0.95}\n"
551
+ ]
552
+ },
553
+ {
554
+ "data": {
555
+ "application/vnd.jupyter.widget-view+json": {
556
+ "model_id": "2e9a8f06d39e448a9523d9a29699cadc",
557
+ "version_major": 2,
558
+ "version_minor": 0
559
+ },
560
+ "text/plain": [
561
+ " 0%| | 0/264 [00:00<?, ?it/s]"
562
+ ]
563
+ },
564
+ "metadata": {},
565
+ "output_type": "display_data"
566
+ },
567
+ {
568
+ "name": "stdout",
569
+ "output_type": "stream",
570
+ "text": [
571
+ "{'eval_loss': 0.3133259117603302, 'eval_runtime': 887.0949, 'eval_samples_per_second': 2.376, 'eval_steps_per_second': 0.298, 'epoch': 1.0}\n"
572
+ ]
573
+ },
574
+ {
575
+ "name": "stderr",
576
+ "output_type": "stream",
577
+ "text": [
578
+ "/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
579
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
580
+ ]
581
+ },
582
+ {
583
+ "name": "stdout",
584
+ "output_type": "stream",
585
+ "text": [
586
+ "{'loss': 0.3005, 'learning_rate': 3.165704863973619e-05, 'epoch': 1.19}\n",
587
+ "{'loss': 0.307, 'learning_rate': 2.753503709810388e-05, 'epoch': 1.43}\n",
588
+ "{'loss': 0.2838, 'learning_rate': 2.341302555647156e-05, 'epoch': 1.66}\n",
589
+ "{'loss': 0.2746, 'learning_rate': 1.9291014014839242e-05, 'epoch': 1.9}\n"
590
+ ]
591
+ },
592
+ {
593
+ "data": {
594
+ "application/vnd.jupyter.widget-view+json": {
595
+ "model_id": "1e65ecdbc96246b8b4721505b4252a8a",
596
+ "version_major": 2,
597
+ "version_minor": 0
598
+ },
599
+ "text/plain": [
600
+ " 0%| | 0/264 [00:00<?, ?it/s]"
601
+ ]
602
+ },
603
+ "metadata": {},
604
+ "output_type": "display_data"
605
+ },
606
+ {
607
+ "name": "stdout",
608
+ "output_type": "stream",
609
+ "text": [
610
+ "{'eval_loss': 0.28433552384376526, 'eval_runtime': 880.1965, 'eval_samples_per_second': 2.395, 'eval_steps_per_second': 0.3, 'epoch': 2.0}\n"
611
+ ]
612
+ },
613
+ {
614
+ "name": "stderr",
615
+ "output_type": "stream",
616
+ "text": [
617
+ "/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
618
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
619
+ ]
620
+ },
621
+ {
622
+ "name": "stdout",
623
+ "output_type": "stream",
624
+ "text": [
625
+ "{'loss': 0.2857, 'learning_rate': 1.5169002473206925e-05, 'epoch': 2.14}\n",
626
+ "{'loss': 0.2643, 'learning_rate': 1.104699093157461e-05, 'epoch': 2.38}\n",
627
+ "{'loss': 0.2604, 'learning_rate': 6.924979389942292e-06, 'epoch': 2.61}\n",
628
+ "{'loss': 0.2505, 'learning_rate': 2.8029678483099755e-06, 'epoch': 2.85}\n"
629
+ ]
630
+ },
631
+ {
632
+ "data": {
633
+ "application/vnd.jupyter.widget-view+json": {
634
+ "model_id": "3a4f479ef36f4f00b3c591503b411e5f",
635
+ "version_major": 2,
636
+ "version_minor": 0
637
+ },
638
+ "text/plain": [
639
+ " 0%| | 0/264 [00:00<?, ?it/s]"
640
+ ]
641
+ },
642
+ "metadata": {},
643
+ "output_type": "display_data"
644
+ },
645
+ {
646
+ "name": "stdout",
647
+ "output_type": "stream",
648
+ "text": [
649
+ "{'eval_loss': 0.27759623527526855, 'eval_runtime': 879.7333, 'eval_samples_per_second': 2.396, 'eval_steps_per_second': 0.3, 'epoch': 3.0}\n",
650
+ "{'train_runtime': 35575.7347, 'train_samples_per_second': 0.71, 'train_steps_per_second': 0.036, 'train_loss': 0.4555940831925127, 'epoch': 3.0}\n"
651
+ ]
652
+ },
653
+ {
654
+ "data": {
655
+ "text/plain": [
656
+ "TrainOutput(global_step=1263, training_loss=0.4555940831925127, metrics={'train_runtime': 35575.7347, 'train_samples_per_second': 0.71, 'train_steps_per_second': 0.036, 'train_loss': 0.4555940831925127, 'epoch': 3.0})"
657
+ ]
658
+ },
659
+ "execution_count": 20,
660
+ "metadata": {},
661
+ "output_type": "execute_result"
662
+ }
663
+ ],
664
+ "source": [
665
+ "trainer.train()"
666
+ ]
667
+ },
668
+ {
669
+ "cell_type": "code",
670
+ "execution_count": 22,
671
+ "metadata": {},
672
+ "outputs": [
673
+ {
674
+ "name": "stdout",
675
+ "output_type": "stream",
676
+ "text": [
677
+ "DuyTa/Vietnamese_ASR\n"
678
+ ]
679
+ }
680
+ ],
681
+ "source": [
682
+ "peft_model_id = \"DuyTa/Vietnamese_ASR\"\n",
683
+ "model.push_to_hub(peft_model_id)\n",
684
+ "print(peft_model_id)"
685
+ ]
686
+ },
687
+ {
688
+ "cell_type": "code",
689
+ "execution_count": 10,
690
+ "metadata": {},
691
+ "outputs": [],
692
+ "source": [
693
+ "from peft import PeftModel, PeftConfig\n",
694
+ "from transformers import WhisperForConditionalGeneration\n",
695
+ "peft_model_id = \"./Vietnamese_ASR\"\n",
696
+ "peft_config = PeftConfig.from_pretrained(peft_model_id)\n",
697
+ "model = WhisperForConditionalGeneration.from_pretrained(\n",
698
+ " peft_config.base_model_name_or_path, load_in_8bit=True, device_map=\"auto\"\n",
699
+ ")\n",
700
+ "model = PeftModel.from_pretrained(model, peft_model_id)"
701
+ ]
702
+ },
703
+ {
704
+ "cell_type": "code",
705
+ "execution_count": 11,
706
+ "metadata": {},
707
+ "outputs": [
708
+ {
709
+ "name": "stderr",
710
+ "output_type": "stream",
711
+ "text": [
712
+ " 0%| | 0/88 [00:00<?, ?it/s]"
713
+ ]
714
+ },
715
+ {
716
+ "name": "stderr",
717
+ "output_type": "stream",
718
+ "text": [
719
+ "/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
720
+ " warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
721
+ "100%|██████████| 88/88 [59:07<00:00, 40.31s/it]"
722
+ ]
723
+ },
724
+ {
725
+ "name": "stdout",
726
+ "output_type": "stream",
727
+ "text": [
728
+ "wer=15.57082617523036\n"
729
+ ]
730
+ },
731
+ {
732
+ "name": "stderr",
733
+ "output_type": "stream",
734
+ "text": [
735
+ "\n"
736
+ ]
737
+ }
738
+ ],
739
+ "source": [
740
+ "from torch.utils.data import DataLoader\n",
741
+ "from tqdm import tqdm\n",
742
+ "import numpy as np\n",
743
+ "import gc\n",
744
+ "\n",
745
+ "eval_dataloader = DataLoader(concat[\"test\"], batch_size=24, collate_fn=data_collator)\n",
746
+ "\n",
747
+ "model.eval()\n",
748
+ "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
749
+ " with torch.cuda.amp.autocast():\n",
750
+ " with torch.no_grad():\n",
751
+ " generated_tokens = (\n",
752
+ " model.generate(\n",
753
+ " input_features=batch[\"input_features\"].to(\"cuda\"),\n",
754
+ " decoder_input_ids=batch[\"labels\"][:, :4].to(\"cuda\"),\n",
755
+ " max_new_tokens=255,\n",
756
+ " )\n",
757
+ " .cpu()\n",
758
+ " .numpy()\n",
759
+ " )\n",
760
+ " labels = batch[\"labels\"].cpu().numpy()\n",
761
+ " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
762
+ " decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
763
+ " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
764
+ " metric.add_batch(\n",
765
+ " predictions=decoded_preds,\n",
766
+ " references=decoded_labels,\n",
767
+ " )\n",
768
+ " del generated_tokens, labels, batch\n",
769
+ " gc.collect()\n",
770
+ "wer = 100 * metric.compute()\n",
771
+ "print(f\"{wer=}\")"
772
+ ]
773
+ },
774
+ {
775
+ "cell_type": "markdown",
776
+ "metadata": {},
777
+ "source": [
778
+ "## Text Norm"
779
+ ]
780
+ },
781
+ {
782
+ "cell_type": "code",
783
+ "execution_count": null,
784
+ "metadata": {},
785
+ "outputs": [],
786
+ "source": [
787
+ "# using Vietnamese text normalization after take whisper out token"
788
+ ]
789
+ },
790
+ {
791
+ "cell_type": "code",
792
+ "execution_count": 12,
793
+ "metadata": {},
794
+ "outputs": [
795
+ {
796
+ "name": "stderr",
797
+ "output_type": "stream",
798
+ "text": [
799
+ "100%|██████████| 88/88 [58:18<00:00, 39.75s/it]"
800
+ ]
801
+ },
802
+ {
803
+ "name": "stdout",
804
+ "output_type": "stream",
805
+ "text": [
806
+ "normalized_wer=14.601364195460137\n"
807
+ ]
808
+ },
809
+ {
810
+ "name": "stderr",
811
+ "output_type": "stream",
812
+ "text": [
813
+ "\n"
814
+ ]
815
+ }
816
+ ],
817
+ "source": [
818
+ "from torch.utils.data import DataLoader\n",
819
+ "from tqdm import tqdm\n",
820
+ "import numpy as np\n",
821
+ "import gc\n",
822
+ "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n",
823
+ "normalizer = BasicTextNormalizer()\n",
824
+ "forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)\n",
825
+ "\n",
826
+ "model.eval()\n",
827
+ "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
828
+ " with torch.cuda.amp.autocast():\n",
829
+ " with torch.no_grad():\n",
830
+ " generated_tokens= model.generate(input_features=batch[\"input_features\"].to(\"cuda\"),\n",
831
+ " forced_decoder_ids=forced_decoder_ids,\n",
832
+ " max_new_tokens=255).cpu().numpy()\n",
833
+ " labels = batch[\"labels\"].cpu().numpy()\n",
834
+ " labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)\n",
835
+ " decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
836
+ " decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
837
+ " metric.add_batch(\n",
838
+ " predictions=[normalizer(pred).strip() for pred in decoded_preds],\n",
839
+ " references=[normalizer(label).strip() for label in decoded_labels],\n",
840
+ " )\n",
841
+ " # if step==0:\n",
842
+ " # break\n",
843
+ " del generated_tokens, labels, batch\n",
844
+ " gc.collect()\n",
845
+ "normalized_wer = 100 * metric.compute()\n",
846
+ "print(f\"{normalized_wer=}\")"
847
+ ]
848
+ },
849
+ {
850
+ "cell_type": "code",
851
+ "execution_count": 13,
852
+ "metadata": {},
853
+ "outputs": [
854
+ {
855
+ "name": "stdout",
856
+ "output_type": "stream",
857
+ "text": [
858
+ "pred='toàn bộ phi hành đoàn đã bị giết chết khiến con tàu quân sự đậm thẳng vào ishimura', label='toàn bộ phi hành đoàn đã bị giết chết khiến con tàu quân sự đâm thẳng vào ishimura'\n",
859
+ "pred='đủ kinh nghiệm để mình quản lý nhân viên và mình làm sao để mình đưa ra một cái dịch vụ tốt nhất', label='đủ kinh nghiệm để mình quản lý nhân viên và mình làm sao để mình đưa ra một cái dịch vụ tốt nhất'\n",
860
+ "pred='nói một trong một cái chương trình trong tương lai về ngành thu y thì ở mỹ tất cả các đại học nào lớn đều có ngành thu y hết', label='<unk> nói một trong một cái chương trình trong tương lai về ngành thú y thì ở mỹ tất cả các đại học nào lớn đều có ngành thú y hết'\n",
861
+ "pred='phấn đấu đến năm hai ngàn không trăm hai mười có từ tám trăm đến một ngàn kinh nghiệp tham gia sàn sâu dịch thu mại điện tử của bộ công thương và quốc tế năm mươi phần trăm số này vô bắn trên sàn', label='phấn đấu đến năm hai ngàn không trăm hai mươi có từ tám trăm đến một ngàn doanh nghiệp tham gia sàn giao dịch thương mại điện tử của bộ công thương và quốc tế năm mươi phần trăm số này luôn bán trên sàn'\n",
862
+ "pred='còn trách nhiệm kiểm tra thanh tra là của ủy ban nhân dân các cấp', label='còn trách nhiệm kiểm tra thanh tra là của ủy ban nhân dân các cấp'\n",
863
+ "pred='vậy mà cậu im lặng khóa trái tìm mình chắc địa dành cho cái gì đó vĩ đại hơn chăng', label='vậy mà cậu im lặng khóa trái tim mình chắc để giành cho cái gì đó vĩ đại hơn chăng'\n",
864
+ "pred='khi nộp phiếu trả lời trắc nghiệm thí sinh phải ghi tên và danh sách thí sinh nộp bài', label='khi nộp phiếu trả lời trắc nghiệm thí sinh phải ký tên vào danh sách thí sinh nộp bài'\n",
865
+ "pred='khi nghĩ rằng mình đã khỏi ai ngờ ung thư lại tái phát và tôi đã lắng nghe câu chuyện của tất cả mọi người', label='khi nghĩ rằng mình đã khỏi ai ngờ ung thư lại tái phát và tôi đã lắng nghe câu chuyện của tất cả mọi người'\n",
866
+ "pred='người cùng ấp là trương thật từng muốn kết giao với giám vì ông từ chối', label='người cùng ấp là trương thực từng muốn kết giao với giám bị ông từ chối'\n",
867
+ "pred='bài thơ với những dòng thơ rất xúc động như sau', label='bài thơ với những dòng thơ rất xúc động như sau'\n",
868
+ "pred='công bố chỉ số niềm tin kinh doanh của doanh nghiệp', label='công bố chỉ s�� niềm tin kinh doanh của doanh nghiệp'\n",
869
+ "pred='khi quanh hồ tổng tới thăng lông đúng lúc tô trung tự đang đánh nhau to với đồ quảng', label='khi quân hộ tống tới thăng long đúng lúc tô trung từ đang đánh nhau to với đỗ quảng'\n",
870
+ "pred='chứ không lẽ bây giờ kêu men trai', label='chứ hổng lẽ bây giờ kêu mê trai'\n",
871
+ "pred='trong thời gian đó anh ấy hãy tâm sự với tôi', label='trong thời gian đó anh ấy hay tâm sự với tôi'\n",
872
+ "pred='mi mo sa lại cho màu sắc lá đẹp không cần dùng đến màu nhuộng hoàng sực dỡ từ duy nhẹ bé đó đã giúp vườn mi mo sa của bà nhanh chóng đem lại lợi nhuộn', label='mi mo sa lại cho màu sắc lá đẹp không cần dùng đến màu nhuộm hoa rực rỡ tư duy nhạy bén đó đã giúp vườn mi mo sa của bà nhanh chóng đem lại lợi nhuận'\n",
873
+ "pred='chơi tìm kiếm tài năng thiên đỉnh god thai lần thế các táo đâu hết cả rồi', label='chơi tìm kiếm tài năng thiên đình gót thai lừn thế các táo đâu hết cả rồi'\n",
874
+ "pred='dù đức và pháp bất đồng sâu sắc nhưng chính kiến của họ thì đều sai', label='dù đức và pháp bất đồng sâu sắc nhưng chính kiến của họ thì đều sai'\n",
875
+ "pred='đại ca bảo không hình anh ra mà ngồi anh đánh đi', label='đại ca bảo buông anh ra mà thôi anh'\n",
876
+ "pred='khi mà mang thai bác thị cũng cảnh báo rồi', label='khi mà mang thai thì bác sĩ cũng cảnh báo rồi'\n",
877
+ "pred='là tăng giảm thất thường và đột xuất kéo dài', label='mà tăng giảm thất thường và đột xuất kéo dài'\n"
878
+ ]
879
+ }
880
+ ],
881
+ "source": [
882
+ "for pred,label in zip(decoded_preds,decoded_labels):\n",
883
+ " print(f\"{pred=}, {label=}\")"
884
+ ]
885
+ },
886
+ {
887
+ "cell_type": "code",
888
+ "execution_count": 14,
889
+ "metadata": {},
890
+ "outputs": [],
891
+ "source": [
892
+ "import torch\n",
893
+ "from transformers import (\n",
894
+ " AutomaticSpeechRecognitionPipeline,\n",
895
+ " WhisperForConditionalGeneration,\n",
896
+ " WhisperTokenizer,\n",
897
+ " WhisperProcessor,\n",
898
+ ")\n",
899
+ "from peft import PeftModel, PeftConfig\n",
900
+ "\n",
901
+ "\n",
902
+ "peft_model_id = \"./Vietnamese_ASR\"\n",
903
+ "language = \"Vietnamese\"\n",
904
+ "task = \"transcribe\"\n",
905
+ "\n",
906
+ "peft_config = PeftConfig.from_pretrained(peft_model_id)\n",
907
+ "model = WhisperForConditionalGeneration.from_pretrained(\n",
908
+ " peft_config.base_model_name_or_path\n",
909
+ ")\n",
910
+ "model = PeftModel.from_pretrained(model, peft_model_id)\n",
911
+ "merged_model = model.merge_and_unload()\n"
912
+ ]
913
+ },
914
+ {
915
+ "cell_type": "code",
916
+ "execution_count": 16,
917
+ "metadata": {},
918
+ "outputs": [],
919
+ "source": [
920
+ "merged_model.save_pretrained(\"./Vietnamese_ASR/merged\")"
921
+ ]
922
+ },
923
+ {
924
+ "cell_type": "code",
925
+ "execution_count": 17,
926
+ "metadata": {},
927
+ "outputs": [],
928
+ "source": [
929
+ "from transformers import WhisperTokenizer\n",
930
+ "\n",
931
+ "tokenizer = WhisperTokenizer.from_pretrained('openai/whisper-medium', language=language, task=task)"
932
+ ]
933
+ },
934
+ {
935
+ "cell_type": "code",
936
+ "execution_count": 18,
937
+ "metadata": {},
938
+ "outputs": [
939
+ {
940
+ "data": {
941
+ "text/plain": [
942
+ "('./Vietnamese_ASR/merged/tokenizer_config.json',\n",
943
+ " './Vietnamese_ASR/merged/special_tokens_map.json',\n",
944
+ " './Vietnamese_ASR/merged/vocab.json',\n",
945
+ " './Vietnamese_ASR/merged/merges.txt',\n",
946
+ " './Vietnamese_ASR/merged/normalizer.json',\n",
947
+ " './Vietnamese_ASR/merged/added_tokens.json')"
948
+ ]
949
+ },
950
+ "execution_count": 18,
951
+ "metadata": {},
952
+ "output_type": "execute_result"
953
+ }
954
+ ],
955
+ "source": [
956
+ "tokenizer.save_pretrained('./Vietnamese_ASR/merged')"
957
+ ]
958
+ },
959
+ {
960
+ "cell_type": "code",
961
+ "execution_count": 19,
962
+ "metadata": {},
963
+ "outputs": [
964
+ {
965
+ "name": "stdout",
966
+ "output_type": "stream",
967
+ "text": [
968
+ "/bin/bash: /home/tesla/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n"
969
+ ]
970
+ }
971
+ ],
972
+ "source": [
973
+ "!ct2-transformers-converter --model ./Vietnamese_ASR/merged --output_dir ./Vietnamese_ASR/ct2ranslate"
974
+ ]
975
+ },
976
+ {
977
+ "cell_type": "code",
978
+ "execution_count": 20,
979
+ "metadata": {},
980
+ "outputs": [
981
+ {
982
+ "name": "stdout",
983
+ "output_type": "stream",
984
+ "text": [
985
+ "/bin/bash: /home/tesla/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n"
986
+ ]
987
+ }
988
+ ],
989
+ "source": [
990
+ "!ct2-transformers-converter --model ./Vietnamese_ASR/merged --output_dir ./Vietnamese_ASR/ct2ranslate/quantized --quantization float16"
991
+ ]
992
+ },
993
+ {
994
+ "cell_type": "code",
995
+ "execution_count": 6,
996
+ "metadata": {},
997
+ "outputs": [
998
+ {
999
+ "name": "stderr",
1000
+ "output_type": "stream",
1001
+ "text": [
1002
+ "/media/tesla/New Volume/DEMO/DUY/Vietnamese_ASR/Vietnamese_ASR/src/Vietnamese_ASR is already a clone of https://huggingface.co/DuyTa/Vietnamese_ASR. Make sure you pull the latest changes with `repo.git_pull()`.\n"
1003
+ ]
1004
+ }
1005
+ ],
1006
+ "source": [
1007
+ "from huggingface_hub import Repository\n",
1008
+ "repo = Repository(local_dir=\"\", clone_from='DuyTa/Vietnamese_ASR')"
1009
+ ]
1010
+ },
1011
+ {
1012
+ "cell_type": "code",
1013
+ "execution_count": 6,
1014
+ "metadata": {},
1015
+ "outputs": [
1016
+ {
1017
+ "data": {
1018
+ "application/vnd.jupyter.widget-view+json": {
1019
+ "model_id": "061e5ea903e04d2e95bc3ff8a8de434b",
1020
+ "version_major": 2,
1021
+ "version_minor": 0
1022
+ },
1023
+ "text/plain": [
1024
+ "Clean file runs/Aug17_22-42-43_tesla-T4/events.out.tfevents.1692289257.tesla-T4.201346.0: 14%|#4 | 1.0…"
1025
+ ]
1026
+ },
1027
+ "metadata": {},
1028
+ "output_type": "display_data"
1029
+ }
1030
+ ],
1031
+ "source": [
1032
+ "repo.git_pull(rebase=True)"
1033
+ ]
1034
+ },
1035
+ {
1036
+ "cell_type": "code",
1037
+ "execution_count": null,
1038
+ "metadata": {},
1039
+ "outputs": [],
1040
+ "source": [
1041
+ "repo.git_add(\".\")\n",
1042
+ "repo.git_commit(commit_message=\"3 epochs finetuning and quantized model )\")"
1043
+ ]
1044
+ },
1045
+ {
1046
+ "cell_type": "code",
1047
+ "execution_count": 8,
1048
+ "metadata": {},
1049
+ "outputs": [
1050
+ {
1051
+ "name": "stderr",
1052
+ "output_type": "stream",
1053
+ "text": [
1054
+ "Several commits (3) will be pushed upstream.\n",
1055
+ "The progress bars may be unreliable.\n"
1056
+ ]
1057
+ },
1058
+ {
1059
+ "name": "stderr",
1060
+ "output_type": "stream",
1061
+ "text": [
1062
+ "To https://huggingface.co/DuyTa/Vietnamese_ASR\n",
1063
+ " 63bacc4..82e8e84 main -> main\n",
1064
+ "\n"
1065
+ ]
1066
+ },
1067
+ {
1068
+ "data": {
1069
+ "text/plain": [
1070
+ "'https://huggingface.co/DuyTa/Vietnamese_ASR/commit/82e8e84fe4f1ffee17eff82c39a163f4b81335d5'"
1071
+ ]
1072
+ },
1073
+ "execution_count": 8,
1074
+ "metadata": {},
1075
+ "output_type": "execute_result"
1076
+ }
1077
+ ],
1078
+ "source": [
1079
+ "repo.git_push()"
1080
+ ]
1081
+ },
1082
+ {
1083
+ "cell_type": "code",
1084
+ "execution_count": null,
1085
+ "metadata": {},
1086
+ "outputs": [],
1087
+ "source": [
1088
+ "merged_model.push_to_hub(\"DuyTa/MITI_Whisper\")"
1089
+ ]
1090
+ },
1091
+ {
1092
+ "cell_type": "code",
1093
+ "execution_count": null,
1094
+ "metadata": {},
1095
+ "outputs": [],
1096
+ "source": [
1097
+ "import torch\n",
1098
+ "import gradio as gr\n",
1099
+ "from transformers import (\n",
1100
+ " AutomaticSpeechRecognitionPipeline,\n",
1101
+ " WhisperForConditionalGeneration,\n",
1102
+ " WhisperTokenizer,\n",
1103
+ " WhisperProcessor,\n",
1104
+ ")\n",
1105
+ "from peft import PeftModel, PeftConfig\n",
1106
+ "\n",
1107
+ "\n",
1108
+ "peft_model_id = \"DuyTa/MITI_Whisper\"\n",
1109
+ "language = \"Vietnamese\"\n",
1110
+ "task = \"transcribe\"\n",
1111
+ "peft_config = PeftConfig.from_pretrained(peft_model_id)\n",
1112
+ "model = WhisperForConditionalGeneration.from_pretrained(\n",
1113
+ " peft_config.base_model_name_or_path, load_in_8bit=True, device_map=\"auto\"\n",
1114
+ ")\n",
1115
+ "\n",
1116
+ "model = PeftModel.from_pretrained(model, peft_model_id)\n",
1117
+ "tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
1118
+ "processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
1119
+ "feature_extractor = processor.feature_extractor\n",
1120
+ "forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)\n",
1121
+ "pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)\n",
1122
+ "\n",
1123
+ "\n",
1124
+ "def transcribe(audio):\n",
1125
+ " with torch.cuda.amp.autocast():\n",
1126
+ " text = pipe(audio, generate_kwargs={\"forced_decoder_ids\": forced_decoder_ids}, max_new_tokens=255)[\"text\"]\n",
1127
+ " return text\n",
1128
+ "\n",
1129
+ "\n",
1130
+ "iface = gr.Interface(\n",
1131
+ " fn=transcribe,\n",
1132
+ " inputs=gr.Audio(source=\"upload\", type=\"filepath\"),\n",
1133
+ " outputs=\"text\",\n",
1134
+ " title=\"PEFT LoRA\",\n",
1135
+ " description=\"Realtime demo for Vietnamese speech recognition using `PEFT-LoRA+INT8` fine-tuned Whisper Medium .\",\n",
1136
+ ")\n",
1137
+ "\n",
1138
+ "iface.launch(share=True)"
1139
+ ]
1140
+ }
1141
+ ],
1142
+ "metadata": {
1143
+ "kernelspec": {
1144
+ "display_name": "DUY",
1145
+ "language": "python",
1146
+ "name": "python3"
1147
+ },
1148
+ "language_info": {
1149
+ "codemirror_mode": {
1150
+ "name": "ipython",
1151
+ "version": 3
1152
+ },
1153
+ "file_extension": ".py",
1154
+ "mimetype": "text/x-python",
1155
+ "name": "python",
1156
+ "nbconvert_exporter": "python",
1157
+ "pygments_lexer": "ipython3",
1158
+ "version": "3.9.17"
1159
+ },
1160
+ "orig_nbformat": 4
1161
+ },
1162
+ "nbformat": 4,
1163
+ "nbformat_minor": 2
1164
+ }
src/whisperX.ipynb ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 15,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.\n",
13
+ "Model was trained with torch 1.10.0+cu102, yours is 2.0.0+cu118. Bad things might happen unless you revert torch to 1.x.\n",
14
+ "CPU times: user 826 ms, sys: 96.7 ms, total: 923 ms\n",
15
+ "Wall time: 831 ms\n",
16
+ "[{'text': 'đó là ước vọng của nguyễn ái quốc từ những năm hai mươi của thế kỷ trước về một nhà nước việt nam độc lập dân chủ', 'start': 0.008, 'end': 6.556}]\n"
17
+ ]
18
+ }
19
+ ],
20
+ "source": [
21
+ "import whisperx\n",
22
+ "import gc \n",
23
+ "\n",
24
+ "device = \"cuda\" \n",
25
+ "audio_file = \"6.wav\"\n",
26
+ "batch_size = 16 \n",
27
+ "compute_type = \"float16\" # change to \"int8\" if low on GPU mem (may reduce accuracy)\n",
28
+ "model_path = \"./Vietnamese_ASR/ct2ranslate\"\n",
29
+ "# 1. Transcribe with original whisper (batched)\n",
30
+ "model = whisperx.load_model(model_path, device, compute_type=compute_type,language='vi')\n",
31
+ "\n",
32
+ "audio = whisperx.load_audio(audio_file)\n",
33
+ "%time result = model.transcribe(audio, batch_size=batch_size)\n",
34
+ "print(result[\"segments\"]) # before alignment\n",
35
+ "\n",
36
+ "# delete model if low on GPU resources\n",
37
+ "# import gc; gc.collect(); torch.cuda.empty_cache(); del model\n",
38
+ "\n"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import gc; gc.collect()\n",
48
+ "import torch\n",
49
+ "torch.cuda.empty_cache(); del model"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 9,
55
+ "metadata": {},
56
+ "outputs": [
57
+ {
58
+ "name": "stderr",
59
+ "output_type": "stream",
60
+ "text": [
61
+ "Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at nguyenvulebinh/wav2vec2-base-vi and are newly initialized: ['lm_head.weight', 'lm_head.bias']\n",
62
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
63
+ ]
64
+ },
65
+ {
66
+ "name": "stdout",
67
+ "output_type": "stream",
68
+ "text": [
69
+ "[{'start': 0.008, 'end': 2.396, 'text': 'với một người đi làm thuê như anh số tiền kiếm được chưa đủ để thoả mãn nhu cầu cá nhân nói gì đến chăm lo cho gia đình', 'words': [{'word': 'với', 'start': 0.008, 'end': 0.068, 'score': 0.01}, {'word': 'một', 'start': 0.088, 'end': 0.148, 'score': 0.01}, {'word': 'người', 'start': 0.169, 'end': 0.269, 'score': 0.011}, {'word': 'đi', 'start': 0.289, 'end': 0.329, 'score': 0.01}, {'word': 'làm', 'start': 0.349, 'end': 0.409, 'score': 0.011}, {'word': 'thuê', 'start': 0.429, 'end': 0.51, 'score': 0.01}, {'word': 'như', 'start': 0.53, 'end': 0.59, 'score': 0.012}, {'word': 'anh', 'start': 0.61, 'end': 0.67, 'score': 0.01}, {'word': 'số', 'start': 0.69, 'end': 0.73, 'score': 0.01}, {'word': 'tiền', 'start': 0.75, 'end': 0.831, 'score': 0.01}, {'word': 'kiếm', 'start': 0.851, 'end': 0.931, 'score': 0.01}, {'word': 'được', 'start': 0.951, 'end': 1.031, 'score': 0.01}, {'word': 'chưa', 'start': 1.051, 'end': 1.132, 'score': 0.01}, {'word': 'đủ', 'start': 1.152, 'end': 1.192, 'score': 0.01}, {'word': 'để', 'start': 1.212, 'end': 1.252, 'score': 0.01}, {'word': 'thoả', 'start': 1.272, 'end': 1.353, 'score': 0.01}, {'word': 'mãn', 'start': 1.373, 'end': 1.433, 'score': 0.011}, {'word': 'nhu', 'start': 1.453, 'end': 1.513, 'score': 0.011}, {'word': 'cầu', 'start': 1.533, 'end': 1.593, 'score': 0.011}, {'word': 'cá', 'start': 1.613, 'end': 1.654, 'score': 0.01}, {'word': 'nhân', 'start': 1.674, 'end': 1.754, 'score': 0.011}, {'word': 'nói', 'start': 1.774, 'end': 1.834, 'score': 0.01}, {'word': 'gì', 'start': 1.854, 'end': 1.894, 'score': 0.011}, {'word': 'đến', 'start': 1.914, 'end': 1.975, 'score': 0.01}, {'word': 'chăm', 'start': 1.995, 'end': 2.075, 'score': 0.011}, {'word': 'lo', 'start': 2.095, 'end': 2.135, 'score': 0.009}, {'word': 'cho', 'start': 2.155, 'end': 2.215, 'score': 0.011}, {'word': 'gia', 'start': 2.235, 'end': 2.296, 'score': 0.01}, {'word': 'đình', 'start': 2.316, 'end': 2.396, 'score': 0.011}]}]\n"
70
+ ]
71
+ }
72
+ ],
73
+ "source": [
74
+ "# 2. Align whisper output\n",
75
+ "device = \"cuda\" \n",
76
+ "audio_file = \"audio.wav\"\n",
77
+ "batch_size = 16 \n",
78
+ "compute_type = \"float16\" # change to \"int8\" if low on GPU mem (may reduce accuracy)\n",
79
+ "model_path = \"./Vietnamese_ASR/ct2ranslate\"\n",
80
+ "model_a, metadata = whisperx.load_align_model(language_code=\"vi\" ,device=device)\n",
81
+ "result = whisperx.align(result[\"segments\"], model_a, metadata, audio, device, return_char_alignments=False)\n",
82
+ "\n",
83
+ "print(result[\"segments\"]) # after alignment\n",
84
+ "\n",
85
+ "# delete model if low on GPU resources\n",
86
+ "import gc; gc.collect(); torch.cuda.empty_cache(); del model_a\n",
87
+ "\n"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "# 3. Assign speaker labels\n",
97
+ "diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)\n",
98
+ "\n",
99
+ "# add min/max number of speakers if known\n",
100
+ "diarize_segments = diarize_model(audio)\n",
101
+ "# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)\n",
102
+ "\n",
103
+ "result = whisperx.assign_word_speakers(diarize_segments, result)\n",
104
+ "print(diarize_segments)\n",
105
+ "print(result[\"segments\"]) # segments are now assigned speaker IDs"
106
+ ]
107
+ }
108
+ ],
109
+ "metadata": {
110
+ "kernelspec": {
111
+ "display_name": "DUY",
112
+ "language": "python",
113
+ "name": "python3"
114
+ },
115
+ "language_info": {
116
+ "codemirror_mode": {
117
+ "name": "ipython",
118
+ "version": 3
119
+ },
120
+ "file_extension": ".py",
121
+ "mimetype": "text/x-python",
122
+ "name": "python",
123
+ "nbconvert_exporter": "python",
124
+ "pygments_lexer": "ipython3",
125
+ "version": "3.9.17"
126
+ },
127
+ "orig_nbformat": 4
128
+ },
129
+ "nbformat": 4,
130
+ "nbformat_minor": 2
131
+ }
src/whisper_quant.py ADDED
@@ -0,0 +1,995 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import logging
3
+ import os
4
+ import zlib
5
+
6
+ from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
7
+
8
+ import ctranslate2
9
+ import numpy as np
10
+ import tokenizers
11
+
12
+ from faster_whisper.audio import decode_audio
13
+ from faster_whisper.feature_extractor import FeatureExtractor
14
+ from faster_whisper.tokenizer import Tokenizer
15
+ from download_quantized import download_model, format_timestamp, get_logger
16
+ from faster_whisper.vad import (
17
+ SpeechTimestampsMap,
18
+ VadOptions,
19
+ collect_chunks,
20
+ get_speech_timestamps,
21
+ )
22
+
23
+
24
+ class Word(NamedTuple):
25
+ start: float
26
+ end: float
27
+ word: str
28
+ probability: float
29
+
30
+
31
+ class Segment(NamedTuple):
32
+ id: int
33
+ seek: int
34
+ start: float
35
+ end: float
36
+ text: str
37
+ tokens: List[int]
38
+ temperature: float
39
+ avg_logprob: float
40
+ compression_ratio: float
41
+ no_speech_prob: float
42
+ words: Optional[List[Word]]
43
+
44
+
45
+ class TranscriptionOptions(NamedTuple):
46
+ beam_size: int
47
+ best_of: int
48
+ patience: float
49
+ length_penalty: float
50
+ repetition_penalty: float
51
+ log_prob_threshold: Optional[float]
52
+ no_speech_threshold: Optional[float]
53
+ compression_ratio_threshold: Optional[float]
54
+ condition_on_previous_text: bool
55
+ prompt_reset_on_temperature: float
56
+ temperatures: List[float]
57
+ initial_prompt: Optional[Union[str, Iterable[int]]]
58
+ prefix: Optional[str]
59
+ suppress_blank: bool
60
+ suppress_tokens: Optional[List[int]]
61
+ without_timestamps: bool
62
+ max_initial_timestamp: float
63
+ word_timestamps: bool
64
+ prepend_punctuations: str
65
+ append_punctuations: str
66
+
67
+
68
+ class TranscriptionInfo(NamedTuple):
69
+ language: str
70
+ language_probability: float
71
+ duration: float
72
+ all_language_probs: Optional[List[Tuple[str, float]]]
73
+ transcription_options: TranscriptionOptions
74
+ vad_options: VadOptions
75
+
76
+
77
+ class WhisperModel:
78
+ def __init__(
79
+ self,
80
+ model_size_or_path: str,
81
+ device: str = "auto",
82
+ device_index: Union[int, List[int]] = 0,
83
+ compute_type: str = "default",
84
+ cpu_threads: int = 0,
85
+ num_workers: int = 1,
86
+ download_root: Optional[str] = None,
87
+ local_files_only: bool = False,
88
+ ):
89
+ """Initializes the Whisper model.
90
+
91
+ Args:
92
+ model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
93
+ small, small.en, medium, medium.en, large-v1, or large-v2), a path to a converted
94
+ model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub.
95
+ When a size or a model ID is configured, the converted model is downloaded
96
+ from the Hugging Face Hub.
97
+ device: Device to use for computation ("cpu", "cuda", "auto").
98
+ device_index: Device ID to use.
99
+ The model can also be loaded on multiple GPUs by passing a list of IDs
100
+ (e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
101
+ when transcribe() is called from multiple Python threads (see also num_workers).
102
+ compute_type: Type to use for computation.
103
+ See https://opennmt.net/CTranslate2/quantization.html.
104
+ cpu_threads: Number of threads to use when running on CPU (4 by default).
105
+ A non zero value overrides the OMP_NUM_THREADS environment variable.
106
+ num_workers: When transcribe() is called from multiple Python threads,
107
+ having multiple workers enables true parallelism when running the model
108
+ (concurrent calls to self.model.generate() will run in parallel).
109
+ This can improve the global throughput at the cost of increased memory usage.
110
+ download_root: Directory where the models should be saved. If not set, the models
111
+ are saved in the standard Hugging Face cache directory.
112
+ local_files_only: If True, avoid downloading the file and return the path to the
113
+ local cached file if it exists.
114
+ """
115
+ self.logger = get_logger()
116
+
117
+ if os.path.isdir(model_size_or_path):
118
+ model_path = model_size_or_path
119
+ else:
120
+ model_path = download_model(
121
+ model_size_or_path,
122
+ local_files_only=local_files_only,
123
+ cache_dir=download_root,
124
+ )
125
+
126
+ self.model = ctranslate2.models.Whisper(
127
+ model_path,
128
+ device=device,
129
+ device_index=device_index,
130
+ compute_type=compute_type,
131
+ intra_threads=cpu_threads,
132
+ inter_threads=num_workers,
133
+ )
134
+
135
+ tokenizer_file = os.path.join(model_path, "tokenizer.json")
136
+ if os.path.isfile(tokenizer_file):
137
+ self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
138
+ else:
139
+ self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
140
+ "openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
141
+ )
142
+
143
+ self.feature_extractor = FeatureExtractor()
144
+ self.num_samples_per_token = self.feature_extractor.hop_length * 2
145
+ self.frames_per_second = (
146
+ self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
147
+ )
148
+ self.tokens_per_second = (
149
+ self.feature_extractor.sampling_rate // self.num_samples_per_token
150
+ )
151
+ self.input_stride = 2
152
+ self.time_precision = 0.02
153
+ self.max_length = 448
154
+
155
+ def transcribe(
156
+ self,
157
+ audio: Union[str, BinaryIO, np.ndarray],
158
+ language: Optional[str] = None,
159
+ task: str = "transcribe",
160
+ beam_size: int = 5,
161
+ best_of: int = 5,
162
+ patience: float = 1,
163
+ length_penalty: float = 1,
164
+ repetition_penalty: float = 1,
165
+ temperature: Union[float, List[float], Tuple[float, ...]] = [
166
+ 0.0,
167
+ 0.2,
168
+ 0.4,
169
+ 0.6,
170
+ 0.8,
171
+ 1.0,
172
+ ],
173
+ compression_ratio_threshold: Optional[float] = 2.4,
174
+ log_prob_threshold: Optional[float] = -1.0,
175
+ no_speech_threshold: Optional[float] = 0.6,
176
+ condition_on_previous_text: bool = True,
177
+ prompt_reset_on_temperature: float = 0.5,
178
+ initial_prompt: Optional[Union[str, Iterable[int]]] = None,
179
+ prefix: Optional[str] = None,
180
+ suppress_blank: bool = True,
181
+ suppress_tokens: Optional[List[int]] = [-1],
182
+ without_timestamps: bool = False,
183
+ max_initial_timestamp: float = 1.0,
184
+ word_timestamps: bool = False,
185
+ prepend_punctuations: str = "\"'“¿([{-",
186
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
187
+ vad_filter: bool = False,
188
+ vad_parameters: Optional[Union[dict, VadOptions]] = None,
189
+ ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
190
+ """Transcribes an input file.
191
+
192
+ Arguments:
193
+ audio: Path to the input file (or a file-like object), or the audio waveform.
194
+ language: The language spoken in the audio. It should be a language code such
195
+ as "en" or "fr". If not set, the language will be detected in the first 30 seconds
196
+ of audio.
197
+ task: Task to execute (transcribe or translate).
198
+ beam_size: Beam size to use for decoding.
199
+ best_of: Number of candidates when sampling with non-zero temperature.
200
+ patience: Beam search patience factor.
201
+ length_penalty: Exponential length penalty constant.
202
+ repetition_penalty: Penalty applied to the score of previously generated tokens
203
+ (set > 1 to penalize).
204
+ temperature: Temperature for sampling. It can be a tuple of temperatures,
205
+ which will be successively used upon failures according to either
206
+ `compression_ratio_threshold` or `log_prob_threshold`.
207
+ compression_ratio_threshold: If the gzip compression ratio is above this value,
208
+ treat as failed.
209
+ log_prob_threshold: If the average log probability over sampled tokens is
210
+ below this value, treat as failed.
211
+ no_speech_threshold: If the no_speech probability is higher than this value AND
212
+ the average log probability over sampled tokens is below `log_prob_threshold`,
213
+ consider the segment as silent.
214
+ condition_on_previous_text: If True, the previous output of the model is provided
215
+ as a prompt for the next window; disabling may make the text inconsistent across
216
+ windows, but the model becomes less prone to getting stuck in a failure loop,
217
+ such as repetition looping or timestamps going out of sync.
218
+ prompt_reset_on_temperature: Resets prompt if temperature is above this value.
219
+ Arg has effect only if condition_on_previous_text is True.
220
+ initial_prompt: Optional text string or iterable of token ids to provide as a
221
+ prompt for the first window.
222
+ prefix: Optional text to provide as a prefix for the first window.
223
+ suppress_blank: Suppress blank outputs at the beginning of the sampling.
224
+ suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
225
+ of symbols as defined in the model config.json file.
226
+ without_timestamps: Only sample text tokens.
227
+ max_initial_timestamp: The initial timestamp cannot be later than this.
228
+ word_timestamps: Extract word-level timestamps using the cross-attention pattern
229
+ and dynamic time warping, and include the timestamps for each word in each segment.
230
+ prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
231
+ with the next word
232
+ append_punctuations: If word_timestamps is True, merge these punctuation symbols
233
+ with the previous word
234
+ vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
235
+ without speech. This step is using the Silero VAD model
236
+ https://github.com/snakers4/silero-vad.
237
+ vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
238
+ parameters and default values in the class `VadOptions`).
239
+
240
+ Returns:
241
+ A tuple with:
242
+
243
+ - a generator over transcribed segments
244
+ - an instance of TranscriptionInfo
245
+ """
246
+ sampling_rate = self.feature_extractor.sampling_rate
247
+
248
+ if not isinstance(audio, np.ndarray):
249
+ audio = decode_audio(audio, sampling_rate=sampling_rate)
250
+
251
+ duration = audio.shape[0] / sampling_rate
252
+
253
+ self.logger.info(
254
+ "Processing audio with duration %s", format_timestamp(duration)
255
+ )
256
+
257
+ if vad_filter:
258
+ if vad_parameters is None:
259
+ vad_parameters = VadOptions()
260
+ elif isinstance(vad_parameters, dict):
261
+ vad_parameters = VadOptions(**vad_parameters)
262
+ speech_chunks = get_speech_timestamps(audio, vad_parameters)
263
+ audio = collect_chunks(audio, speech_chunks)
264
+
265
+ self.logger.info(
266
+ "VAD filter removed %s of audio",
267
+ format_timestamp(duration - (audio.shape[0] / sampling_rate)),
268
+ )
269
+
270
+ if self.logger.isEnabledFor(logging.DEBUG):
271
+ self.logger.debug(
272
+ "VAD filter kept the following audio segments: %s",
273
+ ", ".join(
274
+ "[%s -> %s]"
275
+ % (
276
+ format_timestamp(chunk["start"] / sampling_rate),
277
+ format_timestamp(chunk["end"] / sampling_rate),
278
+ )
279
+ for chunk in speech_chunks
280
+ ),
281
+ )
282
+
283
+ else:
284
+ speech_chunks = None
285
+
286
+ features = self.feature_extractor(audio)
287
+
288
+ encoder_output = None
289
+ all_language_probs = None
290
+
291
+ if language is None:
292
+ if not self.model.is_multilingual:
293
+ language = "en"
294
+ language_probability = 1
295
+ else:
296
+ segment = features[:, : self.feature_extractor.nb_max_frames]
297
+ encoder_output = self.encode(segment)
298
+ # results is a list of tuple[str, float] with language names and
299
+ # probabilities.
300
+ results = self.model.detect_language(encoder_output)[0]
301
+ # Parse language names to strip out markers
302
+ all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
303
+ # Get top language token and probability
304
+ language, language_probability = all_language_probs[0]
305
+
306
+ self.logger.info(
307
+ "Detected language '%s' with probability %.2f",
308
+ language,
309
+ language_probability,
310
+ )
311
+ else:
312
+ language_probability = 1
313
+
314
+ tokenizer = Tokenizer(
315
+ self.hf_tokenizer,
316
+ self.model.is_multilingual,
317
+ task=task,
318
+ language=language,
319
+ )
320
+
321
+ options = TranscriptionOptions(
322
+ beam_size=beam_size,
323
+ best_of=best_of,
324
+ patience=patience,
325
+ length_penalty=length_penalty,
326
+ repetition_penalty=repetition_penalty,
327
+ log_prob_threshold=log_prob_threshold,
328
+ no_speech_threshold=no_speech_threshold,
329
+ compression_ratio_threshold=compression_ratio_threshold,
330
+ condition_on_previous_text=condition_on_previous_text,
331
+ prompt_reset_on_temperature=prompt_reset_on_temperature,
332
+ temperatures=(
333
+ temperature if isinstance(temperature, (list, tuple)) else [temperature]
334
+ ),
335
+ initial_prompt=initial_prompt,
336
+ prefix=prefix,
337
+ suppress_blank=suppress_blank,
338
+ suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
339
+ without_timestamps=without_timestamps,
340
+ max_initial_timestamp=max_initial_timestamp,
341
+ word_timestamps=word_timestamps,
342
+ prepend_punctuations=prepend_punctuations,
343
+ append_punctuations=append_punctuations,
344
+ )
345
+
346
+ segments = self.generate_segments(features, tokenizer, options, encoder_output)
347
+
348
+ if speech_chunks:
349
+ segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
350
+
351
+ info = TranscriptionInfo(
352
+ language=language,
353
+ language_probability=language_probability,
354
+ duration=duration,
355
+ transcription_options=options,
356
+ vad_options=vad_parameters,
357
+ all_language_probs=all_language_probs,
358
+ )
359
+
360
+ return segments, info
361
+
362
+ def generate_segments(
363
+ self,
364
+ features: np.ndarray,
365
+ tokenizer: Tokenizer,
366
+ options: TranscriptionOptions,
367
+ encoder_output: Optional[ctranslate2.StorageView] = None,
368
+ ) -> Iterable[Segment]:
369
+ content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
370
+ idx = 0
371
+ seek = 0
372
+ all_tokens = []
373
+ prompt_reset_since = 0
374
+
375
+ if options.initial_prompt is not None:
376
+ if isinstance(options.initial_prompt, str):
377
+ initial_prompt = " " + options.initial_prompt.strip()
378
+ initial_prompt_tokens = tokenizer.encode(initial_prompt)
379
+ all_tokens.extend(initial_prompt_tokens)
380
+ else:
381
+ all_tokens.extend(options.initial_prompt)
382
+
383
+ last_speech_timestamp = 0.0
384
+ while seek < content_frames:
385
+ time_offset = seek * self.feature_extractor.time_per_frame
386
+ segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
387
+ segment_size = min(
388
+ self.feature_extractor.nb_max_frames, content_frames - seek
389
+ )
390
+ segment_duration = segment_size * self.feature_extractor.time_per_frame
391
+
392
+ if self.logger.isEnabledFor(logging.DEBUG):
393
+ self.logger.debug(
394
+ "Processing segment at %s", format_timestamp(time_offset)
395
+ )
396
+
397
+ previous_tokens = all_tokens[prompt_reset_since:]
398
+ prompt = self.get_prompt(
399
+ tokenizer,
400
+ previous_tokens,
401
+ without_timestamps=options.without_timestamps,
402
+ prefix=options.prefix if seek == 0 else None,
403
+ )
404
+
405
+ if encoder_output is None:
406
+ encoder_output = self.encode(segment)
407
+
408
+ (
409
+ result,
410
+ avg_logprob,
411
+ temperature,
412
+ compression_ratio,
413
+ ) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
414
+
415
+ if options.no_speech_threshold is not None:
416
+ # no voice activity check
417
+ should_skip = result.no_speech_prob > options.no_speech_threshold
418
+
419
+ if (
420
+ options.log_prob_threshold is not None
421
+ and avg_logprob > options.log_prob_threshold
422
+ ):
423
+ # don't skip if the logprob is high enough, despite the no_speech_prob
424
+ should_skip = False
425
+
426
+ if should_skip:
427
+ self.logger.debug(
428
+ "No speech threshold is met (%f > %f)",
429
+ result.no_speech_prob,
430
+ options.no_speech_threshold,
431
+ )
432
+
433
+ # fast-forward to the next segment boundary
434
+ seek += segment_size
435
+ encoder_output = None
436
+ continue
437
+
438
+ tokens = result.sequences_ids[0]
439
+
440
+ previous_seek = seek
441
+ current_segments = []
442
+
443
+ single_timestamp_ending = (
444
+ len(tokens) >= 2
445
+ and tokens[-2] < tokenizer.timestamp_begin
446
+ and tokens[-1] >= tokenizer.timestamp_begin
447
+ )
448
+
449
+ consecutive_timestamps = [
450
+ i
451
+ for i in range(len(tokens))
452
+ if i > 0
453
+ and tokens[i] >= tokenizer.timestamp_begin
454
+ and tokens[i - 1] >= tokenizer.timestamp_begin
455
+ ]
456
+
457
+ if len(consecutive_timestamps) > 0:
458
+ slices = list(consecutive_timestamps)
459
+ if single_timestamp_ending:
460
+ slices.append(len(tokens))
461
+
462
+ last_slice = 0
463
+ for current_slice in slices:
464
+ sliced_tokens = tokens[last_slice:current_slice]
465
+ start_timestamp_position = (
466
+ sliced_tokens[0] - tokenizer.timestamp_begin
467
+ )
468
+ end_timestamp_position = (
469
+ sliced_tokens[-1] - tokenizer.timestamp_begin
470
+ )
471
+ start_time = (
472
+ time_offset + start_timestamp_position * self.time_precision
473
+ )
474
+ end_time = (
475
+ time_offset + end_timestamp_position * self.time_precision
476
+ )
477
+
478
+ current_segments.append(
479
+ dict(
480
+ seek=seek,
481
+ start=start_time,
482
+ end=end_time,
483
+ tokens=sliced_tokens,
484
+ )
485
+ )
486
+ last_slice = current_slice
487
+
488
+ if single_timestamp_ending:
489
+ # single timestamp at the end means no speech after the last timestamp.
490
+ seek += segment_size
491
+ else:
492
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
493
+ last_timestamp_position = (
494
+ tokens[last_slice - 1] - tokenizer.timestamp_begin
495
+ )
496
+ seek += last_timestamp_position * self.input_stride
497
+
498
+ else:
499
+ duration = segment_duration
500
+ timestamps = [
501
+ token for token in tokens if token >= tokenizer.timestamp_begin
502
+ ]
503
+ if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
504
+ last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
505
+ duration = last_timestamp_position * self.time_precision
506
+
507
+ current_segments.append(
508
+ dict(
509
+ seek=seek,
510
+ start=time_offset,
511
+ end=time_offset + duration,
512
+ tokens=tokens,
513
+ )
514
+ )
515
+
516
+ seek += segment_size
517
+
518
+ if options.word_timestamps:
519
+ self.add_word_timestamps(
520
+ current_segments,
521
+ tokenizer,
522
+ encoder_output,
523
+ segment_size,
524
+ options.prepend_punctuations,
525
+ options.append_punctuations,
526
+ last_speech_timestamp=last_speech_timestamp,
527
+ )
528
+
529
+ word_end_timestamps = [
530
+ w["end"] for s in current_segments for w in s["words"]
531
+ ]
532
+ if len(word_end_timestamps) > 0:
533
+ last_speech_timestamp = word_end_timestamps[-1]
534
+ if not single_timestamp_ending and len(word_end_timestamps) > 0:
535
+ seek_shift = round(
536
+ (word_end_timestamps[-1] - time_offset) * self.frames_per_second
537
+ )
538
+
539
+ if seek_shift > 0:
540
+ seek = previous_seek + seek_shift
541
+
542
+ encoder_output = None
543
+
544
+ for segment in current_segments:
545
+ tokens = segment["tokens"]
546
+ text = tokenizer.decode(tokens)
547
+
548
+ if segment["start"] == segment["end"] or not text.strip():
549
+ continue
550
+
551
+ all_tokens.extend(tokens)
552
+ idx += 1
553
+
554
+ yield Segment(
555
+ id=idx,
556
+ seek=seek,
557
+ start=segment["start"],
558
+ end=segment["end"],
559
+ text=text,
560
+ tokens=tokens,
561
+ temperature=temperature,
562
+ avg_logprob=avg_logprob,
563
+ compression_ratio=compression_ratio,
564
+ no_speech_prob=result.no_speech_prob,
565
+ words=(
566
+ [Word(**word) for word in segment["words"]]
567
+ if options.word_timestamps
568
+ else None
569
+ ),
570
+ )
571
+
572
+ if (
573
+ not options.condition_on_previous_text
574
+ or temperature > options.prompt_reset_on_temperature
575
+ ):
576
+ if options.condition_on_previous_text:
577
+ self.logger.debug(
578
+ "Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
579
+ temperature,
580
+ options.prompt_reset_on_temperature,
581
+ )
582
+
583
+ prompt_reset_since = len(all_tokens)
584
+
585
+ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
586
+ # When the model is running on multiple GPUs, the encoder output should be moved
587
+ # to the CPU since we don't know which GPU will handle the next job.
588
+ to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
589
+
590
+ features = np.expand_dims(features, 0)
591
+ features = get_ctranslate2_storage(features)
592
+
593
+ return self.model.encode(features, to_cpu=to_cpu)
594
+
595
+ def generate_with_fallback(
596
+ self,
597
+ encoder_output: ctranslate2.StorageView,
598
+ prompt: List[int],
599
+ tokenizer: Tokenizer,
600
+ options: TranscriptionOptions,
601
+ ) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
602
+ decode_result = None
603
+ all_results = []
604
+ below_cr_threshold_results = []
605
+
606
+ max_initial_timestamp_index = int(
607
+ round(options.max_initial_timestamp / self.time_precision)
608
+ )
609
+
610
+ for temperature in options.temperatures:
611
+ if temperature > 0:
612
+ kwargs = {
613
+ "beam_size": 1,
614
+ "num_hypotheses": options.best_of,
615
+ "sampling_topk": 0,
616
+ "sampling_temperature": temperature,
617
+ }
618
+ else:
619
+ kwargs = {
620
+ "beam_size": options.beam_size,
621
+ "patience": options.patience,
622
+ }
623
+
624
+ result = self.model.generate(
625
+ encoder_output,
626
+ [prompt],
627
+ length_penalty=options.length_penalty,
628
+ repetition_penalty=options.repetition_penalty,
629
+ max_length=self.max_length,
630
+ return_scores=True,
631
+ return_no_speech_prob=True,
632
+ suppress_blank=options.suppress_blank,
633
+ suppress_tokens=options.suppress_tokens,
634
+ max_initial_timestamp_index=max_initial_timestamp_index,
635
+ **kwargs,
636
+ )[0]
637
+
638
+ tokens = result.sequences_ids[0]
639
+
640
+ # Recover the average log prob from the returned score.
641
+ seq_len = len(tokens)
642
+ cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
643
+ avg_logprob = cum_logprob / (seq_len + 1)
644
+
645
+ text = tokenizer.decode(tokens).strip()
646
+ compression_ratio = get_compression_ratio(text)
647
+
648
+ decode_result = (
649
+ result,
650
+ avg_logprob,
651
+ temperature,
652
+ compression_ratio,
653
+ )
654
+ all_results.append(decode_result)
655
+
656
+ needs_fallback = False
657
+
658
+ if options.compression_ratio_threshold is not None:
659
+ if compression_ratio > options.compression_ratio_threshold:
660
+ needs_fallback = True # too repetitive
661
+
662
+ self.logger.debug(
663
+ "Compression ratio threshold is not met with temperature %.1f (%f > %f)",
664
+ temperature,
665
+ compression_ratio,
666
+ options.compression_ratio_threshold,
667
+ )
668
+ else:
669
+ below_cr_threshold_results.append(decode_result)
670
+
671
+ if (
672
+ options.log_prob_threshold is not None
673
+ and avg_logprob < options.log_prob_threshold
674
+ ):
675
+ needs_fallback = True # average log probability is too low
676
+
677
+ self.logger.debug(
678
+ "Log probability threshold is not met with temperature %.1f (%f < %f)",
679
+ temperature,
680
+ avg_logprob,
681
+ options.log_prob_threshold,
682
+ )
683
+
684
+ if (
685
+ options.no_speech_threshold is not None
686
+ and result.no_speech_prob > options.no_speech_threshold
687
+ ):
688
+ needs_fallback = False # silence
689
+
690
+ if not needs_fallback:
691
+ break
692
+ else:
693
+ # all failed, select the result with the highest average log probability
694
+ decode_result = max(
695
+ below_cr_threshold_results or all_results, key=lambda x: x[1]
696
+ )
697
+
698
+ return decode_result
699
+
700
+ def get_prompt(
701
+ self,
702
+ tokenizer: Tokenizer,
703
+ previous_tokens: List[int],
704
+ without_timestamps: bool = False,
705
+ prefix: Optional[str] = None,
706
+ ) -> List[int]:
707
+ prompt = []
708
+
709
+ if previous_tokens:
710
+ prompt.append(tokenizer.sot_prev)
711
+ prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
712
+
713
+ prompt.extend(tokenizer.sot_sequence)
714
+
715
+ if without_timestamps:
716
+ prompt.append(tokenizer.no_timestamps)
717
+
718
+ if prefix:
719
+ prefix_tokens = tokenizer.encode(" " + prefix.strip())
720
+ if len(prefix_tokens) >= self.max_length // 2:
721
+ prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
722
+ if not without_timestamps:
723
+ prompt.append(tokenizer.timestamp_begin)
724
+ prompt.extend(prefix_tokens)
725
+
726
+ return prompt
727
+
728
+ def add_word_timestamps(
729
+ self,
730
+ segments: List[dict],
731
+ tokenizer: Tokenizer,
732
+ encoder_output: ctranslate2.StorageView,
733
+ num_frames: int,
734
+ prepend_punctuations: str,
735
+ append_punctuations: str,
736
+ last_speech_timestamp: float,
737
+ ):
738
+ if len(segments) == 0:
739
+ return
740
+
741
+ text_tokens_per_segment = [
742
+ [token for token in segment["tokens"] if token < tokenizer.eot]
743
+ for segment in segments
744
+ ]
745
+
746
+ text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
747
+ alignment = self.find_alignment(
748
+ tokenizer, text_tokens, encoder_output, num_frames
749
+ )
750
+ word_durations = np.array([word["end"] - word["start"] for word in alignment])
751
+ word_durations = word_durations[word_durations.nonzero()]
752
+ median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
753
+ max_duration = median_duration * 2
754
+
755
+ # hack: truncate long words at sentence boundaries.
756
+ # a better segmentation algorithm based on VAD should be able to replace this.
757
+ if len(word_durations) > 0:
758
+ sentence_end_marks = ".。!!??"
759
+ # ensure words at sentence boundaries
760
+ # are not longer than twice the median word duration.
761
+ for i in range(1, len(alignment)):
762
+ if alignment[i]["end"] - alignment[i]["start"] > max_duration:
763
+ if alignment[i]["word"] in sentence_end_marks:
764
+ alignment[i]["end"] = alignment[i]["start"] + max_duration
765
+ elif alignment[i - 1]["word"] in sentence_end_marks:
766
+ alignment[i]["start"] = alignment[i]["end"] - max_duration
767
+
768
+ merge_punctuations(alignment, prepend_punctuations, append_punctuations)
769
+
770
+ time_offset = (
771
+ segments[0]["seek"]
772
+ * self.feature_extractor.hop_length
773
+ / self.feature_extractor.sampling_rate
774
+ )
775
+
776
+ word_index = 0
777
+
778
+ for segment, text_tokens in zip(segments, text_tokens_per_segment):
779
+ saved_tokens = 0
780
+ words = []
781
+
782
+ while word_index < len(alignment) and saved_tokens < len(text_tokens):
783
+ timing = alignment[word_index]
784
+
785
+ if timing["word"]:
786
+ words.append(
787
+ dict(
788
+ word=timing["word"],
789
+ start=round(time_offset + timing["start"], 2),
790
+ end=round(time_offset + timing["end"], 2),
791
+ probability=timing["probability"],
792
+ )
793
+ )
794
+
795
+ saved_tokens += len(timing["tokens"])
796
+ word_index += 1
797
+
798
+ # hack: truncate long words at segment boundaries.
799
+ # a better segmentation algorithm based on VAD should be able to replace this.
800
+ if len(words) > 0:
801
+ # ensure the first and second word after a pause is not longer than
802
+ # twice the median word duration.
803
+ if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
804
+ words[0]["end"] - words[0]["start"] > max_duration
805
+ or (
806
+ len(words) > 1
807
+ and words[1]["end"] - words[0]["start"] > max_duration * 2
808
+ )
809
+ ):
810
+ if (
811
+ len(words) > 1
812
+ and words[1]["end"] - words[1]["start"] > max_duration
813
+ ):
814
+ boundary = max(
815
+ words[1]["end"] / 2, words[1]["end"] - max_duration
816
+ )
817
+ words[0]["end"] = words[1]["start"] = boundary
818
+ words[0]["start"] = max(0, words[0]["end"] - max_duration)
819
+
820
+ # prefer the segment-level start timestamp if the first word is too long.
821
+ if (
822
+ segment["start"] < words[0]["end"]
823
+ and segment["start"] - 0.5 > words[0]["start"]
824
+ ):
825
+ words[0]["start"] = max(
826
+ 0, min(words[0]["end"] - median_duration, segment["start"])
827
+ )
828
+ else:
829
+ segment["start"] = words[0]["start"]
830
+
831
+ # prefer the segment-level end timestamp if the last word is too long.
832
+ if (
833
+ segment["end"] > words[-1]["start"]
834
+ and segment["end"] + 0.5 < words[-1]["end"]
835
+ ):
836
+ words[-1]["end"] = max(
837
+ words[-1]["start"] + median_duration, segment["end"]
838
+ )
839
+ else:
840
+ segment["end"] = words[-1]["end"]
841
+
842
+ last_speech_timestamp = segment["end"]
843
+
844
+ segment["words"] = words
845
+
846
+ def find_alignment(
847
+ self,
848
+ tokenizer: Tokenizer,
849
+ text_tokens: List[int],
850
+ encoder_output: ctranslate2.StorageView,
851
+ num_frames: int,
852
+ median_filter_width: int = 7,
853
+ ) -> List[dict]:
854
+ if len(text_tokens) == 0:
855
+ return []
856
+
857
+ result = self.model.align(
858
+ encoder_output,
859
+ tokenizer.sot_sequence,
860
+ [text_tokens],
861
+ num_frames,
862
+ median_filter_width=median_filter_width,
863
+ )[0]
864
+
865
+ text_token_probs = result.text_token_probs
866
+
867
+ alignments = result.alignments
868
+ text_indices = np.array([pair[0] for pair in alignments])
869
+ time_indices = np.array([pair[1] for pair in alignments])
870
+
871
+ words, word_tokens = tokenizer.split_to_word_tokens(
872
+ text_tokens + [tokenizer.eot]
873
+ )
874
+ word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
875
+ if len(word_boundaries) <= 1:
876
+ return []
877
+
878
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
879
+ jump_times = time_indices[jumps] / self.tokens_per_second
880
+ start_times = jump_times[word_boundaries[:-1]]
881
+ end_times = jump_times[word_boundaries[1:]]
882
+ word_probabilities = [
883
+ np.mean(text_token_probs[i:j])
884
+ for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
885
+ ]
886
+
887
+ return [
888
+ dict(
889
+ word=word, tokens=tokens, start=start, end=end, probability=probability
890
+ )
891
+ for word, tokens, start, end, probability in zip(
892
+ words, word_tokens, start_times, end_times, word_probabilities
893
+ )
894
+ ]
895
+
896
+
897
+ def restore_speech_timestamps(
898
+ segments: Iterable[Segment],
899
+ speech_chunks: List[dict],
900
+ sampling_rate: int,
901
+ ) -> Iterable[Segment]:
902
+ ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
903
+
904
+ for segment in segments:
905
+ if segment.words:
906
+ words = []
907
+ for word in segment.words:
908
+ # Ensure the word start and end times are resolved to the same chunk.
909
+ middle = (word.start + word.end) / 2
910
+ chunk_index = ts_map.get_chunk_index(middle)
911
+ word = word._replace(
912
+ start=ts_map.get_original_time(word.start, chunk_index),
913
+ end=ts_map.get_original_time(word.end, chunk_index),
914
+ )
915
+ words.append(word)
916
+
917
+ segment = segment._replace(
918
+ start=words[0].start,
919
+ end=words[-1].end,
920
+ words=words,
921
+ )
922
+
923
+ else:
924
+ segment = segment._replace(
925
+ start=ts_map.get_original_time(segment.start),
926
+ end=ts_map.get_original_time(segment.end),
927
+ )
928
+
929
+ yield segment
930
+
931
+
932
+ def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
933
+ segment = np.ascontiguousarray(segment)
934
+ segment = ctranslate2.StorageView.from_array(segment)
935
+ return segment
936
+
937
+
938
+ def get_compression_ratio(text: str) -> float:
939
+ text_bytes = text.encode("utf-8")
940
+ return len(text_bytes) / len(zlib.compress(text_bytes))
941
+
942
+
943
+ def get_suppressed_tokens(tokenizer, suppress_tokens):
944
+ if not suppress_tokens or -1 in suppress_tokens:
945
+ return suppress_tokens
946
+
947
+ suppress_tokens = list(suppress_tokens)
948
+
949
+ # Ensure the following special tokens are suppressed when the user does
950
+ # not use the default set (-1).
951
+ suppress_tokens.extend(
952
+ [
953
+ tokenizer.transcribe,
954
+ tokenizer.translate,
955
+ tokenizer.sot,
956
+ tokenizer.sot_prev,
957
+ tokenizer.sot_lm,
958
+ ]
959
+ )
960
+
961
+ return sorted(set(suppress_tokens))
962
+
963
+
964
+ def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
965
+ # merge prepended punctuations
966
+ i = len(alignment) - 2
967
+ j = len(alignment) - 1
968
+ while i >= 0:
969
+ previous = alignment[i]
970
+ following = alignment[j]
971
+ if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
972
+ # prepend it to the following word
973
+ following["word"] = previous["word"] + following["word"]
974
+ following["tokens"] = previous["tokens"] + following["tokens"]
975
+ previous["word"] = ""
976
+ previous["tokens"] = []
977
+ else:
978
+ j = i
979
+ i -= 1
980
+
981
+ # merge appended punctuations
982
+ i = 0
983
+ j = 1
984
+ while j < len(alignment):
985
+ previous = alignment[i]
986
+ following = alignment[j]
987
+ if not previous["word"].endswith(" ") and following["word"] in appended:
988
+ # append it to the previous word
989
+ previous["word"] = previous["word"] + following["word"]
990
+ previous["tokens"] = previous["tokens"] + following["tokens"]
991
+ following["word"] = ""
992
+ following["tokens"] = []
993
+ else:
994
+ i = j
995
+ j += 1