from typing import Any, Dict

from torchvision import transforms
from transformers import pipeline

from pair_classification import PairClassificationPipeline


class PreTrainedPipeline:
    """Inference wrapper around a fine-tuned "pair-classification" pipeline.

    NOTE(review): ``PairClassificationPipeline`` is imported but never referenced
    directly here; presumably importing ``pair_classification`` registers the
    "pair-classification" task with transformers. Confirm before removing it.
    """

    # Checkpoint whose feature extractor (preprocessing config) is handed to
    # the pipeline.
    FEATURE_EXTRACTOR_CHECKPOINT = "google/vit-base-patch16-224-in21k"

    def __init__(self, path: str) -> None:
        """Build the pipeline from the fine-tuned model stored at ``path``.

        Args:
            path: Model directory or Hub id passed to ``pipeline`` as ``model``.
                Its saved config and weights are what the pipeline runs.
        """
        # Load the fine-tuned checkpoint directly so that ``self.pipe`` (the
        # object ``__call__`` uses) runs the fine-tuned weights. Building the
        # pipeline from the base checkpoint and loading ``path`` into a
        # separate model object left those weights unused at inference time.
        # NOTE(review): the label maps and ``num_channels`` are expected to come
        # from the config saved at ``path`` — confirm it was written with
        # ``save_pretrained`` after fine-tuning.
        self.pipe = pipeline(
            "pair-classification",
            model=path,
            feature_extractor=self.FEATURE_EXTRACTOR_CHECKPOINT,
        )
        # Kept for backward compatibility: callers may read ``self.model``.
        # It is now the same model instance the pipeline uses.
        self.model = self.pipe.model

    def __call__(self, inputs: Any) -> Any:
        """Run the pair-classification pipeline on ``inputs``.

        Args:
            inputs: Forwarded unchanged to the pipeline. The accepted format is
                defined by ``PairClassificationPipeline``, which is not visible
                here.

        Returns:
            Whatever the pipeline returns for ``inputs``.
        """
        return self.pipe(inputs)