Commit
·
a875c0d
1
Parent(s):
47a6a98
WIP feature extractor for wav2vec2
Browse files- feature_extractor.py +68 -0
- readme.MD +14 -0
feature_extractor.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import soundfile as sf
|
3 |
+
import pdb
|
4 |
+
from pydub import AudioSegment
|
5 |
+
from transformers import AutoTokenizer, Wav2Vec2ForCTC
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import glob
|
9 |
+
import numpy
|
10 |
+
import os.path
|
11 |
+
|
12 |
+
processor = AutoTokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
13 |
+
|
14 |
+
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60")
|
15 |
+
|
16 |
+
# Dementia path
|
17 |
+
# /home/bmoell/data/media.talkbank.org/dementia/English/Pitt
|
18 |
+
# cookie dementia /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Dementia/cookie
|
19 |
+
# /home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie
|
20 |
+
|
21 |
+
|
22 |
+
def convert_mp3_to_wav(audio_file):
|
23 |
+
sound = AudioSegment.from_mp3(audio_file)
|
24 |
+
sound.export(audio_file + ".wav", format="wav")
|
25 |
+
|
26 |
+
|
27 |
+
def feature_extractor(path):
|
28 |
+
print("the path is", path)
|
29 |
+
|
30 |
+
wav_files = glob.glob(path + "/*.wav")
|
31 |
+
#print(wav_files)
|
32 |
+
for wav_file in wav_files:
|
33 |
+
print("the wavfile is", wav_files)
|
34 |
+
# wav2vec2 embeddings
|
35 |
+
if not os.path.isfile(wav_file + ".wav2vec2.pt"):
|
36 |
+
get_wav2vecembeddings_from_audiofile(wav_file)
|
37 |
+
|
38 |
+
def get_wav2vecembeddings_from_audiofile(wav_file):
|
39 |
+
print("the file is", wav_file)
|
40 |
+
speech, sample_rate = sf.read(wav_file)
|
41 |
+
input_values = processor(wav_file, return_tensors="pt", padding=True) # there is no truncation param anymore
|
42 |
+
print("input values", input_values)
|
43 |
+
|
44 |
+
file_info = os.stat(wav_file)
|
45 |
+
file_size = file_info.st_size
|
46 |
+
print("the size is", file_size)
|
47 |
+
|
48 |
+
if file_size > 250:
|
49 |
+
with torch.no_grad():
|
50 |
+
encoded_states = model(
|
51 |
+
input_values=input_values["input_ids"],
|
52 |
+
attention_mask=input_values["attention_mask"],
|
53 |
+
output_hidden_states=True
|
54 |
+
)
|
55 |
+
|
56 |
+
last_hidden_state = encoded_states.hidden_states[-1] # The last hidden-state is the first element of the output tuple
|
57 |
+
print("getting wav2vec2 embeddings")
|
58 |
+
print(last_hidden_state)
|
59 |
+
torch.save(last_hidden_state, wav_file + '.wav2vec2.pt')
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
feature_extractor("/home/bmoell/data/media.talkbank.org/dementia/English/Pitt/Control/cookie")
|
68 |
+
|
readme.MD
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# todo
|
2 |
+
install things to run on tpu / hugginface / datasets
|
3 |
+
load in data
|
4 |
+
train
|
5 |
+
|
6 |
+
# Important readmes
|
7 |
+
https://github.com/huggingface/transformers/tree/f42a0abf4bd765ad08e14b347a3acbe9fade31b9/examples/research_projects/jax-projects/wav2vec2
|
8 |
+
|
9 |
+
# path to files
|
10 |
+
# cookie control
|
11 |
+
data/media.talkbank.org/dementia/English/Pitt/Control/cookie
|
12 |
+
|
13 |
+
# cookie dementia
|
14 |
+
data/media.talkbank.org/dementia/English/Pitt/Control/cookie
|