Update README.md
README.md CHANGED
````diff
@@ -16,6 +16,7 @@ Our model is SOTA-comparable on multiple MIR tasks even under probing settings,
 ```python
 from transformers import Wav2Vec2Processor, Data2VecAudioModel
 import torch
+from torch import nn
 from datasets import load_dataset
 
 # load demo audio and set processor
@@ -31,11 +32,20 @@ model = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")
 # audio file is decoded on the fly
 inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
 with torch.no_grad():
-    outputs = model(**inputs)
+    outputs = model(**inputs, output_hidden_states=True)
 
-# take a look at the output shape
-
-
+# take a look at the output shape; there are 13 layers of representations
+# each layer performs differently on different downstream tasks, so choose empirically
+all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
+print(all_layer_hidden_states.shape)  # [13 layers, 292 timesteps, 768 feature_dim]
+
+# for utterance-level classification tasks, you can simply reduce the representation over time
+time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
+print(time_reduced_hidden_states.shape)  # [13, 768]
+
+# you can even use a learnable weighted average over the 13 layer representations
+aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
+weighted_avg_hidden_states = aggregator(time_reduced_hidden_states).squeeze()
 ```
 
 Our model is based on the [data2vec audio model](https://huggingface.co/docs/transformers/model_doc/data2vec#transformers.Data2VecAudioModel).
````
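The diff collapses the setup code between the two hunks; only `model = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")` is visible in the second hunk header. For readers trying the example end to end, here is a minimal sketch of what that setup plausibly looks like, mirroring the data2vec audio example the README links to; the demo dataset and the processor checkpoint are assumptions, not part of this diff:

```python
from transformers import Wav2Vec2Processor, Data2VecAudioModel
from datasets import load_dataset

# assumed setup, mirroring the transformers data2vec audio example;
# the demo dataset and processor checkpoint are not shown in the diff
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
sampling_rate = dataset.features["audio"].sampling_rate

processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
model = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")  # confirmed by the hunk header
```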
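Two notes on the new snippet: `output_hidden_states=True` yields 13 tensors because the base-sized model has 12 transformer layers plus the initial embedding output, and an `nn.Conv1d` with `kernel_size=1` applied across the layer axis is exactly a learnable weighted average of those 13 layer vectors. Below is a sketch of how that aggregation might feed a linear probe for an utterance-level task, in the probing spirit the README mentions; the probe class, class count, and names here are illustrative assumptions, not from the README:

```python
import torch
from torch import nn

class LayerWeightedProbe(nn.Module):
    """Hypothetical probe head: learned layer weighting followed by a linear classifier."""
    def __init__(self, num_layers=13, feature_dim=768, num_classes=10):
        super().__init__()
        # kernel_size=1 conv across the layer axis == learnable weighted average of layers
        self.aggregator = nn.Conv1d(in_channels=num_layers, out_channels=1, kernel_size=1)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, layer_feats):                        # layer_feats: [batch, 13, 768]
        pooled = self.aggregator(layer_feats).squeeze(1)   # -> [batch, 768]
        return self.classifier(pooled)                     # -> [batch, num_classes]

probe = LayerWeightedProbe()
# time_reduced_hidden_states from the snippet above is [13, 768]; add a batch dimension
logits = probe(time_reduced_hidden_states.unsqueeze(0))
print(logits.shape)  # torch.Size([1, 10])
```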