|
from torchvision import transforms |
|
from pair_classification import PairClassificationPipeline |
|
from typing import Dict |
|
|
|
|
|
class PreTrainedPipeline():

    def __init__(self, path):
        """Build the pair-classification pipeline and load fine-tuned weights.

        Args:
            path: Directory/identifier of the fine-tuned model weights that
                are loaded over the base checkpoint via ``from_pretrained``.

        NOTE(review): ``pipeline``, ``label2id`` and ``id2label`` are not
        defined or imported anywhere in this file — presumably ``pipeline``
        comes from ``transformers`` and the label maps from the
        ``pair_classification`` module; confirm and add the imports.
        """
        checkpoint = 'google/vit-base-patch16-224-in21k'

        # Keyword arguments forwarded to the underlying model constructor.
        # num_channels=6 suggests two stacked 3-channel images — TODO confirm
        # against PairClassificationPipeline.
        extra_model_kwargs = {
            'num_labels': len(label2id),
            'label2id': label2id,
            'id2label': id2label,
            'num_channels': 6,
            'ignore_mismatched_sizes': True,
        }

        self.pipe = pipeline(
            "pair-classification",
            model=checkpoint,
            feature_extractor=checkpoint,
            model_kwargs=extra_model_kwargs,
        )

        # Swap in the fine-tuned weights from `path` on top of the base model.
        self.model = self.pipe.model.from_pretrained(path)

    def __call__(self, inputs):
        """Run the wrapped pipeline on ``inputs`` and return its output.

        NOTE(review): the original docstring described raw audio waveforms,
        but this pipeline is built from a ViT image checkpoint — presumably
        ``inputs`` is image data consumed by ``PairClassificationPipeline``;
        confirm against that class before relying on this contract.
        """
        return self.pipe(inputs)