yizhilll committed on
Commit
b74e8bd
1 Parent(s): d0726b6

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -2
README.md CHANGED
@@ -59,20 +59,34 @@ from transformers import Wav2Vec2Processor
59
  from transformers import AutoModel
60
  import torch
61
  from torch import nn
 
62
  from datasets import load_dataset
63
 
 
64
  # load demo audio and set processor
65
  dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
66
  dataset = dataset.sort("id")
67
  sampling_rate = dataset.features["audio"].sampling_rate
68
- processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
 
 
 
 
 
 
 
69
 
70
  # loading our model weights
71
  commit_hash='7bab7bb5d8b52448eff4873a980dc17f0015a09c'# this is recommended for security reason, the hash might be updated
72
  model = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True, revision=commit_hash)
73
 
74
  # audio file is decoded on the fly
75
- inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
 
 
 
 
 
76
  with torch.no_grad():
77
  outputs = model(**inputs, output_hidden_states=True)
78
 
 
59
  from transformers import AutoModel
60
  import torch
61
  from torch import nn
62
+ import torchaudio.transforms as T
63
  from datasets import load_dataset
64
 
65
+
66
  # load demo audio and set processor
67
  dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
68
  dataset = dataset.sort("id")
69
  sampling_rate = dataset.features["audio"].sampling_rate
70
+ processor = Wav2Vec2Processor.from_pretrained("m-a-p/MERT-v1-330M")
71
+
72
+ resample_rate = processor.feature_extractor.sampling_rate
73
+ # make sure the audio sampling rate matches the rate the processor expects
74
+ if resample_rate != sampling_rate:
75
+ resampler = T.Resample(sampling_rate, resample_rate)
76
+ else:
77
+ resampler = None
78
 
79
  # loading our model weights
80
  commit_hash='7bab7bb5d8b52448eff4873a980dc17f0015a09c'# this is recommended for security reason, the hash might be updated
81
  model = AutoModel.from_pretrained("m-a-p/MERT-v1-330M", trust_remote_code=True, revision=commit_hash)
82
 
83
  # audio file is decoded on the fly
84
+ if resampler is None:
85
+ input_audio = dataset[0]["audio"]["array"]
86
+ else:
87
+ input_audio = resampler(torch.from_numpy(dataset[0]["audio"]["array"]).float())
88
+
89
+ inputs = processor(input_audio, sampling_rate=resample_rate, return_tensors="pt")
90
  with torch.no_grad():
91
  outputs = model(**inputs, output_hidden_states=True)
92