File size: 2,476 Bytes
100aae6
45c38e5
100aae6
 
 
 
45c38e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100aae6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from torchvision import transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from transformers import ImageClassificationPipeline
import torch


class PreTrainedPipeline:
    """Entry point that classifies one side-by-side image pair with a ViT model.

    All per-image work is delegated to ``PairClassificationPipeline``, built
    in ``__init__`` from a ``ViTForImageClassification`` checkpoint and a
    ``ViTFeatureExtractor``.
    """

    def __init__(self, path, feature_extractor_name='google/vit-base-patch16-224-in21k'):
        """
        Load the feature extractor and model, then build the pair pipeline.

        Args:
            path (:obj:`str`):
                Checkpoint passed to ``ViTForImageClassification.from_pretrained``.
            feature_extractor_name (:obj:`str`, `optional`):
                Checkpoint passed to ``ViTFeatureExtractor.from_pretrained``.
                Defaults to ``'google/vit-base-patch16-224-in21k'``, the value
                that used to be hard-coded here.
        """
        # The preprocessing config is loaded from a separate checkpoint,
        # not from ``path``.
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(feature_extractor_name)
        self.model = ViTForImageClassification.from_pretrained(path)
        self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)

    def __call__(self, inputs):
        """
        Classify one side-by-side image pair.

        Args:
            inputs:
                The image handed unchanged to ``PairClassificationPipeline``.
                That pipeline splits it into left and right halves.
                NOTE(review): expected to be a PIL image; confirm which
                input types callers actually send.
        Return:
            Whatever ``PairClassificationPipeline`` returns, unchanged. For an
            ``ImageClassificationPipeline`` subclass this is presumably a
            list of ``{"label": ..., "score": ...}`` dicts; verify against
            the installed transformers version.
        """
        return self.pipe(inputs)


class PairClassificationPipeline(ImageClassificationPipeline):
    """Image-classification pipeline for two images placed side by side.

    ``preprocess`` cuts each input image into a left and a right half. It runs
    both halves through the feature extractor separately and concatenates
    their ``pixel_values`` along dim 1. The model therefore receives a single
    input made of both halves stacked together.

    NOTE(review): for PyTorch ``pixel_values`` dim 1 is presumably the channel
    axis, so the model must accept twice the usual channel count (e.g.
    ``num_channels=6`` for RGB). Confirm against the checkpoint config.
    Forward pass and postprocessing are inherited unchanged.
    """

    # Kept as class attributes for external users; the methods below do not
    # use them.
    pipe_to_tensor = transforms.ToTensor()
    pipe_to_pil = transforms.ToPILImage()

    def preprocess(self, image, timeout=None):
        """
        Split a side-by-side image and build the stacked model inputs.

        Args:
            image: A PIL image, or anything ``load_image`` accepts (e.g. a
                local path or URL).
            timeout: Optional download timeout for ``load_image``. Newer
                transformers versions forward it here from the pipeline call.
        Return:
            The left half's feature-extractor output, with ``pixel_values``
            replaced by the left/right concatenation.
        """
        # Local import keeps the module-level imports unchanged.
        from transformers.image_utils import load_image

        # Overriding preprocess bypasses the base class's input normalization.
        # Redo it here: path/URL -> PIL image, converted to RGB. Without it,
        # strings fail on ``.size`` and RGBA/greyscale images reach a
        # 3-channel extractor. ``timeout`` is only forwarded when provided, so
        # load_image versions without that parameter keep working.
        image = load_image(image) if timeout is None else load_image(image, timeout=timeout)
        left_image, right_image = self.horizontal_split_image(image)
        return self.extract_split_feature(left_image, right_image)

    def horizontal_split_image(self, image):
        """
        Cut ``image`` into equal-width left and right halves.

        Both halves are ``width // 2`` columns wide. When the width is odd,
        the rightmost column is dropped.
        """
        width, height = image.size
        half_w = width // 2
        left_image = image.crop((0, 0, half_w, height))
        right_image = image.crop((half_w, 0, 2 * half_w, height))
        return left_image, right_image

    def extract_split_feature(self, left_image, right_image):
        """
        Run both halves through the feature extractor and stack them.

        Return:
            The left half's feature-extractor output, with ``pixel_values``
            replaced by ``torch.cat([left, right], dim=1)``.
        """
        # torch.cat means this path only supports framework == "pt".
        model_inputs = self.feature_extractor(images=left_image, return_tensors=self.framework)
        right_inputs = self.feature_extractor(images=right_image, return_tensors=self.framework)
        model_inputs['pixel_values'] = torch.cat(
            [model_inputs['pixel_values'], right_inputs['pixel_values']], dim=1
        )
        return model_inputs