File size: 1,440 Bytes
4fec958 af7cc76 4fec958 af7cc76 4fec958 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
"""
Example of a model wrapper that users would submit for inference: it loads an s3prl downstream checkpoint (``hubert_sd.ckpt``) and runs it on a raw audio waveform.
"""
import os
from typing import Dict, List
import torch
from s3prl.downstream.runner import Runner
class PreTrainedModel(Runner):
    """Inference wrapper that rebuilds an s3prl ``Runner`` from a checkpoint.

    The checkpoint file ``hubert_sd.ckpt`` is read from ``path``. It must be a
    mapping with an ``"Args"`` object (whose attributes are overridden below)
    and a ``"Config"`` entry; both are handed to ``Runner.__init__``.
    """

    def __init__(self, path="", *, device="cpu"):
        """
        Initialize downstream model.

        Args:
            path: Directory containing ``hubert_sd.ckpt``. The default ``""``
                resolves the file relative to the current working directory.
            device: Torch device string written to the checkpoint's
                ``Args.device`` and used for the input tensor built in
                ``__call__``. Defaults to ``"cpu"``, the previously hard-coded
                value, so existing callers are unaffected.
        """
        ckp_file = os.path.join(path, "hubert_sd.ckpt")
        # NOTE(review): torch.load unpickles arbitrary Python objects, so only
        # load trusted checkpoints. Recent torch releases default to
        # weights_only=True, which may reject the non-tensor "Args" object —
        # confirm against the torch version this runs with.
        ckp = torch.load(ckp_file, map_location="cpu")
        args = ckp["Args"]
        # Point Runner at this same file; presumably it restores the trained
        # weights from init_ckpt — confirm in Runner.
        args.init_ckpt = ckp_file
        args.mode = "inference"
        args.device = device
        # Remembered so __call__ can place its input tensor on the same device.
        self._inference_device = device
        super().__init__(args, ckp["Config"])

    def __call__(self, inputs) -> List[int]:
        """
        Run the upstream, featurizer and downstream models on one waveform.

        Args:
            inputs (:obj:`np.array`): The raw waveform of audio received. By
                default at 16KHz. Anything ``torch.as_tensor`` accepts (NumPy
                array, list, tensor) is cast to float32.

        Returns:
            The first element of ``self.downstream.model.inference(...)``.
            NOTE(review): earlier docs describe this as "a list with logits"
            while the annotation says ``List[int]``; confirm which one the
            downstream model actually returns.
        """
        for entry in self.all_entries:
            entry.model.eval()
        # The models are called with a list of waveform tensors; this wraps
        # the single input as a batch of one, on the configured device.
        wavs = [
            torch.as_tensor(
                inputs, dtype=torch.float32, device=self._inference_device
            )
        ]
        with torch.no_grad():
            features = self.upstream.model(wavs)
            features = self.featurizer.model(wavs, features)
            # The second argument is passed as an empty list, as before; its
            # meaning is defined by the downstream model's inference().
            preds = self.downstream.model.inference(features, [])
        return preds[0]
# Usage example. It is kept inside a bare string literal, so importing this
# module neither downloads the audio file nor loads the model.
# NOTE(review): `samplerate` returned by sf.read is not checked against the
# 16 kHz that PreTrainedModel.__call__'s docstring assumes.
"""
import io
import soundfile as sf
from urllib.request import urlopen
model = PreTrainedModel()
url = "https://huggingface.co/datasets/lewtun/s3prl-sd-dummy/raw/main/audio.wav"
data, samplerate = sf.read(io.BytesIO(urlopen(url).read()))
print(model(data))
""" |