Source )
Browse files- src/EDA.ipynb +391 -0
- src/MITI.ipynb +342 -0
- src/download_quantized.py +132 -0
- src/laboratory.ipynb +0 -0
- src/lora_tuning.py +773 -0
- src/merge_lora.py +44 -0
- src/prepare_data.py +212 -0
- src/realtime.py +157 -0
- src/requirements.txt +15 -0
- src/test_whisper.ipynb +1546 -0
- src/training.py +183 -0
- src/vin_whisper_medium.ipynb +1164 -0
- src/whisperX.ipynb +131 -0
- src/whisper_quant.py +995 -0
src/EDA.ipynb
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 6,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import glob\n",
|
10 |
+
"\n",
|
11 |
+
"def count_files_by_extension(path, extension):\n",
|
12 |
+
" \"\"\"\n",
|
13 |
+
" path : root path to check ,\n",
|
14 |
+
" extension : .wav , ...\n",
|
15 |
+
" \"\"\"\n",
|
16 |
+
"\n",
|
17 |
+
" files = glob.glob(f\"{path}/*.{extension}\")\n",
|
18 |
+
" return len(files)\n",
|
19 |
+
"\n",
|
20 |
+
"\n",
|
21 |
+
"root_path = \"./vin_data/vlsp2020_train_set_02/\"\n"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": 7,
|
27 |
+
"metadata": {},
|
28 |
+
"outputs": [],
|
29 |
+
"source": [
|
30 |
+
"num_wav_files = count_files_by_extension(root_path, \"wav\")"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 8,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [],
|
38 |
+
"source": [
|
39 |
+
"num_txt_files = count_files_by_extension(root_path, \"txt\")"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": 9,
|
45 |
+
"metadata": {},
|
46 |
+
"outputs": [
|
47 |
+
{
|
48 |
+
"name": "stdout",
|
49 |
+
"output_type": "stream",
|
50 |
+
"text": [
|
51 |
+
"Số lượng file WAV: 56427\n",
|
52 |
+
"Số lượng file text: 56427\n"
|
53 |
+
]
|
54 |
+
}
|
55 |
+
],
|
56 |
+
"source": [
|
57 |
+
"print(f\"Số lượng file WAV: {num_wav_files}\")\n",
|
58 |
+
"print(f\"Số lượng file text: {num_txt_files}\")"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "code",
|
63 |
+
"execution_count": 10,
|
64 |
+
"metadata": {},
|
65 |
+
"outputs": [
|
66 |
+
{
|
67 |
+
"name": "stdout",
|
68 |
+
"output_type": "stream",
|
69 |
+
"text": [
|
70 |
+
"Tần số mẫu (sample rate): 16000 Hz\n",
|
71 |
+
"Số kênh (channels): 1\n"
|
72 |
+
]
|
73 |
+
}
|
74 |
+
],
|
75 |
+
"source": [
|
76 |
+
"import os\n",
|
77 |
+
"import random\n",
|
78 |
+
"import wave\n",
|
79 |
+
"\n",
|
80 |
+
"\n",
|
81 |
+
"def get_random_wav_file_info(folder_path):\n",
|
82 |
+
" wav_files = glob.glob(f\"{folder_path}/*.wav\")\n",
|
83 |
+
" \n",
|
84 |
+
" if not wav_files:\n",
|
85 |
+
" return None, None\n",
|
86 |
+
" \n",
|
87 |
+
" random_wav_file = random.choice(wav_files)\n",
|
88 |
+
" \n",
|
89 |
+
" with wave.open(random_wav_file, 'rb') as wav_file:\n",
|
90 |
+
" sample_rate = wav_file.getframerate()\n",
|
91 |
+
" channels = wav_file.getnchannels()\n",
|
92 |
+
" \n",
|
93 |
+
" return sample_rate, channels\n",
|
94 |
+
"\n",
|
95 |
+
"path_to_wav_folder = \"./vin_data/vlsp2020_train_set_02/\"\n",
|
96 |
+
"\n",
|
97 |
+
"sample_rate, channels = get_random_wav_file_info(path_to_wav_folder)\n",
|
98 |
+
"\n",
|
99 |
+
"if sample_rate is not None and channels is not None:\n",
|
100 |
+
" print(f\"Tần số mẫu (sample rate): {sample_rate} Hz\")\n",
|
101 |
+
" print(f\"Số kênh (channels): {channels}\")\n",
|
102 |
+
"else:\n",
|
103 |
+
" print(\"Nothing.\")\n"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": 13,
|
109 |
+
"metadata": {},
|
110 |
+
"outputs": [],
|
111 |
+
"source": [
|
112 |
+
"import os\n",
|
113 |
+
"import csv\n",
|
114 |
+
"from tqdm import tqdm\n",
|
115 |
+
"\n",
|
116 |
+
"def create_csv_from_wav_folder(folder_path, output_csv_file):\n",
|
117 |
+
" wav_files = glob.glob(f\"{folder_path}/*.wav\")\n",
|
118 |
+
"\n",
|
119 |
+
" if not wav_files:\n",
|
120 |
+
" print(\"Không có file WAV nào trong thư mục.\")\n",
|
121 |
+
" return\n",
|
122 |
+
"\n",
|
123 |
+
" # Mở tệp CSV đầu ra và tạo bộ đếm số lượng file WAV\n",
|
124 |
+
" with open(output_csv_file, mode='w', newline='') as csv_file:\n",
|
125 |
+
" csv_writer = csv.writer(csv_file)\n",
|
126 |
+
" csv_writer.writerow(['path', 'name','sentence'])\n",
|
127 |
+
"\n",
|
128 |
+
" for wav_file_path in tqdm(wav_files):\n",
|
129 |
+
"\n",
|
130 |
+
" text_file_path = os.path.splitext(wav_file_path)[0] + \".txt\"\n",
|
131 |
+
" if os.path.exists(text_file_path):\n",
|
132 |
+
" with open(text_file_path, 'r') as txt_file:\n",
|
133 |
+
" text_content = txt_file.read()\n",
|
134 |
+
" else:\n",
|
135 |
+
" text_content = \"Not found.\"\n",
|
136 |
+
"\n",
|
137 |
+
" csv_writer.writerow([wav_file_path, os.path.basename(wav_file_path), sample_rate, channels, text_content])\n"
|
138 |
+
]
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "code",
|
142 |
+
"execution_count": 14,
|
143 |
+
"metadata": {},
|
144 |
+
"outputs": [
|
145 |
+
{
|
146 |
+
"name": "stderr",
|
147 |
+
"output_type": "stream",
|
148 |
+
"text": [
|
149 |
+
"100%|██████████| 56427/56427 [00:37<00:00, 1492.44it/s]\n"
|
150 |
+
]
|
151 |
+
}
|
152 |
+
],
|
153 |
+
"source": [
|
154 |
+
"output_csv_file = \"vin.csv\"\n",
|
155 |
+
"path_to_wav_folder = \"./vin_data/vlsp2020_train_set_02/\"\n",
|
156 |
+
"create_csv_from_wav_folder(path_to_wav_folder, output_csv_file)"
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "code",
|
161 |
+
"execution_count": 34,
|
162 |
+
"metadata": {},
|
163 |
+
"outputs": [
|
164 |
+
{
|
165 |
+
"data": {
|
166 |
+
"text/html": [
|
167 |
+
"<div>\n",
|
168 |
+
"<style scoped>\n",
|
169 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
170 |
+
" vertical-align: middle;\n",
|
171 |
+
" }\n",
|
172 |
+
"\n",
|
173 |
+
" .dataframe tbody tr th {\n",
|
174 |
+
" vertical-align: top;\n",
|
175 |
+
" }\n",
|
176 |
+
"\n",
|
177 |
+
" .dataframe thead th {\n",
|
178 |
+
" text-align: right;\n",
|
179 |
+
" }\n",
|
180 |
+
"</style>\n",
|
181 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
182 |
+
" <thead>\n",
|
183 |
+
" <tr style=\"text-align: right;\">\n",
|
184 |
+
" <th></th>\n",
|
185 |
+
" <th>path</th>\n",
|
186 |
+
" <th>name</th>\n",
|
187 |
+
" <th>sentence</th>\n",
|
188 |
+
" </tr>\n",
|
189 |
+
" </thead>\n",
|
190 |
+
" <tbody>\n",
|
191 |
+
" <tr>\n",
|
192 |
+
" <th>0</th>\n",
|
193 |
+
" <td>./vin_data/vlsp2020_train_set_02/spkyut-201907...</td>\n",
|
194 |
+
" <td>spkyut-20190730-utt000000716.wav</td>\n",
|
195 |
+
" <td>cây cam canh là loại cây ăn quả dễ trồng dễ ch...</td>\n",
|
196 |
+
" </tr>\n",
|
197 |
+
" <tr>\n",
|
198 |
+
" <th>1</th>\n",
|
199 |
+
" <td>./vin_data/vlsp2020_train_set_02/database_sa3_...</td>\n",
|
200 |
+
" <td>database_sa3_1_150h_15Jan2020_cleaned_utt_0000...</td>\n",
|
201 |
+
" <td>những đặc sản vùng miền nổi tiếng như miến don...</td>\n",
|
202 |
+
" </tr>\n",
|
203 |
+
" <tr>\n",
|
204 |
+
" <th>2</th>\n",
|
205 |
+
" <td>./vin_data/vlsp2020_train_set_02/speaker_544-0...</td>\n",
|
206 |
+
" <td>speaker_544-069450-1.wav</td>\n",
|
207 |
+
" <td>trước thông tin này trương nam thành chia sẻ c...</td>\n",
|
208 |
+
" </tr>\n",
|
209 |
+
" <tr>\n",
|
210 |
+
" <th>3</th>\n",
|
211 |
+
" <td>./vin_data/vlsp2020_train_set_02/database_sa1_...</td>\n",
|
212 |
+
" <td>database_sa1_Jan08_Mar19_cleaned_utt_000005361...</td>\n",
|
213 |
+
" <td>giống như những nữ hoàng á</td>\n",
|
214 |
+
" </tr>\n",
|
215 |
+
" <tr>\n",
|
216 |
+
" <th>4</th>\n",
|
217 |
+
" <td>./vin_data/vlsp2020_train_set_02/database_sa2_...</td>\n",
|
218 |
+
" <td>database_sa2_Jan4_Feb29_cleaned_utt_0000154206...</td>\n",
|
219 |
+
" <td>thay vì phun toàn bộ cánh đồng bằng hóa chất c...</td>\n",
|
220 |
+
" </tr>\n",
|
221 |
+
" </tbody>\n",
|
222 |
+
"</table>\n",
|
223 |
+
"</div>"
|
224 |
+
],
|
225 |
+
"text/plain": [
|
226 |
+
" path \\\n",
|
227 |
+
"0 ./vin_data/vlsp2020_train_set_02/spkyut-201907... \n",
|
228 |
+
"1 ./vin_data/vlsp2020_train_set_02/database_sa3_... \n",
|
229 |
+
"2 ./vin_data/vlsp2020_train_set_02/speaker_544-0... \n",
|
230 |
+
"3 ./vin_data/vlsp2020_train_set_02/database_sa1_... \n",
|
231 |
+
"4 ./vin_data/vlsp2020_train_set_02/database_sa2_... \n",
|
232 |
+
"\n",
|
233 |
+
" name \\\n",
|
234 |
+
"0 spkyut-20190730-utt000000716.wav \n",
|
235 |
+
"1 database_sa3_1_150h_15Jan2020_cleaned_utt_0000... \n",
|
236 |
+
"2 speaker_544-069450-1.wav \n",
|
237 |
+
"3 database_sa1_Jan08_Mar19_cleaned_utt_000005361... \n",
|
238 |
+
"4 database_sa2_Jan4_Feb29_cleaned_utt_0000154206... \n",
|
239 |
+
"\n",
|
240 |
+
" sentence \n",
|
241 |
+
"0 cây cam canh là loại cây ăn quả dễ trồng dễ ch... \n",
|
242 |
+
"1 những đặc sản vùng miền nổi tiếng như miến don... \n",
|
243 |
+
"2 trước thông tin này trương nam thành chia sẻ c... \n",
|
244 |
+
"3 giống như những nữ hoàng á \n",
|
245 |
+
"4 thay vì phun toàn bộ cánh đồng bằng hóa chất c... "
|
246 |
+
]
|
247 |
+
},
|
248 |
+
"execution_count": 34,
|
249 |
+
"metadata": {},
|
250 |
+
"output_type": "execute_result"
|
251 |
+
}
|
252 |
+
],
|
253 |
+
"source": [
|
254 |
+
"import pandas as pd \n",
|
255 |
+
"data = pd.read_csv('vin_test.csv')\n",
|
256 |
+
"data.head(5)"
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "code",
|
261 |
+
"execution_count": 30,
|
262 |
+
"metadata": {},
|
263 |
+
"outputs": [],
|
264 |
+
"source": [
|
265 |
+
"import csv\n",
|
266 |
+
"import random\n",
|
267 |
+
"\n",
|
268 |
+
"def split_csv_file(input_file, output_file1, output_file2, ratio):\n",
|
269 |
+
" with open(input_file, 'r', newline='', encoding='utf-8') as csvfile:\n",
|
270 |
+
" csvreader = csv.reader(csvfile)\n",
|
271 |
+
" header = next(csvreader) \n",
|
272 |
+
" \n",
|
273 |
+
" data = list(csvreader)\n",
|
274 |
+
" random.shuffle(data)\n",
|
275 |
+
"\n",
|
276 |
+
" total_rows = len(data)\n",
|
277 |
+
" rows_output_file1 = int(total_rows * ratio)\n",
|
278 |
+
" rows_output_file2 = total_rows - rows_output_file1\n",
|
279 |
+
" \n",
|
280 |
+
" # Split the data into two parts\n",
|
281 |
+
" data1 = data[:rows_output_file1]\n",
|
282 |
+
" data2 = data[rows_output_file1:]\n",
|
283 |
+
"\n",
|
284 |
+
" with open(output_file1, 'w', newline='', encoding='utf-8') as csvfile1:\n",
|
285 |
+
" csvwriter1 = csv.writer(csvfile1, quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
|
286 |
+
" csvwriter1.writerow(header)\n",
|
287 |
+
" csvwriter1.writerows(data1)\n",
|
288 |
+
"\n",
|
289 |
+
" with open(output_file2, 'w', newline='', encoding='utf-8') as csvfile2:\n",
|
290 |
+
" csvwriter2 = csv.writer(csvfile2, quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
|
291 |
+
" csvwriter2.writerow(header)\n",
|
292 |
+
" csvwriter2.writerows(data2)\n",
|
293 |
+
"\n",
|
294 |
+
"input_file = 'vin.csv'\n",
|
295 |
+
"output_file1 = 'vin_train.csv'\n",
|
296 |
+
"output_file2 = 'vin_test.csv'\n",
|
297 |
+
"ratio = 0.8 \n",
|
298 |
+
"\n",
|
299 |
+
"split_csv_file(input_file, output_file1, output_file2, ratio)\n"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"cell_type": "code",
|
304 |
+
"execution_count": null,
|
305 |
+
"metadata": {},
|
306 |
+
"outputs": [],
|
307 |
+
"source": [
|
308 |
+
"from datasets import load_dataset, DatasetDict\n",
|
309 |
+
"\n",
|
310 |
+
"vivos = DatasetDict()"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
{
|
314 |
+
"cell_type": "code",
|
315 |
+
"execution_count": 46,
|
316 |
+
"metadata": {},
|
317 |
+
"outputs": [],
|
318 |
+
"source": [
|
319 |
+
"import os\n",
|
320 |
+
"import numpy as np\n",
|
321 |
+
"\n",
|
322 |
+
"import torch\n",
|
323 |
+
"import torchaudio\n",
|
324 |
+
"\n",
|
325 |
+
"import pandas as pd\n",
|
326 |
+
"import whisper\n",
|
327 |
+
"import torchaudio.transforms as at\n",
|
328 |
+
"from pathlib import Path\n",
|
329 |
+
"\n",
|
330 |
+
"def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:\n",
|
331 |
+
" waveform, sr = torchaudio.load(wave_path, normalize=True)\n",
|
332 |
+
" if sample_rate != sr:\n",
|
333 |
+
" waveform = at.Resample(sr, sample_rate)(waveform)\n",
|
334 |
+
" return waveform\n",
|
335 |
+
"\n",
|
336 |
+
"\n",
|
337 |
+
"\n",
|
338 |
+
"def get_list_files_vin100h(phase, dataset_path='./vin_data/vlsp2020_train_set_02/', text_max_length=10000, audio_max_sample_length=1000000, sample_rate=16000):\n",
|
339 |
+
" audio_transcript_pair_list = []\n",
|
340 |
+
" if phase == 'train':\n",
|
341 |
+
" csv_file = 'vin_train.csv'\n",
|
342 |
+
" else:\n",
|
343 |
+
" csv_file = 'vin_test.csv'\n",
|
344 |
+
" df = pd.read_csv(csv_file)\n",
|
345 |
+
" for index, row in df.iterrows():\n",
|
346 |
+
" new_path = Path(row['path'])\n",
|
347 |
+
" audio_id = index\n",
|
348 |
+
" text = row['sentence']\n",
|
349 |
+
" if new_path.exists():\n",
|
350 |
+
" audio = load_wave(new_path, sample_rate=sample_rate)[0]\n",
|
351 |
+
" # if len(text) > text_max_length or len(audio) > audio_max_sample_length:\n",
|
352 |
+
" # print('skip file:', new_path, 'with len text:', len(text), 'and len audio', len(audio))\n",
|
353 |
+
" # continue\n",
|
354 |
+
" audio_transcript_pair_list.append((audio_id, str(new_path), text))\n",
|
355 |
+
" print(audio_transcript_pair_list)\n",
|
356 |
+
" return audio, audio_transcript_pair_list\n"
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"cell_type": "code",
|
361 |
+
"execution_count": null,
|
362 |
+
"metadata": {},
|
363 |
+
"outputs": [],
|
364 |
+
"source": [
|
365 |
+
"get_list_files_vin100h(phase='train')"
|
366 |
+
]
|
367 |
+
}
|
368 |
+
],
|
369 |
+
"metadata": {
|
370 |
+
"kernelspec": {
|
371 |
+
"display_name": "DUY",
|
372 |
+
"language": "python",
|
373 |
+
"name": "python3"
|
374 |
+
},
|
375 |
+
"language_info": {
|
376 |
+
"codemirror_mode": {
|
377 |
+
"name": "ipython",
|
378 |
+
"version": 3
|
379 |
+
},
|
380 |
+
"file_extension": ".py",
|
381 |
+
"mimetype": "text/x-python",
|
382 |
+
"name": "python",
|
383 |
+
"nbconvert_exporter": "python",
|
384 |
+
"pygments_lexer": "ipython3",
|
385 |
+
"version": "3.9.17"
|
386 |
+
},
|
387 |
+
"orig_nbformat": 4
|
388 |
+
},
|
389 |
+
"nbformat": 4,
|
390 |
+
"nbformat_minor": 2
|
391 |
+
}
|
src/MITI.ipynb
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 64,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import os\n",
|
10 |
+
"import glob\n",
|
11 |
+
"\n",
|
12 |
+
"def count_files_by_extension(path, extension):\n",
|
13 |
+
" \"\"\"\n",
|
14 |
+
" path : root path to check,\n",
|
15 |
+
" extension : .wav, ...\n",
|
16 |
+
" \"\"\"\n",
|
17 |
+
" total_count = 0\n",
|
18 |
+
" \n",
|
19 |
+
" for foldername, subfolders, filenames in os.walk(path):\n",
|
20 |
+
" files = glob.glob(os.path.join(foldername, f\"*.{extension}\"))\n",
|
21 |
+
" total_count += len(files)\n",
|
22 |
+
" \n",
|
23 |
+
" return total_count\n",
|
24 |
+
"\n",
|
25 |
+
"\n",
|
26 |
+
"root_path = \"./Cleaned_MITI/dataset_2\""
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": 65,
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [],
|
34 |
+
"source": [
|
35 |
+
"num_wav_files = count_files_by_extension(root_path, \"wav\")\n",
|
36 |
+
"num_txt_files = count_files_by_extension(root_path, \"txt\")"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": 66,
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [
|
44 |
+
{
|
45 |
+
"name": "stdout",
|
46 |
+
"output_type": "stream",
|
47 |
+
"text": [
|
48 |
+
"Số lượng file WAV: 2099\n",
|
49 |
+
"Số lượng file text: 2099\n"
|
50 |
+
]
|
51 |
+
}
|
52 |
+
],
|
53 |
+
"source": [
|
54 |
+
"print(f\"Số lượng file WAV: {num_wav_files}\")\n",
|
55 |
+
"print(f\"Số lượng file text: {num_txt_files}\")"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 70,
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [
|
63 |
+
{
|
64 |
+
"name": "stdout",
|
65 |
+
"output_type": "stream",
|
66 |
+
"text": [
|
67 |
+
"Tần số mẫu (sample rate): 44100 Hz\n",
|
68 |
+
"Số kênh (channels): 1\n"
|
69 |
+
]
|
70 |
+
}
|
71 |
+
],
|
72 |
+
"source": [
|
73 |
+
"import os\n",
|
74 |
+
"import random\n",
|
75 |
+
"import wave\n",
|
76 |
+
"\n",
|
77 |
+
"\n",
|
78 |
+
"def get_random_wav_file_info(folder_path):\n",
|
79 |
+
" for foldername, subfolders, filenames in os.walk(folder_path): \n",
|
80 |
+
" wav_files = glob.glob(f\"{foldername}/*.wav\")\n",
|
81 |
+
" \n",
|
82 |
+
" if not wav_files:\n",
|
83 |
+
" return None, None\n",
|
84 |
+
" \n",
|
85 |
+
" random_wav_file = random.choice(wav_files)\n",
|
86 |
+
" \n",
|
87 |
+
" with wave.open(random_wav_file, 'rb') as wav_file:\n",
|
88 |
+
" sample_rate = wav_file.getframerate()\n",
|
89 |
+
" channels = wav_file.getnchannels()\n",
|
90 |
+
" \n",
|
91 |
+
" return sample_rate, channels\n",
|
92 |
+
"\n",
|
93 |
+
"path_to_wav_folder = \"./Cleaned_MITI/dataset_2/\"\n",
|
94 |
+
"\n",
|
95 |
+
"sample_rate, channels = get_random_wav_file_info(path_to_wav_folder)\n",
|
96 |
+
"\n",
|
97 |
+
"if sample_rate is not None and channels is not None:\n",
|
98 |
+
" print(f\"Tần số mẫu (sample rate): {sample_rate} Hz\")\n",
|
99 |
+
" print(f\"Số kênh (channels): {channels}\")\n",
|
100 |
+
"else:\n",
|
101 |
+
" print(\"Nothing.\")\n"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": null,
|
107 |
+
"metadata": {},
|
108 |
+
"outputs": [],
|
109 |
+
"source": [
|
110 |
+
"def remove_special_characters(input_string):\n",
|
111 |
+
" special_characters = ['.', ',', '-', '_', \" \"]\n",
|
112 |
+
" \n",
|
113 |
+
" # Duyệt qua từng ký tự trong chuỗi\n",
|
114 |
+
" filtered_string = ''.join([char for char in input_string if char not in special_characters])\n",
|
115 |
+
" \n",
|
116 |
+
" return filtered_string\n",
|
117 |
+
"\n",
|
118 |
+
"# Sử dụng hàm\n",
|
119 |
+
"input_string = \"Hello, this_is_a-test.string!\"\n",
|
120 |
+
"output_string = remove_special_characters(input_string)\n",
|
121 |
+
"print(output_string) # Kết quả: \"Hello thisisa teststring\"\n"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "code",
|
126 |
+
"execution_count": 86,
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [
|
129 |
+
{
|
130 |
+
"name": "stderr",
|
131 |
+
"output_type": "stream",
|
132 |
+
"text": [
|
133 |
+
" 84%|████████▎ | 164/196 [00:00<00:00, 1629.92it/s]"
|
134 |
+
]
|
135 |
+
},
|
136 |
+
{
|
137 |
+
"name": "stderr",
|
138 |
+
"output_type": "stream",
|
139 |
+
"text": [
|
140 |
+
"100%|██████████| 196/196 [00:00<00:00, 1580.86it/s]\n",
|
141 |
+
"100%|██████████| 218/218 [00:00<00:00, 1440.12it/s]\n",
|
142 |
+
"100%|██████████| 216/216 [00:00<00:00, 1364.20it/s]\n",
|
143 |
+
"100%|██████████| 205/205 [00:00<00:00, 1412.14it/s]\n",
|
144 |
+
"100%|██████████| 204/204 [00:00<00:00, 1426.29it/s]\n",
|
145 |
+
"100%|██████████| 220/220 [00:00<00:00, 1511.87it/s]\n",
|
146 |
+
"100%|██████████| 225/225 [00:00<00:00, 1499.30it/s]\n",
|
147 |
+
"100%|██████████| 175/175 [00:00<00:00, 1492.85it/s]\n",
|
148 |
+
"100%|██████████| 220/220 [00:00<00:00, 1496.34it/s]\n",
|
149 |
+
"100%|██████████| 220/220 [00:00<00:00, 1480.81it/s]\n"
|
150 |
+
]
|
151 |
+
}
|
152 |
+
],
|
153 |
+
"source": [
|
154 |
+
"import os\n",
|
155 |
+
"import csv\n",
|
156 |
+
"from tqdm import tqdm\n",
|
157 |
+
"import glob\n",
|
158 |
+
"from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n",
|
159 |
+
"normalizer = BasicTextNormalizer()\n",
|
160 |
+
"def create_csv_from_wav_folder(folder_path, output_csv_file):\n",
|
161 |
+
" with open(output_csv_file, mode='w', newline='') as csv_file:\n",
|
162 |
+
" csv_writer = csv.writer(csv_file)\n",
|
163 |
+
" csv_writer.writerow(['path', 'name', 'sentence'])\n",
|
164 |
+
"\n",
|
165 |
+
" for person_foldername, _, _ in os.walk(folder_path):\n",
|
166 |
+
" if \"person_\" in person_foldername:\n",
|
167 |
+
" wav_files = glob.glob(os.path.join(person_foldername, \"*.wav\"))\n",
|
168 |
+
"\n",
|
169 |
+
" for wav_file_path in tqdm(wav_files):\n",
|
170 |
+
" wav_filename = os.path.basename(wav_file_path)\n",
|
171 |
+
" text_filename = os.path.splitext(wav_filename)[0] + \".txt\"\n",
|
172 |
+
" text_file_path = os.path.join(person_foldername, text_filename)\n",
|
173 |
+
"\n",
|
174 |
+
" if os.path.exists(text_file_path):\n",
|
175 |
+
" with open(text_file_path, 'r') as txt_file:\n",
|
176 |
+
" text_content = normalizer(txt_file.read())\n",
|
177 |
+
" else:\n",
|
178 |
+
" text_content = \"Not found.\"\n",
|
179 |
+
"\n",
|
180 |
+
" csv_writer.writerow([wav_file_path, wav_filename, text_content])\n",
|
181 |
+
"\n",
|
182 |
+
"root_path = \"./Cleaned_MITI/dataset_2\" \n",
|
183 |
+
"output_csv_file = \"MITI.csv\" \n",
|
184 |
+
"\n",
|
185 |
+
"create_csv_from_wav_folder(root_path, output_csv_file)\n"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "code",
|
190 |
+
"execution_count": 89,
|
191 |
+
"metadata": {},
|
192 |
+
"outputs": [
|
193 |
+
{
|
194 |
+
"data": {
|
195 |
+
"text/plain": [
|
196 |
+
"2099"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
"execution_count": 89,
|
200 |
+
"metadata": {},
|
201 |
+
"output_type": "execute_result"
|
202 |
+
}
|
203 |
+
],
|
204 |
+
"source": [
|
205 |
+
"import pandas as pd \n",
|
206 |
+
"data = pd.read_csv('MITI.csv')\n",
|
207 |
+
"len(data)"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "code",
|
212 |
+
"execution_count": 90,
|
213 |
+
"metadata": {},
|
214 |
+
"outputs": [],
|
215 |
+
"source": [
|
216 |
+
"import csv\n",
|
217 |
+
"import random\n",
|
218 |
+
"\n",
|
219 |
+
"def split_csv_file(input_file, output_file1, output_file2, ratio):\n",
|
220 |
+
" with open(input_file, 'r', newline='', encoding='utf-8') as csvfile:\n",
|
221 |
+
" csvreader = csv.reader(csvfile)\n",
|
222 |
+
" header = next(csvreader) \n",
|
223 |
+
" \n",
|
224 |
+
" data = list(csvreader)\n",
|
225 |
+
" random.shuffle(data)\n",
|
226 |
+
"\n",
|
227 |
+
" total_rows = len(data)\n",
|
228 |
+
" rows_output_file1 = int(total_rows * ratio)\n",
|
229 |
+
" rows_output_file2 = total_rows - rows_output_file1\n",
|
230 |
+
" \n",
|
231 |
+
" # Split the data into two parts\n",
|
232 |
+
" data1 = data[:rows_output_file1]\n",
|
233 |
+
" data2 = data[rows_output_file1:]\n",
|
234 |
+
"\n",
|
235 |
+
" with open(output_file1, 'w', newline='', encoding='utf-8') as csvfile1:\n",
|
236 |
+
" csvwriter1 = csv.writer(csvfile1, quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
|
237 |
+
" csvwriter1.writerow(header)\n",
|
238 |
+
" csvwriter1.writerows(data1)\n",
|
239 |
+
"\n",
|
240 |
+
" with open(output_file2, 'w', newline='', encoding='utf-8') as csvfile2:\n",
|
241 |
+
" csvwriter2 = csv.writer(csvfile2, quotechar='|', quoting=csv.QUOTE_MINIMAL)\n",
|
242 |
+
" csvwriter2.writerow(header)\n",
|
243 |
+
" csvwriter2.writerows(data2)\n",
|
244 |
+
"\n",
|
245 |
+
"input_file = 'MITI.csv'\n",
|
246 |
+
"output_file1 = 'MITI_train.csv'\n",
|
247 |
+
"output_file2 = 'MITI_test.csv'\n",
|
248 |
+
"ratio = 0.8 \n",
|
249 |
+
"\n",
|
250 |
+
"split_csv_file(input_file, output_file1, output_file2, ratio)\n"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": null,
|
256 |
+
"metadata": {},
|
257 |
+
"outputs": [],
|
258 |
+
"source": [
|
259 |
+
"from datasets import load_dataset, DatasetDict\n",
|
260 |
+
"\n",
|
261 |
+
"vivos = DatasetDict()"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "code",
|
266 |
+
"execution_count": 46,
|
267 |
+
"metadata": {},
|
268 |
+
"outputs": [],
|
269 |
+
"source": [
|
270 |
+
"import os\n",
|
271 |
+
"import numpy as np\n",
|
272 |
+
"\n",
|
273 |
+
"import torch\n",
|
274 |
+
"import torchaudio\n",
|
275 |
+
"\n",
|
276 |
+
"import pandas as pd\n",
|
277 |
+
"import whisper\n",
|
278 |
+
"import torchaudio.transforms as at\n",
|
279 |
+
"from pathlib import Path\n",
|
280 |
+
"\n",
|
281 |
+
"def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:\n",
|
282 |
+
" waveform, sr = torchaudio.load(wave_path, normalize=True)\n",
|
283 |
+
" if sample_rate != sr:\n",
|
284 |
+
" waveform = at.Resample(sr, sample_rate)(waveform)\n",
|
285 |
+
" return waveform\n",
|
286 |
+
"\n",
|
287 |
+
"\n",
|
288 |
+
"\n",
|
289 |
+
"def get_list_files_vin100h(phase, dataset_path='./vin_data/vlsp2020_train_set_02/', text_max_length=10000, audio_max_sample_length=1000000, sample_rate=16000):\n",
|
290 |
+
" audio_transcript_pair_list = []\n",
|
291 |
+
" if phase == 'train':\n",
|
292 |
+
" csv_file = 'vin_train.csv'\n",
|
293 |
+
" else:\n",
|
294 |
+
" csv_file = 'vin_test.csv'\n",
|
295 |
+
" df = pd.read_csv(csv_file)\n",
|
296 |
+
" for index, row in df.iterrows():\n",
|
297 |
+
" new_path = Path(row['path'])\n",
|
298 |
+
" audio_id = index\n",
|
299 |
+
" text = row['sentence']\n",
|
300 |
+
" if new_path.exists():\n",
|
301 |
+
" audio = load_wave(new_path, sample_rate=sample_rate)[0]\n",
|
302 |
+
" # if len(text) > text_max_length or len(audio) > audio_max_sample_length:\n",
|
303 |
+
" # print('skip file:', new_path, 'with len text:', len(text), 'and len audio', len(audio))\n",
|
304 |
+
" # continue\n",
|
305 |
+
" audio_transcript_pair_list.append((audio_id, str(new_path), text))\n",
|
306 |
+
" print(audio_transcript_pair_list)\n",
|
307 |
+
" return audio, audio_transcript_pair_list\n"
|
308 |
+
]
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"cell_type": "code",
|
312 |
+
"execution_count": null,
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [],
|
315 |
+
"source": [
|
316 |
+
"get_list_files_vin100h(phase='train')"
|
317 |
+
]
|
318 |
+
}
|
319 |
+
],
|
320 |
+
"metadata": {
|
321 |
+
"kernelspec": {
|
322 |
+
"display_name": "DUY",
|
323 |
+
"language": "python",
|
324 |
+
"name": "python3"
|
325 |
+
},
|
326 |
+
"language_info": {
|
327 |
+
"codemirror_mode": {
|
328 |
+
"name": "ipython",
|
329 |
+
"version": 3
|
330 |
+
},
|
331 |
+
"file_extension": ".py",
|
332 |
+
"mimetype": "text/x-python",
|
333 |
+
"name": "python",
|
334 |
+
"nbconvert_exporter": "python",
|
335 |
+
"pygments_lexer": "ipython3",
|
336 |
+
"version": "3.9.17"
|
337 |
+
},
|
338 |
+
"orig_nbformat": 4
|
339 |
+
},
|
340 |
+
"nbformat": 4,
|
341 |
+
"nbformat_minor": 2
|
342 |
+
}
|
src/download_quantized.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import huggingface_hub
|
8 |
+
import requests
|
9 |
+
|
10 |
+
from tqdm.auto import tqdm
|
11 |
+
|
12 |
+
_MODELS = (
|
13 |
+
"medium"
|
14 |
+
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
def get_assets_path():
|
19 |
+
"""Returns the path to the assets directory."""
|
20 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
|
21 |
+
|
22 |
+
|
23 |
+
def get_logger():
|
24 |
+
"""Returns the module logger."""
|
25 |
+
return logging.getLogger("faster_whisper")
|
26 |
+
|
27 |
+
|
28 |
+
def download_model(
|
29 |
+
size_or_id: str,
|
30 |
+
output_dir: Optional[str] = None,
|
31 |
+
local_files_only: bool = False,
|
32 |
+
cache_dir: Optional[str] = None,
|
33 |
+
):
|
34 |
+
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
|
35 |
+
|
36 |
+
The model is downloaded from https://huggingface.co/DuyTa.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
size_or_id: Size of the model to download (tiny, tiny.en, base, base.en, small, small.en,
|
40 |
+
medium, medium.en, large-v1, or large-v2), or a CTranslate2-converted model ID
|
41 |
+
from the Hugging Face Hub (e.g. guillaumekln/faster-whisper-large-v2).
|
42 |
+
output_dir: Directory where the model should be saved. If not set, the model is saved in
|
43 |
+
the cache directory.
|
44 |
+
local_files_only: If True, avoid downloading the file and return the path to the local
|
45 |
+
cached file if it exists.
|
46 |
+
cache_dir: Path to the folder where cached files are stored.
|
47 |
+
|
48 |
+
Returns:
|
49 |
+
The path to the downloaded model.
|
50 |
+
|
51 |
+
Raises:
|
52 |
+
ValueError: if the model size is invalid.
|
53 |
+
"""
|
54 |
+
if re.match(r".*/.*", size_or_id):
|
55 |
+
repo_id = size_or_id
|
56 |
+
else:
|
57 |
+
if size_or_id not in _MODELS:
|
58 |
+
raise ValueError(
|
59 |
+
"Invalid model size '%s', expected one of: %s"
|
60 |
+
% (size_or_id, ", ".join(_MODELS))
|
61 |
+
)
|
62 |
+
|
63 |
+
#repo_id = "DuyTa/vi-whisper-%s-Lora" % size_or_id
|
64 |
+
repo_id = "DuyTa/Vietnamese_ASR"
|
65 |
+
|
66 |
+
allow_patterns = [
|
67 |
+
"config.json",
|
68 |
+
"model.bin",
|
69 |
+
"tokenizer.json",
|
70 |
+
"vocabulary.*",
|
71 |
+
]
|
72 |
+
|
73 |
+
kwargs = {
|
74 |
+
"local_files_only": local_files_only,
|
75 |
+
"allow_patterns": allow_patterns,
|
76 |
+
"tqdm_class": disabled_tqdm,
|
77 |
+
}
|
78 |
+
|
79 |
+
if output_dir is not None:
|
80 |
+
kwargs["local_dir"] = output_dir
|
81 |
+
kwargs["local_dir_use_symlinks"] = False
|
82 |
+
|
83 |
+
if cache_dir is not None:
|
84 |
+
kwargs["cache_dir"] = cache_dir
|
85 |
+
|
86 |
+
try:
|
87 |
+
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
88 |
+
except (
|
89 |
+
huggingface_hub.utils.HfHubHTTPError,
|
90 |
+
requests.exceptions.ConnectionError,
|
91 |
+
) as exception:
|
92 |
+
logger = get_logger()
|
93 |
+
logger.warning(
|
94 |
+
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
|
95 |
+
repo_id,
|
96 |
+
exception,
|
97 |
+
)
|
98 |
+
logger.warning(
|
99 |
+
"Trying to load the model directly from the local cache, if it exists."
|
100 |
+
)
|
101 |
+
|
102 |
+
kwargs["local_files_only"] = True
|
103 |
+
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
104 |
+
|
105 |
+
|
106 |
+
def format_timestamp(
|
107 |
+
seconds: float,
|
108 |
+
always_include_hours: bool = False,
|
109 |
+
decimal_marker: str = ".",
|
110 |
+
) -> str:
|
111 |
+
assert seconds >= 0, "non-negative timestamp expected"
|
112 |
+
milliseconds = round(seconds * 1000.0)
|
113 |
+
|
114 |
+
hours = milliseconds // 3_600_000
|
115 |
+
milliseconds -= hours * 3_600_000
|
116 |
+
|
117 |
+
minutes = milliseconds // 60_000
|
118 |
+
milliseconds -= minutes * 60_000
|
119 |
+
|
120 |
+
seconds = milliseconds // 1_000
|
121 |
+
milliseconds -= seconds * 1_000
|
122 |
+
|
123 |
+
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
124 |
+
return (
|
125 |
+
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
class disabled_tqdm(tqdm):
|
130 |
+
def __init__(self, *args, **kwargs):
|
131 |
+
kwargs["disable"] = True
|
132 |
+
super().__init__(*args, **kwargs)
|
src/laboratory.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
src/lora_tuning.py
ADDED
@@ -0,0 +1,773 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import gc
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from datetime import datetime
|
9 |
+
from pathlib import Path
|
10 |
+
from random import randint
|
11 |
+
from typing import Any, Dict, List, Union
|
12 |
+
|
13 |
+
# datasets imports
|
14 |
+
import datasets
|
15 |
+
|
16 |
+
# metric imports
|
17 |
+
import evaluate
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
import transformers
|
21 |
+
import wandb
|
22 |
+
|
23 |
+
# accelerate imports
|
24 |
+
from accelerate import Accelerator, dispatch_model
|
25 |
+
from accelerate.logging import get_logger
|
26 |
+
from datasets import Audio, DatasetDict, IterableDatasetDict, interleave_datasets, load_dataset
|
27 |
+
|
28 |
+
# hf imports
|
29 |
+
from huggingface_hub import Repository
|
30 |
+
from torch.utils.data import DataLoader
|
31 |
+
from tqdm import tqdm
|
32 |
+
from transformers import (
|
33 |
+
SchedulerType,
|
34 |
+
WhisperForConditionalGeneration,
|
35 |
+
WhisperProcessor,
|
36 |
+
get_scheduler,
|
37 |
+
set_seed,
|
38 |
+
)
|
39 |
+
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
|
40 |
+
from transformers.utils import get_full_repo_name
|
41 |
+
|
42 |
+
# peft imports
|
43 |
+
from peft import AdaLoraConfig, LoraConfig, PeftModel, get_peft_model
|
44 |
+
|
45 |
+
|
46 |
+
logger = get_logger(__name__, log_level="INFO")
|
47 |
+
|
48 |
+
|
49 |
+
def parse_args():
|
50 |
+
parser = argparse.ArgumentParser(description="Whisper Fine-Tuning with AdaLora")
|
51 |
+
parser.add_argument(
|
52 |
+
"--model_name_or_path",
|
53 |
+
type=str,
|
54 |
+
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
55 |
+
required=True,
|
56 |
+
)
|
57 |
+
parser.add_argument("--language", type=str, help="Language to use for training; e.g., 'Hindi' ", required=True)
|
58 |
+
parser.add_argument("--language_abbr", type=str, help="Language to use for training; e.g., 'hi' ", required=True)
|
59 |
+
parser.add_argument(
|
60 |
+
"--task", type=str, default="transcribe", help="Task to use for training; e.g., 'transcribe' ", required=False
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--dataset_name",
|
64 |
+
type=str,
|
65 |
+
default="mozilla-foundation/common_voice_11_0",
|
66 |
+
help="Dataset to use for training; e.g., 'whisper' ",
|
67 |
+
required=False,
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--dataset_in_streaming_mode",
|
71 |
+
action="store_true",
|
72 |
+
help="Whether to use streaming mode for the dataset.",
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--do_lower_case", action="store_true", help="lowercase the transcribed text before tokenizing"
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--do_remove_punctuation", action="store_true", help="remove punctuation from the transcribed text"
|
79 |
+
)
|
80 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
81 |
+
parser.add_argument(
|
82 |
+
"--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
|
83 |
+
)
|
84 |
+
parser.add_argument("--max_audio_input_length", type=float, default=30.0, help="Maximum audio length in seconds.")
|
85 |
+
parser.add_argument(
|
86 |
+
"--preprocessing_num_workers",
|
87 |
+
type=int,
|
88 |
+
default=None,
|
89 |
+
help="The number of processes to use for the preprocessing.",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--per_device_train_batch_size",
|
93 |
+
type=int,
|
94 |
+
default=8,
|
95 |
+
help="Batch size (per device) for the training dataloader.",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--per_device_eval_batch_size",
|
99 |
+
type=int,
|
100 |
+
default=8,
|
101 |
+
help="Batch size (per device) for the evaluation dataloader.",
|
102 |
+
)
|
103 |
+
parser.add_argument(
|
104 |
+
"--buffer_size",
|
105 |
+
type=int,
|
106 |
+
default=5000,
|
107 |
+
help="Number of samples to prefetch in the streaming mode.",
|
108 |
+
)
|
109 |
+
parser.add_argument(
|
110 |
+
"--dataloader_pin_memory",
|
111 |
+
action="store_true",
|
112 |
+
help="Whether or not to pin memory for the DataLoader.",
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
"--dataloader_num_workers",
|
116 |
+
type=int,
|
117 |
+
default=0,
|
118 |
+
help="Number of subprocesses to use for data loading.",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--learning_rate",
|
122 |
+
type=float,
|
123 |
+
default=5e-5,
|
124 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
125 |
+
)
|
126 |
+
parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
|
127 |
+
parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
|
128 |
+
parser.add_argument(
|
129 |
+
"--max_train_steps",
|
130 |
+
type=int,
|
131 |
+
default=None,
|
132 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
133 |
+
)
|
134 |
+
parser.add_argument(
|
135 |
+
"--gradient_accumulation_steps",
|
136 |
+
type=int,
|
137 |
+
default=1,
|
138 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
139 |
+
)
|
140 |
+
parser.add_argument(
|
141 |
+
"--lr_scheduler_type",
|
142 |
+
type=SchedulerType,
|
143 |
+
default="linear",
|
144 |
+
help="The scheduler type to use.",
|
145 |
+
choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
|
146 |
+
)
|
147 |
+
parser.add_argument(
|
148 |
+
"--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
|
149 |
+
)
|
150 |
+
parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
|
151 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
152 |
+
parser.add_argument(
|
153 |
+
"--load_best_model",
|
154 |
+
action="store_true",
|
155 |
+
help="Whether to load the best model at the end of training",
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--with_tracking",
|
159 |
+
action="store_true",
|
160 |
+
help="Whether to enable experiment trackers for logging.",
|
161 |
+
)
|
162 |
+
parser.add_argument(
|
163 |
+
"--report_to",
|
164 |
+
type=str,
|
165 |
+
default="all",
|
166 |
+
help=(
|
167 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
|
168 |
+
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
|
169 |
+
"Only applicable when `--with_tracking` is passed."
|
170 |
+
),
|
171 |
+
)
|
172 |
+
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
|
173 |
+
parser.add_argument(
|
174 |
+
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
|
175 |
+
)
|
176 |
+
parser.add_argument(
|
177 |
+
"--checkpointing_steps",
|
178 |
+
type=int,
|
179 |
+
default=500,
|
180 |
+
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
181 |
+
)
|
182 |
+
parser.add_argument(
|
183 |
+
"--logging_steps",
|
184 |
+
type=int,
|
185 |
+
default=100,
|
186 |
+
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
187 |
+
)
|
188 |
+
parser.add_argument(
|
189 |
+
"--evaluation_steps",
|
190 |
+
type=int,
|
191 |
+
default=500,
|
192 |
+
help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
|
193 |
+
)
|
194 |
+
parser.add_argument(
|
195 |
+
"--resume_from_checkpoint",
|
196 |
+
type=str,
|
197 |
+
default=None,
|
198 |
+
help="If the training should continue from a checkpoint folder.",
|
199 |
+
)
|
200 |
+
|
201 |
+
# lora/adalora specific args
|
202 |
+
parser.add_argument(
|
203 |
+
"--use_peft",
|
204 |
+
action="store_true",
|
205 |
+
help="Whether to use PEFT",
|
206 |
+
)
|
207 |
+
parser.add_argument(
|
208 |
+
"--use_adalora",
|
209 |
+
action="store_true",
|
210 |
+
help="Whether to use AdaLoRA or LoRA. If set, uses AdaLoRA instead of the default LoRA.",
|
211 |
+
)
|
212 |
+
parser.add_argument(
|
213 |
+
"--init_r",
|
214 |
+
type=int,
|
215 |
+
default=12,
|
216 |
+
help="Initial AdaLoRA rank",
|
217 |
+
)
|
218 |
+
parser.add_argument(
|
219 |
+
"--target_r",
|
220 |
+
type=int,
|
221 |
+
default=4,
|
222 |
+
help="Target AdaLoRA rank",
|
223 |
+
)
|
224 |
+
parser.add_argument(
|
225 |
+
"--tinit",
|
226 |
+
type=int,
|
227 |
+
default=200,
|
228 |
+
help="number of warmup steps for AdaLoRA wherein no pruning is performed",
|
229 |
+
)
|
230 |
+
parser.add_argument(
|
231 |
+
"--tfinal",
|
232 |
+
type=int,
|
233 |
+
default=1000,
|
234 |
+
help=" fix the resulting budget distribution and fine-tune the model for tfinal steps when using AdaLoRA ",
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--delta_t",
|
238 |
+
type=int,
|
239 |
+
default=10,
|
240 |
+
help="interval of steps for AdaLoRA to update rank",
|
241 |
+
)
|
242 |
+
parser.add_argument(
|
243 |
+
"--lora_alpha",
|
244 |
+
type=int,
|
245 |
+
default=32,
|
246 |
+
help="LORA alpha",
|
247 |
+
)
|
248 |
+
parser.add_argument(
|
249 |
+
"--r",
|
250 |
+
type=int,
|
251 |
+
default=8,
|
252 |
+
help="LORA rank",
|
253 |
+
)
|
254 |
+
parser.add_argument(
|
255 |
+
"--lora_dropout",
|
256 |
+
type=float,
|
257 |
+
default=0.1,
|
258 |
+
help="LORA dropout",
|
259 |
+
)
|
260 |
+
parser.add_argument(
|
261 |
+
"--orth_reg_weight",
|
262 |
+
type=float,
|
263 |
+
default=0.5,
|
264 |
+
help="Orthogonal regularization weight",
|
265 |
+
)
|
266 |
+
parser.add_argument(
|
267 |
+
"--debug_mode",
|
268 |
+
action="store_true",
|
269 |
+
help="Whether to use debug mode",
|
270 |
+
)
|
271 |
+
|
272 |
+
args = parser.parse_args()
|
273 |
+
|
274 |
+
if args.push_to_hub:
|
275 |
+
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
|
276 |
+
|
277 |
+
return args
|
278 |
+
|
279 |
+
|
280 |
+
def load_streaming_dataset(dataset_name, dataset_config_name, split, **kwargs):
|
281 |
+
if "+" in split:
|
282 |
+
# load multiple splits separated by the `+` symbol *with* streaming mode
|
283 |
+
dataset_splits = [
|
284 |
+
load_dataset(dataset_name, dataset_config_name, split=split_name, streaming=True, **kwargs)
|
285 |
+
for split_name in split.split("+")
|
286 |
+
]
|
287 |
+
# interleave multiple splits to form one dataset
|
288 |
+
interleaved_dataset = interleave_datasets(dataset_splits)
|
289 |
+
return interleaved_dataset
|
290 |
+
else:
|
291 |
+
# load a single split *with* streaming mode
|
292 |
+
dataset = load_dataset(dataset_name, dataset_config_name, split=split, streaming=True, **kwargs)
|
293 |
+
return dataset
|
294 |
+
|
295 |
+
|
296 |
+
def prepare_dataset_wrapper(do_lower_case, do_remove_punctuation, processor, normalizer):
|
297 |
+
def prepare_dataset(batch):
|
298 |
+
# load and (possibly) resample audio data to 16kHz
|
299 |
+
audio = batch["audio"]
|
300 |
+
|
301 |
+
# compute log-Mel input features from input audio array
|
302 |
+
batch["input_features"] = processor.feature_extractor(
|
303 |
+
audio["array"], sampling_rate=audio["sampling_rate"]
|
304 |
+
).input_features[0]
|
305 |
+
# compute input length of audio sample in seconds
|
306 |
+
batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]
|
307 |
+
|
308 |
+
# optional pre-processing steps
|
309 |
+
transcription = batch["sentence"]
|
310 |
+
if do_lower_case:
|
311 |
+
transcription = transcription.lower()
|
312 |
+
if do_remove_punctuation:
|
313 |
+
transcription = normalizer(transcription).strip()
|
314 |
+
|
315 |
+
# encode target text to label ids
|
316 |
+
batch["labels"] = processor.tokenizer(transcription).input_ids
|
317 |
+
return batch
|
318 |
+
|
319 |
+
return prepare_dataset
|
320 |
+
|
321 |
+
|
322 |
+
def save_model_hook(models, weights, output_dir):
|
323 |
+
for model in models:
|
324 |
+
model.save_pretrained(output_dir)
|
325 |
+
# make sure to pop weight so that corresponding model is not saved again
|
326 |
+
weights.pop()
|
327 |
+
|
328 |
+
|
329 |
+
def load_model_hook(models, input_dir):
|
330 |
+
while len(models) > 0:
|
331 |
+
model = models.pop()
|
332 |
+
# pop models so that they are not loaded again
|
333 |
+
PeftModel.from_pretrained(model.base_model.model, input_dir)
|
334 |
+
|
335 |
+
|
336 |
+
@dataclass
|
337 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
338 |
+
processor: Any
|
339 |
+
|
340 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
341 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
342 |
+
# first treat the audio inputs by simply returning torch tensors
|
343 |
+
input_features = [{"input_features": feature["input_features"]} for feature in features]
|
344 |
+
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
345 |
+
|
346 |
+
# get the tokenized label sequences
|
347 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
348 |
+
# pad the labels to max length
|
349 |
+
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
350 |
+
|
351 |
+
# replace padding with -100 to ignore loss correctly
|
352 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
353 |
+
|
354 |
+
# if bos token is appended in previous tokenization step,
|
355 |
+
# cut bos token here as it's append later anyways
|
356 |
+
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
|
357 |
+
labels = labels[:, 1:]
|
358 |
+
|
359 |
+
batch["labels"] = labels
|
360 |
+
|
361 |
+
return batch
|
362 |
+
|
363 |
+
|
364 |
+
def get_audio_length_processor(max_input_length):
|
365 |
+
def is_audio_in_length_range(length):
|
366 |
+
return length < max_input_length
|
367 |
+
|
368 |
+
return is_audio_in_length_range
|
369 |
+
|
370 |
+
|
371 |
+
def evaluation_loop(model, eval_dataloader, processor, normalizer, metric, forced_decoder_ids, accelerator):
|
372 |
+
model.eval()
|
373 |
+
predictions = []
|
374 |
+
references = []
|
375 |
+
normalized_predictions = []
|
376 |
+
normalized_references = []
|
377 |
+
for _, batch in enumerate(tqdm(eval_dataloader)):
|
378 |
+
with torch.cuda.amp.autocast():
|
379 |
+
with torch.no_grad():
|
380 |
+
generated_tokens = (
|
381 |
+
model.generate(
|
382 |
+
input_features=batch["input_features"],
|
383 |
+
forced_decoder_ids=forced_decoder_ids,
|
384 |
+
max_new_tokens=255,
|
385 |
+
)
|
386 |
+
.cpu()
|
387 |
+
.numpy()
|
388 |
+
)
|
389 |
+
labels = batch["labels"].cpu().numpy()
|
390 |
+
labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)
|
391 |
+
decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
392 |
+
decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
|
393 |
+
predictions.extend(decoded_preds)
|
394 |
+
references.extend(decoded_labels)
|
395 |
+
normalized_predictions.extend([normalizer(pred).strip() for pred in decoded_preds])
|
396 |
+
normalized_references.extend([normalizer(label).strip() for label in decoded_labels])
|
397 |
+
del generated_tokens, labels, batch
|
398 |
+
gc.collect()
|
399 |
+
wer = 100 * metric.compute(predictions=predictions, references=references)
|
400 |
+
normalized_wer = 100 * metric.compute(predictions=normalized_predictions, references=normalized_references)
|
401 |
+
eval_metrics = {"eval/wer": wer, "eval/normalized_wer": normalized_wer}
|
402 |
+
if accelerator.get_tracker("wandb"):
|
403 |
+
sample_size = min(len(predictions), 256)
|
404 |
+
ids = [randint(0, len(predictions) - 1) for p in range(0, sample_size)]
|
405 |
+
sample_predictions = [predictions[i] for i in ids]
|
406 |
+
sample_references = [references[i] for i in ids]
|
407 |
+
sample_normalized_predictions = [normalized_predictions[i] for i in ids]
|
408 |
+
sample_normalized_references = [normalized_references[i] for i in ids]
|
409 |
+
table_rows = [
|
410 |
+
list(r)
|
411 |
+
for r in zip(
|
412 |
+
sample_predictions, sample_references, sample_normalized_predictions, sample_normalized_references
|
413 |
+
)
|
414 |
+
]
|
415 |
+
eval_metrics["eval_samples"] = wandb.Table(
|
416 |
+
columns=["predictions", "references", "normalized_predictions", "normalized_references"],
|
417 |
+
rows=table_rows,
|
418 |
+
)
|
419 |
+
return eval_metrics
|
420 |
+
|
421 |
+
|
422 |
+
def main():
|
423 |
+
args = parse_args()
|
424 |
+
|
425 |
+
# initialize accelerator
|
426 |
+
accelerator = (
|
427 |
+
Accelerator(
|
428 |
+
log_with=args.report_to,
|
429 |
+
project_dir=args.output_dir,
|
430 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
431 |
+
)
|
432 |
+
if args.with_tracking
|
433 |
+
else Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
434 |
+
)
|
435 |
+
|
436 |
+
# Make one log on every process with the configuration for debugging.
|
437 |
+
logging.basicConfig(
|
438 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
439 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
440 |
+
level=logging.INFO,
|
441 |
+
)
|
442 |
+
logger.info(accelerator.state, main_process_only=False)
|
443 |
+
if accelerator.is_local_main_process:
|
444 |
+
datasets.utils.logging.set_verbosity_warning()
|
445 |
+
transformers.utils.logging.set_verbosity_info()
|
446 |
+
else:
|
447 |
+
datasets.utils.logging.set_verbosity_error()
|
448 |
+
transformers.utils.logging.set_verbosity_error()
|
449 |
+
|
450 |
+
# If passed along, set the training seed now.
|
451 |
+
if args.seed is not None:
|
452 |
+
set_seed(args.seed)
|
453 |
+
|
454 |
+
# Handle the repository creation
|
455 |
+
if accelerator.is_main_process:
|
456 |
+
if args.push_to_hub:
|
457 |
+
if args.hub_model_id is None:
|
458 |
+
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
|
459 |
+
else:
|
460 |
+
repo_name = args.hub_model_id
|
461 |
+
repo = Repository(args.output_dir, clone_from=repo_name)
|
462 |
+
|
463 |
+
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
|
464 |
+
if "step_*" not in gitignore:
|
465 |
+
gitignore.write("step_*\n")
|
466 |
+
if "epoch_*" not in gitignore:
|
467 |
+
gitignore.write("epoch_*\n")
|
468 |
+
elif args.output_dir is not None:
|
469 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
470 |
+
accelerator.wait_for_everyone()
|
471 |
+
|
472 |
+
# load dataset either in streaming mode or not
|
473 |
+
processor = WhisperProcessor.from_pretrained(args.model_name_or_path, language=args.language, task=args.task)
|
474 |
+
normalizer = BasicTextNormalizer()
|
475 |
+
prepare_dataset = prepare_dataset_wrapper(args.do_lower_case, args.do_remove_punctuation, processor, normalizer)
|
476 |
+
is_audio_in_length_range = get_audio_length_processor(args.max_audio_input_length)
|
477 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
478 |
+
|
479 |
+
if args.dataset_in_streaming_mode:
|
480 |
+
raw_datasets = IterableDatasetDict()
|
481 |
+
loading_method = load_streaming_dataset
|
482 |
+
else:
|
483 |
+
raw_datasets = DatasetDict()
|
484 |
+
loading_method = load_dataset
|
485 |
+
|
486 |
+
if args.debug_mode:
|
487 |
+
train_split = "train[:100]"
|
488 |
+
test_split = "test[:10]"
|
489 |
+
else:
|
490 |
+
train_split = "train+validation"
|
491 |
+
test_split = "test"
|
492 |
+
|
493 |
+
raw_datasets["train"] = loading_method(
|
494 |
+
args.dataset_name, args.language_abbr, split=train_split, use_auth_token=True
|
495 |
+
)
|
496 |
+
raw_datasets["test"] = loading_method(args.dataset_name, args.language_abbr, split=test_split, use_auth_token=True)
|
497 |
+
raw_datasets = raw_datasets.cast_column("audio", Audio(sampling_rate=16000))
|
498 |
+
|
499 |
+
logger.info("Dataset loaded: %s", raw_datasets)
|
500 |
+
logger.info(f'{raw_datasets["train"][0]}')
|
501 |
+
|
502 |
+
vectorized_datasets = raw_datasets.map(
|
503 |
+
prepare_dataset,
|
504 |
+
remove_columns=list(next(iter(raw_datasets.values())).features),
|
505 |
+
num_proc=args.preprocessing_num_workers,
|
506 |
+
).with_format("torch")
|
507 |
+
|
508 |
+
if args.dataset_in_streaming_mode:
|
509 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(
|
510 |
+
buffer_size=args.buffer_size,
|
511 |
+
seed=args.seed,
|
512 |
+
)
|
513 |
+
|
514 |
+
# filter out audio files that are too long from the training set
|
515 |
+
is_audio_in_length_range = get_audio_length_processor(args.max_audio_input_length)
|
516 |
+
vectorized_datasets["train"] = vectorized_datasets["train"].filter(
|
517 |
+
is_audio_in_length_range, input_columns=["input_length"]
|
518 |
+
)
|
519 |
+
|
520 |
+
# get dataloaders
|
521 |
+
train_dataloader = DataLoader(
|
522 |
+
vectorized_datasets["train"],
|
523 |
+
batch_size=args.per_device_train_batch_size,
|
524 |
+
shuffle=True,
|
525 |
+
collate_fn=data_collator,
|
526 |
+
num_workers=args.dataloader_num_workers,
|
527 |
+
pin_memory=args.dataloader_pin_memory,
|
528 |
+
)
|
529 |
+
eval_dataloader = DataLoader(
|
530 |
+
vectorized_datasets["test"],
|
531 |
+
batch_size=args.per_device_eval_batch_size,
|
532 |
+
collate_fn=data_collator,
|
533 |
+
num_workers=args.dataloader_num_workers,
|
534 |
+
pin_memory=args.dataloader_pin_memory,
|
535 |
+
)
|
536 |
+
|
537 |
+
# metric
|
538 |
+
metric = evaluate.load("wer")
|
539 |
+
|
540 |
+
# model
|
541 |
+
model = WhisperForConditionalGeneration.from_pretrained(args.model_name_or_path, load_in_8bit=True)
|
542 |
+
model.config.forced_decoder_ids = None
|
543 |
+
model.config.suppress_tokens = []
|
544 |
+
if len(set(model.hf_device_map.values()).intersection({"cpu", "disk"})) > 0:
|
545 |
+
raise ValueError("Training on CPU or disk is not supported.")
|
546 |
+
if len(set(model.hf_device_map.values())) > 1:
|
547 |
+
device_map = model.hf_device_map.copy()
|
548 |
+
# required because `labels` are on main execution device (0) while the output of `proj_out` is on other device.
|
549 |
+
# So, this leads to device mismatch error when calculation cross-entropy between logits and labels.
|
550 |
+
# Won't arise during inference as `labels` aren't supplied during that time
|
551 |
+
# instead of changing device of one of the tied modules, I have to do this for all tied modules
|
552 |
+
# else the execution device of remaining tied modules isn't changed
|
553 |
+
device_map["model.decoder.embed_tokens"] = model._hf_hook.execution_device
|
554 |
+
device_map["model.decoder.embed_positions"] = model._hf_hook.execution_device
|
555 |
+
device_map["proj_out"] = model._hf_hook.execution_device
|
556 |
+
dispatch_model(model, device_map=device_map)
|
557 |
+
|
558 |
+
# preparing peft model
|
559 |
+
if args.use_peft:
|
560 |
+
from peft import prepare_model_for_int8_training
|
561 |
+
|
562 |
+
model = prepare_model_for_int8_training(model)
|
563 |
+
|
564 |
+
# as Whisper model uses Conv layer in encoder, checkpointing disables grad computation
|
565 |
+
# to avoid this, make the inputs trainable
|
566 |
+
def make_inputs_require_grad(module, input, output):
|
567 |
+
output.requires_grad_(True)
|
568 |
+
|
569 |
+
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)
|
570 |
+
|
571 |
+
# wrapping model with adalora tuner
|
572 |
+
if args.use_adalora:
|
573 |
+
config = AdaLoraConfig(
|
574 |
+
init_r=args.init_r,
|
575 |
+
target_r=args.target_r,
|
576 |
+
beta1=0.85,
|
577 |
+
beta2=0.85,
|
578 |
+
tinit=args.tinit,
|
579 |
+
tfinal=args.tfinal,
|
580 |
+
deltaT=args.delta_t,
|
581 |
+
lora_alpha=args.lora_alpha,
|
582 |
+
lora_dropout=args.lora_dropout,
|
583 |
+
target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
|
584 |
+
orth_reg_weight=args.orth_reg_weight,
|
585 |
+
)
|
586 |
+
else:
|
587 |
+
config = LoraConfig(
|
588 |
+
r=args.r,
|
589 |
+
lora_alpha=args.lora_alpha,
|
590 |
+
target_modules=["q_proj", "v_proj"],
|
591 |
+
lora_dropout=args.lora_dropout,
|
592 |
+
)
|
593 |
+
|
594 |
+
model = get_peft_model(model, config)
|
595 |
+
model.print_trainable_parameters()
|
596 |
+
|
597 |
+
# optimizer
|
598 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
|
599 |
+
|
600 |
+
if args.max_train_steps is None:
|
601 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
602 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
603 |
+
else:
|
604 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
605 |
+
|
606 |
+
# scheduler
|
607 |
+
lr_scheduler = get_scheduler(
|
608 |
+
name=args.lr_scheduler_type,
|
609 |
+
optimizer=optimizer,
|
610 |
+
num_warmup_steps=args.num_warmup_steps,
|
611 |
+
num_training_steps=args.max_train_steps,
|
612 |
+
)
|
613 |
+
|
614 |
+
# Prepare everything with our `accelerator`.
|
615 |
+
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
|
616 |
+
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
|
617 |
+
)
|
618 |
+
|
619 |
+
accelerator.print(model)
|
620 |
+
|
621 |
+
# Note here that the max steps is adjusted by the accelerator's num_processes
|
622 |
+
args.max_train_steps = math.ceil(args.max_train_steps / accelerator.num_processes)
|
623 |
+
if args.use_peft and args.use_adalora:
|
624 |
+
model.base_model.peft_config["default"].total_step = args.max_train_steps
|
625 |
+
# model.base_model.peft_config.total_step = args.max_train_steps
|
626 |
+
|
627 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
628 |
+
# The trackers initializes automatically on the main process.
|
629 |
+
if args.with_tracking:
|
630 |
+
run_name = f"run-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
|
631 |
+
experiment_config = vars(args)
|
632 |
+
# TensorBoard cannot log Enums, need the raw value
|
633 |
+
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
|
634 |
+
accelerator.init_trackers(
|
635 |
+
"Whisper PEFT Fine-Tuning", config=experiment_config, init_kwargs={"wandb": {"name": run_name}}
|
636 |
+
)
|
637 |
+
|
638 |
+
# saving and loading checkpoints for resuming training
|
639 |
+
accelerator.register_save_state_pre_hook(save_model_hook)
|
640 |
+
accelerator.register_load_state_pre_hook(load_model_hook)
|
641 |
+
|
642 |
+
total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
643 |
+
logger.info("***** Running training *****")
|
644 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
645 |
+
logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
|
646 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
647 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
648 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
649 |
+
# Only show the progress bar once on each machine.
|
650 |
+
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
651 |
+
global_step = 0
|
652 |
+
starting_epoch = 0
|
653 |
+
best_metric = None
|
654 |
+
resume_step = 0
|
655 |
+
forced_decoder_ids = processor.get_decoder_prompt_ids(language=args.language, task=args.task)
|
656 |
+
|
657 |
+
# Potentially load in the weights and states from a previous save
|
658 |
+
if args.resume_from_checkpoint:
|
659 |
+
accelerator.load_state(args.resume_from_checkpoint)
|
660 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
661 |
+
training_difference = os.path.splitext(path)[0]
|
662 |
+
global_step = resume_step = int(training_difference.replace("step_", ""))
|
663 |
+
starting_epoch = resume_step // len(train_dataloader)
|
664 |
+
resume_step -= starting_epoch * len(train_dataloader)
|
665 |
+
|
666 |
+
# We need to adjust the progress bar to the current step
|
667 |
+
progress_bar.update(resume_step)
|
668 |
+
for epoch in range(starting_epoch, args.num_train_epochs):
|
669 |
+
model.train()
|
670 |
+
if args.with_tracking:
|
671 |
+
total_loss = 0
|
672 |
+
running_loss = 0
|
673 |
+
for step, batch in enumerate(accelerator.skip_first_batches(train_dataloader, num_batches=resume_step)):
|
674 |
+
with accelerator.accumulate(model):
|
675 |
+
outputs = model(**batch)
|
676 |
+
loss = outputs.loss
|
677 |
+
accelerator.backward(loss)
|
678 |
+
optimizer.step()
|
679 |
+
lr_scheduler.step()
|
680 |
+
|
681 |
+
# Update the importance of low-rank matrices
|
682 |
+
# and allocate the budget accordingly.
|
683 |
+
# This is only needed for AdaLora.
|
684 |
+
# Note that this requires parameter gradients.
|
685 |
+
# Hence being called before optimizer.zero_grad().
|
686 |
+
if args.use_peft and args.use_adalora:
|
687 |
+
model.update_and_allocate(global_step)
|
688 |
+
|
689 |
+
optimizer.zero_grad()
|
690 |
+
global_step += 1
|
691 |
+
progress_bar.update(1)
|
692 |
+
|
693 |
+
if args.with_tracking:
|
694 |
+
step_loss = accelerator.reduce(loss.detach().clone()).item()
|
695 |
+
total_loss += step_loss
|
696 |
+
running_loss += step_loss
|
697 |
+
|
698 |
+
if global_step % args.checkpointing_steps == 0:
|
699 |
+
output_dir = os.path.join(args.output_dir, f"step_{global_step}")
|
700 |
+
accelerator.save_state(output_dir)
|
701 |
+
|
702 |
+
if global_step % args.logging_steps == 0:
|
703 |
+
if args.with_tracking:
|
704 |
+
accelerator.log({"train/running_loss": running_loss / args.logging_steps}, step=global_step)
|
705 |
+
running_loss = 0
|
706 |
+
|
707 |
+
if global_step % args.evaluation_steps == 0:
|
708 |
+
eval_metrics = evaluation_loop(
|
709 |
+
model, eval_dataloader, processor, normalizer, metric, forced_decoder_ids, accelerator
|
710 |
+
)
|
711 |
+
if args.with_tracking:
|
712 |
+
logger.info(f"Step {global_step} eval metrics: {eval_metrics}")
|
713 |
+
accelerator.log(eval_metrics, step=global_step)
|
714 |
+
if best_metric is None or eval_metrics["eval/wer"] < best_metric:
|
715 |
+
best_metric = eval_metrics["eval/wer"]
|
716 |
+
accelerator.save_state(os.path.join(args.output_dir, "best_checkpoint"))
|
717 |
+
model.train()
|
718 |
+
|
719 |
+
if global_step >= args.max_train_steps:
|
720 |
+
break
|
721 |
+
|
722 |
+
if args.with_tracking:
|
723 |
+
train_epoch_loss = total_loss / (step + 1)
|
724 |
+
logger.info(f"Epoch {epoch} train loss: {train_epoch_loss}")
|
725 |
+
accelerator.log({"epoch/train_loss": train_epoch_loss}, step=epoch)
|
726 |
+
|
727 |
+
if args.push_to_hub and epoch <= args.num_train_epochs - 1:
|
728 |
+
accelerator.wait_for_everyone()
|
729 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
730 |
+
unwrapped_model.save_pretrained(args.output_dir, is_main_process=accelerator.is_main_process)
|
731 |
+
# evaluate the model at the end of training
|
732 |
+
eval_metrics = evaluation_loop(
|
733 |
+
model, eval_dataloader, processor, normalizer, metric, forced_decoder_ids, accelerator
|
734 |
+
)
|
735 |
+
if args.with_tracking:
|
736 |
+
logger.info(f"Step {global_step} eval metrics: {eval_metrics}")
|
737 |
+
accelerator.log(eval_metrics, step=global_step)
|
738 |
+
if best_metric is None or eval_metrics["eval/wer"] < best_metric:
|
739 |
+
best_metric = eval_metrics["eval/wer"]
|
740 |
+
accelerator.save_state(os.path.join(args.output_dir, "best_checkpoint"))
|
741 |
+
|
742 |
+
if accelerator.is_main_process:
|
743 |
+
processor.tokenizer.save_pretrained(args.output_dir)
|
744 |
+
repo.push_to_hub(
|
745 |
+
commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
|
746 |
+
)
|
747 |
+
|
748 |
+
if args.load_best_model:
|
749 |
+
# load the best model
|
750 |
+
accelerator.load_state(os.path.join(args.output_dir, "best_checkpoint"))
|
751 |
+
model.resize_modules_by_rank_pattern(model.peft_config["default"].rank_pattern, "default")
|
752 |
+
eval_metrics = evaluation_loop(
|
753 |
+
model, eval_dataloader, processor, normalizer, metric, forced_decoder_ids, accelerator
|
754 |
+
)
|
755 |
+
if args.with_tracking:
|
756 |
+
best_metrics = {"best_" + k: v for k, v in eval_metrics.items()}
|
757 |
+
accelerator.log(best_metrics, step=global_step)
|
758 |
+
|
759 |
+
accelerator.wait_for_everyone()
|
760 |
+
unwrapped_model = accelerator.unwrap_model(model)
|
761 |
+
unwrapped_model.save_pretrained(args.output_dir, is_main_process=accelerator.is_main_process)
|
762 |
+
if accelerator.is_main_process:
|
763 |
+
processor.tokenizer.save_pretrained(args.output_dir)
|
764 |
+
if args.push_to_hub:
|
765 |
+
repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
|
766 |
+
|
767 |
+
with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
|
768 |
+
eval_metrics.pop("eval_samples")
|
769 |
+
json.dump(eval_metrics, f)
|
770 |
+
|
771 |
+
|
772 |
+
if __name__ == "__main__":
|
773 |
+
main()
|
src/merge_lora.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import functools
|
3 |
+
import os
|
4 |
+
|
5 |
+
from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizerFast,\
|
6 |
+
WhisperProcessor
|
7 |
+
from peft import PeftModel, PeftConfig
|
8 |
+
from utils.utils import print_arguments, add_arguments
|
9 |
+
|
10 |
+
parser = argparse.ArgumentParser(description=__doc__)
|
11 |
+
add_arg = functools.partial(add_arguments, argparser=parser)
|
12 |
+
add_arg("lora_model", type=str, default="output/whisper-tiny/checkpoint-best/", help="微调保存的模型路径")
|
13 |
+
add_arg('output_dir', type=str, default='models/', help="合并模型的保存目录")
|
14 |
+
add_arg("local_files_only", type=bool, default=False, help="是否只在本地加载模型,不尝试下载")
|
15 |
+
args = parser.parse_args()
|
16 |
+
print_arguments(args)
|
17 |
+
|
18 |
+
assert os.path.exists(args.lora_model), f"模型文件{args.lora_model}不存在"
|
19 |
+
|
20 |
+
peft_config = PeftConfig.from_pretrained(args.lora_model)
|
21 |
+
#
|
22 |
+
base_model = WhisperForConditionalGeneration.from_pretrained(peft_config.base_model_name_or_path, device_map={"": "cpu"},
|
23 |
+
local_files_only=args.local_files_only)
|
24 |
+
|
25 |
+
model = PeftModel.from_pretrained(base_model, args.lora_model, local_files_only=args.local_files_only)
|
26 |
+
feature_extractor = WhisperFeatureExtractor.from_pretrained(peft_config.base_model_name_or_path,
|
27 |
+
local_files_only=args.local_files_only)
|
28 |
+
tokenizer = WhisperTokenizerFast.from_pretrained(peft_config.base_model_name_or_path,
|
29 |
+
local_files_only=args.local_files_only)
|
30 |
+
processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path,
|
31 |
+
local_files_only=args.local_files_only)
|
32 |
+
|
33 |
+
|
34 |
+
model = model.merge_and_unload()
|
35 |
+
model.train(False)
|
36 |
+
|
37 |
+
save_directory = os.path.join(args.output_dir, f'{os.path.basename(peft_config.base_model_name_or_path)}-finetune')
|
38 |
+
os.makedirs(save_directory, exist_ok=True)
|
39 |
+
|
40 |
+
model.save_pretrained(save_directory)
|
41 |
+
feature_extractor.save_pretrained(save_directory)
|
42 |
+
tokenizer.save_pretrained(save_directory)
|
43 |
+
processor.save_pretrained(save_directory)
|
44 |
+
print(f'合并模型保持在:{save_directory}')
|
src/prepare_data.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
|
3 |
+
import datasets
|
4 |
+
from datasets import DatasetDict, load_dataset, concatenate_datasets
|
5 |
+
from tqdm import tqdm
|
6 |
+
from transformers import (
|
7 |
+
AutoConfig,
|
8 |
+
AutoFeatureExtractor,
|
9 |
+
AutoModelForSpeechSeq2Seq,
|
10 |
+
AutoTokenizer,
|
11 |
+
set_seed,
|
12 |
+
)
|
13 |
+
from transformers.utils.versions import require_version
|
14 |
+
from transformers.utils import check_min_version
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
from audiomentations import (
|
18 |
+
AddBackgroundNoise,
|
19 |
+
AddGaussianNoise,
|
20 |
+
Compose,
|
21 |
+
Gain,
|
22 |
+
OneOf,
|
23 |
+
PitchShift,
|
24 |
+
PolarityInversion,
|
25 |
+
TimeStretch,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
check_min_version("4.27.0.dev0")
|
30 |
+
|
31 |
+
require_version(
|
32 |
+
"datasets>=1.18.0",
|
33 |
+
"To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt",
|
34 |
+
)
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__)
|
37 |
+
from datasets import Dataset, DatasetDict
|
38 |
+
import torchaudio
|
39 |
+
from torchaudio import transforms as at
|
40 |
+
import pandas as pd
|
41 |
+
import torch
|
42 |
+
from pathlib import Path
|
43 |
+
import random
|
44 |
+
def main():
|
45 |
+
# Set seed before initializing model.
|
46 |
+
set_seed(42)
|
47 |
+
|
48 |
+
# 5. Load pretrained model, tokenizer, and feature extractor
|
49 |
+
#
|
50 |
+
# Distributed training:
|
51 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
52 |
+
config = AutoConfig.from_pretrained(
|
53 |
+
"openai/whisper-medium", revision="main", use_auth_token=True
|
54 |
+
)
|
55 |
+
|
56 |
+
config.update({"forced_decoder_ids": None, "suppress_tokens": None})
|
57 |
+
|
58 |
+
# *****************************SpecAugment for whisper models
|
59 |
+
# if getattr(config, "model_type", None) == "whisper":
|
60 |
+
config.update({"apply_spec_augment": True})
|
61 |
+
|
62 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
63 |
+
"openai/whisper-medium",
|
64 |
+
revision="main",
|
65 |
+
use_auth_token=True,
|
66 |
+
)
|
67 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
68 |
+
"openai/whisper-medium",
|
69 |
+
use_fast=True,
|
70 |
+
revision="main",
|
71 |
+
use_auth_token=True,
|
72 |
+
)
|
73 |
+
|
74 |
+
tokenizer.set_prefix_tokens(language="vi", task="transcribe")
|
75 |
+
|
76 |
+
# 7. Preprocessing the datasets.
|
77 |
+
# We need to read the audio files as arrays and tokenize the targets.
|
78 |
+
max_input_length = 30.0 * 16000
|
79 |
+
min_input_length = 0.0 * 16000
|
80 |
+
audio_column_name = "audio"
|
81 |
+
num_workers = 16
|
82 |
+
text_column_name = "text"
|
83 |
+
model_input_name = feature_extractor.model_input_names[0]
|
84 |
+
|
85 |
+
# if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
|
86 |
+
forward_attention_mask = True
|
87 |
+
|
88 |
+
# noise_dir = "../noise/ESC-50-master/audio/"
|
89 |
+
# define augmentation
|
90 |
+
augmentation = Compose(
|
91 |
+
[
|
92 |
+
TimeStretch(min_rate=0.9, max_rate=1.1, p=0.2, leave_length_unchanged=True),
|
93 |
+
Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.1),
|
94 |
+
PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
|
95 |
+
]
|
96 |
+
)
|
97 |
+
|
98 |
+
def augment_dataset(batch):
|
99 |
+
# load and (possibly) resample audio data to 16kHz
|
100 |
+
sample = batch["audio"]
|
101 |
+
|
102 |
+
# apply augmentation
|
103 |
+
augmented_waveform = augmentation(
|
104 |
+
sample, sample_rate=16000
|
105 |
+
)
|
106 |
+
batch["audio"]["array"] = augmented_waveform
|
107 |
+
return batch
|
108 |
+
|
109 |
+
def prepare_dataset(batch):
|
110 |
+
# process audio
|
111 |
+
sample = batch[audio_column_name]
|
112 |
+
inputs = feature_extractor(
|
113 |
+
sample,
|
114 |
+
sampling_rate= 16000,
|
115 |
+
return_attention_mask=forward_attention_mask,
|
116 |
+
)
|
117 |
+
# process audio length
|
118 |
+
batch[model_input_name] = inputs.get(model_input_name)[0]
|
119 |
+
batch["input_length"] = len(sample)
|
120 |
+
if forward_attention_mask:
|
121 |
+
batch["attention_mask"] = inputs.get("attention_mask")[0]
|
122 |
+
|
123 |
+
# process targets
|
124 |
+
input_str = batch[text_column_name]
|
125 |
+
batch["labels"] = tokenizer(input_str).input_ids
|
126 |
+
return batch
|
127 |
+
|
128 |
+
|
129 |
+
def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
|
130 |
+
waveform, sr = torchaudio.load(wave_path, normalize=True)
|
131 |
+
if sample_rate != sr:
|
132 |
+
waveform = at.Resample(sr, sample_rate)(waveform)
|
133 |
+
return waveform
|
134 |
+
|
135 |
+
|
136 |
+
def get_list_files_MITI(phase, sample_rate=16000, audio_max_sample_length=480000, fraction=0.15):
|
137 |
+
audio_list = []
|
138 |
+
text_list = []
|
139 |
+
if phase == 'train':
|
140 |
+
csv_file = 'vin_train.csv'
|
141 |
+
else:
|
142 |
+
csv_file = 'vin_test.csv'
|
143 |
+
df = pd.read_csv(csv_file)
|
144 |
+
|
145 |
+
# Calculate the number of samples to select based on the fraction
|
146 |
+
num_samples = int(len(df) * fraction)
|
147 |
+
|
148 |
+
# Randomly select the indices of samples
|
149 |
+
selected_indices = random.sample(range(len(df)), num_samples)
|
150 |
+
|
151 |
+
for index, row in tqdm(df.iterrows()):
|
152 |
+
if index not in selected_indices:
|
153 |
+
continue
|
154 |
+
|
155 |
+
new_path = Path(row['path'])
|
156 |
+
audio_id = index
|
157 |
+
text = row['sentence']
|
158 |
+
|
159 |
+
if new_path.exists():
|
160 |
+
audio = load_wave(new_path, sample_rate=sample_rate)[0]
|
161 |
+
if len(audio) > audio_max_sample_length or len(audio) < 0:
|
162 |
+
print('skip file:', new_path, 'with len audio', len(audio))
|
163 |
+
continue
|
164 |
+
audio_list.append(audio)
|
165 |
+
text_list.append(text)
|
166 |
+
|
167 |
+
return audio_list, text_list
|
168 |
+
|
169 |
+
# Assuming you have two CSV files, 'vin_train.csv' and 'vin_test.csv', in the same directory
|
170 |
+
|
171 |
+
# Get the training dataset
|
172 |
+
train_audio, train_text = get_list_files_MITI(phase='train')
|
173 |
+
|
174 |
+
# Get the testing dataset
|
175 |
+
test_audio, test_text = get_list_files_MITI(phase='test')
|
176 |
+
|
177 |
+
# Create the Dataset objects
|
178 |
+
train_dataset = Dataset.from_dict({"audio": train_audio, "text": train_text})
|
179 |
+
test_dataset = Dataset.from_dict({"audio": test_audio, "text": test_text})
|
180 |
+
|
181 |
+
# Create the DatasetDict
|
182 |
+
vin_100h = DatasetDict({"train": train_dataset, "test": test_dataset})
|
183 |
+
|
184 |
+
|
185 |
+
|
186 |
+
|
187 |
+
|
188 |
+
print(vin_100h)
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
vectorized_datasets = vin_100h.map(
|
195 |
+
prepare_dataset,
|
196 |
+
remove_columns=["audio", "text"],
|
197 |
+
num_proc=1,
|
198 |
+
desc="preprocess train dataset",
|
199 |
+
)
|
200 |
+
|
201 |
+
|
202 |
+
print(vectorized_datasets)
|
203 |
+
|
204 |
+
vectorized_datasets.save_to_disk(
|
205 |
+
"./vin_10h", num_proc=1
|
206 |
+
)
|
207 |
+
|
208 |
+
return
|
209 |
+
|
210 |
+
|
211 |
+
if __name__ == "__main__":
|
212 |
+
main()
|
src/realtime.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! python3.7
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import io
|
5 |
+
import os
|
6 |
+
import speech_recognition as sr
|
7 |
+
import whisperx
|
8 |
+
import torch
|
9 |
+
|
10 |
+
from datetime import datetime, timedelta
|
11 |
+
from queue import Queue
|
12 |
+
from tempfile import NamedTemporaryFile
|
13 |
+
from time import sleep
|
14 |
+
from sys import platform
|
15 |
+
|
16 |
+
|
17 |
+
def main():
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument("--model", default="Vietnamese_ASR/ct2ranslate", help="Size of model or the local path for model ",
|
20 |
+
type=str)
|
21 |
+
parser.add_argument("--non_english", action='store_true',
|
22 |
+
help="Don't use the English model.")
|
23 |
+
parser.add_argument("--language", default="vi", help="The language to infer the model with whisper", type=str)
|
24 |
+
parser.add_argument("--device", default="cpu",
|
25 |
+
help="Choose device for inference "
|
26 |
+
, type=str)
|
27 |
+
parser.add_argument("--energy_threshold", default=900,
|
28 |
+
help="Energy level for mic to detect.", type=int)
|
29 |
+
parser.add_argument("--record_timeout", default=0.6,
|
30 |
+
help="How real-time the recording is in seconds.", type=float)
|
31 |
+
parser.add_argument("--phrase_timeout", default=3,
|
32 |
+
help="How much empty space between recordings before we "
|
33 |
+
"consider it a new line in the transcription.", type=float)
|
34 |
+
if 'linux' in platform:
|
35 |
+
parser.add_argument("--default_microphone", default='pulse',
|
36 |
+
help="Default microphone name for SpeechRecognition. "
|
37 |
+
"Run this with 'list' to view available Microphones.", type=str)
|
38 |
+
args = parser.parse_args()
|
39 |
+
|
40 |
+
|
41 |
+
# The last time a recording was retreived from the queue.
|
42 |
+
phrase_time = None
|
43 |
+
# Current raw audio bytes.
|
44 |
+
last_sample = bytes()
|
45 |
+
# Thread safe Queue for passing data from the threaded recording callback.
|
46 |
+
data_queue = Queue()
|
47 |
+
# We use SpeechRecognizer to record our audio because it has a nice feauture where it can detect when speech ends.
|
48 |
+
recorder = sr.Recognizer()
|
49 |
+
recorder.energy_threshold = args.energy_threshold
|
50 |
+
# Definitely do this, dynamic energy compensation lowers the energy threshold dramtically to a point where the SpeechRecognizer never stops recording.
|
51 |
+
recorder.dynamic_energy_threshold = False
|
52 |
+
|
53 |
+
# Important for linux users.
|
54 |
+
# Prevents permanent application hang and crash by using the wrong Microphone
|
55 |
+
if 'linux' in platform:
|
56 |
+
mic_name = args.default_microphone
|
57 |
+
if not mic_name or mic_name == 'list':
|
58 |
+
print("Available microphone devices are: ")
|
59 |
+
for index, name in enumerate(sr.Microphone.list_microphone_names()):
|
60 |
+
print(f"Microphone with name \"{name}\" found")
|
61 |
+
return
|
62 |
+
else:
|
63 |
+
for index, name in enumerate(sr.Microphone.list_microphone_names()):
|
64 |
+
if mic_name in name:
|
65 |
+
source = sr.Microphone(sample_rate=16000, device_index=index)
|
66 |
+
break
|
67 |
+
else:
|
68 |
+
source = sr.Microphone(sample_rate=16000)
|
69 |
+
|
70 |
+
# Load / Download model
|
71 |
+
model = args.model
|
72 |
+
# if args.model != "large" and not args.non_english:
|
73 |
+
# model = model + ".en"
|
74 |
+
audio_model = whisperx.load_model(model, device=args.device, compute_type="float16", language = args.language)
|
75 |
+
|
76 |
+
record_timeout = args.record_timeout
|
77 |
+
phrase_timeout = args.phrase_timeout
|
78 |
+
|
79 |
+
temp_file = NamedTemporaryFile().name
|
80 |
+
transcription = ['']
|
81 |
+
|
82 |
+
with source:
|
83 |
+
recorder.adjust_for_ambient_noise(source)
|
84 |
+
|
85 |
+
def record_callback(_, audio:sr.AudioData) -> None:
|
86 |
+
"""
|
87 |
+
Threaded callback function to recieve audio data when recordings finish.
|
88 |
+
audio: An AudioData containing the recorded bytes.
|
89 |
+
"""
|
90 |
+
# Grab the raw bytes and push it into the thread safe queue.
|
91 |
+
data = audio.get_raw_data()
|
92 |
+
data_queue.put(data)
|
93 |
+
|
94 |
+
# Create a background thread that will pass us raw audio bytes.
|
95 |
+
# We could do this manually but SpeechRecognizer provides a nice helper.
|
96 |
+
recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
|
97 |
+
|
98 |
+
# Cue the user that we're ready to go.
|
99 |
+
print("Model loaded.\n")
|
100 |
+
|
101 |
+
while True:
|
102 |
+
try:
|
103 |
+
now = datetime.utcnow()
|
104 |
+
# Pull raw recorded audio from the queue.
|
105 |
+
if not data_queue.empty():
|
106 |
+
phrase_complete = False
|
107 |
+
# If enough time has passed between recordings, consider the phrase complete.
|
108 |
+
# Clear the current working audio buffer to start over with the new data.
|
109 |
+
if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
|
110 |
+
last_sample = bytes()
|
111 |
+
phrase_complete = True
|
112 |
+
# This is the last time we received new audio data from the queue.
|
113 |
+
phrase_time = now
|
114 |
+
|
115 |
+
# Concatenate our current audio data with the latest audio data.
|
116 |
+
while not data_queue.empty():
|
117 |
+
data = data_queue.get()
|
118 |
+
last_sample += data
|
119 |
+
|
120 |
+
# Use AudioData to convert the raw data to wav data.
|
121 |
+
audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE, source.SAMPLE_WIDTH)
|
122 |
+
wav_data = io.BytesIO(audio_data.get_wav_data())
|
123 |
+
|
124 |
+
# Write wav data to the temporary file as bytes.
|
125 |
+
with open(temp_file, 'w+b') as f:
|
126 |
+
f.write(wav_data.read())
|
127 |
+
|
128 |
+
# Read the transcription.
|
129 |
+
result = audio_model.transcribe(temp_file, language="en",batch_size = 8)
|
130 |
+
text = result['segments'][0]['text'].strip()
|
131 |
+
|
132 |
+
# If we detected a pause between recordings, add a new item to our transcripion.
|
133 |
+
# Otherwise edit the existing one.
|
134 |
+
if phrase_complete:
|
135 |
+
transcription.append(text)
|
136 |
+
else:
|
137 |
+
transcription[-1] = text
|
138 |
+
|
139 |
+
# Clear the console to reprint the updated transcription.
|
140 |
+
os.system('cls' if os.name=='nt' else 'clear')
|
141 |
+
for line in transcription:
|
142 |
+
print(line)
|
143 |
+
# Flush stdout.
|
144 |
+
print('', end='', flush=True)
|
145 |
+
|
146 |
+
# Infinite loops are bad for processors, must sleep.
|
147 |
+
sleep(0.25)
|
148 |
+
except KeyboardInterrupt:
|
149 |
+
break
|
150 |
+
|
151 |
+
print("\n\nTranscription:")
|
152 |
+
for line in transcription:
|
153 |
+
print(line)
|
154 |
+
|
155 |
+
|
156 |
+
if __name__ == "__main__":
|
157 |
+
main()
|
src/requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
git+https://github.com/huggingface/peft.git@main
|
2 |
+
bitsandbytes
|
3 |
+
accelerate
|
4 |
+
loralib
|
5 |
+
librosa
|
6 |
+
datasets>=2.6.1
|
7 |
+
evaluate>=0.3.0
|
8 |
+
jiwer
|
9 |
+
tensorboard
|
10 |
+
soundfile==0.12.1
|
11 |
+
git+https://github.com/m-bain/whisperX.git
|
12 |
+
#nvidia-cudnn-cu11-8.7.0.84 need
|
13 |
+
lightning-fabric
|
14 |
+
pyaudio
|
15 |
+
SpeechRecognition
|
src/test_whisper.ipynb
ADDED
@@ -0,0 +1,1546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"data": {
|
10 |
+
"application/vnd.jupyter.widget-view+json": {
|
11 |
+
"model_id": "9d7b03aae28b4282b143eb17c3d8d687",
|
12 |
+
"version_major": 2,
|
13 |
+
"version_minor": 0
|
14 |
+
},
|
15 |
+
"text/plain": [
|
16 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
17 |
+
]
|
18 |
+
},
|
19 |
+
"metadata": {},
|
20 |
+
"output_type": "display_data"
|
21 |
+
}
|
22 |
+
],
|
23 |
+
"source": [
|
24 |
+
"from huggingface_hub import notebook_login\n",
|
25 |
+
"\n",
|
26 |
+
"notebook_login()"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": 2,
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [
|
34 |
+
{
|
35 |
+
"name": "stderr",
|
36 |
+
"output_type": "stream",
|
37 |
+
"text": [
|
38 |
+
"Found cached dataset vivos (/home/tesla/.cache/huggingface/datasets/vivos/default/1.1.0/ab59078eb266c1a0ea856786ba56b5b8d56f29b42dfb37d92115cf81a7b1a5e0)\n",
|
39 |
+
"Found cached dataset vivos (/home/tesla/.cache/huggingface/datasets/vivos/default/1.1.0/ab59078eb266c1a0ea856786ba56b5b8d56f29b42dfb37d92115cf81a7b1a5e0)\n"
|
40 |
+
]
|
41 |
+
}
|
42 |
+
],
|
43 |
+
"source": [
|
44 |
+
"from datasets import load_dataset, DatasetDict\n",
|
45 |
+
"\n",
|
46 |
+
"vivos = DatasetDict()\n",
|
47 |
+
"\n",
|
48 |
+
"vivos[\"train\"] = load_dataset(\"vivos\", split=\"train\", use_auth_token=True)\n",
|
49 |
+
"vivos[\"test\"] = load_dataset(\"vivos\", split=\"test\", use_auth_token=True)\n",
|
50 |
+
"\n",
|
51 |
+
"\n"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": 38,
|
57 |
+
"metadata": {},
|
58 |
+
"outputs": [
|
59 |
+
{
|
60 |
+
"data": {
|
61 |
+
"text/plain": [
|
62 |
+
"DatasetDict({\n",
|
63 |
+
" train: Dataset({\n",
|
64 |
+
" features: ['speaker_id', 'path', 'audio', 'sentence'],\n",
|
65 |
+
" num_rows: 11660\n",
|
66 |
+
" })\n",
|
67 |
+
" test: Dataset({\n",
|
68 |
+
" features: ['speaker_id', 'path', 'audio', 'sentence'],\n",
|
69 |
+
" num_rows: 760\n",
|
70 |
+
" })\n",
|
71 |
+
"})"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
"execution_count": 38,
|
75 |
+
"metadata": {},
|
76 |
+
"output_type": "execute_result"
|
77 |
+
}
|
78 |
+
],
|
79 |
+
"source": [
|
80 |
+
"vivos"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"execution_count": 3,
|
86 |
+
"metadata": {},
|
87 |
+
"outputs": [],
|
88 |
+
"source": [
|
89 |
+
"vivos_clean = vivos.remove_columns([\"speaker_id\", \"path\"])"
|
90 |
+
]
|
91 |
+
},
|
92 |
+
{
|
93 |
+
"cell_type": "code",
|
94 |
+
"execution_count": 40,
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [
|
97 |
+
{
|
98 |
+
"data": {
|
99 |
+
"text/plain": [
|
100 |
+
"DatasetDict({\n",
|
101 |
+
" train: Dataset({\n",
|
102 |
+
" features: ['audio', 'sentence'],\n",
|
103 |
+
" num_rows: 11660\n",
|
104 |
+
" })\n",
|
105 |
+
" test: Dataset({\n",
|
106 |
+
" features: ['audio', 'sentence'],\n",
|
107 |
+
" num_rows: 760\n",
|
108 |
+
" })\n",
|
109 |
+
"})"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
"execution_count": 40,
|
113 |
+
"metadata": {},
|
114 |
+
"output_type": "execute_result"
|
115 |
+
}
|
116 |
+
],
|
117 |
+
"source": [
|
118 |
+
"vivos_clean"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": 79,
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [
|
126 |
+
{
|
127 |
+
"name": "stdout",
|
128 |
+
"output_type": "stream",
|
129 |
+
"text": [
|
130 |
+
"{'audio': {'path': 'vivos/train/waves/VIVOSSPK27/VIVOSSPK27_084.wav', 'array': array([ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, ...,\n",
|
131 |
+
" 9.15527344e-05, -5.18798828e-04, -9.15527344e-04]), 'sampling_rate': 16000}, 'sentence': 'CHƯA HẾT ĐI KHIẾU NẠI THÌ NHÀ MẠNG BẢO VỀ ĐẠI LÝ CHỌN SỐ KHÁC ĐI'}\n"
|
132 |
+
]
|
133 |
+
}
|
134 |
+
],
|
135 |
+
"source": [
|
136 |
+
"print(vivos_clean['train'][12])"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"execution_count": 4,
|
142 |
+
"metadata": {},
|
143 |
+
"outputs": [
|
144 |
+
{
|
145 |
+
"name": "stderr",
|
146 |
+
"output_type": "stream",
|
147 |
+
"text": [
|
148 |
+
"Found cached dataset common_voice_13_0 (/home/tesla/.cache/huggingface/datasets/mozilla-foundation___common_voice_13_0/vi/13.0.0/2506e9a8950f5807ceae08c2920e814222909fd7f477b74f5d225802e9f04055)\n",
|
149 |
+
"Found cached dataset common_voice_13_0 (/home/tesla/.cache/huggingface/datasets/mozilla-foundation___common_voice_13_0/vi/13.0.0/2506e9a8950f5807ceae08c2920e814222909fd7f477b74f5d225802e9f04055)\n"
|
150 |
+
]
|
151 |
+
}
|
152 |
+
],
|
153 |
+
"source": [
|
154 |
+
"\n",
|
155 |
+
"common_voice = DatasetDict()\n",
|
156 |
+
"\n",
|
157 |
+
"common_voice[\"train\"] = load_dataset(\"mozilla-foundation/common_voice_13_0\", \"vi\", split=\"train+validation\", use_auth_token=True)\n",
|
158 |
+
"common_voice[\"test\"] = load_dataset(\"mozilla-foundation/common_voice_13_0\", \"vi\", split=\"test\", use_auth_token=True)\n",
|
159 |
+
"\n",
|
160 |
+
"\n"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"cell_type": "code",
|
165 |
+
"execution_count": 42,
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [
|
168 |
+
{
|
169 |
+
"data": {
|
170 |
+
"text/plain": [
|
171 |
+
"DatasetDict({\n",
|
172 |
+
" train: Dataset({\n",
|
173 |
+
" features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],\n",
|
174 |
+
" num_rows: 2854\n",
|
175 |
+
" })\n",
|
176 |
+
" test: Dataset({\n",
|
177 |
+
" features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'variant'],\n",
|
178 |
+
" num_rows: 1225\n",
|
179 |
+
" })\n",
|
180 |
+
"})"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
"execution_count": 42,
|
184 |
+
"metadata": {},
|
185 |
+
"output_type": "execute_result"
|
186 |
+
}
|
187 |
+
],
|
188 |
+
"source": [
|
189 |
+
"common_voice"
|
190 |
+
]
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "code",
|
194 |
+
"execution_count": 67,
|
195 |
+
"metadata": {},
|
196 |
+
"outputs": [],
|
197 |
+
"source": [
|
198 |
+
"common_voice_clean = common_voice.remove_columns([\"client_id\", \"path\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\", \"age\", \"accent\", \"variant\"])\n"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"cell_type": "code",
|
203 |
+
"execution_count": 89,
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [
|
206 |
+
{
|
207 |
+
"data": {
|
208 |
+
"text/plain": [
|
209 |
+
"DatasetDict({\n",
|
210 |
+
" train: Dataset({\n",
|
211 |
+
" features: ['audio', 'sentence'],\n",
|
212 |
+
" num_rows: 2854\n",
|
213 |
+
" })\n",
|
214 |
+
" test: Dataset({\n",
|
215 |
+
" features: ['audio', 'sentence'],\n",
|
216 |
+
" num_rows: 1225\n",
|
217 |
+
" })\n",
|
218 |
+
"})"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
"execution_count": 89,
|
222 |
+
"metadata": {},
|
223 |
+
"output_type": "execute_result"
|
224 |
+
}
|
225 |
+
],
|
226 |
+
"source": [
|
227 |
+
"common_voice_clean"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "code",
|
232 |
+
"execution_count": 1,
|
233 |
+
"metadata": {},
|
234 |
+
"outputs": [
|
235 |
+
{
|
236 |
+
"ename": "NameError",
|
237 |
+
"evalue": "name 'common_voice_clean' is not defined",
|
238 |
+
"output_type": "error",
|
239 |
+
"traceback": [
|
240 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
241 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
242 |
+
"Cell \u001b[0;32mIn[1], line 9\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[39mreturn\u001b[39;00m example\n\u001b[1;32m 7\u001b[0m common_voice_clear \u001b[39m=\u001b[39m DatasetDict()\n\u001b[0;32m----> 9\u001b[0m common_voice_clear[\u001b[39m\"\u001b[39m\u001b[39mtrain\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m common_voice_clean[\u001b[39m\"\u001b[39m\u001b[39mtrain\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39mmap(convert_to_uppercase)\n\u001b[1;32m 10\u001b[0m common_voice_clear[\u001b[39m\"\u001b[39m\u001b[39mtest\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m common_voice_clean[\u001b[39m\"\u001b[39m\u001b[39mtest\u001b[39m\u001b[39m\"\u001b[39m]\u001b[39m.\u001b[39mmap(convert_to_uppercase)\n",
|
243 |
+
"\u001b[0;31mNameError\u001b[0m: name 'common_voice_clean' is not defined"
|
244 |
+
]
|
245 |
+
}
|
246 |
+
],
|
247 |
+
"source": [
|
248 |
+
"from datasets import DatasetDict\n",
|
249 |
+
"\n",
|
250 |
+
"def convert_to_uppercase(example):\n",
|
251 |
+
" example[\"sentence\"] = example[\"sentence\"].upper()\n",
|
252 |
+
" return example\n",
|
253 |
+
"\n",
|
254 |
+
"common_voice_clear = DatasetDict()\n",
|
255 |
+
"\n",
|
256 |
+
"common_voice_clear[\"train\"] = common_voice_clean[\"train\"].map(convert_to_uppercase)\n",
|
257 |
+
"common_voice_clear[\"test\"] = common_voice_clean[\"test\"].map(convert_to_uppercase)"
|
258 |
+
]
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"execution_count": 93,
|
263 |
+
"metadata": {},
|
264 |
+
"outputs": [
|
265 |
+
{
|
266 |
+
"name": "stdout",
|
267 |
+
"output_type": "stream",
|
268 |
+
"text": [
|
269 |
+
"{'audio': {'path': '/home/tesla/.cache/huggingface/datasets/downloads/extracted/acb70896120347904e003bb826dcabc1ddd05a02210935cb44ce1c807e8742a5/vi_train_0/common_voice_vi_23901118.mp3', 'array': array([ 0.00000000e+00, 4.20543185e-14, 1.38823347e-14, ...,\n",
|
270 |
+
" -8.41874498e-06, -8.36193431e-06, -6.76584477e-06]), 'sampling_rate': 48000}, 'sentence': 'KHI CON CÓ MẸ'}\n"
|
271 |
+
]
|
272 |
+
}
|
273 |
+
],
|
274 |
+
"source": [
|
275 |
+
"print(common_voice_clear['train'][1])"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"cell_type": "code",
|
280 |
+
"execution_count": 94,
|
281 |
+
"metadata": {},
|
282 |
+
"outputs": [
|
283 |
+
{
|
284 |
+
"name": "stderr",
|
285 |
+
"output_type": "stream",
|
286 |
+
"text": [
|
287 |
+
"100%|██████████| 2854/2854 [33:25<00:00, 1.42it/s]\n"
|
288 |
+
]
|
289 |
+
}
|
290 |
+
],
|
291 |
+
"source": [
|
292 |
+
"from pydub import AudioSegment\n",
|
293 |
+
"import os\n",
|
294 |
+
"\n",
|
295 |
+
"from tqdm import tqdm\n",
|
296 |
+
"\n",
|
297 |
+
"def convert_mp3_to_wav(mp3_path, wav_path, target_sampling_rate):\n",
|
298 |
+
" audio = AudioSegment.from_mp3(mp3_path)\n",
|
299 |
+
" audio = audio.set_frame_rate(target_sampling_rate)\n",
|
300 |
+
" audio.export(wav_path, format='wav')\n",
|
301 |
+
"\n",
|
302 |
+
"target_sampling_rate = 16000\n",
|
303 |
+
"\n",
|
304 |
+
"for example in tqdm(common_voice_clear[\"train\"]):\n",
|
305 |
+
" mp3_path = example[\"audio\"][\"path\"]\n",
|
306 |
+
" wav_path = os.path.splitext(mp3_path)[0] + \".wav\"\n",
|
307 |
+
" convert_mp3_to_wav(mp3_path, wav_path, target_sampling_rate)\n",
|
308 |
+
" example[\"audio\"][\"path\"] = wav_path\n",
|
309 |
+
"\n"
|
310 |
+
]
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"cell_type": "code",
|
314 |
+
"execution_count": 95,
|
315 |
+
"metadata": {},
|
316 |
+
"outputs": [],
|
317 |
+
"source": [
|
318 |
+
"import datasets\n",
|
319 |
+
"from datasets import Audio\n",
|
320 |
+
"\n",
|
321 |
+
"common_voice_clean = common_voice_clean.cast_column(\"audio\", Audio(sampling_rate=16000))"
|
322 |
+
]
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"cell_type": "code",
|
326 |
+
"execution_count": 47,
|
327 |
+
"metadata": {},
|
328 |
+
"outputs": [],
|
329 |
+
"source": [
|
330 |
+
"concat = DatasetDict()"
|
331 |
+
]
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"cell_type": "code",
|
335 |
+
"execution_count": 96,
|
336 |
+
"metadata": {},
|
337 |
+
"outputs": [],
|
338 |
+
"source": [
|
339 |
+
"concat[\"train\"] = datasets.concatenate_datasets([common_voice_clean[\"train\"], vivos_clean[\"train\"]])\n",
|
340 |
+
"\n",
|
341 |
+
"#concat['test']= datasets.concatenate_datasets([common_voice_clean[\"test\"], vivos_clean[\"test\"]])\n",
|
342 |
+
"concat['test']= vivos_clean[\"test\"]\n"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"cell_type": "code",
|
347 |
+
"execution_count": 97,
|
348 |
+
"metadata": {},
|
349 |
+
"outputs": [
|
350 |
+
{
|
351 |
+
"data": {
|
352 |
+
"text/plain": [
|
353 |
+
"DatasetDict({\n",
|
354 |
+
" train: Dataset({\n",
|
355 |
+
" features: ['audio', 'sentence'],\n",
|
356 |
+
" num_rows: 14514\n",
|
357 |
+
" })\n",
|
358 |
+
" test: Dataset({\n",
|
359 |
+
" features: ['audio', 'sentence'],\n",
|
360 |
+
" num_rows: 760\n",
|
361 |
+
" })\n",
|
362 |
+
"})"
|
363 |
+
]
|
364 |
+
},
|
365 |
+
"execution_count": 97,
|
366 |
+
"metadata": {},
|
367 |
+
"output_type": "execute_result"
|
368 |
+
}
|
369 |
+
],
|
370 |
+
"source": [
|
371 |
+
"concat"
|
372 |
+
]
|
373 |
+
},
|
374 |
+
{
|
375 |
+
"cell_type": "code",
|
376 |
+
"execution_count": 98,
|
377 |
+
"metadata": {},
|
378 |
+
"outputs": [],
|
379 |
+
"source": [
|
380 |
+
"from transformers import WhisperFeatureExtractor\n",
|
381 |
+
"\n",
|
382 |
+
"feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-small\")\n"
|
383 |
+
]
|
384 |
+
},
|
385 |
+
{
|
386 |
+
"cell_type": "code",
|
387 |
+
"execution_count": 99,
|
388 |
+
"metadata": {},
|
389 |
+
"outputs": [],
|
390 |
+
"source": [
|
391 |
+
"from transformers import WhisperTokenizerFast\n",
|
392 |
+
"\n",
|
393 |
+
"tokenizer = WhisperTokenizerFast.from_pretrained(\"openai/whisper-small\", language=\"Vietnamese\", task=\"transcribe\")\n"
|
394 |
+
]
|
395 |
+
},
|
396 |
+
{
|
397 |
+
"cell_type": "code",
|
398 |
+
"execution_count": 80,
|
399 |
+
"metadata": {},
|
400 |
+
"outputs": [
|
401 |
+
{
|
402 |
+
"name": "stdout",
|
403 |
+
"output_type": "stream",
|
404 |
+
"text": [
|
405 |
+
"Input: KHÔNG CÓ AI BÁC BỎ QUYỀN ĐÓ\n",
|
406 |
+
"Decoded w/ special: <|startoftranscript|><|notimestamps|>KHÔNG CÓ AI BÁC BỎ QUYỀN ĐÓ<|endoftext|>\n",
|
407 |
+
"Decoded w/out special: KHÔNG CÓ AI BÁC BỎ QUYỀN ĐÓ\n",
|
408 |
+
"Are equal: True\n"
|
409 |
+
]
|
410 |
+
}
|
411 |
+
],
|
412 |
+
"source": [
|
413 |
+
"input_str = concat[\"train\"][8550][\"sentence\"]\n",
|
414 |
+
"labels = tokenizer(input_str).input_ids\n",
|
415 |
+
"decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)\n",
|
416 |
+
"decoded_str = tokenizer.decode(labels, skip_special_tokens=True)\n",
|
417 |
+
"\n",
|
418 |
+
"print(f\"Input: {input_str}\")\n",
|
419 |
+
"print(f\"Decoded w/ special: {decoded_with_special}\")\n",
|
420 |
+
"print(f\"Decoded w/out special: {decoded_str}\")\n",
|
421 |
+
"print(f\"Are equal: {input_str == decoded_str}\")\n"
|
422 |
+
]
|
423 |
+
},
|
424 |
+
{
|
425 |
+
"cell_type": "code",
|
426 |
+
"execution_count": 100,
|
427 |
+
"metadata": {},
|
428 |
+
"outputs": [],
|
429 |
+
"source": [
|
430 |
+
"from transformers import WhisperProcessor\n",
|
431 |
+
"\n",
|
432 |
+
"processor = WhisperProcessor.from_pretrained(\"openai/whisper-small\", language=\"Vietnamese\", task=\"transcribe\")\n"
|
433 |
+
]
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"cell_type": "code",
|
437 |
+
"execution_count": 19,
|
438 |
+
"metadata": {},
|
439 |
+
"outputs": [],
|
440 |
+
"source": [
|
441 |
+
"from datasets import Audio\n",
|
442 |
+
"\n",
|
443 |
+
"concat = concat.cast_column(\"audio\", Audio(sampling_rate=16000))"
|
444 |
+
]
|
445 |
+
},
|
446 |
+
{
|
447 |
+
"cell_type": "code",
|
448 |
+
"execution_count": 59,
|
449 |
+
"metadata": {},
|
450 |
+
"outputs": [
|
451 |
+
{
|
452 |
+
"name": "stdout",
|
453 |
+
"output_type": "stream",
|
454 |
+
"text": [
|
455 |
+
"{'audio': {'path': 'vivos/train/waves/VIVOSSPK12/VIVOSSPK12_R077.wav', 'array': array([ 0.00000000e+00, 0.00000000e+00, -3.05175781e-05, ...,\n",
|
456 |
+
" 1.31225586e-03, 1.12915039e-03, 1.55639648e-03]), 'sampling_rate': 16000}, 'sentence': 'KIÊN GIANG'}\n"
|
457 |
+
]
|
458 |
+
}
|
459 |
+
],
|
460 |
+
"source": [
|
461 |
+
"print(concat[\"train\"][4500])"
|
462 |
+
]
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"cell_type": "code",
|
466 |
+
"execution_count": 101,
|
467 |
+
"metadata": {},
|
468 |
+
"outputs": [],
|
469 |
+
"source": [
|
470 |
+
"def prepare_dataset(batch):\n",
|
471 |
+
" # load and resample audio data from 48 to 16kHz\n",
|
472 |
+
" audio = batch[\"audio\"]\n",
|
473 |
+
"\n",
|
474 |
+
" # compute log-Mel input features from input audio array \n",
|
475 |
+
" batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
|
476 |
+
"\n",
|
477 |
+
" # encode target text to label ids \n",
|
478 |
+
" batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n",
|
479 |
+
" return batch\n"
|
480 |
+
]
|
481 |
+
},
|
482 |
+
{
|
483 |
+
"cell_type": "code",
|
484 |
+
"execution_count": 102,
|
485 |
+
"metadata": {},
|
486 |
+
"outputs": [
|
487 |
+
{
|
488 |
+
"data": {
|
489 |
+
"application/vnd.jupyter.widget-view+json": {
|
490 |
+
"model_id": "c35c921e0dde433fb0ef9346310238a3",
|
491 |
+
"version_major": 2,
|
492 |
+
"version_minor": 0
|
493 |
+
},
|
494 |
+
"text/plain": [
|
495 |
+
"Map (num_proc=6): 0%| | 0/14514 [00:00<?, ? examples/s]"
|
496 |
+
]
|
497 |
+
},
|
498 |
+
"metadata": {},
|
499 |
+
"output_type": "display_data"
|
500 |
+
},
|
501 |
+
{
|
502 |
+
"data": {
|
503 |
+
"application/vnd.jupyter.widget-view+json": {
|
504 |
+
"model_id": "8c5af4ed5f8141d2b0673972f7616941",
|
505 |
+
"version_major": 2,
|
506 |
+
"version_minor": 0
|
507 |
+
},
|
508 |
+
"text/plain": [
|
509 |
+
"Map (num_proc=6): 0%| | 0/760 [00:00<?, ? examples/s]"
|
510 |
+
]
|
511 |
+
},
|
512 |
+
"metadata": {},
|
513 |
+
"output_type": "display_data"
|
514 |
+
}
|
515 |
+
],
|
516 |
+
"source": [
|
517 |
+
"concat = concat.map(prepare_dataset, remove_columns=concat.column_names[\"train\"], num_proc=6)"
|
518 |
+
]
|
519 |
+
},
|
520 |
+
{
|
521 |
+
"cell_type": "code",
|
522 |
+
"execution_count": 103,
|
523 |
+
"metadata": {},
|
524 |
+
"outputs": [],
|
525 |
+
"source": [
|
526 |
+
"import torch\n",
|
527 |
+
"\n",
|
528 |
+
"from dataclasses import dataclass\n",
|
529 |
+
"from typing import Any, Dict, List, Union\n",
|
530 |
+
"\n",
|
531 |
+
"@dataclass\n",
|
532 |
+
"class DataCollatorSpeechSeq2SeqWithPadding:\n",
|
533 |
+
" processor: Any\n",
|
534 |
+
"\n",
|
535 |
+
" def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
|
536 |
+
" # split inputs and labels since they have to be of different lengths and need different padding methods\n",
|
537 |
+
" # first treat the audio inputs by simply returning torch tensors\n",
|
538 |
+
" input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
|
539 |
+
" batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
|
540 |
+
"\n",
|
541 |
+
" # get the tokenized label sequences\n",
|
542 |
+
" label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
|
543 |
+
" # pad the labels to max length\n",
|
544 |
+
" labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
|
545 |
+
"\n",
|
546 |
+
" # replace padding with -100 to ignore loss correctly\n",
|
547 |
+
" labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
|
548 |
+
"\n",
|
549 |
+
" # if bos token is appended in previous tokenization step,\n",
|
550 |
+
" # cut bos token here as it's append later anyways\n",
|
551 |
+
" if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
|
552 |
+
" labels = labels[:, 1:]\n",
|
553 |
+
"\n",
|
554 |
+
" batch[\"labels\"] = labels\n",
|
555 |
+
"\n",
|
556 |
+
" return batch\n"
|
557 |
+
]
|
558 |
+
},
|
559 |
+
{
|
560 |
+
"cell_type": "code",
|
561 |
+
"execution_count": 104,
|
562 |
+
"metadata": {},
|
563 |
+
"outputs": [],
|
564 |
+
"source": [
|
565 |
+
"data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)"
|
566 |
+
]
|
567 |
+
},
|
568 |
+
{
|
569 |
+
"cell_type": "code",
|
570 |
+
"execution_count": 105,
|
571 |
+
"metadata": {},
|
572 |
+
"outputs": [],
|
573 |
+
"source": [
|
574 |
+
"import os\n",
|
575 |
+
"\n",
|
576 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\" "
|
577 |
+
]
|
578 |
+
},
|
579 |
+
{
|
580 |
+
"cell_type": "markdown",
|
581 |
+
"metadata": {},
|
582 |
+
"source": [
|
583 |
+
"Train\n"
|
584 |
+
]
|
585 |
+
},
|
586 |
+
{
|
587 |
+
"cell_type": "code",
|
588 |
+
"execution_count": 106,
|
589 |
+
"metadata": {},
|
590 |
+
"outputs": [],
|
591 |
+
"source": [
|
592 |
+
"import evaluate\n",
|
593 |
+
"\n",
|
594 |
+
"metric = evaluate.load(\"wer\")\n",
|
595 |
+
"\n",
|
596 |
+
"\n",
|
597 |
+
"def compute_metrics(pred):\n",
|
598 |
+
" pred_ids = pred.predictions\n",
|
599 |
+
" label_ids = pred.label_ids\n",
|
600 |
+
"\n",
|
601 |
+
" # replace -100 with the pad_token_id\n",
|
602 |
+
" label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
|
603 |
+
"\n",
|
604 |
+
" # we do not want to group tokens when computing the metrics\n",
|
605 |
+
" pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
|
606 |
+
" label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
|
607 |
+
"\n",
|
608 |
+
" wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n",
|
609 |
+
"\n",
|
610 |
+
" return {\"wer\": wer}\n"
|
611 |
+
]
|
612 |
+
},
|
613 |
+
{
|
614 |
+
"cell_type": "code",
|
615 |
+
"execution_count": 107,
|
616 |
+
"metadata": {},
|
617 |
+
"outputs": [],
|
618 |
+
"source": [
|
619 |
+
"from transformers import WhisperForConditionalGeneration\n",
|
620 |
+
"\n",
|
621 |
+
"model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-small\")\n"
|
622 |
+
]
|
623 |
+
},
|
624 |
+
{
|
625 |
+
"cell_type": "code",
|
626 |
+
"execution_count": 108,
|
627 |
+
"metadata": {},
|
628 |
+
"outputs": [],
|
629 |
+
"source": [
|
630 |
+
"model.config.forced_decoder_ids = None\n",
|
631 |
+
"model.config.suppress_tokens = []"
|
632 |
+
]
|
633 |
+
},
|
634 |
+
{
|
635 |
+
"cell_type": "code",
|
636 |
+
"execution_count": 109,
|
637 |
+
"metadata": {},
|
638 |
+
"outputs": [],
|
639 |
+
"source": [
|
640 |
+
"from transformers import Seq2SeqTrainingArguments\n",
|
641 |
+
"\n",
|
642 |
+
"training_args = Seq2SeqTrainingArguments(\n",
|
643 |
+
" output_dir=\"./vi_whisper-small\", # change to a repo name of your choice\n",
|
644 |
+
" per_device_train_batch_size=16,\n",
|
645 |
+
" gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
|
646 |
+
" learning_rate=1e-4,\n",
|
647 |
+
" warmup_steps=1000,\n",
|
648 |
+
" max_steps=8000,\n",
|
649 |
+
" gradient_checkpointing=True,\n",
|
650 |
+
" fp16=True,\n",
|
651 |
+
" evaluation_strategy=\"steps\",\n",
|
652 |
+
" per_device_eval_batch_size=8,\n",
|
653 |
+
" predict_with_generate=True,\n",
|
654 |
+
" generation_max_length=225,\n",
|
655 |
+
" save_steps=4000,\n",
|
656 |
+
" eval_steps=1000,\n",
|
657 |
+
" logging_steps=25,\n",
|
658 |
+
" report_to=[\"tensorboard\"],\n",
|
659 |
+
" load_best_model_at_end=True,\n",
|
660 |
+
" metric_for_best_model=\"wer\",\n",
|
661 |
+
" greater_is_better=False,\n",
|
662 |
+
" push_to_hub=True,\n",
|
663 |
+
")\n"
|
664 |
+
]
|
665 |
+
},
|
666 |
+
{
|
667 |
+
"cell_type": "code",
|
668 |
+
"execution_count": 126,
|
669 |
+
"metadata": {},
|
670 |
+
"outputs": [
|
671 |
+
{
|
672 |
+
"name": "stderr",
|
673 |
+
"output_type": "stream",
|
674 |
+
"text": [
|
675 |
+
"/media/tesla/New Volume1/DEMO/DUY/Vietnamese_ASR/./vi_whisper-small is already a clone of https://huggingface.co/DuyTa/vi_whisper-small. Make sure you pull the latest changes with `repo.git_pull()`.\n"
|
676 |
+
]
|
677 |
+
},
|
678 |
+
{
|
679 |
+
"ename": "OSError",
|
680 |
+
"evalue": "From https://huggingface.co/DuyTa/vi_whisper-small\n d7893fc..47c00b5 main -> origin/main\nhint: You have divergent branches and need to specify how to reconcile them.\nhint: You can do so by running one of the following commands sometime before\nhint: your next pull:\nhint: \nhint: git config pull.rebase false # merge (the default strategy)\nhint: git config pull.rebase true # rebase\nhint: git config pull.ff only # fast-forward only\nhint: \nhint: You can replace \"git config\" with \"git config --global\" to set a default\nhint: preference for all repositories. You can also pass --rebase, --no-rebase,\nhint: or --ff-only on the command line to override the configured default per\nhint: invocation.\nfatal: Need to specify how to reconcile divergent branches.\n",
|
681 |
+
"output_type": "error",
|
682 |
+
"traceback": [
|
683 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
684 |
+
"\u001b[0;31mCalledProcessError\u001b[0m Traceback (most recent call last)",
|
685 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:984\u001b[0m, in \u001b[0;36mRepository.git_pull\u001b[0;34m(self, rebase, lfs)\u001b[0m\n\u001b[1;32m 983\u001b[0m \u001b[39mwith\u001b[39;00m _lfs_log_progress():\n\u001b[0;32m--> 984\u001b[0m result \u001b[39m=\u001b[39m run_subprocess(command, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mlocal_dir)\n\u001b[1;32m 985\u001b[0m logger\u001b[39m.\u001b[39minfo(result\u001b[39m.\u001b[39mstdout)\n",
|
686 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/utils/_subprocess.py:83\u001b[0m, in \u001b[0;36mrun_subprocess\u001b[0;34m(command, folder, check, **kwargs)\u001b[0m\n\u001b[1;32m 81\u001b[0m folder \u001b[39m=\u001b[39m \u001b[39mstr\u001b[39m(folder)\n\u001b[0;32m---> 83\u001b[0m \u001b[39mreturn\u001b[39;00m subprocess\u001b[39m.\u001b[39;49mrun(\n\u001b[1;32m 84\u001b[0m command,\n\u001b[1;32m 85\u001b[0m stderr\u001b[39m=\u001b[39;49msubprocess\u001b[39m.\u001b[39;49mPIPE,\n\u001b[1;32m 86\u001b[0m stdout\u001b[39m=\u001b[39;49msubprocess\u001b[39m.\u001b[39;49mPIPE,\n\u001b[1;32m 87\u001b[0m check\u001b[39m=\u001b[39;49mcheck,\n\u001b[1;32m 88\u001b[0m encoding\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mutf-8\u001b[39;49m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 89\u001b[0m errors\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mreplace\u001b[39;49m\u001b[39m\"\u001b[39;49m, \u001b[39m# if not utf-8, replace char by �\u001b[39;49;00m\n\u001b[1;32m 90\u001b[0m cwd\u001b[39m=\u001b[39;49mfolder \u001b[39mor\u001b[39;49;00m os\u001b[39m.\u001b[39;49mgetcwd(),\n\u001b[1;32m 91\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 92\u001b[0m )\n",
|
687 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/subprocess.py:528\u001b[0m, in \u001b[0;36mrun\u001b[0;34m(input, capture_output, timeout, check, *popenargs, **kwargs)\u001b[0m\n\u001b[1;32m 527\u001b[0m \u001b[39mif\u001b[39;00m check \u001b[39mand\u001b[39;00m retcode:\n\u001b[0;32m--> 528\u001b[0m \u001b[39mraise\u001b[39;00m CalledProcessError(retcode, process\u001b[39m.\u001b[39margs,\n\u001b[1;32m 529\u001b[0m output\u001b[39m=\u001b[39mstdout, stderr\u001b[39m=\u001b[39mstderr)\n\u001b[1;32m 530\u001b[0m \u001b[39mreturn\u001b[39;00m CompletedProcess(process\u001b[39m.\u001b[39margs, retcode, stdout, stderr)\n",
|
688 |
+
"\u001b[0;31mCalledProcessError\u001b[0m: Command '['git', 'pull']' returned non-zero exit status 128.",
|
689 |
+
"\nDuring handling of the above exception, another exception occurred:\n",
|
690 |
+
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
691 |
+
"Cell \u001b[0;32mIn[126], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtransformers\u001b[39;00m \u001b[39mimport\u001b[39;00m Seq2SeqTrainer\n\u001b[0;32m----> 3\u001b[0m trainer \u001b[39m=\u001b[39m Seq2SeqTrainer(\n\u001b[1;32m 4\u001b[0m args\u001b[39m=\u001b[39;49mtraining_args,\n\u001b[1;32m 5\u001b[0m model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m 6\u001b[0m train_dataset\u001b[39m=\u001b[39;49mconcat[\u001b[39m\"\u001b[39;49m\u001b[39mtrain\u001b[39;49m\u001b[39m\"\u001b[39;49m],\n\u001b[1;32m 7\u001b[0m \n\u001b[1;32m 8\u001b[0m \n\u001b[1;32m 9\u001b[0m \n\u001b[1;32m 10\u001b[0m \n\u001b[1;32m 11\u001b[0m eval_dataset\u001b[39m=\u001b[39;49mconcat[\u001b[39m\"\u001b[39;49m\u001b[39mtest\u001b[39;49m\u001b[39m\"\u001b[39;49m],\n\u001b[1;32m 12\u001b[0m data_collator\u001b[39m=\u001b[39;49mdata_collator,\n\u001b[1;32m 13\u001b[0m compute_metrics\u001b[39m=\u001b[39;49mcompute_metrics,\n\u001b[1;32m 14\u001b[0m tokenizer\u001b[39m=\u001b[39;49mprocessor\u001b[39m.\u001b[39;49mfeature_extractor,\n\u001b[1;32m 15\u001b[0m )\n",
|
692 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/trainer_seq2seq.py:56\u001b[0m, in \u001b[0;36mSeq2SeqTrainer.__init__\u001b[0;34m(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\n\u001b[1;32m 43\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 44\u001b[0m model: Union[\u001b[39m\"\u001b[39m\u001b[39mPreTrainedModel\u001b[39m\u001b[39m\"\u001b[39m, nn\u001b[39m.\u001b[39mModule] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 54\u001b[0m preprocess_logits_for_metrics: Optional[Callable[[torch\u001b[39m.\u001b[39mTensor, torch\u001b[39m.\u001b[39mTensor], torch\u001b[39m.\u001b[39mTensor]] \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m,\n\u001b[1;32m 55\u001b[0m ):\n\u001b[0;32m---> 56\u001b[0m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\n\u001b[1;32m 57\u001b[0m model\u001b[39m=\u001b[39;49mmodel,\n\u001b[1;32m 58\u001b[0m args\u001b[39m=\u001b[39;49margs,\n\u001b[1;32m 59\u001b[0m data_collator\u001b[39m=\u001b[39;49mdata_collator,\n\u001b[1;32m 60\u001b[0m train_dataset\u001b[39m=\u001b[39;49mtrain_dataset,\n\u001b[1;32m 61\u001b[0m eval_dataset\u001b[39m=\u001b[39;49meval_dataset,\n\u001b[1;32m 62\u001b[0m tokenizer\u001b[39m=\u001b[39;49mtokenizer,\n\u001b[1;32m 63\u001b[0m model_init\u001b[39m=\u001b[39;49mmodel_init,\n\u001b[1;32m 64\u001b[0m compute_metrics\u001b[39m=\u001b[39;49mcompute_metrics,\n\u001b[1;32m 65\u001b[0m callbacks\u001b[39m=\u001b[39;49mcallbacks,\n\u001b[1;32m 66\u001b[0m optimizers\u001b[39m=\u001b[39;49moptimizers,\n\u001b[1;32m 67\u001b[0m preprocess_logits_for_metrics\u001b[39m=\u001b[39;49mpreprocess_logits_for_metrics,\n\u001b[1;32m 68\u001b[0m )\n\u001b[1;32m 70\u001b[0m \u001b[39m# Override self.model.generation_config if a GenerationConfig is specified in args.\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \u001b[39m# Priority: args.generation_config > model.generation_config > default GenerationConfig.\u001b[39;00m\n\u001b[1;32m 72\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mgeneration_config \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
|
693 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/trainer.py:551\u001b[0m, in \u001b[0;36mTrainer.__init__\u001b[0;34m(self, model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[39m# Create clone of distant repo and output directory if needed\u001b[39;00m\n\u001b[1;32m 550\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mpush_to_hub:\n\u001b[0;32m--> 551\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49minit_git_repo(at_init\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 552\u001b[0m \u001b[39m# In case of pull, we need to make sure every process has the latest.\u001b[39;00m\n\u001b[1;32m 553\u001b[0m \u001b[39mif\u001b[39;00m is_torch_tpu_available():\n",
|
694 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/trainer.py:3449\u001b[0m, in \u001b[0;36mTrainer.init_git_repo\u001b[0;34m(self, at_init)\u001b[0m\n\u001b[1;32m 3446\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 3447\u001b[0m \u001b[39mraise\u001b[39;00m\n\u001b[0;32m-> 3449\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrepo\u001b[39m.\u001b[39;49mgit_pull()\n\u001b[1;32m 3451\u001b[0m \u001b[39m# By default, ignore the checkpoint folders\u001b[39;00m\n\u001b[1;32m 3452\u001b[0m \u001b[39mif\u001b[39;00m (\n\u001b[1;32m 3453\u001b[0m \u001b[39mnot\u001b[39;00m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mexists(os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39moutput_dir, \u001b[39m\"\u001b[39m\u001b[39m.gitignore\u001b[39m\u001b[39m\"\u001b[39m))\n\u001b[1;32m 3454\u001b[0m \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mhub_strategy \u001b[39m!=\u001b[39m HubStrategy\u001b[39m.\u001b[39mALL_CHECKPOINTS\n\u001b[1;32m 3455\u001b[0m ):\n",
|
695 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:987\u001b[0m, in \u001b[0;36mRepository.git_pull\u001b[0;34m(self, rebase, lfs)\u001b[0m\n\u001b[1;32m 985\u001b[0m logger\u001b[39m.\u001b[39minfo(result\u001b[39m.\u001b[39mstdout)\n\u001b[1;32m 986\u001b[0m \u001b[39mexcept\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError \u001b[39mas\u001b[39;00m exc:\n\u001b[0;32m--> 987\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(exc\u001b[39m.\u001b[39mstderr)\n",
|
696 |
+
"\u001b[0;31mOSError\u001b[0m: From https://huggingface.co/DuyTa/vi_whisper-small\n d7893fc..47c00b5 main -> origin/main\nhint: You have divergent branches and need to specify how to reconcile them.\nhint: You can do so by running one of the following commands sometime before\nhint: your next pull:\nhint: \nhint: git config pull.rebase false # merge (the default strategy)\nhint: git config pull.rebase true # rebase\nhint: git config pull.ff only # fast-forward only\nhint: \nhint: You can replace \"git config\" with \"git config --global\" to set a default\nhint: preference for all repositories. You can also pass --rebase, --no-rebase,\nhint: or --ff-only on the command line to override the configured default per\nhint: invocation.\nfatal: Need to specify how to reconcile divergent branches.\n"
|
697 |
+
]
|
698 |
+
}
|
699 |
+
],
|
700 |
+
"source": [
|
701 |
+
"from transformers import Seq2SeqTrainer\n",
|
702 |
+
"\n",
|
703 |
+
"trainer = Seq2SeqTrainer(\n",
|
704 |
+
" args=training_args,\n",
|
705 |
+
" model=model,\n",
|
706 |
+
" train_dataset=concat[\"train\"],\n",
|
707 |
+
"\n",
|
708 |
+
" eval_dataset=concat[\"test\"],\n",
|
709 |
+
" data_collator=data_collator,\n",
|
710 |
+
" compute_metrics=compute_metrics,\n",
|
711 |
+
" tokenizer=processor.feature_extractor,\n",
|
712 |
+
")\n"
|
713 |
+
]
|
714 |
+
},
|
715 |
+
{
|
716 |
+
"cell_type": "code",
|
717 |
+
"execution_count": 130,
|
718 |
+
"metadata": {},
|
719 |
+
"outputs": [
|
720 |
+
{
|
721 |
+
"data": {
|
722 |
+
"text/plain": [
|
723 |
+
"('./vi_whisper-small/tokenizer_config.json',\n",
|
724 |
+
" './vi_whisper-small/special_tokens_map.json',\n",
|
725 |
+
" './vi_whisper-small/vocab.json',\n",
|
726 |
+
" './vi_whisper-small/merges.txt',\n",
|
727 |
+
" './vi_whisper-small/normalizer.json',\n",
|
728 |
+
" './vi_whisper-small/added_tokens.json',\n",
|
729 |
+
" './vi_whisper-small/tokenizer.json')"
|
730 |
+
]
|
731 |
+
},
|
732 |
+
"execution_count": 130,
|
733 |
+
"metadata": {},
|
734 |
+
"output_type": "execute_result"
|
735 |
+
}
|
736 |
+
],
|
737 |
+
"source": [
|
738 |
+
"tokenizer.save_pretrained(\"./vi_whisper-small/\")"
|
739 |
+
]
|
740 |
+
},
|
741 |
+
{
|
742 |
+
"cell_type": "code",
|
743 |
+
"execution_count": 31,
|
744 |
+
"metadata": {},
|
745 |
+
"outputs": [
|
746 |
+
{
|
747 |
+
"name": "stdout",
|
748 |
+
"output_type": "stream",
|
749 |
+
"text": [
|
750 |
+
"Device 0:\n",
|
751 |
+
" Currently allocated memory: 922.884765625 MB\n",
|
752 |
+
" Peak memory usage: 922.884765625 MB\n"
|
753 |
+
]
|
754 |
+
}
|
755 |
+
],
|
756 |
+
"source": [
|
757 |
+
"import torch\n",
|
758 |
+
"\n",
|
759 |
+
"device_count = torch.cuda.device_count()\n",
|
760 |
+
"\n",
|
761 |
+
"for device in range(device_count):\n",
|
762 |
+
" torch.cuda.device(device)\n",
|
763 |
+
" allocated_memory = torch.cuda.memory_allocated(device)\n",
|
764 |
+
" peak_memory = torch.cuda.max_memory_allocated(device)\n",
|
765 |
+
" print(f\"Device {device}:\")\n",
|
766 |
+
" print(f\" Currently allocated memory: {allocated_memory / 1024**2} MB\")\n",
|
767 |
+
" print(f\" Peak memory usage: {peak_memory / 1024**2} MB\")\n",
|
768 |
+
"\n"
|
769 |
+
]
|
770 |
+
},
|
771 |
+
{
|
772 |
+
"cell_type": "code",
|
773 |
+
"execution_count": 32,
|
774 |
+
"metadata": {},
|
775 |
+
"outputs": [
|
776 |
+
{
|
777 |
+
"name": "stdout",
|
778 |
+
"output_type": "stream",
|
779 |
+
"text": [
|
780 |
+
"Device 0:\n",
|
781 |
+
" Name: Tesla T4\n",
|
782 |
+
" Max Memory: 14966.375 MB\n"
|
783 |
+
]
|
784 |
+
}
|
785 |
+
],
|
786 |
+
"source": [
|
787 |
+
"device_count = torch.cuda.device_count()\n",
|
788 |
+
"\n",
|
789 |
+
"for device in range(device_count):\n",
|
790 |
+
" properties = torch.cuda.get_device_properties(device)\n",
|
791 |
+
" print(f\"Device {device}:\")\n",
|
792 |
+
" print(f\" Name: {properties.name}\")\n",
|
793 |
+
" print(f\" Max Memory: {properties.total_memory / 1024**2} MB\")\n"
|
794 |
+
]
|
795 |
+
},
|
796 |
+
{
|
797 |
+
"cell_type": "code",
|
798 |
+
"execution_count": 111,
|
799 |
+
"metadata": {},
|
800 |
+
"outputs": [
|
801 |
+
{
|
802 |
+
"name": "stderr",
|
803 |
+
"output_type": "stream",
|
804 |
+
"text": [
|
805 |
+
"/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
|
806 |
+
" warnings.warn(\n"
|
807 |
+
]
|
808 |
+
},
|
809 |
+
{
|
810 |
+
"data": {
|
811 |
+
"application/vnd.jupyter.widget-view+json": {
|
812 |
+
"model_id": "da5a536a979e4d34bc59eaede9204d06",
|
813 |
+
"version_major": 2,
|
814 |
+
"version_minor": 0
|
815 |
+
},
|
816 |
+
"text/plain": [
|
817 |
+
" 0%| | 0/8000 [00:00<?, ?it/s]"
|
818 |
+
]
|
819 |
+
},
|
820 |
+
"metadata": {},
|
821 |
+
"output_type": "display_data"
|
822 |
+
},
|
823 |
+
{
|
824 |
+
"name": "stdout",
|
825 |
+
"output_type": "stream",
|
826 |
+
"text": [
|
827 |
+
"{'loss': 3.8537, 'learning_rate': 2.1000000000000002e-06, 'epoch': 0.03}\n",
|
828 |
+
"{'loss': 2.2347, 'learning_rate': 4.6e-06, 'epoch': 0.06}\n",
|
829 |
+
"{'loss': 1.2627, 'learning_rate': 7.1e-06, 'epoch': 0.08}\n",
|
830 |
+
"{'loss': 0.8976, 'learning_rate': 9.600000000000001e-06, 'epoch': 0.11}\n",
|
831 |
+
"{'loss': 0.7313, 'learning_rate': 1.2100000000000001e-05, 'epoch': 0.14}\n",
|
832 |
+
"{'loss': 0.6526, 'learning_rate': 1.4599999999999999e-05, 'epoch': 0.17}\n",
|
833 |
+
"{'loss': 0.7221, 'learning_rate': 1.7000000000000003e-05, 'epoch': 0.19}\n",
|
834 |
+
"{'loss': 0.6478, 'learning_rate': 1.9500000000000003e-05, 'epoch': 0.22}\n",
|
835 |
+
"{'loss': 1.7029, 'learning_rate': 2.19e-05, 'epoch': 0.25}\n",
|
836 |
+
"{'loss': 1.1476, 'learning_rate': 2.44e-05, 'epoch': 0.28}\n",
|
837 |
+
"{'loss': 0.5837, 'learning_rate': 2.6900000000000003e-05, 'epoch': 0.3}\n",
|
838 |
+
"{'loss': 0.5912, 'learning_rate': 2.94e-05, 'epoch': 0.33}\n",
|
839 |
+
"{'loss': 0.6872, 'learning_rate': 3.19e-05, 'epoch': 0.36}\n",
|
840 |
+
"{'loss': 0.4103, 'learning_rate': 3.4399999999999996e-05, 'epoch': 0.39}\n",
|
841 |
+
"{'loss': 0.4293, 'learning_rate': 3.69e-05, 'epoch': 0.41}\n",
|
842 |
+
"{'loss': 0.3055, 'learning_rate': 3.94e-05, 'epoch': 0.44}\n",
|
843 |
+
"{'loss': 0.311, 'learning_rate': 4.19e-05, 'epoch': 0.47}\n",
|
844 |
+
"{'loss': 0.3212, 'learning_rate': 4.44e-05, 'epoch': 0.5}\n",
|
845 |
+
"{'loss': 0.2917, 'learning_rate': 4.69e-05, 'epoch': 0.52}\n",
|
846 |
+
"{'loss': 0.2975, 'learning_rate': 4.94e-05, 'epoch': 0.55}\n",
|
847 |
+
"{'loss': 0.3254, 'learning_rate': 5.19e-05, 'epoch': 0.58}\n",
|
848 |
+
"{'loss': 0.2825, 'learning_rate': 5.440000000000001e-05, 'epoch': 0.61}\n",
|
849 |
+
"{'loss': 0.2929, 'learning_rate': 5.69e-05, 'epoch': 0.63}\n",
|
850 |
+
"{'loss': 0.3056, 'learning_rate': 5.94e-05, 'epoch': 0.66}\n",
|
851 |
+
"{'loss': 0.3105, 'learning_rate': 6.19e-05, 'epoch': 0.69}\n",
|
852 |
+
"{'loss': 0.3702, 'learning_rate': 6.440000000000001e-05, 'epoch': 0.72}\n",
|
853 |
+
"{'loss': 0.2684, 'learning_rate': 6.690000000000001e-05, 'epoch': 0.74}\n",
|
854 |
+
"{'loss': 0.2767, 'learning_rate': 6.939999999999999e-05, 'epoch': 0.77}\n",
|
855 |
+
"{'loss': 0.315, 'learning_rate': 7.19e-05, 'epoch': 0.8}\n",
|
856 |
+
"{'loss': 0.3132, 'learning_rate': 7.44e-05, 'epoch': 0.83}\n",
|
857 |
+
"{'loss': 0.3933, 'learning_rate': 7.69e-05, 'epoch': 0.85}\n",
|
858 |
+
"{'loss': 0.311, 'learning_rate': 7.94e-05, 'epoch': 0.88}\n",
|
859 |
+
"{'loss': 0.3104, 'learning_rate': 8.19e-05, 'epoch': 0.91}\n",
|
860 |
+
"{'loss': 0.297, 'learning_rate': 8.44e-05, 'epoch': 0.94}\n",
|
861 |
+
"{'loss': 0.3094, 'learning_rate': 8.69e-05, 'epoch': 0.96}\n",
|
862 |
+
"{'loss': 0.29, 'learning_rate': 8.94e-05, 'epoch': 0.99}\n",
|
863 |
+
"{'loss': 0.2712, 'learning_rate': 9.190000000000001e-05, 'epoch': 1.02}\n",
|
864 |
+
"{'loss': 0.262, 'learning_rate': 9.44e-05, 'epoch': 1.05}\n",
|
865 |
+
"{'loss': 0.2481, 'learning_rate': 9.69e-05, 'epoch': 1.07}\n",
|
866 |
+
"{'loss': 0.249, 'learning_rate': 9.94e-05, 'epoch': 1.1}\n"
|
867 |
+
]
|
868 |
+
},
|
869 |
+
{
|
870 |
+
"data": {
|
871 |
+
"application/vnd.jupyter.widget-view+json": {
|
872 |
+
"model_id": "ea1e7cf193cc4b5dacd0883200ff6ef6",
|
873 |
+
"version_major": 2,
|
874 |
+
"version_minor": 0
|
875 |
+
},
|
876 |
+
"text/plain": [
|
877 |
+
" 0%| | 0/95 [00:00<?, ?it/s]"
|
878 |
+
]
|
879 |
+
},
|
880 |
+
"metadata": {},
|
881 |
+
"output_type": "display_data"
|
882 |
+
},
|
883 |
+
{
|
884 |
+
"name": "stdout",
|
885 |
+
"output_type": "stream",
|
886 |
+
"text": [
|
887 |
+
"{'eval_loss': 0.3765707015991211, 'eval_wer': 32.16783216783217, 'eval_runtime': 349.1035, 'eval_samples_per_second': 2.177, 'eval_steps_per_second': 0.272, 'epoch': 1.1}\n",
|
888 |
+
"{'loss': 0.2729, 'learning_rate': 9.972857142857144e-05, 'epoch': 1.13}\n",
|
889 |
+
"{'loss': 0.267, 'learning_rate': 9.937142857142857e-05, 'epoch': 1.16}\n",
|
890 |
+
"{'loss': 0.2617, 'learning_rate': 9.901428571428571e-05, 'epoch': 1.18}\n",
|
891 |
+
"{'loss': 0.2613, 'learning_rate': 9.865714285714286e-05, 'epoch': 1.21}\n",
|
892 |
+
"{'loss': 0.2736, 'learning_rate': 9.83e-05, 'epoch': 1.24}\n",
|
893 |
+
"{'loss': 0.245, 'learning_rate': 9.794285714285714e-05, 'epoch': 1.27}\n",
|
894 |
+
"{'loss': 0.2385, 'learning_rate': 9.75857142857143e-05, 'epoch': 1.29}\n",
|
895 |
+
"{'loss': 0.258, 'learning_rate': 9.722857142857144e-05, 'epoch': 1.32}\n",
|
896 |
+
"{'loss': 0.2623, 'learning_rate': 9.687142857142858e-05, 'epoch': 1.35}\n",
|
897 |
+
"{'loss': 0.2346, 'learning_rate': 9.651428571428572e-05, 'epoch': 1.38}\n",
|
898 |
+
"{'loss': 0.2376, 'learning_rate': 9.615714285714286e-05, 'epoch': 1.4}\n",
|
899 |
+
"{'loss': 0.246, 'learning_rate': 9.58e-05, 'epoch': 1.43}\n",
|
900 |
+
"{'loss': 0.2201, 'learning_rate': 9.544285714285715e-05, 'epoch': 1.46}\n",
|
901 |
+
"{'loss': 0.2233, 'learning_rate': 9.508571428571429e-05, 'epoch': 1.49}\n",
|
902 |
+
"{'loss': 0.2154, 'learning_rate': 9.472857142857143e-05, 'epoch': 1.51}\n",
|
903 |
+
"{'loss': 0.2348, 'learning_rate': 9.437142857142857e-05, 'epoch': 1.54}\n",
|
904 |
+
"{'loss': 0.2159, 'learning_rate': 9.401428571428572e-05, 'epoch': 1.57}\n",
|
905 |
+
"{'loss': 0.2265, 'learning_rate': 9.365714285714286e-05, 'epoch': 1.6}\n",
|
906 |
+
"{'loss': 0.2118, 'learning_rate': 9.33e-05, 'epoch': 1.62}\n",
|
907 |
+
"{'loss': 0.2223, 'learning_rate': 9.294285714285714e-05, 'epoch': 1.65}\n",
|
908 |
+
"{'loss': 0.2, 'learning_rate': 9.258571428571428e-05, 'epoch': 1.68}\n",
|
909 |
+
"{'loss': 0.206, 'learning_rate': 9.222857142857142e-05, 'epoch': 1.71}\n",
|
910 |
+
"{'loss': 0.1979, 'learning_rate': 9.187142857142858e-05, 'epoch': 1.73}\n",
|
911 |
+
"{'loss': 0.2022, 'learning_rate': 9.151428571428572e-05, 'epoch': 1.76}\n",
|
912 |
+
"{'loss': 0.2028, 'learning_rate': 9.115714285714286e-05, 'epoch': 1.79}\n",
|
913 |
+
"{'loss': 0.2161, 'learning_rate': 9.080000000000001e-05, 'epoch': 1.82}\n",
|
914 |
+
"{'loss': 0.1964, 'learning_rate': 9.044285714285715e-05, 'epoch': 1.84}\n",
|
915 |
+
"{'loss': 0.2151, 'learning_rate': 9.008571428571429e-05, 'epoch': 1.87}\n",
|
916 |
+
"{'loss': 0.2056, 'learning_rate': 8.972857142857143e-05, 'epoch': 1.9}\n",
|
917 |
+
"{'loss': 0.189, 'learning_rate': 8.937142857142857e-05, 'epoch': 1.93}\n",
|
918 |
+
"{'loss': 0.1944, 'learning_rate': 8.901428571428571e-05, 'epoch': 1.95}\n",
|
919 |
+
"{'loss': 0.1834, 'learning_rate': 8.865714285714287e-05, 'epoch': 1.98}\n",
|
920 |
+
"{'loss': 0.1557, 'learning_rate': 8.83e-05, 'epoch': 2.01}\n",
|
921 |
+
"{'loss': 0.1337, 'learning_rate': 8.794285714285714e-05, 'epoch': 2.04}\n",
|
922 |
+
"{'loss': 0.1338, 'learning_rate': 8.75857142857143e-05, 'epoch': 2.06}\n",
|
923 |
+
"{'loss': 0.1338, 'learning_rate': 8.722857142857144e-05, 'epoch': 2.09}\n",
|
924 |
+
"{'loss': 0.1385, 'learning_rate': 8.687142857142856e-05, 'epoch': 2.12}\n",
|
925 |
+
"{'loss': 0.1259, 'learning_rate': 8.651428571428572e-05, 'epoch': 2.15}\n",
|
926 |
+
"{'loss': 0.1268, 'learning_rate': 8.615714285714286e-05, 'epoch': 2.18}\n",
|
927 |
+
"{'loss': 0.1416, 'learning_rate': 8.58e-05, 'epoch': 2.2}\n"
|
928 |
+
]
|
929 |
+
},
|
930 |
+
{
|
931 |
+
"data": {
|
932 |
+
"application/vnd.jupyter.widget-view+json": {
|
933 |
+
"model_id": "2b48a513542d43558d8fef6a8ef00629",
|
934 |
+
"version_major": 2,
|
935 |
+
"version_minor": 0
|
936 |
+
},
|
937 |
+
"text/plain": [
|
938 |
+
" 0%| | 0/95 [00:00<?, ?it/s]"
|
939 |
+
]
|
940 |
+
},
|
941 |
+
"metadata": {},
|
942 |
+
"output_type": "display_data"
|
943 |
+
},
|
944 |
+
{
|
945 |
+
"name": "stdout",
|
946 |
+
"output_type": "stream",
|
947 |
+
"text": [
|
948 |
+
"{'eval_loss': 0.2880653738975525, 'eval_wer': 46.464646464646464, 'eval_runtime': 337.0754, 'eval_samples_per_second': 2.255, 'eval_steps_per_second': 0.282, 'epoch': 2.2}\n",
|
949 |
+
"{'loss': 0.1271, 'learning_rate': 8.544285714285715e-05, 'epoch': 2.23}\n",
|
950 |
+
"{'loss': 0.1345, 'learning_rate': 8.508571428571429e-05, 'epoch': 2.26}\n",
|
951 |
+
"{'loss': 0.149, 'learning_rate': 8.472857142857143e-05, 'epoch': 2.29}\n",
|
952 |
+
"{'loss': 0.1289, 'learning_rate': 8.437142857142859e-05, 'epoch': 2.31}\n",
|
953 |
+
"{'loss': 0.1391, 'learning_rate': 8.401428571428573e-05, 'epoch': 2.34}\n",
|
954 |
+
"{'loss': 0.1532, 'learning_rate': 8.365714285714285e-05, 'epoch': 2.37}\n",
|
955 |
+
"{'loss': 0.1283, 'learning_rate': 8.33e-05, 'epoch': 2.4}\n",
|
956 |
+
"{'loss': 0.1336, 'learning_rate': 8.294285714285715e-05, 'epoch': 2.42}\n",
|
957 |
+
"{'loss': 0.129, 'learning_rate': 8.258571428571429e-05, 'epoch': 2.45}\n",
|
958 |
+
"{'loss': 0.1399, 'learning_rate': 8.222857142857144e-05, 'epoch': 2.48}\n",
|
959 |
+
"{'loss': 0.1411, 'learning_rate': 8.187142857142858e-05, 'epoch': 2.51}\n",
|
960 |
+
"{'loss': 0.1298, 'learning_rate': 8.151428571428572e-05, 'epoch': 2.53}\n",
|
961 |
+
"{'loss': 0.1397, 'learning_rate': 8.115714285714286e-05, 'epoch': 2.56}\n",
|
962 |
+
"{'loss': 0.1356, 'learning_rate': 8.080000000000001e-05, 'epoch': 2.59}\n",
|
963 |
+
"{'loss': 0.1366, 'learning_rate': 8.044285714285714e-05, 'epoch': 2.62}\n",
|
964 |
+
"{'loss': 0.1331, 'learning_rate': 8.008571428571429e-05, 'epoch': 2.64}\n",
|
965 |
+
"{'loss': 0.1297, 'learning_rate': 7.972857142857143e-05, 'epoch': 2.67}\n",
|
966 |
+
"{'loss': 0.1414, 'learning_rate': 7.937142857142857e-05, 'epoch': 2.7}\n",
|
967 |
+
"{'loss': 0.1189, 'learning_rate': 7.901428571428571e-05, 'epoch': 2.73}\n",
|
968 |
+
"{'loss': 0.1416, 'learning_rate': 7.865714285714287e-05, 'epoch': 2.75}\n",
|
969 |
+
"{'loss': 0.1378, 'learning_rate': 7.83e-05, 'epoch': 2.78}\n",
|
970 |
+
"{'loss': 0.1305, 'learning_rate': 7.794285714285715e-05, 'epoch': 2.81}\n",
|
971 |
+
"{'loss': 0.1571, 'learning_rate': 7.75857142857143e-05, 'epoch': 2.84}\n",
|
972 |
+
"{'loss': 0.1285, 'learning_rate': 7.722857142857143e-05, 'epoch': 2.86}\n",
|
973 |
+
"{'loss': 0.1339, 'learning_rate': 7.687142857142857e-05, 'epoch': 2.89}\n",
|
974 |
+
"{'loss': 0.1216, 'learning_rate': 7.651428571428572e-05, 'epoch': 2.92}\n",
|
975 |
+
"{'loss': 0.1321, 'learning_rate': 7.615714285714286e-05, 'epoch': 2.95}\n",
|
976 |
+
"{'loss': 0.1259, 'learning_rate': 7.58e-05, 'epoch': 2.97}\n",
|
977 |
+
"{'loss': 0.1259, 'learning_rate': 7.544285714285715e-05, 'epoch': 3.0}\n",
|
978 |
+
"{'loss': 0.0851, 'learning_rate': 7.508571428571429e-05, 'epoch': 3.03}\n",
|
979 |
+
"{'loss': 0.0764, 'learning_rate': 7.472857142857143e-05, 'epoch': 3.06}\n",
|
980 |
+
"{'loss': 0.0986, 'learning_rate': 7.438571428571429e-05, 'epoch': 3.08}\n",
|
981 |
+
"{'loss': 0.0883, 'learning_rate': 7.402857142857143e-05, 'epoch': 3.11}\n",
|
982 |
+
"{'loss': 0.0811, 'learning_rate': 7.367142857142858e-05, 'epoch': 3.14}\n",
|
983 |
+
"{'loss': 0.0872, 'learning_rate': 7.331428571428571e-05, 'epoch': 3.17}\n",
|
984 |
+
"{'loss': 0.0872, 'learning_rate': 7.295714285714286e-05, 'epoch': 3.19}\n",
|
985 |
+
"{'loss': 0.0805, 'learning_rate': 7.26e-05, 'epoch': 3.22}\n",
|
986 |
+
"{'loss': 0.0803, 'learning_rate': 7.224285714285714e-05, 'epoch': 3.25}\n",
|
987 |
+
"{'loss': 0.0753, 'learning_rate': 7.188571428571428e-05, 'epoch': 3.28}\n",
|
988 |
+
"{'loss': 0.0839, 'learning_rate': 7.152857142857144e-05, 'epoch': 3.3}\n"
|
989 |
+
]
|
990 |
+
},
|
991 |
+
{
|
992 |
+
"data": {
|
993 |
+
"application/vnd.jupyter.widget-view+json": {
|
994 |
+
"model_id": "ce95b9d2a270464fbeab22454f214a1d",
|
995 |
+
"version_major": 2,
|
996 |
+
"version_minor": 0
|
997 |
+
},
|
998 |
+
"text/plain": [
|
999 |
+
" 0%| | 0/95 [00:00<?, ?it/s]"
|
1000 |
+
]
|
1001 |
+
},
|
1002 |
+
"metadata": {},
|
1003 |
+
"output_type": "display_data"
|
1004 |
+
},
|
1005 |
+
{
|
1006 |
+
"name": "stdout",
|
1007 |
+
"output_type": "stream",
|
1008 |
+
"text": [
|
1009 |
+
"{'eval_loss': 0.279912531375885, 'eval_wer': 22.779072779072777, 'eval_runtime': 345.4945, 'eval_samples_per_second': 2.2, 'eval_steps_per_second': 0.275, 'epoch': 3.3}\n",
|
1010 |
+
"{'loss': 0.0885, 'learning_rate': 7.117142857142858e-05, 'epoch': 3.33}\n",
|
1011 |
+
"{'loss': 0.0845, 'learning_rate': 7.081428571428572e-05, 'epoch': 3.36}\n",
|
1012 |
+
"{'loss': 0.0761, 'learning_rate': 7.045714285714287e-05, 'epoch': 3.39}\n",
|
1013 |
+
"{'loss': 0.0756, 'learning_rate': 7.01e-05, 'epoch': 3.41}\n",
|
1014 |
+
"{'loss': 0.0859, 'learning_rate': 6.974285714285715e-05, 'epoch': 3.44}\n",
|
1015 |
+
"{'loss': 0.0972, 'learning_rate': 6.938571428571429e-05, 'epoch': 3.47}\n",
|
1016 |
+
"{'loss': 0.0822, 'learning_rate': 6.902857142857143e-05, 'epoch': 3.5}\n",
|
1017 |
+
"{'loss': 0.0892, 'learning_rate': 6.867142857142857e-05, 'epoch': 3.52}\n",
|
1018 |
+
"{'loss': 0.0735, 'learning_rate': 6.831428571428572e-05, 'epoch': 3.55}\n",
|
1019 |
+
"{'loss': 0.0893, 'learning_rate': 6.795714285714286e-05, 'epoch': 3.58}\n",
|
1020 |
+
"{'loss': 0.0869, 'learning_rate': 6.76e-05, 'epoch': 3.61}\n",
|
1021 |
+
"{'loss': 0.0877, 'learning_rate': 6.724285714285714e-05, 'epoch': 3.63}\n",
|
1022 |
+
"{'loss': 0.07, 'learning_rate': 6.688571428571428e-05, 'epoch': 3.66}\n",
|
1023 |
+
"{'loss': 0.0807, 'learning_rate': 6.652857142857142e-05, 'epoch': 3.69}\n",
|
1024 |
+
"{'loss': 0.0831, 'learning_rate': 6.617142857142858e-05, 'epoch': 3.72}\n",
|
1025 |
+
"{'loss': 0.0836, 'learning_rate': 6.581428571428572e-05, 'epoch': 3.74}\n",
|
1026 |
+
"{'loss': 0.0875, 'learning_rate': 6.545714285714286e-05, 'epoch': 3.77}\n",
|
1027 |
+
"{'loss': 0.0846, 'learning_rate': 6.510000000000001e-05, 'epoch': 3.8}\n",
|
1028 |
+
"{'loss': 0.0779, 'learning_rate': 6.474285714285715e-05, 'epoch': 3.83}\n",
|
1029 |
+
"{'loss': 0.0871, 'learning_rate': 6.438571428571429e-05, 'epoch': 3.85}\n",
|
1030 |
+
"{'loss': 0.0777, 'learning_rate': 6.402857142857143e-05, 'epoch': 3.88}\n",
|
1031 |
+
"{'loss': 0.0856, 'learning_rate': 6.367142857142857e-05, 'epoch': 3.91}\n",
|
1032 |
+
"{'loss': 0.083, 'learning_rate': 6.331428571428571e-05, 'epoch': 3.94}\n",
|
1033 |
+
"{'loss': 0.0667, 'learning_rate': 6.295714285714286e-05, 'epoch': 3.96}\n",
|
1034 |
+
"{'loss': 0.083, 'learning_rate': 6.26e-05, 'epoch': 3.99}\n",
|
1035 |
+
"{'loss': 0.0505, 'learning_rate': 6.224285714285714e-05, 'epoch': 4.02}\n",
|
1036 |
+
"{'loss': 0.0426, 'learning_rate': 6.18857142857143e-05, 'epoch': 4.05}\n",
|
1037 |
+
"{'loss': 0.0453, 'learning_rate': 6.152857142857144e-05, 'epoch': 4.07}\n",
|
1038 |
+
"{'loss': 0.0482, 'learning_rate': 6.117142857142858e-05, 'epoch': 4.1}\n",
|
1039 |
+
"{'loss': 0.0511, 'learning_rate': 6.081428571428571e-05, 'epoch': 4.13}\n",
|
1040 |
+
"{'loss': 0.0583, 'learning_rate': 6.045714285714286e-05, 'epoch': 4.16}\n",
|
1041 |
+
"{'loss': 0.0466, 'learning_rate': 6.0100000000000004e-05, 'epoch': 4.19}\n",
|
1042 |
+
"{'loss': 0.0502, 'learning_rate': 5.9742857142857144e-05, 'epoch': 4.21}\n",
|
1043 |
+
"{'loss': 0.0414, 'learning_rate': 5.938571428571429e-05, 'epoch': 4.24}\n",
|
1044 |
+
"{'loss': 0.0501, 'learning_rate': 5.902857142857143e-05, 'epoch': 4.27}\n",
|
1045 |
+
"{'loss': 0.0478, 'learning_rate': 5.867142857142858e-05, 'epoch': 4.3}\n",
|
1046 |
+
"{'loss': 0.0482, 'learning_rate': 5.8314285714285724e-05, 'epoch': 4.32}\n",
|
1047 |
+
"{'loss': 0.0463, 'learning_rate': 5.7957142857142864e-05, 'epoch': 4.35}\n",
|
1048 |
+
"{'loss': 0.0513, 'learning_rate': 5.76e-05, 'epoch': 4.38}\n",
|
1049 |
+
"{'loss': 0.0546, 'learning_rate': 5.7242857142857144e-05, 'epoch': 4.41}\n"
|
1050 |
+
]
|
1051 |
+
},
|
1052 |
+
{
|
1053 |
+
"data": {
|
1054 |
+
"application/vnd.jupyter.widget-view+json": {
|
1055 |
+
"model_id": "5b1ec44b408a4de8aa8a3e9702dae453",
|
1056 |
+
"version_major": 2,
|
1057 |
+
"version_minor": 0
|
1058 |
+
},
|
1059 |
+
"text/plain": [
|
1060 |
+
" 0%| | 0/95 [00:00<?, ?it/s]"
|
1061 |
+
]
|
1062 |
+
},
|
1063 |
+
"metadata": {},
|
1064 |
+
"output_type": "display_data"
|
1065 |
+
},
|
1066 |
+
{
|
1067 |
+
"name": "stdout",
|
1068 |
+
"output_type": "stream",
|
1069 |
+
"text": [
|
1070 |
+
"{'eval_loss': 0.28944167494773865, 'eval_wer': 21.885521885521886, 'eval_runtime': 344.5818, 'eval_samples_per_second': 2.206, 'eval_steps_per_second': 0.276, 'epoch': 4.41}\n",
|
1071 |
+
"{'loss': 0.0515, 'learning_rate': 5.6885714285714284e-05, 'epoch': 4.43}\n",
|
1072 |
+
"{'loss': 0.0394, 'learning_rate': 5.652857142857143e-05, 'epoch': 4.46}\n",
|
1073 |
+
"{'loss': 0.0562, 'learning_rate': 5.617142857142858e-05, 'epoch': 4.49}\n",
|
1074 |
+
"{'loss': 0.0532, 'learning_rate': 5.581428571428572e-05, 'epoch': 4.52}\n",
|
1075 |
+
"{'loss': 0.0525, 'learning_rate': 5.5457142857142864e-05, 'epoch': 4.54}\n",
|
1076 |
+
"{'loss': 0.0553, 'learning_rate': 5.5100000000000004e-05, 'epoch': 4.57}\n",
|
1077 |
+
"{'loss': 0.0464, 'learning_rate': 5.474285714285714e-05, 'epoch': 4.6}\n",
|
1078 |
+
"{'loss': 0.0425, 'learning_rate': 5.4385714285714284e-05, 'epoch': 4.63}\n",
|
1079 |
+
"{'loss': 0.0529, 'learning_rate': 5.402857142857143e-05, 'epoch': 4.65}\n",
|
1080 |
+
"{'loss': 0.0534, 'learning_rate': 5.367142857142857e-05, 'epoch': 4.68}\n",
|
1081 |
+
"{'loss': 0.0505, 'learning_rate': 5.331428571428572e-05, 'epoch': 4.71}\n",
|
1082 |
+
"{'loss': 0.0416, 'learning_rate': 5.295714285714286e-05, 'epoch': 4.74}\n",
|
1083 |
+
"{'loss': 0.0438, 'learning_rate': 5.2600000000000005e-05, 'epoch': 4.76}\n",
|
1084 |
+
"{'loss': 0.0568, 'learning_rate': 5.224285714285715e-05, 'epoch': 4.79}\n",
|
1085 |
+
"{'loss': 0.0519, 'learning_rate': 5.188571428571429e-05, 'epoch': 4.82}\n",
|
1086 |
+
"{'loss': 0.0415, 'learning_rate': 5.1528571428571425e-05, 'epoch': 4.85}\n",
|
1087 |
+
"{'loss': 0.0502, 'learning_rate': 5.117142857142857e-05, 'epoch': 4.87}\n",
|
1088 |
+
"{'loss': 0.0433, 'learning_rate': 5.081428571428571e-05, 'epoch': 4.9}\n",
|
1089 |
+
"{'loss': 0.0527, 'learning_rate': 5.045714285714286e-05, 'epoch': 4.93}\n",
|
1090 |
+
"{'loss': 0.0434, 'learning_rate': 5.0100000000000005e-05, 'epoch': 4.96}\n",
|
1091 |
+
"{'loss': 0.0485, 'learning_rate': 4.9742857142857145e-05, 'epoch': 4.98}\n",
|
1092 |
+
"{'loss': 0.0358, 'learning_rate': 4.938571428571429e-05, 'epoch': 5.01}\n",
|
1093 |
+
"{'loss': 0.0218, 'learning_rate': 4.902857142857143e-05, 'epoch': 5.04}\n",
|
1094 |
+
"{'loss': 0.0245, 'learning_rate': 4.867142857142857e-05, 'epoch': 5.07}\n",
|
1095 |
+
"{'loss': 0.0272, 'learning_rate': 4.831428571428572e-05, 'epoch': 5.09}\n",
|
1096 |
+
"{'loss': 0.0258, 'learning_rate': 4.795714285714286e-05, 'epoch': 5.12}\n",
|
1097 |
+
"{'loss': 0.0228, 'learning_rate': 4.76e-05, 'epoch': 5.15}\n",
|
1098 |
+
"{'loss': 0.0275, 'learning_rate': 4.7242857142857145e-05, 'epoch': 5.18}\n",
|
1099 |
+
"{'loss': 0.0269, 'learning_rate': 4.6885714285714285e-05, 'epoch': 5.2}\n",
|
1100 |
+
"{'loss': 0.0237, 'learning_rate': 4.652857142857143e-05, 'epoch': 5.23}\n",
|
1101 |
+
"{'loss': 0.0288, 'learning_rate': 4.617142857142857e-05, 'epoch': 5.26}\n",
|
1102 |
+
"{'loss': 0.0269, 'learning_rate': 4.581428571428572e-05, 'epoch': 5.29}\n",
|
1103 |
+
"{'loss': 0.0276, 'learning_rate': 4.545714285714286e-05, 'epoch': 5.31}\n",
|
1104 |
+
"{'loss': 0.0276, 'learning_rate': 4.5100000000000005e-05, 'epoch': 5.34}\n",
|
1105 |
+
"{'loss': 0.0242, 'learning_rate': 4.4742857142857145e-05, 'epoch': 5.37}\n",
|
1106 |
+
"{'loss': 0.0238, 'learning_rate': 4.4385714285714285e-05, 'epoch': 5.4}\n",
|
1107 |
+
"{'loss': 0.0302, 'learning_rate': 4.402857142857143e-05, 'epoch': 5.42}\n",
|
1108 |
+
"{'loss': 0.0253, 'learning_rate': 4.367142857142857e-05, 'epoch': 5.45}\n",
|
1109 |
+
"{'loss': 0.0256, 'learning_rate': 4.331428571428572e-05, 'epoch': 5.48}\n",
|
1110 |
+
"{'loss': 0.0256, 'learning_rate': 4.295714285714286e-05, 'epoch': 5.51}\n"
|
1111 |
+
]
|
1112 |
+
},
|
1113 |
+
{
|
1114 |
+
"data": {
|
1115 |
+
"application/vnd.jupyter.widget-view+json": {
|
1116 |
+
"model_id": "2a3e7ace41cd44768ebc093aa571360e",
|
1117 |
+
"version_major": 2,
|
1118 |
+
"version_minor": 0
|
1119 |
+
},
|
1120 |
+
"text/plain": [
|
1121 |
+
" 0%| | 0/95 [00:00<?, ?it/s]"
|
1122 |
+
]
|
1123 |
+
},
|
1124 |
+
"metadata": {},
|
1125 |
+
"output_type": "display_data"
|
1126 |
+
},
|
1127 |
+
{
|
1128 |
+
"name": "stdout",
|
1129 |
+
"output_type": "stream",
|
1130 |
+
"text": [
|
1131 |
+
"{'eval_loss': 0.3023395836353302, 'eval_wer': 32.2973322973323, 'eval_runtime': 361.3589, 'eval_samples_per_second': 2.103, 'eval_steps_per_second': 0.263, 'epoch': 5.51}\n",
|
1132 |
+
"{'loss': 0.0215, 'learning_rate': 4.26e-05, 'epoch': 5.53}\n",
|
1133 |
+
"{'loss': 0.0272, 'learning_rate': 4.2242857142857145e-05, 'epoch': 5.56}\n",
|
1134 |
+
"{'loss': 0.0268, 'learning_rate': 4.188571428571429e-05, 'epoch': 5.59}\n",
|
1135 |
+
"{'loss': 0.028, 'learning_rate': 4.1528571428571425e-05, 'epoch': 5.62}\n",
|
1136 |
+
"{'loss': 0.0209, 'learning_rate': 4.117142857142857e-05, 'epoch': 5.64}\n",
|
1137 |
+
"{'loss': 0.0258, 'learning_rate': 4.081428571428572e-05, 'epoch': 5.67}\n",
|
1138 |
+
"{'loss': 0.0249, 'learning_rate': 4.045714285714286e-05, 'epoch': 5.7}\n",
|
1139 |
+
"{'loss': 0.0249, 'learning_rate': 4.0100000000000006e-05, 'epoch': 5.73}\n",
|
1140 |
+
"{'loss': 0.0209, 'learning_rate': 3.9742857142857146e-05, 'epoch': 5.75}\n",
|
1141 |
+
"{'loss': 0.02, 'learning_rate': 3.9385714285714286e-05, 'epoch': 5.78}\n",
|
1142 |
+
"{'loss': 0.0244, 'learning_rate': 3.902857142857143e-05, 'epoch': 5.81}\n",
|
1143 |
+
"{'loss': 0.025, 'learning_rate': 3.867142857142857e-05, 'epoch': 5.84}\n",
|
1144 |
+
"{'loss': 0.0282, 'learning_rate': 3.831428571428571e-05, 'epoch': 5.86}\n",
|
1145 |
+
"{'loss': 0.0271, 'learning_rate': 3.795714285714286e-05, 'epoch': 5.89}\n",
|
1146 |
+
"{'loss': 0.0233, 'learning_rate': 3.76e-05, 'epoch': 5.92}\n",
|
1147 |
+
"{'loss': 0.0219, 'learning_rate': 3.7242857142857146e-05, 'epoch': 5.95}\n",
|
1148 |
+
"{'loss': 0.0232, 'learning_rate': 3.688571428571429e-05, 'epoch': 5.97}\n",
|
1149 |
+
"{'loss': 0.019, 'learning_rate': 3.6528571428571426e-05, 'epoch': 6.0}\n",
|
1150 |
+
"{'loss': 0.0152, 'learning_rate': 3.617142857142857e-05, 'epoch': 6.03}\n",
|
1151 |
+
"{'loss': 0.0111, 'learning_rate': 3.581428571428572e-05, 'epoch': 6.06}\n",
|
1152 |
+
"{'loss': 0.0162, 'learning_rate': 3.545714285714286e-05, 'epoch': 6.08}\n",
|
1153 |
+
"{'loss': 0.0126, 'learning_rate': 3.51e-05, 'epoch': 6.11}\n",
|
1154 |
+
"{'loss': 0.012, 'learning_rate': 3.4742857142857146e-05, 'epoch': 6.14}\n",
|
1155 |
+
"{'loss': 0.0153, 'learning_rate': 3.4385714285714286e-05, 'epoch': 6.17}\n",
|
1156 |
+
"{'loss': 0.0133, 'learning_rate': 3.402857142857143e-05, 'epoch': 6.19}\n",
|
1157 |
+
"{'loss': 0.0112, 'learning_rate': 3.367142857142857e-05, 'epoch': 6.22}\n",
|
1158 |
+
"{'loss': 0.0187, 'learning_rate': 3.331428571428571e-05, 'epoch': 6.25}\n",
|
1159 |
+
"{'loss': 0.0134, 'learning_rate': 3.295714285714286e-05, 'epoch': 6.28}\n",
|
1160 |
+
"{'loss': 0.0112, 'learning_rate': 3.26e-05, 'epoch': 6.31}\n",
|
1161 |
+
"{'loss': 0.0096, 'learning_rate': 3.2242857142857146e-05, 'epoch': 6.33}\n",
|
1162 |
+
"{'loss': 0.0112, 'learning_rate': 3.1885714285714286e-05, 'epoch': 6.36}\n",
|
1163 |
+
"{'loss': 0.0146, 'learning_rate': 3.1528571428571426e-05, 'epoch': 6.39}\n",
|
1164 |
+
"{'loss': 0.0106, 'learning_rate': 3.117142857142857e-05, 'epoch': 6.42}\n",
|
1165 |
+
"{'loss': 0.01, 'learning_rate': 3.081428571428572e-05, 'epoch': 6.44}\n",
|
1166 |
+
"{'loss': 0.0117, 'learning_rate': 3.0457142857142856e-05, 'epoch': 6.47}\n",
|
1167 |
+
"{'loss': 0.0135, 'learning_rate': 3.01e-05, 'epoch': 6.5}\n",
|
1168 |
+
"{'loss': 0.0137, 'learning_rate': 2.9742857142857143e-05, 'epoch': 6.53}\n",
|
1169 |
+
"{'loss': 0.0089, 'learning_rate': 2.938571428571429e-05, 'epoch': 6.55}\n",
|
1170 |
+
"{'loss': 0.0096, 'learning_rate': 2.9028571428571427e-05, 'epoch': 6.58}\n",
|
1171 |
+
"{'loss': 0.0111, 'learning_rate': 2.867142857142857e-05, 'epoch': 6.61}\n"
|
1172 |
+
]
|
1173 |
+
},
|
1174 |
+
{
|
1175 |
+
"data": {
|
1176 |
+
"application/vnd.jupyter.widget-view+json": {
|
1177 |
+
"model_id": "a911edc03d0b43edbd4c59958024dca9",
|
1178 |
+
"version_major": 2,
|
1179 |
+
"version_minor": 0
|
1180 |
+
},
|
1181 |
+
"text/plain": [
|
1182 |
+
" 0%| | 0/95 [00:00<?, ?it/s]"
|
1183 |
+
]
|
1184 |
+
},
|
1185 |
+
"metadata": {},
|
1186 |
+
"output_type": "display_data"
|
1187 |
+
},
|
1188 |
+
{
|
1189 |
+
"name": "stdout",
|
1190 |
+
"output_type": "stream",
|
1191 |
+
"text": [
|
1192 |
+
"{'eval_loss': 0.3060542345046997, 'eval_wer': 31.015281015281015, 'eval_runtime': 366.4932, 'eval_samples_per_second': 2.074, 'eval_steps_per_second': 0.259, 'epoch': 6.61}\n",
|
1193 |
+
"{'loss': 0.0091, 'learning_rate': 2.8314285714285717e-05, 'epoch': 6.64}\n",
|
1194 |
+
"{'loss': 0.0075, 'learning_rate': 2.795714285714286e-05, 'epoch': 6.66}\n",
|
1195 |
+
"{'loss': 0.0096, 'learning_rate': 2.7600000000000003e-05, 'epoch': 6.69}\n",
|
1196 |
+
"{'loss': 0.0071, 'learning_rate': 2.7242857142857143e-05, 'epoch': 6.72}\n",
|
1197 |
+
"{'loss': 0.0089, 'learning_rate': 2.6885714285714287e-05, 'epoch': 6.75}\n",
|
1198 |
+
"{'loss': 0.0103, 'learning_rate': 2.652857142857143e-05, 'epoch': 6.77}\n",
|
1199 |
+
"{'loss': 0.0125, 'learning_rate': 2.6171428571428574e-05, 'epoch': 6.8}\n",
|
1200 |
+
"{'loss': 0.0082, 'learning_rate': 2.5814285714285713e-05, 'epoch': 6.83}\n",
|
1201 |
+
"{'loss': 0.0079, 'learning_rate': 2.5457142857142857e-05, 'epoch': 6.86}\n",
|
1202 |
+
"{'loss': 0.0108, 'learning_rate': 2.51e-05, 'epoch': 6.88}\n",
|
1203 |
+
"{'loss': 0.0084, 'learning_rate': 2.4742857142857147e-05, 'epoch': 6.91}\n",
|
1204 |
+
"{'loss': 0.0107, 'learning_rate': 2.4385714285714287e-05, 'epoch': 6.94}\n",
|
1205 |
+
"{'loss': 0.009, 'learning_rate': 2.402857142857143e-05, 'epoch': 6.97}\n",
|
1206 |
+
"{'loss': 0.0081, 'learning_rate': 2.3671428571428574e-05, 'epoch': 6.99}\n",
|
1207 |
+
"{'loss': 0.0077, 'learning_rate': 2.3314285714285717e-05, 'epoch': 7.02}\n",
|
1208 |
+
"{'loss': 0.0064, 'learning_rate': 2.2957142857142857e-05, 'epoch': 7.05}\n",
|
1209 |
+
"{'loss': 0.0079, 'learning_rate': 2.26e-05, 'epoch': 7.08}\n",
|
1210 |
+
"{'loss': 0.0063, 'learning_rate': 2.2242857142857144e-05, 'epoch': 7.1}\n",
|
1211 |
+
"{'loss': 0.0044, 'learning_rate': 2.1885714285714287e-05, 'epoch': 7.13}\n",
|
1212 |
+
"{'loss': 0.0041, 'learning_rate': 2.1528571428571427e-05, 'epoch': 7.16}\n",
|
1213 |
+
"{'loss': 0.0048, 'learning_rate': 2.1171428571428574e-05, 'epoch': 7.19}\n",
|
1214 |
+
"{'loss': 0.0041, 'learning_rate': 2.0814285714285714e-05, 'epoch': 7.21}\n",
|
1215 |
+
"{'loss': 0.0031, 'learning_rate': 2.0457142857142857e-05, 'epoch': 7.24}\n",
|
1216 |
+
"{'loss': 0.0026, 'learning_rate': 2.01e-05, 'epoch': 7.27}\n",
|
1217 |
+
"{'loss': 0.0031, 'learning_rate': 1.9742857142857144e-05, 'epoch': 7.3}\n",
|
1218 |
+
"{'loss': 0.0029, 'learning_rate': 1.9385714285714287e-05, 'epoch': 7.32}\n",
|
1219 |
+
"{'loss': 0.0045, 'learning_rate': 1.9028571428571427e-05, 'epoch': 7.35}\n",
|
1220 |
+
"{'loss': 0.0024, 'learning_rate': 1.8671428571428574e-05, 'epoch': 7.38}\n",
|
1221 |
+
"{'loss': 0.002, 'learning_rate': 1.8314285714285714e-05, 'epoch': 7.41}\n",
|
1222 |
+
"{'loss': 0.0038, 'learning_rate': 1.7957142857142858e-05, 'epoch': 7.43}\n",
|
1223 |
+
"{'loss': 0.0035, 'learning_rate': 1.76e-05, 'epoch': 7.46}\n",
|
1224 |
+
"{'loss': 0.0058, 'learning_rate': 1.7242857142857144e-05, 'epoch': 7.49}\n",
|
1225 |
+
"{'loss': 0.0034, 'learning_rate': 1.6885714285714284e-05, 'epoch': 7.52}\n",
|
1226 |
+
"{'loss': 0.0036, 'learning_rate': 1.652857142857143e-05, 'epoch': 7.54}\n",
|
1227 |
+
"{'loss': 0.0031, 'learning_rate': 1.6171428571428574e-05, 'epoch': 7.57}\n",
|
1228 |
+
"{'loss': 0.0041, 'learning_rate': 1.5814285714285714e-05, 'epoch': 7.6}\n",
|
1229 |
+
"{'loss': 0.0021, 'learning_rate': 1.5457142857142858e-05, 'epoch': 7.63}\n",
|
1230 |
+
"{'loss': 0.0032, 'learning_rate': 1.51e-05, 'epoch': 7.65}\n",
|
1231 |
+
"{'loss': 0.0039, 'learning_rate': 1.4742857142857144e-05, 'epoch': 7.68}\n",
|
1232 |
+
"{'loss': 0.0028, 'learning_rate': 1.4385714285714286e-05, 'epoch': 7.71}\n"
|
1233 |
+
]
|
1234 |
+
},
|
1235 |
+
{
|
1236 |
+
"data": {
|
1237 |
+
"application/vnd.jupyter.widget-view+json": {
|
1238 |
+
"model_id": "3ff56df660d1427daf8f920cd57d67fc",
|
1239 |
+
"version_major": 2,
|
1240 |
+
"version_minor": 0
|
1241 |
+
},
|
1242 |
+
"text/plain": [
|
1243 |
+
" 0%| | 0/95 [00:00<?, ?it/s]"
|
1244 |
+
]
|
1245 |
+
},
|
1246 |
+
"metadata": {},
|
1247 |
+
"output_type": "display_data"
|
1248 |
+
},
|
1249 |
+
{
|
1250 |
+
"name": "stdout",
|
1251 |
+
"output_type": "stream",
|
1252 |
+
"text": [
|
1253 |
+
"{'eval_loss': 0.3143082559108734, 'eval_wer': 27.169127169127172, 'eval_runtime': 357.6908, 'eval_samples_per_second': 2.125, 'eval_steps_per_second': 0.266, 'epoch': 7.71}\n",
|
1254 |
+
"{'loss': 0.0032, 'learning_rate': 1.402857142857143e-05, 'epoch': 7.74}\n",
|
1255 |
+
"{'loss': 0.0028, 'learning_rate': 1.3671428571428571e-05, 'epoch': 7.76}\n",
|
1256 |
+
"{'loss': 0.0033, 'learning_rate': 1.3314285714285715e-05, 'epoch': 7.79}\n",
|
1257 |
+
"{'loss': 0.0052, 'learning_rate': 1.2957142857142856e-05, 'epoch': 7.82}\n",
|
1258 |
+
"{'loss': 0.0023, 'learning_rate': 1.2600000000000001e-05, 'epoch': 7.85}\n",
|
1259 |
+
"{'loss': 0.0034, 'learning_rate': 1.2242857142857143e-05, 'epoch': 7.87}\n",
|
1260 |
+
"{'loss': 0.0026, 'learning_rate': 1.1885714285714286e-05, 'epoch': 7.9}\n",
|
1261 |
+
"{'loss': 0.0022, 'learning_rate': 1.1528571428571428e-05, 'epoch': 7.93}\n",
|
1262 |
+
"{'loss': 0.0037, 'learning_rate': 1.1171428571428571e-05, 'epoch': 7.96}\n",
|
1263 |
+
"{'loss': 0.003, 'learning_rate': 1.0814285714285715e-05, 'epoch': 7.98}\n",
|
1264 |
+
"{'loss': 0.0009, 'learning_rate': 1.0457142857142856e-05, 'epoch': 8.01}\n",
|
1265 |
+
"{'loss': 0.0006, 'learning_rate': 1.0100000000000002e-05, 'epoch': 8.04}\n",
|
1266 |
+
"{'loss': 0.0022, 'learning_rate': 9.742857142857143e-06, 'epoch': 8.07}\n",
|
1267 |
+
"{'loss': 0.0005, 'learning_rate': 9.385714285714287e-06, 'epoch': 8.09}\n",
|
1268 |
+
"{'loss': 0.0007, 'learning_rate': 9.02857142857143e-06, 'epoch': 8.12}\n",
|
1269 |
+
"{'loss': 0.0006, 'learning_rate': 8.671428571428572e-06, 'epoch': 8.15}\n",
|
1270 |
+
"{'loss': 0.0006, 'learning_rate': 8.314285714285715e-06, 'epoch': 8.18}\n",
|
1271 |
+
"{'loss': 0.0005, 'learning_rate': 7.957142857142857e-06, 'epoch': 8.2}\n",
|
1272 |
+
"{'loss': 0.0024, 'learning_rate': 7.6e-06, 'epoch': 8.23}\n",
|
1273 |
+
"{'loss': 0.0009, 'learning_rate': 7.242857142857143e-06, 'epoch': 8.26}\n",
|
1274 |
+
"{'loss': 0.0004, 'learning_rate': 6.885714285714286e-06, 'epoch': 8.29}\n",
|
1275 |
+
"{'loss': 0.0006, 'learning_rate': 6.5285714285714285e-06, 'epoch': 8.31}\n",
|
1276 |
+
"{'loss': 0.0031, 'learning_rate': 6.171428571428572e-06, 'epoch': 8.34}\n",
|
1277 |
+
"{'loss': 0.001, 'learning_rate': 5.814285714285714e-06, 'epoch': 8.37}\n",
|
1278 |
+
"{'loss': 0.0013, 'learning_rate': 5.457142857142857e-06, 'epoch': 8.4}\n",
|
1279 |
+
"{'loss': 0.0011, 'learning_rate': 5.1e-06, 'epoch': 8.43}\n",
|
1280 |
+
"{'loss': 0.0004, 'learning_rate': 4.742857142857144e-06, 'epoch': 8.45}\n",
|
1281 |
+
"{'loss': 0.0012, 'learning_rate': 4.385714285714286e-06, 'epoch': 8.48}\n",
|
1282 |
+
"{'loss': 0.0013, 'learning_rate': 4.028571428571429e-06, 'epoch': 8.51}\n",
|
1283 |
+
"{'loss': 0.0008, 'learning_rate': 3.6714285714285717e-06, 'epoch': 8.54}\n",
|
1284 |
+
"{'loss': 0.0009, 'learning_rate': 3.314285714285714e-06, 'epoch': 8.56}\n",
|
1285 |
+
"{'loss': 0.0019, 'learning_rate': 2.957142857142857e-06, 'epoch': 8.59}\n",
|
1286 |
+
"{'loss': 0.0006, 'learning_rate': 2.6e-06, 'epoch': 8.62}\n",
|
1287 |
+
"{'loss': 0.0004, 'learning_rate': 2.242857142857143e-06, 'epoch': 8.65}\n",
|
1288 |
+
"{'loss': 0.0005, 'learning_rate': 1.8857142857142858e-06, 'epoch': 8.67}\n",
|
1289 |
+
"{'loss': 0.0005, 'learning_rate': 1.5285714285714287e-06, 'epoch': 8.7}\n",
|
1290 |
+
"{'loss': 0.001, 'learning_rate': 1.1714285714285715e-06, 'epoch': 8.73}\n",
|
1291 |
+
"{'loss': 0.0012, 'learning_rate': 8.142857142857143e-07, 'epoch': 8.76}\n",
|
1292 |
+
"{'loss': 0.001, 'learning_rate': 4.571428571428572e-07, 'epoch': 8.78}\n",
|
1293 |
+
"{'loss': 0.0014, 'learning_rate': 1.0000000000000001e-07, 'epoch': 8.81}\n"
|
1294 |
+
]
|
1295 |
+
},
|
1296 |
+
{
|
1297 |
+
"data": {
|
1298 |
+
"application/vnd.jupyter.widget-view+json": {
|
1299 |
+
"model_id": "5120f3a800554c1689dfb561c55440fb",
|
1300 |
+
"version_major": 2,
|
1301 |
+
"version_minor": 0
|
1302 |
+
},
|
1303 |
+
"text/plain": [
|
1304 |
+
" 0%| | 0/95 [00:00<?, ?it/s]"
|
1305 |
+
]
|
1306 |
+
},
|
1307 |
+
"metadata": {},
|
1308 |
+
"output_type": "display_data"
|
1309 |
+
},
|
1310 |
+
{
|
1311 |
+
"name": "stdout",
|
1312 |
+
"output_type": "stream",
|
1313 |
+
"text": [
|
1314 |
+
"{'eval_loss': 0.318661630153656, 'eval_wer': 27.36337736337736, 'eval_runtime': 356.3989, 'eval_samples_per_second': 2.132, 'eval_steps_per_second': 0.267, 'epoch': 8.81}\n",
|
1315 |
+
"{'train_runtime': 48216.7702, 'train_samples_per_second': 2.655, 'train_steps_per_second': 0.166, 'train_loss': 0.13303363310021815, 'epoch': 8.81}\n"
|
1316 |
+
]
|
1317 |
+
},
|
1318 |
+
{
|
1319 |
+
"data": {
|
1320 |
+
"text/plain": [
|
1321 |
+
"TrainOutput(global_step=8000, training_loss=0.13303363310021815, metrics={'train_runtime': 48216.7702, 'train_samples_per_second': 2.655, 'train_steps_per_second': 0.166, 'train_loss': 0.13303363310021815, 'epoch': 8.81})"
|
1322 |
+
]
|
1323 |
+
},
|
1324 |
+
"execution_count": 111,
|
1325 |
+
"metadata": {},
|
1326 |
+
"output_type": "execute_result"
|
1327 |
+
}
|
1328 |
+
],
|
1329 |
+
"source": [
|
1330 |
+
"trainer.train()"
|
1331 |
+
]
|
1332 |
+
},
|
1333 |
+
{
|
1334 |
+
"cell_type": "code",
|
1335 |
+
"execution_count": 34,
|
1336 |
+
"metadata": {},
|
1337 |
+
"outputs": [
|
1338 |
+
{
|
1339 |
+
"ename": "OSError",
|
1340 |
+
"evalue": "It looks like the config file at './whisper-base-vi/pytorch_model.bin' is not a valid JSON file.",
|
1341 |
+
"output_type": "error",
|
1342 |
+
"traceback": [
|
1343 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
1344 |
+
"\u001b[0;31mUnicodeDecodeError\u001b[0m Traceback (most recent call last)",
|
1345 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:702\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 700\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 701\u001b[0m \u001b[39m# Load config dict\u001b[39;00m\n\u001b[0;32m--> 702\u001b[0m config_dict \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_dict_from_json_file(resolved_config_file)\n\u001b[1;32m 703\u001b[0m config_dict[\u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m commit_hash\n",
|
1346 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:793\u001b[0m, in \u001b[0;36mPretrainedConfig._dict_from_json_file\u001b[0;34m(cls, json_file)\u001b[0m\n\u001b[1;32m 792\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39m(json_file, \u001b[39m\"\u001b[39m\u001b[39mr\u001b[39m\u001b[39m\"\u001b[39m, encoding\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mas\u001b[39;00m reader:\n\u001b[0;32m--> 793\u001b[0m text \u001b[39m=\u001b[39m reader\u001b[39m.\u001b[39;49mread()\n\u001b[1;32m 794\u001b[0m \u001b[39mreturn\u001b[39;00m json\u001b[39m.\u001b[39mloads(text)\n",
|
1347 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/codecs.py:322\u001b[0m, in \u001b[0;36mBufferedIncrementalDecoder.decode\u001b[0;34m(self, input, final)\u001b[0m\n\u001b[1;32m 321\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mbuffer \u001b[39m+\u001b[39m \u001b[39minput\u001b[39m\n\u001b[0;32m--> 322\u001b[0m (result, consumed) \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_buffer_decode(data, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49merrors, final)\n\u001b[1;32m 323\u001b[0m \u001b[39m# keep undecoded input until the next call\u001b[39;00m\n",
|
1348 |
+
"\u001b[0;31mUnicodeDecodeError\u001b[0m: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte",
|
1349 |
+
"\nDuring handling of the above exception, another exception occurred:\n",
|
1350 |
+
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
1351 |
+
"Cell \u001b[0;32mIn[34], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m pt_model \u001b[39m=\u001b[39m WhisperForConditionalGeneration\u001b[39m.\u001b[39;49mfrom_pretrained(\u001b[39m\"\u001b[39;49m\u001b[39m./whisper-base-vi/pytorch_model.bin\u001b[39;49m\u001b[39m\"\u001b[39;49m, from_tf\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n\u001b[1;32m 2\u001b[0m pt_model\u001b[39m.\u001b[39msave_pretrained(\u001b[39m\"\u001b[39m\u001b[39m./whisper-base-vi/vi_whisper.pt\u001b[39m\u001b[39m\"\u001b[39m)\n",
|
1352 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/modeling_utils.py:2325\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m 2323\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m 2324\u001b[0m config_path \u001b[39m=\u001b[39m config \u001b[39mif\u001b[39;00m config \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m pretrained_model_name_or_path\n\u001b[0;32m-> 2325\u001b[0m config, model_kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49mconfig_class\u001b[39m.\u001b[39;49mfrom_pretrained(\n\u001b[1;32m 2326\u001b[0m config_path,\n\u001b[1;32m 2327\u001b[0m cache_dir\u001b[39m=\u001b[39;49mcache_dir,\n\u001b[1;32m 2328\u001b[0m return_unused_kwargs\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 2329\u001b[0m force_download\u001b[39m=\u001b[39;49mforce_download,\n\u001b[1;32m 2330\u001b[0m resume_download\u001b[39m=\u001b[39;49mresume_download,\n\u001b[1;32m 2331\u001b[0m proxies\u001b[39m=\u001b[39;49mproxies,\n\u001b[1;32m 2332\u001b[0m local_files_only\u001b[39m=\u001b[39;49mlocal_files_only,\n\u001b[1;32m 2333\u001b[0m token\u001b[39m=\u001b[39;49mtoken,\n\u001b[1;32m 2334\u001b[0m revision\u001b[39m=\u001b[39;49mrevision,\n\u001b[1;32m 2335\u001b[0m subfolder\u001b[39m=\u001b[39;49msubfolder,\n\u001b[1;32m 2336\u001b[0m _from_auto\u001b[39m=\u001b[39;49mfrom_auto_class,\n\u001b[1;32m 2337\u001b[0m _from_pipeline\u001b[39m=\u001b[39;49mfrom_pipeline,\n\u001b[1;32m 2338\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 2339\u001b[0m )\n\u001b[1;32m 2340\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 2341\u001b[0m model_kwargs \u001b[39m=\u001b[39m kwargs\n",
|
1353 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:590\u001b[0m, in \u001b[0;36mPretrainedConfig.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)\u001b[0m\n\u001b[1;32m 586\u001b[0m kwargs[\u001b[39m\"\u001b[39m\u001b[39mrevision\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m revision\n\u001b[1;32m 588\u001b[0m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39m_set_token_in_kwargs(kwargs, token)\n\u001b[0;32m--> 590\u001b[0m config_dict, kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49mget_config_dict(pretrained_model_name_or_path, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 591\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config_dict \u001b[39mand\u001b[39;00m \u001b[39mhasattr\u001b[39m(\u001b[39mcls\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m) \u001b[39mand\u001b[39;00m config_dict[\u001b[39m\"\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m!=\u001b[39m \u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mmodel_type:\n\u001b[1;32m 592\u001b[0m logger\u001b[39m.\u001b[39mwarning(\n\u001b[1;32m 593\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mYou are using a model of type \u001b[39m\u001b[39m{\u001b[39;00mconfig_dict[\u001b[39m'\u001b[39m\u001b[39mmodel_type\u001b[39m\u001b[39m'\u001b[39m]\u001b[39m}\u001b[39;00m\u001b[39m to instantiate a model of type \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 594\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mcls\u001b[39m\u001b[39m.\u001b[39mmodel_type\u001b[39m}\u001b[39;00m\u001b[39m. This is not supported for all configurations of models and can yield errors.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 595\u001b[0m )\n",
|
1354 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:617\u001b[0m, in \u001b[0;36mPretrainedConfig.get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 615\u001b[0m original_kwargs \u001b[39m=\u001b[39m copy\u001b[39m.\u001b[39mdeepcopy(kwargs)\n\u001b[1;32m 616\u001b[0m \u001b[39m# Get config dict associated with the base config file\u001b[39;00m\n\u001b[0;32m--> 617\u001b[0m config_dict, kwargs \u001b[39m=\u001b[39m \u001b[39mcls\u001b[39;49m\u001b[39m.\u001b[39;49m_get_config_dict(pretrained_model_name_or_path, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 618\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m config_dict:\n\u001b[1;32m 619\u001b[0m original_kwargs[\u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m config_dict[\u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m]\n",
|
1355 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/configuration_utils.py:705\u001b[0m, in \u001b[0;36mPretrainedConfig._get_config_dict\u001b[0;34m(cls, pretrained_model_name_or_path, **kwargs)\u001b[0m\n\u001b[1;32m 703\u001b[0m config_dict[\u001b[39m\"\u001b[39m\u001b[39m_commit_hash\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m commit_hash\n\u001b[1;32m 704\u001b[0m \u001b[39mexcept\u001b[39;00m (json\u001b[39m.\u001b[39mJSONDecodeError, \u001b[39mUnicodeDecodeError\u001b[39;00m):\n\u001b[0;32m--> 705\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(\n\u001b[1;32m 706\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mIt looks like the config file at \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mresolved_config_file\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m is not a valid JSON file.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 707\u001b[0m )\n\u001b[1;32m 709\u001b[0m \u001b[39mif\u001b[39;00m is_local:\n\u001b[1;32m 710\u001b[0m logger\u001b[39m.\u001b[39minfo(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mloading configuration file \u001b[39m\u001b[39m{\u001b[39;00mresolved_config_file\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n",
|
1356 |
+
"\u001b[0;31mOSError\u001b[0m: It looks like the config file at './whisper-base-vi/pytorch_model.bin' is not a valid JSON file."
|
1357 |
+
]
|
1358 |
+
}
|
1359 |
+
],
|
1360 |
+
"source": [
|
1361 |
+
"pt_model = WhisperForConditionalGeneration.from_pretrained(\"./whisper-base-vi/pytorch_model.bin\", from_tf=True)\n",
|
1362 |
+
"pt_model.save_pretrained(\"./whisper-base-vi/vi_whisper.pt\")"
|
1363 |
+
]
|
1364 |
+
},
|
1365 |
+
{
|
1366 |
+
"cell_type": "code",
|
1367 |
+
"execution_count": null,
|
1368 |
+
"metadata": {},
|
1369 |
+
"outputs": [],
|
1370 |
+
"source": [
|
1371 |
+
"kwargs = {\n",
|
1372 |
+
" \"dataset_tags\": \"vivos-commonvoice\",\n",
|
1373 |
+
" \"dataset\": \"Vivos\", \n",
|
1374 |
+
" \"language\": \"vi\",\n",
|
1375 |
+
" \"model_name\": \"Whisper Small Vi - Duy Ta\", \n",
|
1376 |
+
" \"finetuned_from\": \"openai/whisper-small\",\n",
|
1377 |
+
" \"tasks\": \"automatic-speech-recognition\",\n",
|
1378 |
+
" \"config\" : None\n",
|
1379 |
+
"}\n"
|
1380 |
+
]
|
1381 |
+
},
|
1382 |
+
{
|
1383 |
+
"cell_type": "code",
|
1384 |
+
"execution_count": 131,
|
1385 |
+
"metadata": {},
|
1386 |
+
"outputs": [
|
1387 |
+
{
|
1388 |
+
"name": "stderr",
|
1389 |
+
"output_type": "stream",
|
1390 |
+
"text": [
|
1391 |
+
"Several commits (2) will be pushed upstream.\n",
|
1392 |
+
"The progress bars may be unreliable.\n",
|
1393 |
+
"error: The destination you provided is not a full refname (i.e.,\n",
|
1394 |
+
"starting with \"refs/\"). We tried to guess what you meant by:\n",
|
1395 |
+
"\n",
|
1396 |
+
"- Looking for a ref that matches 'HEAD' on the remote side.\n",
|
1397 |
+
"- Checking if the <src> being pushed ('HEAD')\n",
|
1398 |
+
" is a ref in \"refs/{heads,tags}/\". If so we add a corresponding\n",
|
1399 |
+
" refs/{heads,tags}/ prefix on the remote side.\n",
|
1400 |
+
"\n",
|
1401 |
+
"Neither worked, so we gave up. You must fully qualify the ref.\n",
|
1402 |
+
"hint: The <src> part of the refspec is a commit object.\n",
|
1403 |
+
"hint: Did you mean to create a new branch by pushing to\n",
|
1404 |
+
"hint: 'HEAD:refs/heads/HEAD'?\n",
|
1405 |
+
"error: failed to push some refs to 'https://huggingface.co/DuyTa/vi_whisper-small'\n",
|
1406 |
+
"\n"
|
1407 |
+
]
|
1408 |
+
},
|
1409 |
+
{
|
1410 |
+
"ename": "OSError",
|
1411 |
+
"evalue": "error: The destination you provided is not a full refname (i.e.,\nstarting with \"refs/\"). We tried to guess what you meant by:\n\n- Looking for a ref that matches 'HEAD' on the remote side.\n- Checking if the <src> being pushed ('HEAD')\n is a ref in \"refs/{heads,tags}/\". If so we add a corresponding\n refs/{heads,tags}/ prefix on the remote side.\n\nNeither worked, so we gave up. You must fully qualify the ref.\nhint: The <src> part of the refspec is a commit object.\nhint: Did you mean to create a new branch by pushing to\nhint: 'HEAD:refs/heads/HEAD'?\nerror: failed to push some refs to 'https://huggingface.co/DuyTa/vi_whisper-small'\n",
|
1412 |
+
"output_type": "error",
|
1413 |
+
"traceback": [
|
1414 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
1415 |
+
"\u001b[0;31mCalledProcessError\u001b[0m Traceback (most recent call last)",
|
1416 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:1099\u001b[0m, in \u001b[0;36mRepository.git_push\u001b[0;34m(self, upstream, blocking, auto_lfs_prune)\u001b[0m\n\u001b[1;32m 1098\u001b[0m \u001b[39mif\u001b[39;00m return_code:\n\u001b[0;32m-> 1099\u001b[0m \u001b[39mraise\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError(return_code, process\u001b[39m.\u001b[39margs, output\u001b[39m=\u001b[39mstdout, stderr\u001b[39m=\u001b[39mstderr)\n\u001b[1;32m 1101\u001b[0m \u001b[39mexcept\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError \u001b[39mas\u001b[39;00m exc:\n",
|
1417 |
+
"\u001b[0;31mCalledProcessError\u001b[0m: Command '['git', 'push', '--set-upstream', 'origin', 'HEAD']' returned non-zero exit status 1.",
|
1418 |
+
"\nDuring handling of the above exception, another exception occurred:\n",
|
1419 |
+
"\u001b[0;31mOSError\u001b[0m Traceback (most recent call last)",
|
1420 |
+
"Cell \u001b[0;32mIn[131], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m trainer\u001b[39m.\u001b[39;49mpush_to_hub(commit_message\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mchange\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n",
|
1421 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/trainer.py:3609\u001b[0m, in \u001b[0;36mTrainer.push_to_hub\u001b[0;34m(self, commit_message, blocking, **kwargs)\u001b[0m\n\u001b[1;32m 3606\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpush_in_progress\u001b[39m.\u001b[39m_process\u001b[39m.\u001b[39mkill()\n\u001b[1;32m 3607\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mpush_in_progress \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[0;32m-> 3609\u001b[0m git_head_commit_url \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrepo\u001b[39m.\u001b[39;49mpush_to_hub(\n\u001b[1;32m 3610\u001b[0m commit_message\u001b[39m=\u001b[39;49mcommit_message, blocking\u001b[39m=\u001b[39;49mblocking, auto_lfs_prune\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m\n\u001b[1;32m 3611\u001b[0m )\n\u001b[1;32m 3612\u001b[0m \u001b[39m# push separately the model card to be independant from the rest of the model\u001b[39;00m\n\u001b[1;32m 3613\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mshould_save:\n",
|
1422 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:1307\u001b[0m, in \u001b[0;36mRepository.push_to_hub\u001b[0;34m(self, commit_message, blocking, clean_ok, auto_lfs_prune)\u001b[0m\n\u001b[1;32m 1305\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgit_add(auto_lfs_track\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m 1306\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mgit_commit(commit_message)\n\u001b[0;32m-> 1307\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgit_push(\n\u001b[1;32m 1308\u001b[0m upstream\u001b[39m=\u001b[39;49m\u001b[39mf\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39morigin \u001b[39;49m\u001b[39m{\u001b[39;49;00m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcurrent_branch\u001b[39m}\u001b[39;49;00m\u001b[39m\"\u001b[39;49m,\n\u001b[1;32m 1309\u001b[0m blocking\u001b[39m=\u001b[39;49mblocking,\n\u001b[1;32m 1310\u001b[0m auto_lfs_prune\u001b[39m=\u001b[39;49mauto_lfs_prune,\n\u001b[1;32m 1311\u001b[0m )\n",
|
1423 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/repository.py:1102\u001b[0m, in \u001b[0;36mRepository.git_push\u001b[0;34m(self, upstream, blocking, auto_lfs_prune)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[39mraise\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError(return_code, process\u001b[39m.\u001b[39margs, output\u001b[39m=\u001b[39mstdout, stderr\u001b[39m=\u001b[39mstderr)\n\u001b[1;32m 1101\u001b[0m \u001b[39mexcept\u001b[39;00m subprocess\u001b[39m.\u001b[39mCalledProcessError \u001b[39mas\u001b[39;00m exc:\n\u001b[0;32m-> 1102\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mEnvironmentError\u001b[39;00m(exc\u001b[39m.\u001b[39mstderr)\n\u001b[1;32m 1104\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m blocking:\n\u001b[1;32m 1106\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstatus_method\u001b[39m():\n",
|
1424 |
+
"\u001b[0;31mOSError\u001b[0m: error: The destination you provided is not a full refname (i.e.,\nstarting with \"refs/\"). We tried to guess what you meant by:\n\n- Looking for a ref that matches 'HEAD' on the remote side.\n- Checking if the <src> being pushed ('HEAD')\n is a ref in \"refs/{heads,tags}/\". If so we add a corresponding\n refs/{heads,tags}/ prefix on the remote side.\n\nNeither worked, so we gave up. You must fully qualify the ref.\nhint: The <src> part of the refspec is a commit object.\nhint: Did you mean to create a new branch by pushing to\nhint: 'HEAD:refs/heads/HEAD'?\nerror: failed to push some refs to 'https://huggingface.co/DuyTa/vi_whisper-small'\n"
|
1425 |
+
]
|
1426 |
+
}
|
1427 |
+
],
|
1428 |
+
"source": [
|
1429 |
+
"trainer.push_to_hub(commit_message=\"change\")"
|
1430 |
+
]
|
1431 |
+
},
|
1432 |
+
{
|
1433 |
+
"cell_type": "code",
|
1434 |
+
"execution_count": null,
|
1435 |
+
"metadata": {
|
1436 |
+
"tags": [
|
1437 |
+
"parameters"
|
1438 |
+
]
|
1439 |
+
},
|
1440 |
+
"outputs": [
|
1441 |
+
{
|
1442 |
+
"data": {
|
1443 |
+
"application/vnd.jupyter.widget-view+json": {
|
1444 |
+
"model_id": "b6e666bab7b2450abf3e2adf07679122",
|
1445 |
+
"version_major": 2,
|
1446 |
+
"version_minor": 0
|
1447 |
+
},
|
1448 |
+
"text/plain": [
|
1449 |
+
"Downloading (…)lve/main/config.json: 0%| | 0.00/1.31k [00:00<?, ?B/s]"
|
1450 |
+
]
|
1451 |
+
},
|
1452 |
+
"metadata": {},
|
1453 |
+
"output_type": "display_data"
|
1454 |
+
},
|
1455 |
+
{
|
1456 |
+
"data": {
|
1457 |
+
"application/vnd.jupyter.widget-view+json": {
|
1458 |
+
"model_id": "b212026dca9241cf994f9710f0b93c22",
|
1459 |
+
"version_major": 2,
|
1460 |
+
"version_minor": 0
|
1461 |
+
},
|
1462 |
+
"text/plain": [
|
1463 |
+
"Downloading (…)okenizer_config.json: 0%| | 0.00/838 [00:00<?, ?B/s]"
|
1464 |
+
]
|
1465 |
+
},
|
1466 |
+
"metadata": {},
|
1467 |
+
"output_type": "display_data"
|
1468 |
+
}
|
1469 |
+
],
|
1470 |
+
"source": [
|
1471 |
+
"from transformers import WhisperForConditionalGeneration, WhisperProcessor\n",
|
1472 |
+
"\n",
|
1473 |
+
"model = WhisperForConditionalGeneration.from_pretrained(\"DuyTa/vi_whisper\")\n",
|
1474 |
+
"processor = WhisperProcessor.from_pretrained(\"DuyTa/vi_whisper\")\n"
|
1475 |
+
]
|
1476 |
+
},
|
1477 |
+
{
|
1478 |
+
"cell_type": "code",
|
1479 |
+
"execution_count": 36,
|
1480 |
+
"metadata": {},
|
1481 |
+
"outputs": [
|
1482 |
+
{
|
1483 |
+
"ename": "RuntimeError",
|
1484 |
+
"evalue": "Instantiating a pipeline without a task set raised an error: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './vi_whisper-small'.",
|
1485 |
+
"output_type": "error",
|
1486 |
+
"traceback": [
|
1487 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
1488 |
+
"\u001b[0;31mHFValidationError\u001b[0m Traceback (most recent call last)",
|
1489 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/pipelines/__init__.py:432\u001b[0m, in \u001b[0;36mget_task\u001b[0;34m(model, use_auth_token)\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 432\u001b[0m info \u001b[39m=\u001b[39m model_info(model, token\u001b[39m=\u001b[39;49muse_auth_token)\n\u001b[1;32m 433\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n",
|
1490 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:110\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[39mif\u001b[39;00m arg_name \u001b[39min\u001b[39;00m [\u001b[39m\"\u001b[39m\u001b[39mrepo_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mfrom_id\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mto_id\u001b[39m\u001b[39m\"\u001b[39m]:\n\u001b[0;32m--> 110\u001b[0m validate_repo_id(arg_value)\n\u001b[1;32m 112\u001b[0m \u001b[39melif\u001b[39;00m arg_name \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mtoken\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mand\u001b[39;00m arg_value \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n",
|
1491 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/huggingface_hub/utils/_validators.py:164\u001b[0m, in \u001b[0;36mvalidate_repo_id\u001b[0;34m(repo_id)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m REPO_ID_REGEX\u001b[39m.\u001b[39mmatch(repo_id):\n\u001b[0;32m--> 164\u001b[0m \u001b[39mraise\u001b[39;00m HFValidationError(\n\u001b[1;32m 165\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mRepo id must use alphanumeric chars or \u001b[39m\u001b[39m'\u001b[39m\u001b[39m-\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m_\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m'\u001b[39m\u001b[39m, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m--\u001b[39m\u001b[39m'\u001b[39m\u001b[39m and \u001b[39m\u001b[39m'\u001b[39m\u001b[39m..\u001b[39m\u001b[39m'\u001b[39m\u001b[39m are\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 166\u001b[0m \u001b[39m\"\u001b[39m\u001b[39m forbidden, \u001b[39m\u001b[39m'\u001b[39m\u001b[39m-\u001b[39m\u001b[39m'\u001b[39m\u001b[39m and \u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m'\u001b[39m\u001b[39m cannot start or end the name, max length is 96:\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 167\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m\u001b[39m{\u001b[39;00mrepo_id\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 168\u001b[0m )\n\u001b[1;32m 170\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39m--\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m repo_id \u001b[39mor\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39m..\u001b[39m\u001b[39m\"\u001b[39m \u001b[39min\u001b[39;00m repo_id:\n",
|
1492 |
+
"\u001b[0;31mHFValidationError\u001b[0m: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './vi_whisper-small'.",
|
1493 |
+
"\nDuring handling of the above exception, another exception occurred:\n",
|
1494 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
1495 |
+
"Cell \u001b[0;32mIn[36], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtransformers\u001b[39;00m \u001b[39mimport\u001b[39;00m pipeline\n\u001b[1;32m 2\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mgradio\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mgr\u001b[39;00m\n\u001b[0;32m----> 4\u001b[0m pipe \u001b[39m=\u001b[39m pipeline(model\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39m./vi_whisper-small\u001b[39;49m\u001b[39m\"\u001b[39;49m) \n\u001b[1;32m 6\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mtranscribe\u001b[39m(audio):\n\u001b[1;32m 7\u001b[0m text \u001b[39m=\u001b[39m pipe(audio)[\u001b[39m\"\u001b[39m\u001b[39mtext\u001b[39m\u001b[39m\"\u001b[39m]\n",
|
1496 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/pipelines/__init__.py:726\u001b[0m, in \u001b[0;36mpipeline\u001b[0;34m(task, model, config, tokenizer, feature_extractor, image_processor, framework, revision, use_fast, use_auth_token, device, device_map, torch_dtype, trust_remote_code, model_kwargs, pipeline_class, **kwargs)\u001b[0m\n\u001b[1;32m 721\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39misinstance\u001b[39m(model, \u001b[39mstr\u001b[39m):\n\u001b[1;32m 722\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m 723\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mInferring the task automatically requires to check the hub with a model_id defined as a `str`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 724\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00mmodel\u001b[39m}\u001b[39;00m\u001b[39m is not a valid model_id.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 725\u001b[0m )\n\u001b[0;32m--> 726\u001b[0m task \u001b[39m=\u001b[39m get_task(model, use_auth_token)\n\u001b[1;32m 728\u001b[0m \u001b[39m# Retrieve the task\u001b[39;00m\n\u001b[1;32m 729\u001b[0m \u001b[39mif\u001b[39;00m task \u001b[39min\u001b[39;00m custom_tasks:\n",
|
1497 |
+
"File \u001b[0;32m~/miniconda3/envs/DUY/lib/python3.9/site-packages/transformers/pipelines/__init__.py:434\u001b[0m, in \u001b[0;36mget_task\u001b[0;34m(model, use_auth_token)\u001b[0m\n\u001b[1;32m 432\u001b[0m info \u001b[39m=\u001b[39m model_info(model, token\u001b[39m=\u001b[39muse_auth_token)\n\u001b[1;32m 433\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[0;32m--> 434\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mInstantiating a pipeline without a task set raised an error: \u001b[39m\u001b[39m{\u001b[39;00me\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 435\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m info\u001b[39m.\u001b[39mpipeline_tag:\n\u001b[1;32m 436\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mRuntimeError\u001b[39;00m(\n\u001b[1;32m 437\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mThe model \u001b[39m\u001b[39m{\u001b[39;00mmodel\u001b[39m}\u001b[39;00m\u001b[39m does not seem to have a correct `pipeline_tag` set to infer the task automatically\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 438\u001b[0m )\n",
|
1498 |
+
"\u001b[0;31mRuntimeError\u001b[0m: Instantiating a pipeline without a task set raised an error: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: './vi_whisper-small'."
|
1499 |
+
]
|
1500 |
+
}
|
1501 |
+
],
|
1502 |
+
"source": [
|
1503 |
+
"from transformers import pipeline\n",
|
1504 |
+
"import gradio as gr\n",
|
1505 |
+
"\n",
|
1506 |
+
"pipe = pipeline(model=\"./vi_whisper-small\") \n",
|
1507 |
+
"\n",
|
1508 |
+
"def transcribe(audio):\n",
|
1509 |
+
" text = pipe(audio)[\"text\"]\n",
|
1510 |
+
" return text\n",
|
1511 |
+
"\n",
|
1512 |
+
"iface = gr.Interface(\n",
|
1513 |
+
" fn=transcribe,\n",
|
1514 |
+
" inputs=gr.Audio(source=\"upload\", type=\"filepath\"),\n",
|
1515 |
+
" outputs=\"text\",\n",
|
1516 |
+
" title=\"Whisper Base Vietnamese\",\n",
|
1517 |
+
" description=\"Realtime demo for Vietnamese speech recognition using a fine-tuned Whisper base model.\",\n",
|
1518 |
+
")\n",
|
1519 |
+
"\n",
|
1520 |
+
"iface.launch()"
|
1521 |
+
]
|
1522 |
+
}
|
1523 |
+
],
|
1524 |
+
"metadata": {
|
1525 |
+
"kernelspec": {
|
1526 |
+
"display_name": "DUY",
|
1527 |
+
"language": "python",
|
1528 |
+
"name": "python3"
|
1529 |
+
},
|
1530 |
+
"language_info": {
|
1531 |
+
"codemirror_mode": {
|
1532 |
+
"name": "ipython",
|
1533 |
+
"version": 3
|
1534 |
+
},
|
1535 |
+
"file_extension": ".py",
|
1536 |
+
"mimetype": "text/x-python",
|
1537 |
+
"name": "python",
|
1538 |
+
"nbconvert_exporter": "python",
|
1539 |
+
"pygments_lexer": "ipython3",
|
1540 |
+
"version": "3.9.17"
|
1541 |
+
},
|
1542 |
+
"orig_nbformat": 4
|
1543 |
+
},
|
1544 |
+
"nbformat": 4,
|
1545 |
+
"nbformat_minor": 2
|
1546 |
+
}
|
src/training.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from huggingface_hub import interpreter_login
|
2 |
+
|
3 |
+
from datasets import load_dataset, DatasetDict, load_from_disk
|
4 |
+
|
5 |
+
from transformers import WhisperProcessor
|
6 |
+
from transformers import WhisperForConditionalGeneration
|
7 |
+
from transformers import Seq2SeqTrainingArguments
|
8 |
+
from transformers import Seq2SeqTrainer
|
9 |
+
from transformers import EarlyStoppingCallback
|
10 |
+
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
11 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
12 |
+
|
13 |
+
from peft import prepare_model_for_int8_training
|
14 |
+
from peft import PeftModel, LoraModel, LoraConfig, get_peft_model
|
15 |
+
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Any, Dict, List, Union
|
20 |
+
|
21 |
+
import evaluate
|
22 |
+
|
23 |
+
import os
|
24 |
+
|
25 |
+
class SavePeftModelCallback(TrainerCallback):
|
26 |
+
def on_save(
|
27 |
+
self,
|
28 |
+
args: TrainingArguments,
|
29 |
+
state: TrainerState,
|
30 |
+
control: TrainerControl,
|
31 |
+
**kwargs,
|
32 |
+
):
|
33 |
+
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
34 |
+
|
35 |
+
peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
|
36 |
+
kwargs["model"].save_pretrained(peft_model_path)
|
37 |
+
|
38 |
+
pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
|
39 |
+
if os.path.exists(pytorch_model_path):
|
40 |
+
os.remove(pytorch_model_path)
|
41 |
+
return control
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class DataCollatorSpeechSeq2SeqWithPadding:
|
45 |
+
processor: Any
|
46 |
+
|
47 |
+
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
48 |
+
# split inputs and labels since they have to be of different lengths and need different padding methods
|
49 |
+
# first treat the audio inputs by simply returning torch tensors
|
50 |
+
input_features = [{"input_features": feature["input_features"]} for feature in features]
|
51 |
+
batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
|
52 |
+
|
53 |
+
# get the tokenized label sequences
|
54 |
+
label_features = [{"input_ids": feature["labels"]} for feature in features]
|
55 |
+
|
56 |
+
# ******************This is only in the case of augmented data ***************** Remove if not
|
57 |
+
batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])
|
58 |
+
|
59 |
+
# pad the labels to max length
|
60 |
+
labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
|
61 |
+
|
62 |
+
# replace padding with -100 to ignore loss correctly
|
63 |
+
labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
|
64 |
+
|
65 |
+
# if bos token is appended in previous tokenization step,
|
66 |
+
# cut bos token here as it's append later anyways
|
67 |
+
if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
|
68 |
+
labels = labels[:, 1:]
|
69 |
+
|
70 |
+
batch["labels"] = labels
|
71 |
+
|
72 |
+
return batch
|
73 |
+
|
74 |
+
def compute_metrics(pred):
|
75 |
+
pred_ids = pred.predictions
|
76 |
+
label_ids = pred.label_ids
|
77 |
+
|
78 |
+
# replace -100 with the pad_token_id
|
79 |
+
label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
|
80 |
+
|
81 |
+
# we do not want to group tokens when computing the metrics
|
82 |
+
pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
|
83 |
+
label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
|
84 |
+
|
85 |
+
wer = 100 * metric.compute(predictions=pred_str, references=label_str)
|
86 |
+
|
87 |
+
return {"wer": wer}
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
|
93 |
+
|
94 |
+
early_stopping_callback = EarlyStoppingCallback(
|
95 |
+
early_stopping_patience=3, # Stop training if the metric doesn't improve for 3 evaluations
|
96 |
+
early_stopping_threshold=0.0005, # Minimum change in the metric to be considered an improvement
|
97 |
+
)
|
98 |
+
|
99 |
+
# Load Dataset
|
100 |
+
processed_dataset = DatasetDict()
|
101 |
+
processed_dataset = load_from_disk("./vin_clean")
|
102 |
+
|
103 |
+
|
104 |
+
print(processed_dataset)
|
105 |
+
|
106 |
+
# load processor
|
107 |
+
processor = WhisperProcessor.from_pretrained("openai/whisper-medium", language="Vietnamese", task="transcribe")
|
108 |
+
|
109 |
+
|
110 |
+
# intialize data collator
|
111 |
+
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
|
112 |
+
|
113 |
+
# download metric
|
114 |
+
metric = evaluate.load("wer")
|
115 |
+
|
116 |
+
# Download model in 8bit
|
117 |
+
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium", load_in_8bit=True, device_map="auto")
|
118 |
+
model.config.forced_decoder_ids = None
|
119 |
+
model.config.suppress_tokens = []
|
120 |
+
|
121 |
+
# preparing model with PEFT
|
122 |
+
model = prepare_model_for_int8_training(model, output_imbedding_layer="proj_out")
|
123 |
+
|
124 |
+
config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
|
125 |
+
|
126 |
+
model = get_peft_model(model, config)
|
127 |
+
model.print_trainable_parameters()
|
128 |
+
|
129 |
+
|
130 |
+
# Define trainnig arguments
|
131 |
+
training_args = Seq2SeqTrainingArguments(
|
132 |
+
output_dir="./whisper-medium-Lora", # change to a repo name of your choice
|
133 |
+
per_device_train_batch_size=32,
|
134 |
+
gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size
|
135 |
+
learning_rate=5e-5,
|
136 |
+
warmup_steps=500,
|
137 |
+
max_steps=10000,
|
138 |
+
evaluation_strategy="steps",
|
139 |
+
gradient_checkpointing=True,
|
140 |
+
optim="adamw_torch",
|
141 |
+
fp16=True,
|
142 |
+
per_device_eval_batch_size=8,
|
143 |
+
generation_max_length=225,
|
144 |
+
save_steps=2000,
|
145 |
+
eval_steps=500,
|
146 |
+
logging_steps=25,
|
147 |
+
report_to=["tensorboard"],
|
148 |
+
predict_with_generate=True,
|
149 |
+
# load_best_model_at_end=True,
|
150 |
+
metric_for_best_model="wer",
|
151 |
+
greater_is_better=False,
|
152 |
+
# required as the PeftModel forward doesn't have the signature of the wrapped model's forward
|
153 |
+
remove_unused_columns=False,
|
154 |
+
label_names=["labels"], # same reason as above
|
155 |
+
push_to_hub=False,
|
156 |
+
)
|
157 |
+
|
158 |
+
# initialize trainer
|
159 |
+
trainer = Seq2SeqTrainer(
|
160 |
+
args=training_args,
|
161 |
+
model=model,
|
162 |
+
train_dataset=processed_dataset["train"],
|
163 |
+
eval_dataset=processed_dataset["test"],
|
164 |
+
data_collator=data_collator,
|
165 |
+
tokenizer=processor.feature_extractor,
|
166 |
+
callbacks=[early_stopping_callback, SavePeftModelCallback],
|
167 |
+
)
|
168 |
+
|
169 |
+
|
170 |
+
# start training
|
171 |
+
trainer.train()
|
172 |
+
|
173 |
+
|
174 |
+
# set up args and push to hub
|
175 |
+
kwargs = {
|
176 |
+
"dataset": "vin100h",
|
177 |
+
"language": "vi",
|
178 |
+
"model_name": "Whisper Medium LoRA - Clean Data",
|
179 |
+
"finetuned_from": "openai/whisper-medium",
|
180 |
+
"tasks": "automatic-speech-recognition",
|
181 |
+
}
|
182 |
+
|
183 |
+
model.push_to_hub(**kwargs)
|
src/vin_whisper_medium.ipynb
ADDED
@@ -0,0 +1,1164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/whisper/__init__.py\n"
|
13 |
+
]
|
14 |
+
}
|
15 |
+
],
|
16 |
+
"source": [
|
17 |
+
"import whisper\n",
|
18 |
+
"print(whisper.__file__)\n"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 1,
|
24 |
+
"metadata": {},
|
25 |
+
"outputs": [],
|
26 |
+
"source": [
|
27 |
+
"import os\n",
|
28 |
+
"\n",
|
29 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"cell_type": "code",
|
34 |
+
"execution_count": 3,
|
35 |
+
"metadata": {},
|
36 |
+
"outputs": [
|
37 |
+
{
|
38 |
+
"name": "stdout",
|
39 |
+
"output_type": "stream",
|
40 |
+
"text": [
|
41 |
+
"True\n",
|
42 |
+
"2\n",
|
43 |
+
"0\n",
|
44 |
+
"<torch.cuda.device object at 0x7f69e1e31eb0>\n",
|
45 |
+
"Tesla T4\n"
|
46 |
+
]
|
47 |
+
}
|
48 |
+
],
|
49 |
+
"source": [
|
50 |
+
"import torch\n",
|
51 |
+
"\n",
|
52 |
+
"print(torch.cuda.is_available())\n",
|
53 |
+
"\n",
|
54 |
+
"\n",
|
55 |
+
"print(torch.cuda.device_count())\n",
|
56 |
+
"\n",
|
57 |
+
"\n",
|
58 |
+
"print(torch.cuda.current_device())\n",
|
59 |
+
"print(torch.cuda.device(0))\n",
|
60 |
+
"\n",
|
61 |
+
"print(torch.cuda.get_device_name(0))\n"
|
62 |
+
]
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"execution_count": 2,
|
67 |
+
"metadata": {},
|
68 |
+
"outputs": [
|
69 |
+
{
|
70 |
+
"data": {
|
71 |
+
"application/vnd.jupyter.widget-view+json": {
|
72 |
+
"model_id": "f290d4efc37a4112a662c062e621e482",
|
73 |
+
"version_major": 2,
|
74 |
+
"version_minor": 0
|
75 |
+
},
|
76 |
+
"text/plain": [
|
77 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
"metadata": {},
|
81 |
+
"output_type": "display_data"
|
82 |
+
}
|
83 |
+
],
|
84 |
+
"source": [
|
85 |
+
"from huggingface_hub import notebook_login\n",
|
86 |
+
"\n",
|
87 |
+
"notebook_login()"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": 2,
|
93 |
+
"metadata": {},
|
94 |
+
"outputs": [],
|
95 |
+
"source": [
|
96 |
+
"model_name_or_path = \"openai/whisper-medium\"\n",
|
97 |
+
"task = \"transcribe\""
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": 3,
|
103 |
+
"metadata": {},
|
104 |
+
"outputs": [],
|
105 |
+
"source": [
|
106 |
+
"dataset_name = \"Vin100h_MITI_private\"\n",
|
107 |
+
"language = \"Vietnamese\"\n",
|
108 |
+
"language_abbr = \"vi\" # Short hand code for the language we want to fine-tune"
|
109 |
+
]
|
110 |
+
},
|
111 |
+
{
|
112 |
+
"cell_type": "code",
|
113 |
+
"execution_count": 4,
|
114 |
+
"metadata": {},
|
115 |
+
"outputs": [
|
116 |
+
{
|
117 |
+
"name": "stdout",
|
118 |
+
"output_type": "stream",
|
119 |
+
"text": [
|
120 |
+
"DatasetDict({\n",
|
121 |
+
" train: Dataset({\n",
|
122 |
+
" features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
|
123 |
+
" num_rows: 1679\n",
|
124 |
+
" })\n",
|
125 |
+
" test: Dataset({\n",
|
126 |
+
" features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
|
127 |
+
" num_rows: 420\n",
|
128 |
+
" })\n",
|
129 |
+
"})\n",
|
130 |
+
"DatasetDict({\n",
|
131 |
+
" train: Dataset({\n",
|
132 |
+
" features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
|
133 |
+
" num_rows: 6735\n",
|
134 |
+
" })\n",
|
135 |
+
" test: Dataset({\n",
|
136 |
+
" features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
|
137 |
+
" num_rows: 1688\n",
|
138 |
+
" })\n",
|
139 |
+
"})\n"
|
140 |
+
]
|
141 |
+
}
|
142 |
+
],
|
143 |
+
"source": [
|
144 |
+
" # Load Dataset\n",
|
145 |
+
"from datasets import load_dataset, DatasetDict, load_from_disk\n",
|
146 |
+
"processed_dataset = DatasetDict()\n",
|
147 |
+
"processed_dataset = load_from_disk(\"./MITI_clean\")\n",
|
148 |
+
"processed_dataset2 = load_from_disk(\"./vin_10h/\")\n",
|
149 |
+
"\n",
|
150 |
+
"print(processed_dataset)\n",
|
151 |
+
"print(processed_dataset2)"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": 49,
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [],
|
159 |
+
"source": [
|
160 |
+
"from datasets import Dataset\n",
|
161 |
+
"\n",
|
162 |
+
"# Assuming you have already loaded your dataset\n",
|
163 |
+
"# processed_dataset2 = ...\n",
|
164 |
+
"\n",
|
165 |
+
"# Randomly select 5000 indices from the train dataset\n",
|
166 |
+
"import random\n",
|
167 |
+
"num_samples_train = 5000\n",
|
168 |
+
"num_samples_test = 600\n",
|
169 |
+
"random_indices_train = random.sample(range(len(processed_dataset2['train'])), num_samples_train)\n",
|
170 |
+
"random_indices_test = random.sample(range(len(processed_dataset2['test'])), num_samples_test)\n",
|
171 |
+
"\n",
|
172 |
+
"# Initialize lists for train dataset\n",
|
173 |
+
"input_features_train = []\n",
|
174 |
+
"input_length_train = []\n",
|
175 |
+
"attention_mask_train = []\n",
|
176 |
+
"labels_train = []\n",
|
177 |
+
"\n",
|
178 |
+
"# Initialize lists for test dataset\n",
|
179 |
+
"input_features_test = []\n",
|
180 |
+
"input_length_test = []\n",
|
181 |
+
"attention_mask_test = []\n",
|
182 |
+
"labels_test = []\n",
|
183 |
+
"\n",
|
184 |
+
"# Populate lists for train dataset\n",
|
185 |
+
"for i in random_indices_train:\n",
|
186 |
+
" input_features_train.append(processed_dataset2['train'][i]['input_features'])\n",
|
187 |
+
" input_length_train.append(processed_dataset2['train'][i]['input_length'])\n",
|
188 |
+
" attention_mask_train.append(processed_dataset2['train'][i]['attention_mask'])\n",
|
189 |
+
" labels_train.append(processed_dataset2['train'][i]['labels'])\n",
|
190 |
+
"\n",
|
191 |
+
"# Populate lists for test dataset\n",
|
192 |
+
"for i in random_indices_test:\n",
|
193 |
+
" input_features_test.append(processed_dataset2['test'][i]['input_features'])\n",
|
194 |
+
" input_length_test.append(processed_dataset2['test'][i]['input_length'])\n",
|
195 |
+
" attention_mask_test.append(processed_dataset2['test'][i]['attention_mask'])\n",
|
196 |
+
" labels_test.append(processed_dataset2['test'][i]['labels'])\n",
|
197 |
+
"\n",
|
198 |
+
"# Create a new dataset with the randomly selected rows\n",
|
199 |
+
"random_subset = Dataset.from_dict({\n",
|
200 |
+
" 'train': {\n",
|
201 |
+
" 'input_features': input_features_train,\n",
|
202 |
+
" 'input_length': input_length_train,\n",
|
203 |
+
" 'attention_mask': attention_mask_train,\n",
|
204 |
+
" 'labels': labels_train,\n",
|
205 |
+
" },\n",
|
206 |
+
" 'test': {\n",
|
207 |
+
" 'input_features': input_features_test,\n",
|
208 |
+
" 'input_length': input_length_test,\n",
|
209 |
+
" 'attention_mask': attention_mask_test,\n",
|
210 |
+
" 'labels': labels_test,\n",
|
211 |
+
" }\n",
|
212 |
+
"})\n",
|
213 |
+
"\n",
|
214 |
+
"\n"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
{
|
218 |
+
"cell_type": "code",
|
219 |
+
"execution_count": 5,
|
220 |
+
"metadata": {},
|
221 |
+
"outputs": [],
|
222 |
+
"source": [
|
223 |
+
"import datasets\n",
|
224 |
+
"concat = DatasetDict()\n",
|
225 |
+
"concat[\"train\"] = datasets.concatenate_datasets([processed_dataset[\"train\"], processed_dataset2[\"train\"]])\n",
|
226 |
+
"concat['test']= datasets.concatenate_datasets([processed_dataset[\"test\"], processed_dataset2[\"test\"]])\n"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"cell_type": "code",
|
231 |
+
"execution_count": 7,
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [
|
234 |
+
{
|
235 |
+
"data": {
|
236 |
+
"text/plain": [
|
237 |
+
"DatasetDict({\n",
|
238 |
+
" train: Dataset({\n",
|
239 |
+
" features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
|
240 |
+
" num_rows: 8414\n",
|
241 |
+
" })\n",
|
242 |
+
" test: Dataset({\n",
|
243 |
+
" features: ['input_features', 'input_length', 'attention_mask', 'labels'],\n",
|
244 |
+
" num_rows: 2108\n",
|
245 |
+
" })\n",
|
246 |
+
"})"
|
247 |
+
]
|
248 |
+
},
|
249 |
+
"execution_count": 7,
|
250 |
+
"metadata": {},
|
251 |
+
"output_type": "execute_result"
|
252 |
+
}
|
253 |
+
],
|
254 |
+
"source": [
|
255 |
+
"concat"
|
256 |
+
]
|
257 |
+
},
|
258 |
+
{
|
259 |
+
"cell_type": "code",
|
260 |
+
"execution_count": 7,
|
261 |
+
"metadata": {},
|
262 |
+
"outputs": [],
|
263 |
+
"source": [
|
264 |
+
"from transformers import WhisperFeatureExtractor\n",
|
265 |
+
"\n",
|
266 |
+
"feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"cell_type": "code",
|
271 |
+
"execution_count": 6,
|
272 |
+
"metadata": {},
|
273 |
+
"outputs": [],
|
274 |
+
"source": [
|
275 |
+
"\n",
|
276 |
+
"from transformers import WhisperTokenizer\n",
|
277 |
+
"\n",
|
278 |
+
"tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": 11,
|
284 |
+
"metadata": {},
|
285 |
+
"outputs": [
|
286 |
+
{
|
287 |
+
"data": {
|
288 |
+
"text/plain": [
|
289 |
+
"('./Viet_ASR/tokenizer_config.json',\n",
|
290 |
+
" './Viet_ASR/special_tokens_map.json',\n",
|
291 |
+
" './Viet_ASR/vocab.json',\n",
|
292 |
+
" './Viet_ASR/merges.txt',\n",
|
293 |
+
" './Viet_ASR/normalizer.json',\n",
|
294 |
+
" './Viet_ASR/added_tokens.json')"
|
295 |
+
]
|
296 |
+
},
|
297 |
+
"execution_count": 11,
|
298 |
+
"metadata": {},
|
299 |
+
"output_type": "execute_result"
|
300 |
+
}
|
301 |
+
],
|
302 |
+
"source": [
|
303 |
+
"tokenizer.save_pretrained('./Viet_ASR')"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"execution_count": 8,
|
309 |
+
"metadata": {},
|
310 |
+
"outputs": [],
|
311 |
+
"source": [
|
312 |
+
"import torch\n",
|
313 |
+
"\n",
|
314 |
+
"from dataclasses import dataclass\n",
|
315 |
+
"from typing import Any, Dict, List, Union\n",
|
316 |
+
"from transformers import WhisperProcessor\n",
|
317 |
+
"\n",
|
318 |
+
"processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)\n",
|
319 |
+
"@dataclass\n",
|
320 |
+
"class DataCollatorSpeechSeq2SeqWithPadding:\n",
|
321 |
+
" processor: Any\n",
|
322 |
+
"\n",
|
323 |
+
" def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
|
324 |
+
" # split inputs and labels since they have to be of different lengths and need different padding methods\n",
|
325 |
+
" # first treat the audio inputs by simply returning torch tensors\n",
|
326 |
+
" input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
|
327 |
+
" batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
|
328 |
+
"\n",
|
329 |
+
" # get the tokenized label sequences\n",
|
330 |
+
" label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
|
331 |
+
"\n",
|
332 |
+
"\n",
|
333 |
+
" # pad the labels to max length\n",
|
334 |
+
" labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
|
335 |
+
"\n",
|
336 |
+
" # replace padding with -100 to ignore loss correctly\n",
|
337 |
+
" labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
|
338 |
+
"\n",
|
339 |
+
" # if bos token is appended in previous tokenization step,\n",
|
340 |
+
" # cut bos token here as it's append later anyways\n",
|
341 |
+
" if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
|
342 |
+
" labels = labels[:, 1:]\n",
|
343 |
+
"\n",
|
344 |
+
" batch[\"labels\"] = labels\n",
|
345 |
+
"\n",
|
346 |
+
" return batch\n",
|
347 |
+
"data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)"
|
348 |
+
]
|
349 |
+
},
|
350 |
+
{
|
351 |
+
"cell_type": "code",
|
352 |
+
"execution_count": 9,
|
353 |
+
"metadata": {},
|
354 |
+
"outputs": [],
|
355 |
+
"source": [
|
356 |
+
"import evaluate\n",
|
357 |
+
"\n",
|
358 |
+
"metric = evaluate.load(\"wer\")"
|
359 |
+
]
|
360 |
+
},
|
361 |
+
{
|
362 |
+
"cell_type": "code",
|
363 |
+
"execution_count": 12,
|
364 |
+
"metadata": {},
|
365 |
+
"outputs": [],
|
366 |
+
"source": [
|
367 |
+
"from transformers import WhisperForConditionalGeneration\n",
|
368 |
+
"\n",
|
369 |
+
"model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-medium', load_in_8bit=True, device_map=\"auto\" )"
|
370 |
+
]
|
371 |
+
},
|
372 |
+
{
|
373 |
+
"cell_type": "code",
|
374 |
+
"execution_count": 13,
|
375 |
+
"metadata": {},
|
376 |
+
"outputs": [],
|
377 |
+
"source": [
|
378 |
+
"model.config.forced_decoder_ids = None\n",
|
379 |
+
"model.config.suppress_tokens = []"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"cell_type": "code",
|
384 |
+
"execution_count": 14,
|
385 |
+
"metadata": {},
|
386 |
+
"outputs": [
|
387 |
+
{
|
388 |
+
"data": {
|
389 |
+
"text/plain": [
|
390 |
+
"<torch.utils.hooks.RemovableHandle at 0x7f1a9445da60>"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
"execution_count": 14,
|
394 |
+
"metadata": {},
|
395 |
+
"output_type": "execute_result"
|
396 |
+
}
|
397 |
+
],
|
398 |
+
"source": [
|
399 |
+
"from peft import prepare_model_for_kbit_training\n",
|
400 |
+
"\n",
|
401 |
+
"model = prepare_model_for_kbit_training(model)\n",
|
402 |
+
"def make_inputs_require_grad(module, input, output):\n",
|
403 |
+
" output.requires_grad_(True)\n",
|
404 |
+
"\n",
|
405 |
+
"model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)"
|
406 |
+
]
|
407 |
+
},
|
408 |
+
{
|
409 |
+
"cell_type": "code",
|
410 |
+
"execution_count": 15,
|
411 |
+
"metadata": {},
|
412 |
+
"outputs": [
|
413 |
+
{
|
414 |
+
"name": "stdout",
|
415 |
+
"output_type": "stream",
|
416 |
+
"text": [
|
417 |
+
"trainable params: 9,437,184 || all params: 773,295,104 || trainable%: 1.2203858463844612\n"
|
418 |
+
]
|
419 |
+
}
|
420 |
+
],
|
421 |
+
"source": [
|
422 |
+
"from peft import LoraConfig, PeftModel, LoraModel, LoraConfig, get_peft_model\n",
|
423 |
+
"#target_modules = [\"k_proj\", \"q_proj\", \"v_proj\", \"out_proj\", \"fc1\", \"fc2\"] #will it better ?\n",
|
424 |
+
"target_modules=[\"q_proj\", \"v_proj\"]\n",
|
425 |
+
"config = LoraConfig(r=32, lora_alpha=64, target_modules=target_modules, lora_dropout=0.05, bias=\"none\")\n",
|
426 |
+
"\n",
|
427 |
+
"model = get_peft_model(model, config)\n",
|
428 |
+
"model.print_trainable_parameters()"
|
429 |
+
]
|
430 |
+
},
|
431 |
+
{
|
432 |
+
"cell_type": "code",
|
433 |
+
"execution_count": 16,
|
434 |
+
"metadata": {},
|
435 |
+
"outputs": [],
|
436 |
+
"source": [
|
437 |
+
"from transformers import Seq2SeqTrainingArguments\n",
|
438 |
+
"\n",
|
439 |
+
"training_args = Seq2SeqTrainingArguments(\n",
|
440 |
+
" output_dir=\"./Vietnamese_ASR\", \n",
|
441 |
+
" per_device_train_batch_size=10,\n",
|
442 |
+
" #auto_find_batch_size = True,\n",
|
443 |
+
" gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size\n",
|
444 |
+
" learning_rate=5e-5,\n",
|
445 |
+
" warmup_steps=50,\n",
|
446 |
+
" num_train_epochs=3,\n",
|
447 |
+
" evaluation_strategy=\"epoch\",\n",
|
448 |
+
" gradient_checkpointing=True,\n",
|
449 |
+
" optim=\"adamw_torch\",\n",
|
450 |
+
" fp16=True,\n",
|
451 |
+
" per_device_eval_batch_size=8,\n",
|
452 |
+
" generation_max_length=225,\n",
|
453 |
+
" logging_steps=100,\n",
|
454 |
+
" report_to=[\"tensorboard\"],\n",
|
455 |
+
" predict_with_generate=True,\n",
|
456 |
+
" # load_best_model_at_end=True,\n",
|
457 |
+
" greater_is_better=False,\n",
|
458 |
+
" save_strategy = \"epoch\",\n",
|
459 |
+
" # required as the PeftModel forward doesn't have the signature of the wrapped model's forward\n",
|
460 |
+
" remove_unused_columns=False,\n",
|
461 |
+
" label_names=[\"labels\"], # same reason as above\n",
|
462 |
+
" push_to_hub=True,\n",
|
463 |
+
")"
|
464 |
+
]
|
465 |
+
},
|
466 |
+
{
|
467 |
+
"cell_type": "code",
|
468 |
+
"execution_count": 19,
|
469 |
+
"metadata": {},
|
470 |
+
"outputs": [
|
471 |
+
{
|
472 |
+
"name": "stderr",
|
473 |
+
"output_type": "stream",
|
474 |
+
"text": [
|
475 |
+
"/media/tesla/New Volume/DEMO/DUY/Vietnamese_ASR/./Vietnamese_ASR is already a clone of https://huggingface.co/DuyTa/Vietnamese_ASR. Make sure you pull the latest changes with `repo.git_pull()`.\n"
|
476 |
+
]
|
477 |
+
}
|
478 |
+
],
|
479 |
+
"source": [
|
480 |
+
"from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl\n",
|
481 |
+
"from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n",
|
482 |
+
"\n",
|
483 |
+
"\n",
|
484 |
+
"class SavePeftModelCallback(TrainerCallback):\n",
|
485 |
+
" def on_save(\n",
|
486 |
+
" self,\n",
|
487 |
+
" args: TrainingArguments,\n",
|
488 |
+
" state: TrainerState,\n",
|
489 |
+
" control: TrainerControl,\n",
|
490 |
+
" **kwargs,\n",
|
491 |
+
" ):\n",
|
492 |
+
" checkpoint_folder = os.path.join(args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\")\n",
|
493 |
+
"\n",
|
494 |
+
" peft_model_path = os.path.join(checkpoint_folder, \"adapter_model\")\n",
|
495 |
+
" kwargs[\"model\"].save_pretrained(peft_model_path)\n",
|
496 |
+
"\n",
|
497 |
+
" pytorch_model_path = os.path.join(checkpoint_folder, \"pytorch_model.bin\")\n",
|
498 |
+
" if os.path.exists(pytorch_model_path):\n",
|
499 |
+
" os.remove(pytorch_model_path)\n",
|
500 |
+
" return control\n",
|
501 |
+
"\n",
|
502 |
+
"\n",
|
503 |
+
"trainer = Seq2SeqTrainer(\n",
|
504 |
+
" args=training_args,\n",
|
505 |
+
" model=model,\n",
|
506 |
+
" train_dataset=concat[\"train\"],\n",
|
507 |
+
" eval_dataset=concat[\"test\"],\n",
|
508 |
+
" data_collator=data_collator,\n",
|
509 |
+
" # compute_metrics=compute_metrics,\n",
|
510 |
+
" tokenizer=processor.feature_extractor,\n",
|
511 |
+
" callbacks=[SavePeftModelCallback],\n",
|
512 |
+
")\n",
|
513 |
+
"model.config.use_cache = False # silence the warnings. Please re-enable for inference!"
|
514 |
+
]
|
515 |
+
},
|
516 |
+
{
|
517 |
+
"cell_type": "code",
|
518 |
+
"execution_count": 20,
|
519 |
+
"metadata": {},
|
520 |
+
"outputs": [
|
521 |
+
{
|
522 |
+
"data": {
|
523 |
+
"application/vnd.jupyter.widget-view+json": {
|
524 |
+
"model_id": "b0aa0180f6e64eaa8951a4c940aa518f",
|
525 |
+
"version_major": 2,
|
526 |
+
"version_minor": 0
|
527 |
+
},
|
528 |
+
"text/plain": [
|
529 |
+
" 0%| | 0/1263 [00:00<?, ?it/s]"
|
530 |
+
]
|
531 |
+
},
|
532 |
+
"metadata": {},
|
533 |
+
"output_type": "display_data"
|
534 |
+
},
|
535 |
+
{
|
536 |
+
"name": "stderr",
|
537 |
+
"output_type": "stream",
|
538 |
+
"text": [
|
539 |
+
"/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
|
540 |
+
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
|
541 |
+
]
|
542 |
+
},
|
543 |
+
{
|
544 |
+
"name": "stdout",
|
545 |
+
"output_type": "stream",
|
546 |
+
"text": [
|
547 |
+
"{'loss': 1.9814, 'learning_rate': 4.814509480626546e-05, 'epoch': 0.24}\n",
|
548 |
+
"{'loss': 0.6861, 'learning_rate': 4.402308326463314e-05, 'epoch': 0.48}\n",
|
549 |
+
"{'loss': 0.3736, 'learning_rate': 3.9901071723000826e-05, 'epoch': 0.71}\n",
|
550 |
+
"{'loss': 0.332, 'learning_rate': 3.577906018136851e-05, 'epoch': 0.95}\n"
|
551 |
+
]
|
552 |
+
},
|
553 |
+
{
|
554 |
+
"data": {
|
555 |
+
"application/vnd.jupyter.widget-view+json": {
|
556 |
+
"model_id": "2e9a8f06d39e448a9523d9a29699cadc",
|
557 |
+
"version_major": 2,
|
558 |
+
"version_minor": 0
|
559 |
+
},
|
560 |
+
"text/plain": [
|
561 |
+
" 0%| | 0/264 [00:00<?, ?it/s]"
|
562 |
+
]
|
563 |
+
},
|
564 |
+
"metadata": {},
|
565 |
+
"output_type": "display_data"
|
566 |
+
},
|
567 |
+
{
|
568 |
+
"name": "stdout",
|
569 |
+
"output_type": "stream",
|
570 |
+
"text": [
|
571 |
+
"{'eval_loss': 0.3133259117603302, 'eval_runtime': 887.0949, 'eval_samples_per_second': 2.376, 'eval_steps_per_second': 0.298, 'epoch': 1.0}\n"
|
572 |
+
]
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"name": "stderr",
|
576 |
+
"output_type": "stream",
|
577 |
+
"text": [
|
578 |
+
"/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
|
579 |
+
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
|
580 |
+
]
|
581 |
+
},
|
582 |
+
{
|
583 |
+
"name": "stdout",
|
584 |
+
"output_type": "stream",
|
585 |
+
"text": [
|
586 |
+
"{'loss': 0.3005, 'learning_rate': 3.165704863973619e-05, 'epoch': 1.19}\n",
|
587 |
+
"{'loss': 0.307, 'learning_rate': 2.753503709810388e-05, 'epoch': 1.43}\n",
|
588 |
+
"{'loss': 0.2838, 'learning_rate': 2.341302555647156e-05, 'epoch': 1.66}\n",
|
589 |
+
"{'loss': 0.2746, 'learning_rate': 1.9291014014839242e-05, 'epoch': 1.9}\n"
|
590 |
+
]
|
591 |
+
},
|
592 |
+
{
|
593 |
+
"data": {
|
594 |
+
"application/vnd.jupyter.widget-view+json": {
|
595 |
+
"model_id": "1e65ecdbc96246b8b4721505b4252a8a",
|
596 |
+
"version_major": 2,
|
597 |
+
"version_minor": 0
|
598 |
+
},
|
599 |
+
"text/plain": [
|
600 |
+
" 0%| | 0/264 [00:00<?, ?it/s]"
|
601 |
+
]
|
602 |
+
},
|
603 |
+
"metadata": {},
|
604 |
+
"output_type": "display_data"
|
605 |
+
},
|
606 |
+
{
|
607 |
+
"name": "stdout",
|
608 |
+
"output_type": "stream",
|
609 |
+
"text": [
|
610 |
+
"{'eval_loss': 0.28433552384376526, 'eval_runtime': 880.1965, 'eval_samples_per_second': 2.395, 'eval_steps_per_second': 0.3, 'epoch': 2.0}\n"
|
611 |
+
]
|
612 |
+
},
|
613 |
+
{
|
614 |
+
"name": "stderr",
|
615 |
+
"output_type": "stream",
|
616 |
+
"text": [
|
617 |
+
"/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
|
618 |
+
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n"
|
619 |
+
]
|
620 |
+
},
|
621 |
+
{
|
622 |
+
"name": "stdout",
|
623 |
+
"output_type": "stream",
|
624 |
+
"text": [
|
625 |
+
"{'loss': 0.2857, 'learning_rate': 1.5169002473206925e-05, 'epoch': 2.14}\n",
|
626 |
+
"{'loss': 0.2643, 'learning_rate': 1.104699093157461e-05, 'epoch': 2.38}\n",
|
627 |
+
"{'loss': 0.2604, 'learning_rate': 6.924979389942292e-06, 'epoch': 2.61}\n",
|
628 |
+
"{'loss': 0.2505, 'learning_rate': 2.8029678483099755e-06, 'epoch': 2.85}\n"
|
629 |
+
]
|
630 |
+
},
|
631 |
+
{
|
632 |
+
"data": {
|
633 |
+
"application/vnd.jupyter.widget-view+json": {
|
634 |
+
"model_id": "3a4f479ef36f4f00b3c591503b411e5f",
|
635 |
+
"version_major": 2,
|
636 |
+
"version_minor": 0
|
637 |
+
},
|
638 |
+
"text/plain": [
|
639 |
+
" 0%| | 0/264 [00:00<?, ?it/s]"
|
640 |
+
]
|
641 |
+
},
|
642 |
+
"metadata": {},
|
643 |
+
"output_type": "display_data"
|
644 |
+
},
|
645 |
+
{
|
646 |
+
"name": "stdout",
|
647 |
+
"output_type": "stream",
|
648 |
+
"text": [
|
649 |
+
"{'eval_loss': 0.27759623527526855, 'eval_runtime': 879.7333, 'eval_samples_per_second': 2.396, 'eval_steps_per_second': 0.3, 'epoch': 3.0}\n",
|
650 |
+
"{'train_runtime': 35575.7347, 'train_samples_per_second': 0.71, 'train_steps_per_second': 0.036, 'train_loss': 0.4555940831925127, 'epoch': 3.0}\n"
|
651 |
+
]
|
652 |
+
},
|
653 |
+
{
|
654 |
+
"data": {
|
655 |
+
"text/plain": [
|
656 |
+
"TrainOutput(global_step=1263, training_loss=0.4555940831925127, metrics={'train_runtime': 35575.7347, 'train_samples_per_second': 0.71, 'train_steps_per_second': 0.036, 'train_loss': 0.4555940831925127, 'epoch': 3.0})"
|
657 |
+
]
|
658 |
+
},
|
659 |
+
"execution_count": 20,
|
660 |
+
"metadata": {},
|
661 |
+
"output_type": "execute_result"
|
662 |
+
}
|
663 |
+
],
|
664 |
+
"source": [
|
665 |
+
"trainer.train()"
|
666 |
+
]
|
667 |
+
},
|
668 |
+
{
|
669 |
+
"cell_type": "code",
|
670 |
+
"execution_count": 22,
|
671 |
+
"metadata": {},
|
672 |
+
"outputs": [
|
673 |
+
{
|
674 |
+
"name": "stdout",
|
675 |
+
"output_type": "stream",
|
676 |
+
"text": [
|
677 |
+
"DuyTa/Vietnamese_ASR\n"
|
678 |
+
]
|
679 |
+
}
|
680 |
+
],
|
681 |
+
"source": [
|
682 |
+
"peft_model_id = \"DuyTa/Vietnamese_ASR\"\n",
|
683 |
+
"model.push_to_hub(peft_model_id)\n",
|
684 |
+
"print(peft_model_id)"
|
685 |
+
]
|
686 |
+
},
|
687 |
+
{
|
688 |
+
"cell_type": "code",
|
689 |
+
"execution_count": 10,
|
690 |
+
"metadata": {},
|
691 |
+
"outputs": [],
|
692 |
+
"source": [
|
693 |
+
"from peft import PeftModel, PeftConfig\n",
|
694 |
+
"from transformers import WhisperForConditionalGeneration\n",
|
695 |
+
"peft_model_id = \"./Vietnamese_ASR\"\n",
|
696 |
+
"peft_config = PeftConfig.from_pretrained(peft_model_id)\n",
|
697 |
+
"model = WhisperForConditionalGeneration.from_pretrained(\n",
|
698 |
+
" peft_config.base_model_name_or_path, load_in_8bit=True, device_map=\"auto\"\n",
|
699 |
+
")\n",
|
700 |
+
"model = PeftModel.from_pretrained(model, peft_model_id)"
|
701 |
+
]
|
702 |
+
},
|
703 |
+
{
|
704 |
+
"cell_type": "code",
|
705 |
+
"execution_count": 11,
|
706 |
+
"metadata": {},
|
707 |
+
"outputs": [
|
708 |
+
{
|
709 |
+
"name": "stderr",
|
710 |
+
"output_type": "stream",
|
711 |
+
"text": [
|
712 |
+
" 0%| | 0/88 [00:00<?, ?it/s]"
|
713 |
+
]
|
714 |
+
},
|
715 |
+
{
|
716 |
+
"name": "stderr",
|
717 |
+
"output_type": "stream",
|
718 |
+
"text": [
|
719 |
+
"/home/tesla/miniconda3/envs/DUY/lib/python3.9/site-packages/bitsandbytes/autograd/_functions.py:322: UserWarning: MatMul8bitLt: inputs will be cast from torch.float32 to float16 during quantization\n",
|
720 |
+
" warnings.warn(f\"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization\")\n",
|
721 |
+
"100%|██████████| 88/88 [59:07<00:00, 40.31s/it]"
|
722 |
+
]
|
723 |
+
},
|
724 |
+
{
|
725 |
+
"name": "stdout",
|
726 |
+
"output_type": "stream",
|
727 |
+
"text": [
|
728 |
+
"wer=15.57082617523036\n"
|
729 |
+
]
|
730 |
+
},
|
731 |
+
{
|
732 |
+
"name": "stderr",
|
733 |
+
"output_type": "stream",
|
734 |
+
"text": [
|
735 |
+
"\n"
|
736 |
+
]
|
737 |
+
}
|
738 |
+
],
|
739 |
+
"source": [
|
740 |
+
"from torch.utils.data import DataLoader\n",
|
741 |
+
"from tqdm import tqdm\n",
|
742 |
+
"import numpy as np\n",
|
743 |
+
"import gc\n",
|
744 |
+
"\n",
|
745 |
+
"eval_dataloader = DataLoader(concat[\"test\"], batch_size=24, collate_fn=data_collator)\n",
|
746 |
+
"\n",
|
747 |
+
"model.eval()\n",
|
748 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
749 |
+
" with torch.cuda.amp.autocast():\n",
|
750 |
+
" with torch.no_grad():\n",
|
751 |
+
" generated_tokens = (\n",
|
752 |
+
" model.generate(\n",
|
753 |
+
" input_features=batch[\"input_features\"].to(\"cuda\"),\n",
|
754 |
+
" decoder_input_ids=batch[\"labels\"][:, :4].to(\"cuda\"),\n",
|
755 |
+
" max_new_tokens=255,\n",
|
756 |
+
" )\n",
|
757 |
+
" .cpu()\n",
|
758 |
+
" .numpy()\n",
|
759 |
+
" )\n",
|
760 |
+
" labels = batch[\"labels\"].cpu().numpy()\n",
|
761 |
+
" labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
|
762 |
+
" decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
|
763 |
+
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
764 |
+
" metric.add_batch(\n",
|
765 |
+
" predictions=decoded_preds,\n",
|
766 |
+
" references=decoded_labels,\n",
|
767 |
+
" )\n",
|
768 |
+
" del generated_tokens, labels, batch\n",
|
769 |
+
" gc.collect()\n",
|
770 |
+
"wer = 100 * metric.compute()\n",
|
771 |
+
"print(f\"{wer=}\")"
|
772 |
+
]
|
773 |
+
},
|
774 |
+
{
|
775 |
+
"cell_type": "markdown",
|
776 |
+
"metadata": {},
|
777 |
+
"source": [
|
778 |
+
"## Text Norm"
|
779 |
+
]
|
780 |
+
},
|
781 |
+
{
|
782 |
+
"cell_type": "code",
|
783 |
+
"execution_count": null,
|
784 |
+
"metadata": {},
|
785 |
+
"outputs": [],
|
786 |
+
"source": [
|
787 |
+
"# using Vietnamese text normalization after take whisper out token"
|
788 |
+
]
|
789 |
+
},
|
790 |
+
{
|
791 |
+
"cell_type": "code",
|
792 |
+
"execution_count": 12,
|
793 |
+
"metadata": {},
|
794 |
+
"outputs": [
|
795 |
+
{
|
796 |
+
"name": "stderr",
|
797 |
+
"output_type": "stream",
|
798 |
+
"text": [
|
799 |
+
"100%|██████████| 88/88 [58:18<00:00, 39.75s/it]"
|
800 |
+
]
|
801 |
+
},
|
802 |
+
{
|
803 |
+
"name": "stdout",
|
804 |
+
"output_type": "stream",
|
805 |
+
"text": [
|
806 |
+
"normalized_wer=14.601364195460137\n"
|
807 |
+
]
|
808 |
+
},
|
809 |
+
{
|
810 |
+
"name": "stderr",
|
811 |
+
"output_type": "stream",
|
812 |
+
"text": [
|
813 |
+
"\n"
|
814 |
+
]
|
815 |
+
}
|
816 |
+
],
|
817 |
+
"source": [
|
818 |
+
"from torch.utils.data import DataLoader\n",
|
819 |
+
"from tqdm import tqdm\n",
|
820 |
+
"import numpy as np\n",
|
821 |
+
"import gc\n",
|
822 |
+
"from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n",
|
823 |
+
"normalizer = BasicTextNormalizer()\n",
|
824 |
+
"forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)\n",
|
825 |
+
"\n",
|
826 |
+
"model.eval()\n",
|
827 |
+
"for step, batch in enumerate(tqdm(eval_dataloader)):\n",
|
828 |
+
" with torch.cuda.amp.autocast():\n",
|
829 |
+
" with torch.no_grad():\n",
|
830 |
+
" generated_tokens= model.generate(input_features=batch[\"input_features\"].to(\"cuda\"),\n",
|
831 |
+
" forced_decoder_ids=forced_decoder_ids,\n",
|
832 |
+
" max_new_tokens=255).cpu().numpy()\n",
|
833 |
+
" labels = batch[\"labels\"].cpu().numpy()\n",
|
834 |
+
" labels = np.where(labels != -100, labels, processor.tokenizer.pad_token_id)\n",
|
835 |
+
" decoded_preds = processor.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
|
836 |
+
" decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
837 |
+
" metric.add_batch(\n",
|
838 |
+
" predictions=[normalizer(pred).strip() for pred in decoded_preds],\n",
|
839 |
+
" references=[normalizer(label).strip() for label in decoded_labels],\n",
|
840 |
+
" )\n",
|
841 |
+
" # if step==0:\n",
|
842 |
+
" # break\n",
|
843 |
+
" del generated_tokens, labels, batch\n",
|
844 |
+
" gc.collect()\n",
|
845 |
+
"normalized_wer = 100 * metric.compute()\n",
|
846 |
+
"print(f\"{normalized_wer=}\")"
|
847 |
+
]
|
848 |
+
},
|
849 |
+
{
|
850 |
+
"cell_type": "code",
|
851 |
+
"execution_count": 13,
|
852 |
+
"metadata": {},
|
853 |
+
"outputs": [
|
854 |
+
{
|
855 |
+
"name": "stdout",
|
856 |
+
"output_type": "stream",
|
857 |
+
"text": [
|
858 |
+
"pred='toàn bộ phi hành đoàn đã bị giết chết khiến con tàu quân sự đậm thẳng vào ishimura', label='toàn bộ phi hành đoàn đã bị giết chết khiến con tàu quân sự đâm thẳng vào ishimura'\n",
|
859 |
+
"pred='đủ kinh nghiệm để mình quản lý nhân viên và mình làm sao để mình đưa ra một cái dịch vụ tốt nhất', label='đủ kinh nghiệm để mình quản lý nhân viên và mình làm sao để mình đưa ra một cái dịch vụ tốt nhất'\n",
|
860 |
+
"pred='nói một trong một cái chương trình trong tương lai về ngành thu y thì ở mỹ tất cả các đại học nào lớn đều có ngành thu y hết', label='<unk> nói một trong một cái chương trình trong tương lai về ngành thú y thì ở mỹ tất cả các đại học nào lớn đều có ngành thú y hết'\n",
|
861 |
+
"pred='phấn đấu đến năm hai ngàn không trăm hai mười có từ tám trăm đến một ngàn kinh nghiệp tham gia sàn sâu dịch thu mại điện tử của bộ công thương và quốc tế năm mươi phần trăm số này vô bắn trên sàn', label='phấn đấu đến năm hai ngàn không trăm hai mươi có từ tám trăm đến một ngàn doanh nghiệp tham gia sàn giao dịch thương mại điện tử của bộ công thương và quốc tế năm mươi phần trăm số này luôn bán trên sàn'\n",
|
862 |
+
"pred='còn trách nhiệm kiểm tra thanh tra là của ủy ban nhân dân các cấp', label='còn trách nhiệm kiểm tra thanh tra là của ủy ban nhân dân các cấp'\n",
|
863 |
+
"pred='vậy mà cậu im lặng khóa trái tìm mình chắc địa dành cho cái gì đó vĩ đại hơn chăng', label='vậy mà cậu im lặng khóa trái tim mình chắc để giành cho cái gì đó vĩ đại hơn chăng'\n",
|
864 |
+
"pred='khi nộp phiếu trả lời trắc nghiệm thí sinh phải ghi tên và danh sách thí sinh nộp bài', label='khi nộp phiếu trả lời trắc nghiệm thí sinh phải ký tên vào danh sách thí sinh nộp bài'\n",
|
865 |
+
"pred='khi nghĩ rằng mình đã khỏi ai ngờ ung thư lại tái phát và tôi đã lắng nghe câu chuyện của tất cả mọi người', label='khi nghĩ rằng mình đã khỏi ai ngờ ung thư lại tái phát và tôi đã lắng nghe câu chuyện của tất cả mọi người'\n",
|
866 |
+
"pred='người cùng ấp là trương thật từng muốn kết giao với giám vì ông từ chối', label='người cùng ấp là trương thực từng muốn kết giao với giám bị ông từ chối'\n",
|
867 |
+
"pred='bài thơ với những dòng thơ rất xúc động như sau', label='bài thơ với những dòng thơ rất xúc động như sau'\n",
|
868 |
+
"pred='công bố chỉ số niềm tin kinh doanh của doanh nghiệp', label='công bố chỉ s�� niềm tin kinh doanh của doanh nghiệp'\n",
|
869 |
+
"pred='khi quanh hồ tổng tới thăng lông đúng lúc tô trung tự đang đánh nhau to với đồ quảng', label='khi quân hộ tống tới thăng long đúng lúc tô trung từ đang đánh nhau to với đỗ quảng'\n",
|
870 |
+
"pred='chứ không lẽ bây giờ kêu men trai', label='chứ hổng lẽ bây giờ kêu mê trai'\n",
|
871 |
+
"pred='trong thời gian đó anh ấy hãy tâm sự với tôi', label='trong thời gian đó anh ấy hay tâm sự với tôi'\n",
|
872 |
+
"pred='mi mo sa lại cho màu sắc lá đẹp không cần dùng đến màu nhuộng hoàng sực dỡ từ duy nhẹ bé đó đã giúp vườn mi mo sa của bà nhanh chóng đem lại lợi nhuộn', label='mi mo sa lại cho màu sắc lá đẹp không cần dùng đến màu nhuộm hoa rực rỡ tư duy nhạy bén đó đã giúp vườn mi mo sa của bà nhanh chóng đem lại lợi nhuận'\n",
|
873 |
+
"pred='chơi tìm kiếm tài năng thiên đỉnh god thai lần thế các táo đâu hết cả rồi', label='chơi tìm kiếm tài năng thiên đình gót thai lừn thế các táo đâu hết cả rồi'\n",
|
874 |
+
"pred='dù đức và pháp bất đồng sâu sắc nhưng chính kiến của họ thì đều sai', label='dù đức và pháp bất đồng sâu sắc nhưng chính kiến của họ thì đều sai'\n",
|
875 |
+
"pred='đại ca bảo không hình anh ra mà ngồi anh đánh đi', label='đại ca bảo buông anh ra mà thôi anh'\n",
|
876 |
+
"pred='khi mà mang thai bác thị cũng cảnh báo rồi', label='khi mà mang thai thì bác sĩ cũng cảnh báo rồi'\n",
|
877 |
+
"pred='là tăng giảm thất thường và đột xuất kéo dài', label='mà tăng giảm thất thường và đột xuất kéo dài'\n"
|
878 |
+
]
|
879 |
+
}
|
880 |
+
],
|
881 |
+
"source": [
|
882 |
+
"for pred,label in zip(decoded_preds,decoded_labels):\n",
|
883 |
+
" print(f\"{pred=}, {label=}\")"
|
884 |
+
]
|
885 |
+
},
|
886 |
+
{
|
887 |
+
"cell_type": "code",
|
888 |
+
"execution_count": 14,
|
889 |
+
"metadata": {},
|
890 |
+
"outputs": [],
|
891 |
+
"source": [
|
892 |
+
"import torch\n",
|
893 |
+
"from transformers import (\n",
|
894 |
+
" AutomaticSpeechRecognitionPipeline,\n",
|
895 |
+
" WhisperForConditionalGeneration,\n",
|
896 |
+
" WhisperTokenizer,\n",
|
897 |
+
" WhisperProcessor,\n",
|
898 |
+
")\n",
|
899 |
+
"from peft import PeftModel, PeftConfig\n",
|
900 |
+
"\n",
|
901 |
+
"\n",
|
902 |
+
"peft_model_id = \"./Vietnamese_ASR\"\n",
|
903 |
+
"language = \"Vietnamese\"\n",
|
904 |
+
"task = \"transcribe\"\n",
|
905 |
+
"\n",
|
906 |
+
"peft_config = PeftConfig.from_pretrained(peft_model_id)\n",
|
907 |
+
"model = WhisperForConditionalGeneration.from_pretrained(\n",
|
908 |
+
" peft_config.base_model_name_or_path\n",
|
909 |
+
")\n",
|
910 |
+
"model = PeftModel.from_pretrained(model, peft_model_id)\n",
|
911 |
+
"merged_model = model.merge_and_unload()\n"
|
912 |
+
]
|
913 |
+
},
|
914 |
+
{
|
915 |
+
"cell_type": "code",
|
916 |
+
"execution_count": 16,
|
917 |
+
"metadata": {},
|
918 |
+
"outputs": [],
|
919 |
+
"source": [
|
920 |
+
"merged_model.save_pretrained(\"./Vietnamese_ASR/merged\")"
|
921 |
+
]
|
922 |
+
},
|
923 |
+
{
|
924 |
+
"cell_type": "code",
|
925 |
+
"execution_count": 17,
|
926 |
+
"metadata": {},
|
927 |
+
"outputs": [],
|
928 |
+
"source": [
|
929 |
+
"from transformers import WhisperTokenizer\n",
|
930 |
+
"\n",
|
931 |
+
"tokenizer = WhisperTokenizer.from_pretrained('openai/whisper-medium', language=language, task=task)"
|
932 |
+
]
|
933 |
+
},
|
934 |
+
{
|
935 |
+
"cell_type": "code",
|
936 |
+
"execution_count": 18,
|
937 |
+
"metadata": {},
|
938 |
+
"outputs": [
|
939 |
+
{
|
940 |
+
"data": {
|
941 |
+
"text/plain": [
|
942 |
+
"('./Vietnamese_ASR/merged/tokenizer_config.json',\n",
|
943 |
+
" './Vietnamese_ASR/merged/special_tokens_map.json',\n",
|
944 |
+
" './Vietnamese_ASR/merged/vocab.json',\n",
|
945 |
+
" './Vietnamese_ASR/merged/merges.txt',\n",
|
946 |
+
" './Vietnamese_ASR/merged/normalizer.json',\n",
|
947 |
+
" './Vietnamese_ASR/merged/added_tokens.json')"
|
948 |
+
]
|
949 |
+
},
|
950 |
+
"execution_count": 18,
|
951 |
+
"metadata": {},
|
952 |
+
"output_type": "execute_result"
|
953 |
+
}
|
954 |
+
],
|
955 |
+
"source": [
|
956 |
+
"tokenizer.save_pretrained('./Vietnamese_ASR/merged')"
|
957 |
+
]
|
958 |
+
},
|
959 |
+
{
|
960 |
+
"cell_type": "code",
|
961 |
+
"execution_count": 19,
|
962 |
+
"metadata": {},
|
963 |
+
"outputs": [
|
964 |
+
{
|
965 |
+
"name": "stdout",
|
966 |
+
"output_type": "stream",
|
967 |
+
"text": [
|
968 |
+
"/bin/bash: /home/tesla/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n"
|
969 |
+
]
|
970 |
+
}
|
971 |
+
],
|
972 |
+
"source": [
|
973 |
+
"!ct2-transformers-converter --model ./Vietnamese_ASR/merged --output_dir ./Vietnamese_ASR/ct2ranslate"
|
974 |
+
]
|
975 |
+
},
|
976 |
+
{
|
977 |
+
"cell_type": "code",
|
978 |
+
"execution_count": 20,
|
979 |
+
"metadata": {},
|
980 |
+
"outputs": [
|
981 |
+
{
|
982 |
+
"name": "stdout",
|
983 |
+
"output_type": "stream",
|
984 |
+
"text": [
|
985 |
+
"/bin/bash: /home/tesla/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n"
|
986 |
+
]
|
987 |
+
}
|
988 |
+
],
|
989 |
+
"source": [
|
990 |
+
"!ct2-transformers-converter --model ./Vietnamese_ASR/merged --output_dir ./Vietnamese_ASR/ct2ranslate/quantized --quantization float16"
|
991 |
+
]
|
992 |
+
},
|
993 |
+
{
|
994 |
+
"cell_type": "code",
|
995 |
+
"execution_count": 6,
|
996 |
+
"metadata": {},
|
997 |
+
"outputs": [
|
998 |
+
{
|
999 |
+
"name": "stderr",
|
1000 |
+
"output_type": "stream",
|
1001 |
+
"text": [
|
1002 |
+
"/media/tesla/New Volume/DEMO/DUY/Vietnamese_ASR/Vietnamese_ASR/src/Vietnamese_ASR is already a clone of https://huggingface.co/DuyTa/Vietnamese_ASR. Make sure you pull the latest changes with `repo.git_pull()`.\n"
|
1003 |
+
]
|
1004 |
+
}
|
1005 |
+
],
|
1006 |
+
"source": [
|
1007 |
+
"from huggingface_hub import Repository\n",
|
1008 |
+
"repo = Repository(local_dir=\"\", clone_from='DuyTa/Vietnamese_ASR')"
|
1009 |
+
]
|
1010 |
+
},
|
1011 |
+
{
|
1012 |
+
"cell_type": "code",
|
1013 |
+
"execution_count": 6,
|
1014 |
+
"metadata": {},
|
1015 |
+
"outputs": [
|
1016 |
+
{
|
1017 |
+
"data": {
|
1018 |
+
"application/vnd.jupyter.widget-view+json": {
|
1019 |
+
"model_id": "061e5ea903e04d2e95bc3ff8a8de434b",
|
1020 |
+
"version_major": 2,
|
1021 |
+
"version_minor": 0
|
1022 |
+
},
|
1023 |
+
"text/plain": [
|
1024 |
+
"Clean file runs/Aug17_22-42-43_tesla-T4/events.out.tfevents.1692289257.tesla-T4.201346.0: 14%|#4 | 1.0…"
|
1025 |
+
]
|
1026 |
+
},
|
1027 |
+
"metadata": {},
|
1028 |
+
"output_type": "display_data"
|
1029 |
+
}
|
1030 |
+
],
|
1031 |
+
"source": [
|
1032 |
+
"repo.git_pull(rebase=True)"
|
1033 |
+
]
|
1034 |
+
},
|
1035 |
+
{
|
1036 |
+
"cell_type": "code",
|
1037 |
+
"execution_count": null,
|
1038 |
+
"metadata": {},
|
1039 |
+
"outputs": [],
|
1040 |
+
"source": [
|
1041 |
+
"repo.git_add(\".\")\n",
|
1042 |
+
"repo.git_commit(commit_message=\"3 epochs finetuning and quantized model )\")"
|
1043 |
+
]
|
1044 |
+
},
|
1045 |
+
{
|
1046 |
+
"cell_type": "code",
|
1047 |
+
"execution_count": 8,
|
1048 |
+
"metadata": {},
|
1049 |
+
"outputs": [
|
1050 |
+
{
|
1051 |
+
"name": "stderr",
|
1052 |
+
"output_type": "stream",
|
1053 |
+
"text": [
|
1054 |
+
"Several commits (3) will be pushed upstream.\n",
|
1055 |
+
"The progress bars may be unreliable.\n"
|
1056 |
+
]
|
1057 |
+
},
|
1058 |
+
{
|
1059 |
+
"name": "stderr",
|
1060 |
+
"output_type": "stream",
|
1061 |
+
"text": [
|
1062 |
+
"To https://huggingface.co/DuyTa/Vietnamese_ASR\n",
|
1063 |
+
" 63bacc4..82e8e84 main -> main\n",
|
1064 |
+
"\n"
|
1065 |
+
]
|
1066 |
+
},
|
1067 |
+
{
|
1068 |
+
"data": {
|
1069 |
+
"text/plain": [
|
1070 |
+
"'https://huggingface.co/DuyTa/Vietnamese_ASR/commit/82e8e84fe4f1ffee17eff82c39a163f4b81335d5'"
|
1071 |
+
]
|
1072 |
+
},
|
1073 |
+
"execution_count": 8,
|
1074 |
+
"metadata": {},
|
1075 |
+
"output_type": "execute_result"
|
1076 |
+
}
|
1077 |
+
],
|
1078 |
+
"source": [
|
1079 |
+
"repo.git_push()"
|
1080 |
+
]
|
1081 |
+
},
|
1082 |
+
{
|
1083 |
+
"cell_type": "code",
|
1084 |
+
"execution_count": null,
|
1085 |
+
"metadata": {},
|
1086 |
+
"outputs": [],
|
1087 |
+
"source": [
|
1088 |
+
"merged_model.push_to_hub(\"DuyTa/MITI_Whisper\")"
|
1089 |
+
]
|
1090 |
+
},
|
1091 |
+
{
|
1092 |
+
"cell_type": "code",
|
1093 |
+
"execution_count": null,
|
1094 |
+
"metadata": {},
|
1095 |
+
"outputs": [],
|
1096 |
+
"source": [
|
1097 |
+
"import torch\n",
|
1098 |
+
"import gradio as gr\n",
|
1099 |
+
"from transformers import (\n",
|
1100 |
+
" AutomaticSpeechRecognitionPipeline,\n",
|
1101 |
+
" WhisperForConditionalGeneration,\n",
|
1102 |
+
" WhisperTokenizer,\n",
|
1103 |
+
" WhisperProcessor,\n",
|
1104 |
+
")\n",
|
1105 |
+
"from peft import PeftModel, PeftConfig\n",
|
1106 |
+
"\n",
|
1107 |
+
"\n",
|
1108 |
+
"peft_model_id = \"DuyTa/MITI_Whisper\"\n",
|
1109 |
+
"language = \"Vietnamese\"\n",
|
1110 |
+
"task = \"transcribe\"\n",
|
1111 |
+
"peft_config = PeftConfig.from_pretrained(peft_model_id)\n",
|
1112 |
+
"model = WhisperForConditionalGeneration.from_pretrained(\n",
|
1113 |
+
" peft_config.base_model_name_or_path, load_in_8bit=True, device_map=\"auto\"\n",
|
1114 |
+
")\n",
|
1115 |
+
"\n",
|
1116 |
+
"model = PeftModel.from_pretrained(model, peft_model_id)\n",
|
1117 |
+
"tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
|
1118 |
+
"processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
|
1119 |
+
"feature_extractor = processor.feature_extractor\n",
|
1120 |
+
"forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)\n",
|
1121 |
+
"pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)\n",
|
1122 |
+
"\n",
|
1123 |
+
"\n",
|
1124 |
+
"def transcribe(audio):\n",
|
1125 |
+
" with torch.cuda.amp.autocast():\n",
|
1126 |
+
" text = pipe(audio, generate_kwargs={\"forced_decoder_ids\": forced_decoder_ids}, max_new_tokens=255)[\"text\"]\n",
|
1127 |
+
" return text\n",
|
1128 |
+
"\n",
|
1129 |
+
"\n",
|
1130 |
+
"iface = gr.Interface(\n",
|
1131 |
+
" fn=transcribe,\n",
|
1132 |
+
" inputs=gr.Audio(source=\"upload\", type=\"filepath\"),\n",
|
1133 |
+
" outputs=\"text\",\n",
|
1134 |
+
" title=\"PEFT LoRA\",\n",
|
1135 |
+
" description=\"Realtime demo for Vietnamese speech recognition using `PEFT-LoRA+INT8` fine-tuned Whisper Medium .\",\n",
|
1136 |
+
")\n",
|
1137 |
+
"\n",
|
1138 |
+
"iface.launch(share=True)"
|
1139 |
+
]
|
1140 |
+
}
|
1141 |
+
],
|
1142 |
+
"metadata": {
|
1143 |
+
"kernelspec": {
|
1144 |
+
"display_name": "DUY",
|
1145 |
+
"language": "python",
|
1146 |
+
"name": "python3"
|
1147 |
+
},
|
1148 |
+
"language_info": {
|
1149 |
+
"codemirror_mode": {
|
1150 |
+
"name": "ipython",
|
1151 |
+
"version": 3
|
1152 |
+
},
|
1153 |
+
"file_extension": ".py",
|
1154 |
+
"mimetype": "text/x-python",
|
1155 |
+
"name": "python",
|
1156 |
+
"nbconvert_exporter": "python",
|
1157 |
+
"pygments_lexer": "ipython3",
|
1158 |
+
"version": "3.9.17"
|
1159 |
+
},
|
1160 |
+
"orig_nbformat": 4
|
1161 |
+
},
|
1162 |
+
"nbformat": 4,
|
1163 |
+
"nbformat_minor": 2
|
1164 |
+
}
|
src/whisperX.ipynb
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 15,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"Model was trained with pyannote.audio 0.0.1, yours is 2.1.1. Bad things might happen unless you revert pyannote.audio to 0.x.\n",
|
13 |
+
"Model was trained with torch 1.10.0+cu102, yours is 2.0.0+cu118. Bad things might happen unless you revert torch to 1.x.\n",
|
14 |
+
"CPU times: user 826 ms, sys: 96.7 ms, total: 923 ms\n",
|
15 |
+
"Wall time: 831 ms\n",
|
16 |
+
"[{'text': 'đó là ước vọng của nguyễn ái quốc từ những năm hai mươi của thế kỷ trước về một nhà nước việt nam độc lập dân chủ', 'start': 0.008, 'end': 6.556}]\n"
|
17 |
+
]
|
18 |
+
}
|
19 |
+
],
|
20 |
+
"source": [
|
21 |
+
"import whisperx\n",
|
22 |
+
"import gc \n",
|
23 |
+
"\n",
|
24 |
+
"device = \"cuda\" \n",
|
25 |
+
"audio_file = \"6.wav\"\n",
|
26 |
+
"batch_size = 16 \n",
|
27 |
+
"compute_type = \"float16\" # change to \"int8\" if low on GPU mem (may reduce accuracy)\n",
|
28 |
+
"model_path = \"./Vietnamese_ASR/ct2ranslate\"\n",
|
29 |
+
"# 1. Transcribe with original whisper (batched)\n",
|
30 |
+
"model = whisperx.load_model(model_path, device, compute_type=compute_type,language='vi')\n",
|
31 |
+
"\n",
|
32 |
+
"audio = whisperx.load_audio(audio_file)\n",
|
33 |
+
"%time result = model.transcribe(audio, batch_size=batch_size)\n",
|
34 |
+
"print(result[\"segments\"]) # before alignment\n",
|
35 |
+
"\n",
|
36 |
+
"# delete model if low on GPU resources\n",
|
37 |
+
"# import gc; gc.collect(); torch.cuda.empty_cache(); del model\n",
|
38 |
+
"\n"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": null,
|
44 |
+
"metadata": {},
|
45 |
+
"outputs": [],
|
46 |
+
"source": [
|
47 |
+
"import gc; gc.collect()\n",
|
48 |
+
"import torch\n",
|
49 |
+
"torch.cuda.empty_cache(); del model"
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"cell_type": "code",
|
54 |
+
"execution_count": 9,
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [
|
57 |
+
{
|
58 |
+
"name": "stderr",
|
59 |
+
"output_type": "stream",
|
60 |
+
"text": [
|
61 |
+
"Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at nguyenvulebinh/wav2vec2-base-vi and are newly initialized: ['lm_head.weight', 'lm_head.bias']\n",
|
62 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"name": "stdout",
|
67 |
+
"output_type": "stream",
|
68 |
+
"text": [
|
69 |
+
"[{'start': 0.008, 'end': 2.396, 'text': 'với một người đi làm thuê như anh số tiền kiếm được chưa đủ để thoả mãn nhu cầu cá nhân nói gì đến chăm lo cho gia đình', 'words': [{'word': 'với', 'start': 0.008, 'end': 0.068, 'score': 0.01}, {'word': 'một', 'start': 0.088, 'end': 0.148, 'score': 0.01}, {'word': 'người', 'start': 0.169, 'end': 0.269, 'score': 0.011}, {'word': 'đi', 'start': 0.289, 'end': 0.329, 'score': 0.01}, {'word': 'làm', 'start': 0.349, 'end': 0.409, 'score': 0.011}, {'word': 'thuê', 'start': 0.429, 'end': 0.51, 'score': 0.01}, {'word': 'như', 'start': 0.53, 'end': 0.59, 'score': 0.012}, {'word': 'anh', 'start': 0.61, 'end': 0.67, 'score': 0.01}, {'word': 'số', 'start': 0.69, 'end': 0.73, 'score': 0.01}, {'word': 'tiền', 'start': 0.75, 'end': 0.831, 'score': 0.01}, {'word': 'kiếm', 'start': 0.851, 'end': 0.931, 'score': 0.01}, {'word': 'được', 'start': 0.951, 'end': 1.031, 'score': 0.01}, {'word': 'chưa', 'start': 1.051, 'end': 1.132, 'score': 0.01}, {'word': 'đủ', 'start': 1.152, 'end': 1.192, 'score': 0.01}, {'word': 'để', 'start': 1.212, 'end': 1.252, 'score': 0.01}, {'word': 'thoả', 'start': 1.272, 'end': 1.353, 'score': 0.01}, {'word': 'mãn', 'start': 1.373, 'end': 1.433, 'score': 0.011}, {'word': 'nhu', 'start': 1.453, 'end': 1.513, 'score': 0.011}, {'word': 'cầu', 'start': 1.533, 'end': 1.593, 'score': 0.011}, {'word': 'cá', 'start': 1.613, 'end': 1.654, 'score': 0.01}, {'word': 'nhân', 'start': 1.674, 'end': 1.754, 'score': 0.011}, {'word': 'nói', 'start': 1.774, 'end': 1.834, 'score': 0.01}, {'word': 'gì', 'start': 1.854, 'end': 1.894, 'score': 0.011}, {'word': 'đến', 'start': 1.914, 'end': 1.975, 'score': 0.01}, {'word': 'chăm', 'start': 1.995, 'end': 2.075, 'score': 0.011}, {'word': 'lo', 'start': 2.095, 'end': 2.135, 'score': 0.009}, {'word': 'cho', 'start': 2.155, 'end': 2.215, 'score': 0.011}, {'word': 'gia', 'start': 2.235, 'end': 2.296, 'score': 0.01}, {'word': 'đình', 'start': 2.316, 'end': 2.396, 'score': 0.011}]}]\n"
|
70 |
+
]
|
71 |
+
}
|
72 |
+
],
|
73 |
+
"source": [
|
74 |
+
"# 2. Align whisper output\n",
|
75 |
+
"device = \"cuda\" \n",
|
76 |
+
"audio_file = \"audio.wav\"\n",
|
77 |
+
"batch_size = 16 \n",
|
78 |
+
"compute_type = \"float16\" # change to \"int8\" if low on GPU mem (may reduce accuracy)\n",
|
79 |
+
"model_path = \"./Vietnamese_ASR/ct2ranslate\"\n",
|
80 |
+
"model_a, metadata = whisperx.load_align_model(language_code=\"vi\" ,device=device)\n",
|
81 |
+
"result = whisperx.align(result[\"segments\"], model_a, metadata, audio, device, return_char_alignments=False)\n",
|
82 |
+
"\n",
|
83 |
+
"print(result[\"segments\"]) # after alignment\n",
|
84 |
+
"\n",
|
85 |
+
"# delete model if low on GPU resources\n",
|
86 |
+
"import gc; gc.collect(); torch.cuda.empty_cache(); del model_a\n",
|
87 |
+
"\n"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": null,
|
93 |
+
"metadata": {},
|
94 |
+
"outputs": [],
|
95 |
+
"source": [
|
96 |
+
"# 3. Assign speaker labels\n",
|
97 |
+
"diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)\n",
|
98 |
+
"\n",
|
99 |
+
"# add min/max number of speakers if known\n",
|
100 |
+
"diarize_segments = diarize_model(audio)\n",
|
101 |
+
"# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)\n",
|
102 |
+
"\n",
|
103 |
+
"result = whisperx.assign_word_speakers(diarize_segments, result)\n",
|
104 |
+
"print(diarize_segments)\n",
|
105 |
+
"print(result[\"segments\"]) # segments are now assigned speaker IDs"
|
106 |
+
]
|
107 |
+
}
|
108 |
+
],
|
109 |
+
"metadata": {
|
110 |
+
"kernelspec": {
|
111 |
+
"display_name": "DUY",
|
112 |
+
"language": "python",
|
113 |
+
"name": "python3"
|
114 |
+
},
|
115 |
+
"language_info": {
|
116 |
+
"codemirror_mode": {
|
117 |
+
"name": "ipython",
|
118 |
+
"version": 3
|
119 |
+
},
|
120 |
+
"file_extension": ".py",
|
121 |
+
"mimetype": "text/x-python",
|
122 |
+
"name": "python",
|
123 |
+
"nbconvert_exporter": "python",
|
124 |
+
"pygments_lexer": "ipython3",
|
125 |
+
"version": "3.9.17"
|
126 |
+
},
|
127 |
+
"orig_nbformat": 4
|
128 |
+
},
|
129 |
+
"nbformat": 4,
|
130 |
+
"nbformat_minor": 2
|
131 |
+
}
|
src/whisper_quant.py
ADDED
@@ -0,0 +1,995 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import zlib
|
5 |
+
|
6 |
+
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
|
7 |
+
|
8 |
+
import ctranslate2
|
9 |
+
import numpy as np
|
10 |
+
import tokenizers
|
11 |
+
|
12 |
+
from faster_whisper.audio import decode_audio
|
13 |
+
from faster_whisper.feature_extractor import FeatureExtractor
|
14 |
+
from faster_whisper.tokenizer import Tokenizer
|
15 |
+
from download_quantized import download_model, format_timestamp, get_logger
|
16 |
+
from faster_whisper.vad import (
|
17 |
+
SpeechTimestampsMap,
|
18 |
+
VadOptions,
|
19 |
+
collect_chunks,
|
20 |
+
get_speech_timestamps,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class Word(NamedTuple):
|
25 |
+
start: float
|
26 |
+
end: float
|
27 |
+
word: str
|
28 |
+
probability: float
|
29 |
+
|
30 |
+
|
31 |
+
class Segment(NamedTuple):
|
32 |
+
id: int
|
33 |
+
seek: int
|
34 |
+
start: float
|
35 |
+
end: float
|
36 |
+
text: str
|
37 |
+
tokens: List[int]
|
38 |
+
temperature: float
|
39 |
+
avg_logprob: float
|
40 |
+
compression_ratio: float
|
41 |
+
no_speech_prob: float
|
42 |
+
words: Optional[List[Word]]
|
43 |
+
|
44 |
+
|
45 |
+
class TranscriptionOptions(NamedTuple):
|
46 |
+
beam_size: int
|
47 |
+
best_of: int
|
48 |
+
patience: float
|
49 |
+
length_penalty: float
|
50 |
+
repetition_penalty: float
|
51 |
+
log_prob_threshold: Optional[float]
|
52 |
+
no_speech_threshold: Optional[float]
|
53 |
+
compression_ratio_threshold: Optional[float]
|
54 |
+
condition_on_previous_text: bool
|
55 |
+
prompt_reset_on_temperature: float
|
56 |
+
temperatures: List[float]
|
57 |
+
initial_prompt: Optional[Union[str, Iterable[int]]]
|
58 |
+
prefix: Optional[str]
|
59 |
+
suppress_blank: bool
|
60 |
+
suppress_tokens: Optional[List[int]]
|
61 |
+
without_timestamps: bool
|
62 |
+
max_initial_timestamp: float
|
63 |
+
word_timestamps: bool
|
64 |
+
prepend_punctuations: str
|
65 |
+
append_punctuations: str
|
66 |
+
|
67 |
+
|
68 |
+
class TranscriptionInfo(NamedTuple):
|
69 |
+
language: str
|
70 |
+
language_probability: float
|
71 |
+
duration: float
|
72 |
+
all_language_probs: Optional[List[Tuple[str, float]]]
|
73 |
+
transcription_options: TranscriptionOptions
|
74 |
+
vad_options: VadOptions
|
75 |
+
|
76 |
+
|
77 |
+
class WhisperModel:
|
78 |
+
def __init__(
|
79 |
+
self,
|
80 |
+
model_size_or_path: str,
|
81 |
+
device: str = "auto",
|
82 |
+
device_index: Union[int, List[int]] = 0,
|
83 |
+
compute_type: str = "default",
|
84 |
+
cpu_threads: int = 0,
|
85 |
+
num_workers: int = 1,
|
86 |
+
download_root: Optional[str] = None,
|
87 |
+
local_files_only: bool = False,
|
88 |
+
):
|
89 |
+
"""Initializes the Whisper model.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
|
93 |
+
small, small.en, medium, medium.en, large-v1, or large-v2), a path to a converted
|
94 |
+
model directory, or a CTranslate2-converted Whisper model ID from the Hugging Face Hub.
|
95 |
+
When a size or a model ID is configured, the converted model is downloaded
|
96 |
+
from the Hugging Face Hub.
|
97 |
+
device: Device to use for computation ("cpu", "cuda", "auto").
|
98 |
+
device_index: Device ID to use.
|
99 |
+
The model can also be loaded on multiple GPUs by passing a list of IDs
|
100 |
+
(e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
|
101 |
+
when transcribe() is called from multiple Python threads (see also num_workers).
|
102 |
+
compute_type: Type to use for computation.
|
103 |
+
See https://opennmt.net/CTranslate2/quantization.html.
|
104 |
+
cpu_threads: Number of threads to use when running on CPU (4 by default).
|
105 |
+
A non zero value overrides the OMP_NUM_THREADS environment variable.
|
106 |
+
num_workers: When transcribe() is called from multiple Python threads,
|
107 |
+
having multiple workers enables true parallelism when running the model
|
108 |
+
(concurrent calls to self.model.generate() will run in parallel).
|
109 |
+
This can improve the global throughput at the cost of increased memory usage.
|
110 |
+
download_root: Directory where the models should be saved. If not set, the models
|
111 |
+
are saved in the standard Hugging Face cache directory.
|
112 |
+
local_files_only: If True, avoid downloading the file and return the path to the
|
113 |
+
local cached file if it exists.
|
114 |
+
"""
|
115 |
+
self.logger = get_logger()
|
116 |
+
|
117 |
+
if os.path.isdir(model_size_or_path):
|
118 |
+
model_path = model_size_or_path
|
119 |
+
else:
|
120 |
+
model_path = download_model(
|
121 |
+
model_size_or_path,
|
122 |
+
local_files_only=local_files_only,
|
123 |
+
cache_dir=download_root,
|
124 |
+
)
|
125 |
+
|
126 |
+
self.model = ctranslate2.models.Whisper(
|
127 |
+
model_path,
|
128 |
+
device=device,
|
129 |
+
device_index=device_index,
|
130 |
+
compute_type=compute_type,
|
131 |
+
intra_threads=cpu_threads,
|
132 |
+
inter_threads=num_workers,
|
133 |
+
)
|
134 |
+
|
135 |
+
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
136 |
+
if os.path.isfile(tokenizer_file):
|
137 |
+
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
138 |
+
else:
|
139 |
+
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
|
140 |
+
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
141 |
+
)
|
142 |
+
|
143 |
+
self.feature_extractor = FeatureExtractor()
|
144 |
+
self.num_samples_per_token = self.feature_extractor.hop_length * 2
|
145 |
+
self.frames_per_second = (
|
146 |
+
self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
|
147 |
+
)
|
148 |
+
self.tokens_per_second = (
|
149 |
+
self.feature_extractor.sampling_rate // self.num_samples_per_token
|
150 |
+
)
|
151 |
+
self.input_stride = 2
|
152 |
+
self.time_precision = 0.02
|
153 |
+
self.max_length = 448
|
154 |
+
|
155 |
+
def transcribe(
|
156 |
+
self,
|
157 |
+
audio: Union[str, BinaryIO, np.ndarray],
|
158 |
+
language: Optional[str] = None,
|
159 |
+
task: str = "transcribe",
|
160 |
+
beam_size: int = 5,
|
161 |
+
best_of: int = 5,
|
162 |
+
patience: float = 1,
|
163 |
+
length_penalty: float = 1,
|
164 |
+
repetition_penalty: float = 1,
|
165 |
+
temperature: Union[float, List[float], Tuple[float, ...]] = [
|
166 |
+
0.0,
|
167 |
+
0.2,
|
168 |
+
0.4,
|
169 |
+
0.6,
|
170 |
+
0.8,
|
171 |
+
1.0,
|
172 |
+
],
|
173 |
+
compression_ratio_threshold: Optional[float] = 2.4,
|
174 |
+
log_prob_threshold: Optional[float] = -1.0,
|
175 |
+
no_speech_threshold: Optional[float] = 0.6,
|
176 |
+
condition_on_previous_text: bool = True,
|
177 |
+
prompt_reset_on_temperature: float = 0.5,
|
178 |
+
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
|
179 |
+
prefix: Optional[str] = None,
|
180 |
+
suppress_blank: bool = True,
|
181 |
+
suppress_tokens: Optional[List[int]] = [-1],
|
182 |
+
without_timestamps: bool = False,
|
183 |
+
max_initial_timestamp: float = 1.0,
|
184 |
+
word_timestamps: bool = False,
|
185 |
+
prepend_punctuations: str = "\"'“¿([{-",
|
186 |
+
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
187 |
+
vad_filter: bool = False,
|
188 |
+
vad_parameters: Optional[Union[dict, VadOptions]] = None,
|
189 |
+
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
190 |
+
"""Transcribes an input file.
|
191 |
+
|
192 |
+
Arguments:
|
193 |
+
audio: Path to the input file (or a file-like object), or the audio waveform.
|
194 |
+
language: The language spoken in the audio. It should be a language code such
|
195 |
+
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
|
196 |
+
of audio.
|
197 |
+
task: Task to execute (transcribe or translate).
|
198 |
+
beam_size: Beam size to use for decoding.
|
199 |
+
best_of: Number of candidates when sampling with non-zero temperature.
|
200 |
+
patience: Beam search patience factor.
|
201 |
+
length_penalty: Exponential length penalty constant.
|
202 |
+
repetition_penalty: Penalty applied to the score of previously generated tokens
|
203 |
+
(set > 1 to penalize).
|
204 |
+
temperature: Temperature for sampling. It can be a tuple of temperatures,
|
205 |
+
which will be successively used upon failures according to either
|
206 |
+
`compression_ratio_threshold` or `log_prob_threshold`.
|
207 |
+
compression_ratio_threshold: If the gzip compression ratio is above this value,
|
208 |
+
treat as failed.
|
209 |
+
log_prob_threshold: If the average log probability over sampled tokens is
|
210 |
+
below this value, treat as failed.
|
211 |
+
no_speech_threshold: If the no_speech probability is higher than this value AND
|
212 |
+
the average log probability over sampled tokens is below `log_prob_threshold`,
|
213 |
+
consider the segment as silent.
|
214 |
+
condition_on_previous_text: If True, the previous output of the model is provided
|
215 |
+
as a prompt for the next window; disabling may make the text inconsistent across
|
216 |
+
windows, but the model becomes less prone to getting stuck in a failure loop,
|
217 |
+
such as repetition looping or timestamps going out of sync.
|
218 |
+
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
|
219 |
+
Arg has effect only if condition_on_previous_text is True.
|
220 |
+
initial_prompt: Optional text string or iterable of token ids to provide as a
|
221 |
+
prompt for the first window.
|
222 |
+
prefix: Optional text to provide as a prefix for the first window.
|
223 |
+
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
224 |
+
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
225 |
+
of symbols as defined in the model config.json file.
|
226 |
+
without_timestamps: Only sample text tokens.
|
227 |
+
max_initial_timestamp: The initial timestamp cannot be later than this.
|
228 |
+
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
229 |
+
and dynamic time warping, and include the timestamps for each word in each segment.
|
230 |
+
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
|
231 |
+
with the next word
|
232 |
+
append_punctuations: If word_timestamps is True, merge these punctuation symbols
|
233 |
+
with the previous word
|
234 |
+
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
|
235 |
+
without speech. This step is using the Silero VAD model
|
236 |
+
https://github.com/snakers4/silero-vad.
|
237 |
+
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
|
238 |
+
parameters and default values in the class `VadOptions`).
|
239 |
+
|
240 |
+
Returns:
|
241 |
+
A tuple with:
|
242 |
+
|
243 |
+
- a generator over transcribed segments
|
244 |
+
- an instance of TranscriptionInfo
|
245 |
+
"""
|
246 |
+
sampling_rate = self.feature_extractor.sampling_rate
|
247 |
+
|
248 |
+
if not isinstance(audio, np.ndarray):
|
249 |
+
audio = decode_audio(audio, sampling_rate=sampling_rate)
|
250 |
+
|
251 |
+
duration = audio.shape[0] / sampling_rate
|
252 |
+
|
253 |
+
self.logger.info(
|
254 |
+
"Processing audio with duration %s", format_timestamp(duration)
|
255 |
+
)
|
256 |
+
|
257 |
+
if vad_filter:
|
258 |
+
if vad_parameters is None:
|
259 |
+
vad_parameters = VadOptions()
|
260 |
+
elif isinstance(vad_parameters, dict):
|
261 |
+
vad_parameters = VadOptions(**vad_parameters)
|
262 |
+
speech_chunks = get_speech_timestamps(audio, vad_parameters)
|
263 |
+
audio = collect_chunks(audio, speech_chunks)
|
264 |
+
|
265 |
+
self.logger.info(
|
266 |
+
"VAD filter removed %s of audio",
|
267 |
+
format_timestamp(duration - (audio.shape[0] / sampling_rate)),
|
268 |
+
)
|
269 |
+
|
270 |
+
if self.logger.isEnabledFor(logging.DEBUG):
|
271 |
+
self.logger.debug(
|
272 |
+
"VAD filter kept the following audio segments: %s",
|
273 |
+
", ".join(
|
274 |
+
"[%s -> %s]"
|
275 |
+
% (
|
276 |
+
format_timestamp(chunk["start"] / sampling_rate),
|
277 |
+
format_timestamp(chunk["end"] / sampling_rate),
|
278 |
+
)
|
279 |
+
for chunk in speech_chunks
|
280 |
+
),
|
281 |
+
)
|
282 |
+
|
283 |
+
else:
|
284 |
+
speech_chunks = None
|
285 |
+
|
286 |
+
features = self.feature_extractor(audio)
|
287 |
+
|
288 |
+
encoder_output = None
|
289 |
+
all_language_probs = None
|
290 |
+
|
291 |
+
if language is None:
|
292 |
+
if not self.model.is_multilingual:
|
293 |
+
language = "en"
|
294 |
+
language_probability = 1
|
295 |
+
else:
|
296 |
+
segment = features[:, : self.feature_extractor.nb_max_frames]
|
297 |
+
encoder_output = self.encode(segment)
|
298 |
+
# results is a list of tuple[str, float] with language names and
|
299 |
+
# probabilities.
|
300 |
+
results = self.model.detect_language(encoder_output)[0]
|
301 |
+
# Parse language names to strip out markers
|
302 |
+
all_language_probs = [(token[2:-2], prob) for (token, prob) in results]
|
303 |
+
# Get top language token and probability
|
304 |
+
language, language_probability = all_language_probs[0]
|
305 |
+
|
306 |
+
self.logger.info(
|
307 |
+
"Detected language '%s' with probability %.2f",
|
308 |
+
language,
|
309 |
+
language_probability,
|
310 |
+
)
|
311 |
+
else:
|
312 |
+
language_probability = 1
|
313 |
+
|
314 |
+
tokenizer = Tokenizer(
|
315 |
+
self.hf_tokenizer,
|
316 |
+
self.model.is_multilingual,
|
317 |
+
task=task,
|
318 |
+
language=language,
|
319 |
+
)
|
320 |
+
|
321 |
+
options = TranscriptionOptions(
|
322 |
+
beam_size=beam_size,
|
323 |
+
best_of=best_of,
|
324 |
+
patience=patience,
|
325 |
+
length_penalty=length_penalty,
|
326 |
+
repetition_penalty=repetition_penalty,
|
327 |
+
log_prob_threshold=log_prob_threshold,
|
328 |
+
no_speech_threshold=no_speech_threshold,
|
329 |
+
compression_ratio_threshold=compression_ratio_threshold,
|
330 |
+
condition_on_previous_text=condition_on_previous_text,
|
331 |
+
prompt_reset_on_temperature=prompt_reset_on_temperature,
|
332 |
+
temperatures=(
|
333 |
+
temperature if isinstance(temperature, (list, tuple)) else [temperature]
|
334 |
+
),
|
335 |
+
initial_prompt=initial_prompt,
|
336 |
+
prefix=prefix,
|
337 |
+
suppress_blank=suppress_blank,
|
338 |
+
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
|
339 |
+
without_timestamps=without_timestamps,
|
340 |
+
max_initial_timestamp=max_initial_timestamp,
|
341 |
+
word_timestamps=word_timestamps,
|
342 |
+
prepend_punctuations=prepend_punctuations,
|
343 |
+
append_punctuations=append_punctuations,
|
344 |
+
)
|
345 |
+
|
346 |
+
segments = self.generate_segments(features, tokenizer, options, encoder_output)
|
347 |
+
|
348 |
+
if speech_chunks:
|
349 |
+
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
|
350 |
+
|
351 |
+
info = TranscriptionInfo(
|
352 |
+
language=language,
|
353 |
+
language_probability=language_probability,
|
354 |
+
duration=duration,
|
355 |
+
transcription_options=options,
|
356 |
+
vad_options=vad_parameters,
|
357 |
+
all_language_probs=all_language_probs,
|
358 |
+
)
|
359 |
+
|
360 |
+
return segments, info
|
361 |
+
|
362 |
+
def generate_segments(
|
363 |
+
self,
|
364 |
+
features: np.ndarray,
|
365 |
+
tokenizer: Tokenizer,
|
366 |
+
options: TranscriptionOptions,
|
367 |
+
encoder_output: Optional[ctranslate2.StorageView] = None,
|
368 |
+
) -> Iterable[Segment]:
|
369 |
+
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
370 |
+
idx = 0
|
371 |
+
seek = 0
|
372 |
+
all_tokens = []
|
373 |
+
prompt_reset_since = 0
|
374 |
+
|
375 |
+
if options.initial_prompt is not None:
|
376 |
+
if isinstance(options.initial_prompt, str):
|
377 |
+
initial_prompt = " " + options.initial_prompt.strip()
|
378 |
+
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
379 |
+
all_tokens.extend(initial_prompt_tokens)
|
380 |
+
else:
|
381 |
+
all_tokens.extend(options.initial_prompt)
|
382 |
+
|
383 |
+
last_speech_timestamp = 0.0
|
384 |
+
while seek < content_frames:
|
385 |
+
time_offset = seek * self.feature_extractor.time_per_frame
|
386 |
+
segment = features[:, seek : seek + self.feature_extractor.nb_max_frames]
|
387 |
+
segment_size = min(
|
388 |
+
self.feature_extractor.nb_max_frames, content_frames - seek
|
389 |
+
)
|
390 |
+
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
391 |
+
|
392 |
+
if self.logger.isEnabledFor(logging.DEBUG):
|
393 |
+
self.logger.debug(
|
394 |
+
"Processing segment at %s", format_timestamp(time_offset)
|
395 |
+
)
|
396 |
+
|
397 |
+
previous_tokens = all_tokens[prompt_reset_since:]
|
398 |
+
prompt = self.get_prompt(
|
399 |
+
tokenizer,
|
400 |
+
previous_tokens,
|
401 |
+
without_timestamps=options.without_timestamps,
|
402 |
+
prefix=options.prefix if seek == 0 else None,
|
403 |
+
)
|
404 |
+
|
405 |
+
if encoder_output is None:
|
406 |
+
encoder_output = self.encode(segment)
|
407 |
+
|
408 |
+
(
|
409 |
+
result,
|
410 |
+
avg_logprob,
|
411 |
+
temperature,
|
412 |
+
compression_ratio,
|
413 |
+
) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
|
414 |
+
|
415 |
+
if options.no_speech_threshold is not None:
|
416 |
+
# no voice activity check
|
417 |
+
should_skip = result.no_speech_prob > options.no_speech_threshold
|
418 |
+
|
419 |
+
if (
|
420 |
+
options.log_prob_threshold is not None
|
421 |
+
and avg_logprob > options.log_prob_threshold
|
422 |
+
):
|
423 |
+
# don't skip if the logprob is high enough, despite the no_speech_prob
|
424 |
+
should_skip = False
|
425 |
+
|
426 |
+
if should_skip:
|
427 |
+
self.logger.debug(
|
428 |
+
"No speech threshold is met (%f > %f)",
|
429 |
+
result.no_speech_prob,
|
430 |
+
options.no_speech_threshold,
|
431 |
+
)
|
432 |
+
|
433 |
+
# fast-forward to the next segment boundary
|
434 |
+
seek += segment_size
|
435 |
+
encoder_output = None
|
436 |
+
continue
|
437 |
+
|
438 |
+
tokens = result.sequences_ids[0]
|
439 |
+
|
440 |
+
previous_seek = seek
|
441 |
+
current_segments = []
|
442 |
+
|
443 |
+
single_timestamp_ending = (
|
444 |
+
len(tokens) >= 2
|
445 |
+
and tokens[-2] < tokenizer.timestamp_begin
|
446 |
+
and tokens[-1] >= tokenizer.timestamp_begin
|
447 |
+
)
|
448 |
+
|
449 |
+
consecutive_timestamps = [
|
450 |
+
i
|
451 |
+
for i in range(len(tokens))
|
452 |
+
if i > 0
|
453 |
+
and tokens[i] >= tokenizer.timestamp_begin
|
454 |
+
and tokens[i - 1] >= tokenizer.timestamp_begin
|
455 |
+
]
|
456 |
+
|
457 |
+
if len(consecutive_timestamps) > 0:
|
458 |
+
slices = list(consecutive_timestamps)
|
459 |
+
if single_timestamp_ending:
|
460 |
+
slices.append(len(tokens))
|
461 |
+
|
462 |
+
last_slice = 0
|
463 |
+
for current_slice in slices:
|
464 |
+
sliced_tokens = tokens[last_slice:current_slice]
|
465 |
+
start_timestamp_position = (
|
466 |
+
sliced_tokens[0] - tokenizer.timestamp_begin
|
467 |
+
)
|
468 |
+
end_timestamp_position = (
|
469 |
+
sliced_tokens[-1] - tokenizer.timestamp_begin
|
470 |
+
)
|
471 |
+
start_time = (
|
472 |
+
time_offset + start_timestamp_position * self.time_precision
|
473 |
+
)
|
474 |
+
end_time = (
|
475 |
+
time_offset + end_timestamp_position * self.time_precision
|
476 |
+
)
|
477 |
+
|
478 |
+
current_segments.append(
|
479 |
+
dict(
|
480 |
+
seek=seek,
|
481 |
+
start=start_time,
|
482 |
+
end=end_time,
|
483 |
+
tokens=sliced_tokens,
|
484 |
+
)
|
485 |
+
)
|
486 |
+
last_slice = current_slice
|
487 |
+
|
488 |
+
if single_timestamp_ending:
|
489 |
+
# single timestamp at the end means no speech after the last timestamp.
|
490 |
+
seek += segment_size
|
491 |
+
else:
|
492 |
+
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
493 |
+
last_timestamp_position = (
|
494 |
+
tokens[last_slice - 1] - tokenizer.timestamp_begin
|
495 |
+
)
|
496 |
+
seek += last_timestamp_position * self.input_stride
|
497 |
+
|
498 |
+
else:
|
499 |
+
duration = segment_duration
|
500 |
+
timestamps = [
|
501 |
+
token for token in tokens if token >= tokenizer.timestamp_begin
|
502 |
+
]
|
503 |
+
if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
|
504 |
+
last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
|
505 |
+
duration = last_timestamp_position * self.time_precision
|
506 |
+
|
507 |
+
current_segments.append(
|
508 |
+
dict(
|
509 |
+
seek=seek,
|
510 |
+
start=time_offset,
|
511 |
+
end=time_offset + duration,
|
512 |
+
tokens=tokens,
|
513 |
+
)
|
514 |
+
)
|
515 |
+
|
516 |
+
seek += segment_size
|
517 |
+
|
518 |
+
if options.word_timestamps:
|
519 |
+
self.add_word_timestamps(
|
520 |
+
current_segments,
|
521 |
+
tokenizer,
|
522 |
+
encoder_output,
|
523 |
+
segment_size,
|
524 |
+
options.prepend_punctuations,
|
525 |
+
options.append_punctuations,
|
526 |
+
last_speech_timestamp=last_speech_timestamp,
|
527 |
+
)
|
528 |
+
|
529 |
+
word_end_timestamps = [
|
530 |
+
w["end"] for s in current_segments for w in s["words"]
|
531 |
+
]
|
532 |
+
if len(word_end_timestamps) > 0:
|
533 |
+
last_speech_timestamp = word_end_timestamps[-1]
|
534 |
+
if not single_timestamp_ending and len(word_end_timestamps) > 0:
|
535 |
+
seek_shift = round(
|
536 |
+
(word_end_timestamps[-1] - time_offset) * self.frames_per_second
|
537 |
+
)
|
538 |
+
|
539 |
+
if seek_shift > 0:
|
540 |
+
seek = previous_seek + seek_shift
|
541 |
+
|
542 |
+
encoder_output = None
|
543 |
+
|
544 |
+
for segment in current_segments:
|
545 |
+
tokens = segment["tokens"]
|
546 |
+
text = tokenizer.decode(tokens)
|
547 |
+
|
548 |
+
if segment["start"] == segment["end"] or not text.strip():
|
549 |
+
continue
|
550 |
+
|
551 |
+
all_tokens.extend(tokens)
|
552 |
+
idx += 1
|
553 |
+
|
554 |
+
yield Segment(
|
555 |
+
id=idx,
|
556 |
+
seek=seek,
|
557 |
+
start=segment["start"],
|
558 |
+
end=segment["end"],
|
559 |
+
text=text,
|
560 |
+
tokens=tokens,
|
561 |
+
temperature=temperature,
|
562 |
+
avg_logprob=avg_logprob,
|
563 |
+
compression_ratio=compression_ratio,
|
564 |
+
no_speech_prob=result.no_speech_prob,
|
565 |
+
words=(
|
566 |
+
[Word(**word) for word in segment["words"]]
|
567 |
+
if options.word_timestamps
|
568 |
+
else None
|
569 |
+
),
|
570 |
+
)
|
571 |
+
|
572 |
+
if (
|
573 |
+
not options.condition_on_previous_text
|
574 |
+
or temperature > options.prompt_reset_on_temperature
|
575 |
+
):
|
576 |
+
if options.condition_on_previous_text:
|
577 |
+
self.logger.debug(
|
578 |
+
"Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
|
579 |
+
temperature,
|
580 |
+
options.prompt_reset_on_temperature,
|
581 |
+
)
|
582 |
+
|
583 |
+
prompt_reset_since = len(all_tokens)
|
584 |
+
|
585 |
+
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
586 |
+
# When the model is running on multiple GPUs, the encoder output should be moved
|
587 |
+
# to the CPU since we don't know which GPU will handle the next job.
|
588 |
+
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
|
589 |
+
|
590 |
+
features = np.expand_dims(features, 0)
|
591 |
+
features = get_ctranslate2_storage(features)
|
592 |
+
|
593 |
+
return self.model.encode(features, to_cpu=to_cpu)
|
594 |
+
|
595 |
+
def generate_with_fallback(
|
596 |
+
self,
|
597 |
+
encoder_output: ctranslate2.StorageView,
|
598 |
+
prompt: List[int],
|
599 |
+
tokenizer: Tokenizer,
|
600 |
+
options: TranscriptionOptions,
|
601 |
+
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
|
602 |
+
decode_result = None
|
603 |
+
all_results = []
|
604 |
+
below_cr_threshold_results = []
|
605 |
+
|
606 |
+
max_initial_timestamp_index = int(
|
607 |
+
round(options.max_initial_timestamp / self.time_precision)
|
608 |
+
)
|
609 |
+
|
610 |
+
for temperature in options.temperatures:
|
611 |
+
if temperature > 0:
|
612 |
+
kwargs = {
|
613 |
+
"beam_size": 1,
|
614 |
+
"num_hypotheses": options.best_of,
|
615 |
+
"sampling_topk": 0,
|
616 |
+
"sampling_temperature": temperature,
|
617 |
+
}
|
618 |
+
else:
|
619 |
+
kwargs = {
|
620 |
+
"beam_size": options.beam_size,
|
621 |
+
"patience": options.patience,
|
622 |
+
}
|
623 |
+
|
624 |
+
result = self.model.generate(
|
625 |
+
encoder_output,
|
626 |
+
[prompt],
|
627 |
+
length_penalty=options.length_penalty,
|
628 |
+
repetition_penalty=options.repetition_penalty,
|
629 |
+
max_length=self.max_length,
|
630 |
+
return_scores=True,
|
631 |
+
return_no_speech_prob=True,
|
632 |
+
suppress_blank=options.suppress_blank,
|
633 |
+
suppress_tokens=options.suppress_tokens,
|
634 |
+
max_initial_timestamp_index=max_initial_timestamp_index,
|
635 |
+
**kwargs,
|
636 |
+
)[0]
|
637 |
+
|
638 |
+
tokens = result.sequences_ids[0]
|
639 |
+
|
640 |
+
# Recover the average log prob from the returned score.
|
641 |
+
seq_len = len(tokens)
|
642 |
+
cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
|
643 |
+
avg_logprob = cum_logprob / (seq_len + 1)
|
644 |
+
|
645 |
+
text = tokenizer.decode(tokens).strip()
|
646 |
+
compression_ratio = get_compression_ratio(text)
|
647 |
+
|
648 |
+
decode_result = (
|
649 |
+
result,
|
650 |
+
avg_logprob,
|
651 |
+
temperature,
|
652 |
+
compression_ratio,
|
653 |
+
)
|
654 |
+
all_results.append(decode_result)
|
655 |
+
|
656 |
+
needs_fallback = False
|
657 |
+
|
658 |
+
if options.compression_ratio_threshold is not None:
|
659 |
+
if compression_ratio > options.compression_ratio_threshold:
|
660 |
+
needs_fallback = True # too repetitive
|
661 |
+
|
662 |
+
self.logger.debug(
|
663 |
+
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
|
664 |
+
temperature,
|
665 |
+
compression_ratio,
|
666 |
+
options.compression_ratio_threshold,
|
667 |
+
)
|
668 |
+
else:
|
669 |
+
below_cr_threshold_results.append(decode_result)
|
670 |
+
|
671 |
+
if (
|
672 |
+
options.log_prob_threshold is not None
|
673 |
+
and avg_logprob < options.log_prob_threshold
|
674 |
+
):
|
675 |
+
needs_fallback = True # average log probability is too low
|
676 |
+
|
677 |
+
self.logger.debug(
|
678 |
+
"Log probability threshold is not met with temperature %.1f (%f < %f)",
|
679 |
+
temperature,
|
680 |
+
avg_logprob,
|
681 |
+
options.log_prob_threshold,
|
682 |
+
)
|
683 |
+
|
684 |
+
if (
|
685 |
+
options.no_speech_threshold is not None
|
686 |
+
and result.no_speech_prob > options.no_speech_threshold
|
687 |
+
):
|
688 |
+
needs_fallback = False # silence
|
689 |
+
|
690 |
+
if not needs_fallback:
|
691 |
+
break
|
692 |
+
else:
|
693 |
+
# all failed, select the result with the highest average log probability
|
694 |
+
decode_result = max(
|
695 |
+
below_cr_threshold_results or all_results, key=lambda x: x[1]
|
696 |
+
)
|
697 |
+
|
698 |
+
return decode_result
|
699 |
+
|
700 |
+
def get_prompt(
|
701 |
+
self,
|
702 |
+
tokenizer: Tokenizer,
|
703 |
+
previous_tokens: List[int],
|
704 |
+
without_timestamps: bool = False,
|
705 |
+
prefix: Optional[str] = None,
|
706 |
+
) -> List[int]:
|
707 |
+
prompt = []
|
708 |
+
|
709 |
+
if previous_tokens:
|
710 |
+
prompt.append(tokenizer.sot_prev)
|
711 |
+
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
712 |
+
|
713 |
+
prompt.extend(tokenizer.sot_sequence)
|
714 |
+
|
715 |
+
if without_timestamps:
|
716 |
+
prompt.append(tokenizer.no_timestamps)
|
717 |
+
|
718 |
+
if prefix:
|
719 |
+
prefix_tokens = tokenizer.encode(" " + prefix.strip())
|
720 |
+
if len(prefix_tokens) >= self.max_length // 2:
|
721 |
+
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
722 |
+
if not without_timestamps:
|
723 |
+
prompt.append(tokenizer.timestamp_begin)
|
724 |
+
prompt.extend(prefix_tokens)
|
725 |
+
|
726 |
+
return prompt
|
727 |
+
|
728 |
+
def add_word_timestamps(
|
729 |
+
self,
|
730 |
+
segments: List[dict],
|
731 |
+
tokenizer: Tokenizer,
|
732 |
+
encoder_output: ctranslate2.StorageView,
|
733 |
+
num_frames: int,
|
734 |
+
prepend_punctuations: str,
|
735 |
+
append_punctuations: str,
|
736 |
+
last_speech_timestamp: float,
|
737 |
+
):
|
738 |
+
if len(segments) == 0:
|
739 |
+
return
|
740 |
+
|
741 |
+
text_tokens_per_segment = [
|
742 |
+
[token for token in segment["tokens"] if token < tokenizer.eot]
|
743 |
+
for segment in segments
|
744 |
+
]
|
745 |
+
|
746 |
+
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
747 |
+
alignment = self.find_alignment(
|
748 |
+
tokenizer, text_tokens, encoder_output, num_frames
|
749 |
+
)
|
750 |
+
word_durations = np.array([word["end"] - word["start"] for word in alignment])
|
751 |
+
word_durations = word_durations[word_durations.nonzero()]
|
752 |
+
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
753 |
+
max_duration = median_duration * 2
|
754 |
+
|
755 |
+
# hack: truncate long words at sentence boundaries.
|
756 |
+
# a better segmentation algorithm based on VAD should be able to replace this.
|
757 |
+
if len(word_durations) > 0:
|
758 |
+
sentence_end_marks = ".。!!??"
|
759 |
+
# ensure words at sentence boundaries
|
760 |
+
# are not longer than twice the median word duration.
|
761 |
+
for i in range(1, len(alignment)):
|
762 |
+
if alignment[i]["end"] - alignment[i]["start"] > max_duration:
|
763 |
+
if alignment[i]["word"] in sentence_end_marks:
|
764 |
+
alignment[i]["end"] = alignment[i]["start"] + max_duration
|
765 |
+
elif alignment[i - 1]["word"] in sentence_end_marks:
|
766 |
+
alignment[i]["start"] = alignment[i]["end"] - max_duration
|
767 |
+
|
768 |
+
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
769 |
+
|
770 |
+
time_offset = (
|
771 |
+
segments[0]["seek"]
|
772 |
+
* self.feature_extractor.hop_length
|
773 |
+
/ self.feature_extractor.sampling_rate
|
774 |
+
)
|
775 |
+
|
776 |
+
word_index = 0
|
777 |
+
|
778 |
+
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
779 |
+
saved_tokens = 0
|
780 |
+
words = []
|
781 |
+
|
782 |
+
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
783 |
+
timing = alignment[word_index]
|
784 |
+
|
785 |
+
if timing["word"]:
|
786 |
+
words.append(
|
787 |
+
dict(
|
788 |
+
word=timing["word"],
|
789 |
+
start=round(time_offset + timing["start"], 2),
|
790 |
+
end=round(time_offset + timing["end"], 2),
|
791 |
+
probability=timing["probability"],
|
792 |
+
)
|
793 |
+
)
|
794 |
+
|
795 |
+
saved_tokens += len(timing["tokens"])
|
796 |
+
word_index += 1
|
797 |
+
|
798 |
+
# hack: truncate long words at segment boundaries.
|
799 |
+
# a better segmentation algorithm based on VAD should be able to replace this.
|
800 |
+
if len(words) > 0:
|
801 |
+
# ensure the first and second word after a pause is not longer than
|
802 |
+
# twice the median word duration.
|
803 |
+
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
|
804 |
+
words[0]["end"] - words[0]["start"] > max_duration
|
805 |
+
or (
|
806 |
+
len(words) > 1
|
807 |
+
and words[1]["end"] - words[0]["start"] > max_duration * 2
|
808 |
+
)
|
809 |
+
):
|
810 |
+
if (
|
811 |
+
len(words) > 1
|
812 |
+
and words[1]["end"] - words[1]["start"] > max_duration
|
813 |
+
):
|
814 |
+
boundary = max(
|
815 |
+
words[1]["end"] / 2, words[1]["end"] - max_duration
|
816 |
+
)
|
817 |
+
words[0]["end"] = words[1]["start"] = boundary
|
818 |
+
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
819 |
+
|
820 |
+
# prefer the segment-level start timestamp if the first word is too long.
|
821 |
+
if (
|
822 |
+
segment["start"] < words[0]["end"]
|
823 |
+
and segment["start"] - 0.5 > words[0]["start"]
|
824 |
+
):
|
825 |
+
words[0]["start"] = max(
|
826 |
+
0, min(words[0]["end"] - median_duration, segment["start"])
|
827 |
+
)
|
828 |
+
else:
|
829 |
+
segment["start"] = words[0]["start"]
|
830 |
+
|
831 |
+
# prefer the segment-level end timestamp if the last word is too long.
|
832 |
+
if (
|
833 |
+
segment["end"] > words[-1]["start"]
|
834 |
+
and segment["end"] + 0.5 < words[-1]["end"]
|
835 |
+
):
|
836 |
+
words[-1]["end"] = max(
|
837 |
+
words[-1]["start"] + median_duration, segment["end"]
|
838 |
+
)
|
839 |
+
else:
|
840 |
+
segment["end"] = words[-1]["end"]
|
841 |
+
|
842 |
+
last_speech_timestamp = segment["end"]
|
843 |
+
|
844 |
+
segment["words"] = words
|
845 |
+
|
846 |
+
def find_alignment(
|
847 |
+
self,
|
848 |
+
tokenizer: Tokenizer,
|
849 |
+
text_tokens: List[int],
|
850 |
+
encoder_output: ctranslate2.StorageView,
|
851 |
+
num_frames: int,
|
852 |
+
median_filter_width: int = 7,
|
853 |
+
) -> List[dict]:
|
854 |
+
if len(text_tokens) == 0:
|
855 |
+
return []
|
856 |
+
|
857 |
+
result = self.model.align(
|
858 |
+
encoder_output,
|
859 |
+
tokenizer.sot_sequence,
|
860 |
+
[text_tokens],
|
861 |
+
num_frames,
|
862 |
+
median_filter_width=median_filter_width,
|
863 |
+
)[0]
|
864 |
+
|
865 |
+
text_token_probs = result.text_token_probs
|
866 |
+
|
867 |
+
alignments = result.alignments
|
868 |
+
text_indices = np.array([pair[0] for pair in alignments])
|
869 |
+
time_indices = np.array([pair[1] for pair in alignments])
|
870 |
+
|
871 |
+
words, word_tokens = tokenizer.split_to_word_tokens(
|
872 |
+
text_tokens + [tokenizer.eot]
|
873 |
+
)
|
874 |
+
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
875 |
+
if len(word_boundaries) <= 1:
|
876 |
+
return []
|
877 |
+
|
878 |
+
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
879 |
+
jump_times = time_indices[jumps] / self.tokens_per_second
|
880 |
+
start_times = jump_times[word_boundaries[:-1]]
|
881 |
+
end_times = jump_times[word_boundaries[1:]]
|
882 |
+
word_probabilities = [
|
883 |
+
np.mean(text_token_probs[i:j])
|
884 |
+
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
885 |
+
]
|
886 |
+
|
887 |
+
return [
|
888 |
+
dict(
|
889 |
+
word=word, tokens=tokens, start=start, end=end, probability=probability
|
890 |
+
)
|
891 |
+
for word, tokens, start, end, probability in zip(
|
892 |
+
words, word_tokens, start_times, end_times, word_probabilities
|
893 |
+
)
|
894 |
+
]
|
895 |
+
|
896 |
+
|
897 |
+
def restore_speech_timestamps(
|
898 |
+
segments: Iterable[Segment],
|
899 |
+
speech_chunks: List[dict],
|
900 |
+
sampling_rate: int,
|
901 |
+
) -> Iterable[Segment]:
|
902 |
+
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
|
903 |
+
|
904 |
+
for segment in segments:
|
905 |
+
if segment.words:
|
906 |
+
words = []
|
907 |
+
for word in segment.words:
|
908 |
+
# Ensure the word start and end times are resolved to the same chunk.
|
909 |
+
middle = (word.start + word.end) / 2
|
910 |
+
chunk_index = ts_map.get_chunk_index(middle)
|
911 |
+
word = word._replace(
|
912 |
+
start=ts_map.get_original_time(word.start, chunk_index),
|
913 |
+
end=ts_map.get_original_time(word.end, chunk_index),
|
914 |
+
)
|
915 |
+
words.append(word)
|
916 |
+
|
917 |
+
segment = segment._replace(
|
918 |
+
start=words[0].start,
|
919 |
+
end=words[-1].end,
|
920 |
+
words=words,
|
921 |
+
)
|
922 |
+
|
923 |
+
else:
|
924 |
+
segment = segment._replace(
|
925 |
+
start=ts_map.get_original_time(segment.start),
|
926 |
+
end=ts_map.get_original_time(segment.end),
|
927 |
+
)
|
928 |
+
|
929 |
+
yield segment
|
930 |
+
|
931 |
+
|
932 |
+
def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
|
933 |
+
segment = np.ascontiguousarray(segment)
|
934 |
+
segment = ctranslate2.StorageView.from_array(segment)
|
935 |
+
return segment
|
936 |
+
|
937 |
+
|
938 |
+
def get_compression_ratio(text: str) -> float:
|
939 |
+
text_bytes = text.encode("utf-8")
|
940 |
+
return len(text_bytes) / len(zlib.compress(text_bytes))
|
941 |
+
|
942 |
+
|
943 |
+
def get_suppressed_tokens(tokenizer, suppress_tokens):
|
944 |
+
if not suppress_tokens or -1 in suppress_tokens:
|
945 |
+
return suppress_tokens
|
946 |
+
|
947 |
+
suppress_tokens = list(suppress_tokens)
|
948 |
+
|
949 |
+
# Ensure the following special tokens are suppressed when the user does
|
950 |
+
# not use the default set (-1).
|
951 |
+
suppress_tokens.extend(
|
952 |
+
[
|
953 |
+
tokenizer.transcribe,
|
954 |
+
tokenizer.translate,
|
955 |
+
tokenizer.sot,
|
956 |
+
tokenizer.sot_prev,
|
957 |
+
tokenizer.sot_lm,
|
958 |
+
]
|
959 |
+
)
|
960 |
+
|
961 |
+
return sorted(set(suppress_tokens))
|
962 |
+
|
963 |
+
|
964 |
+
def merge_punctuations(alignment: List[dict], prepended: str, appended: str):
|
965 |
+
# merge prepended punctuations
|
966 |
+
i = len(alignment) - 2
|
967 |
+
j = len(alignment) - 1
|
968 |
+
while i >= 0:
|
969 |
+
previous = alignment[i]
|
970 |
+
following = alignment[j]
|
971 |
+
if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
|
972 |
+
# prepend it to the following word
|
973 |
+
following["word"] = previous["word"] + following["word"]
|
974 |
+
following["tokens"] = previous["tokens"] + following["tokens"]
|
975 |
+
previous["word"] = ""
|
976 |
+
previous["tokens"] = []
|
977 |
+
else:
|
978 |
+
j = i
|
979 |
+
i -= 1
|
980 |
+
|
981 |
+
# merge appended punctuations
|
982 |
+
i = 0
|
983 |
+
j = 1
|
984 |
+
while j < len(alignment):
|
985 |
+
previous = alignment[i]
|
986 |
+
following = alignment[j]
|
987 |
+
if not previous["word"].endswith(" ") and following["word"] in appended:
|
988 |
+
# append it to the previous word
|
989 |
+
previous["word"] = previous["word"] + following["word"]
|
990 |
+
previous["tokens"] = previous["tokens"] + following["tokens"]
|
991 |
+
following["word"] = ""
|
992 |
+
following["tokens"] = []
|
993 |
+
else:
|
994 |
+
i = j
|
995 |
+
j += 1
|