a43992899 committed on
Commit 2e154fb
1 Parent(s): 573f43e

Update README.md

Files changed (1)
  1. README.md +14 -4
README.md CHANGED
@@ -16,6 +16,7 @@ Our model is SOTA-comparable on multiple MIR tasks even under probing settings,
```python
from transformers import Wav2Vec2Processor, Data2VecAudioModel
import torch
+ from torch import nn
from datasets import load_dataset

# load demo audio and set processor
@@ -31,11 +32,20 @@ model = Data2VecAudioModel.from_pretrained("m-a-p/music2vec-v1")
# audio file is decoded on the fly
inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
with torch.no_grad():
-     outputs = model(**inputs)
+     outputs = model(**inputs, output_hidden_states=True)

- # take a look at the output shape
- last_hidden_states = outputs.last_hidden_state
- print(list(last_hidden_states.shape)) # [1, 292, 768]
+ # take a look at the output shape; there are 13 layers of representations
+ # each layer performs differently on different downstream tasks, so choose the layer empirically
+ all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
+ print(all_layer_hidden_states.shape) # [13 layers, 292 timesteps, 768 feature_dim]
+
+ # for utterance-level classification tasks, you can simply average the representation over time
+ time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
+ print(time_reduced_hidden_states.shape) # [13, 768]
+
+ # you can even use a learnable weighted-average representation (one input channel per layer)
+ aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
+ weighted_avg_hidden_states = aggregator(time_reduced_hidden_states).squeeze()
```

  Our model is based on the [data2vec audio model](https://huggingface.co/docs/transformers/model_doc/data2vec#transformers.Data2VecAudioModel).
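
A brief usage note on the layer-selection advice introduced in this commit: the sketch below is a minimal, hypothetical probing example, not part of the model card. It assumes the updated snippet above has already run (so `time_reduced_hidden_states` is a [13, 768] tensor); the layer index and the 10-class linear probe are illustrative placeholders.

```python
from torch import nn

# hypothetical probing sketch (assumes `time_reduced_hidden_states` from the snippet above)
layer_index = 7                                             # placeholder; in practice sweep layers 0..12 on a validation set
utterance_vector = time_reduced_hidden_states[layer_index]  # [768] utterance-level feature from one layer

probe = nn.Linear(768, 10)        # tiny linear probe; 10 is a placeholder number of classes
logits = probe(utterance_vector)  # [10]
print(logits.shape)               # torch.Size([10])
```

In practice such a probe (or the learnable Conv1d aggregator from the diff) would be trained on labeled downstream data while the Music2Vec features stay frozen.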