yizhilll committed
Commit 8721106
Parent: b74e8bd

Update README.md

Files changed (1):
  1. README.md +16 -11
README.md CHANGED
@@ -55,7 +55,8 @@ More details will be written in our coming-soon paper.
 # Model Usage
 
 ```python
-from transformers import Wav2Vec2Processor
+# from transformers import Wav2Vec2Processor
+from transformers import Wav2Vec2FeatureExtractor
 from transformers import AutoModel
 import torch
 from torch import nn
@@ -63,28 +64,32 @@ import torchaudio.transforms as T
 from datasets import load_dataset
 
 
+
+
+commit_hash='b74e8bdecaa1aa58bbd1fd752a7db0695194d9bb'# this is recommended for security reason, the hash might be updated
+# loading our model weights
+model = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True, revision=commit_hash)
+# loading the corresponding preprocessor config
+processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True, revision=commit_hash)
+
 # load demo audio and set processor
 dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
 dataset = dataset.sort("id")
 sampling_rate = dataset.features["audio"].sampling_rate
-processor = Wav2Vec2Processor.from_pretrained("m-a-p/MERT-v1-330M")
 
-resample_rate = processor.feature_extractor.sampling_rate
+resample_rate = processor.sampling_rate
 # make sure the sample_rate aligned
 if resample_rate != sampling_rate:
-    resampler = T.Resample(sample_rate, resample_rate)
+    print(f'setting rate from {sampling_rate} to {resample_rate}')
+    resampler = T.Resample(sampling_rate, resample_rate)
 else:
-    resampler = None
-
-# loading our model weights
-commit_hash='7bab7bb5d8b52448eff4873a980dc17f0015a09c'# this is recommended for security reason, the hash might be updated
-model = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True, revision=commit_hash)
+    resampler = None
 
 # audio file is decoded on the fly
 if resampler is None:
     input_audio = dataset[0]["audio"]["array"]
 else:
-    input_audio = resampler(dataset[0]["audio"]["array"])
+    input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]))
 
 inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
 with torch.no_grad():
@@ -100,7 +105,7 @@ time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
 print(time_reduced_hidden_states.shape) # [25, 1024]
 
 # you can even use a learnable weighted average representation
-aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
+aggregator = nn.Conv1d(in_channels=25, out_channels=1, kernel_size=1)
 weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
 print(weighted_avg_hidden_states.shape) # [1024]
 ```
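A note on the second hunk: the added `torch.from_numpy(...)` wrapper matters because `torchaudio.transforms.Resample` operates on `torch.Tensor` inputs, while `dataset[0]["audio"]["array"]` is a numpy array, so the old line would fail at runtime once the resampler was actually hit. A minimal sketch with synthetic audio; the 16 kHz → 24 kHz rates are illustrative stand-ins for the LibriSpeech demo rate and the rate the MERT feature extractor reports:

```python
import torch
import torchaudio.transforms as T

# synthetic one-second waveform standing in for dataset[0]["audio"]["array"]
waveform_np = torch.randn(16000).numpy()

resampler = T.Resample(16000, 24000)  # illustrative rates

# resampler(waveform_np) would fail: Resample expects a torch.Tensor,
# which is why the commit wraps the array in torch.from_numpy(...)
resampled = resampler(torch.from_numpy(waveform_np))
print(resampled.shape)  # torch.Size([24000])
```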
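The last hunk corrects the aggregator's `in_channels` from 13 to 25, matching the 25 hidden-state tensors the snippet prints (for the 330M model this is presumably the embedding output plus 24 transformer layers; 13 would fit a 12-layer model such as the 95M variant). A self-contained shape check of the learnable weighted average, with random data standing in for the real features:

```python
import torch
from torch import nn

# stand-in for time_reduced_hidden_states: 25 layer-wise vectors of size 1024
time_reduced_hidden_states = torch.randn(25, 1024)

# a kernel_size=1 Conv1d over the layer axis learns one weight per layer,
# i.e. a learnable weighted average of the 25 representations
aggregator = nn.Conv1d(in_channels=25, out_channels=1, kernel_size=1)
weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
print(weighted_avg_hidden_states.shape)  # torch.Size([1024])
```

With the old `in_channels=13`, this call would raise a size-mismatch error against the `[1, 25, 1024]` input.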
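Finally, since only the changed hunks are shown, the unchanged forward pass between `with torch.no_grad():` and the final prints does not appear in this diff. The following end-to-end sketch assembles the post-commit snippet for convenience; the forward-pass block is reconstructed from the last hunk header (`time_reduced_hidden_states = all_layer_hidden_states.mean(-2)`) and the printed `[25, 1024]` shape rather than copied from the README, so treat it as an assumption:

```python
# End-to-end sketch of the updated snippet, assembled from the hunks above.
# The block marked "assumed" is not visible in this diff and is reconstructed
# from the last hunk header and the printed shapes.
from transformers import Wav2Vec2FeatureExtractor, AutoModel
import torch
import torchaudio.transforms as T
from datasets import load_dataset

# pin the remote code to a specific revision for security
commit_hash = 'b74e8bdecaa1aa58bbd1fd752a7db0695194d9bb'
model = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True, revision=commit_hash)
processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True, revision=commit_hash)

# load the demo audio
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate

# resample to the rate the feature extractor expects
resample_rate = processor.sampling_rate
if resample_rate != sampling_rate:
    resampler = T.Resample(sampling_rate, resample_rate)
    input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]))
else:
    input_audio = dataset[0]["audio"]["array"]

inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)  # assumed: standard HF call

# assumed layout: embedding output plus 24 transformer layers -> 25 hidden states
all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
time_reduced_hidden_states = all_layer_hidden_states.mean(-2)  # quoted in the hunk header
print(time_reduced_hidden_states.shape)  # [25, 1024]
```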