stupidog04 committed on
Commit
0597420
1 Parent(s): 164a729

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +36 -12
pipeline.py CHANGED
@@ -1,6 +1,8 @@
1
  from torchvision import transforms
2
- from pair_classification import PairClassificationPipeline
3
- from typing import Dict
 
 
4
 
5
 
6
  class PreTrainedPipeline():
@@ -8,17 +10,13 @@ class PreTrainedPipeline():
8
  """
9
  Initialize model
10
  """
11
- model_flag = 'google/vit-base-patch16-224-in21k'
12
  # self.processor = feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
13
- self.pipe = pipeline("pair-classification", model=model_flag , feature_extractor=model_flag ,
14
- model_kwargs={'num_labels':len(label2id),
15
- 'label2id':label2id,
16
- 'id2label':id2label,
17
- 'num_channels':6,
18
- 'ignore_mismatched_sizes': True })
19
- self.model = self.pipe.model.from_pretrained(path)
20
 
21
-
22
  def __call__(self, inputs):
23
  """
24
  Args:
@@ -30,4 +28,30 @@ class PreTrainedPipeline():
30
  """
31
  # input_values = self.processor(inputs, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
32
  # logits = self.model(input_values).logits.cpu().detach().numpy()[0]
33
- return self.pipe(inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from torchvision import transforms
2
+ from transformers import ViTFeatureExtractor, ViTForImageClassification, ViTConfig
3
+ from transformers import ImageClassificationPipeline
4
+ import torch
5
+
6
 
7
 
8
class PreTrainedPipeline():
    """Inference entry point for the pair-classification ViT model.

    Wraps a fine-tuned ``ViTForImageClassification`` in a
    ``PairClassificationPipeline`` so that a single side-by-side image pair
    can be classified with one call.
    """

    # NOTE(review): the original rendering lost the ``def __init__`` line,
    # leaving the initializer body at class scope (a NameError on ``self`` /
    # ``path``). Reconstructed with the conventional HF-widget signature.
    def __init__(self, path=""):
        """Load the model and its preprocessing from pretrained weights.

        Args:
            path: Directory (or hub id) holding the fine-tuned model
                weights and config; passed straight to ``from_pretrained``.
        """
        # Backbone whose feature extractor defines the per-half
        # preprocessing (resize + normalize).
        model_flag = 'google/vit-base-patch16-224-in21k'
        # model_flag = 'google/vit-base-patch16-384'
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
        self.model = ViTForImageClassification.from_pretrained(path)
        # PairClassificationPipeline is defined later in this module; it is
        # resolved at call time, so the forward reference is safe.
        self.pipe = PairClassificationPipeline(self.model, feature_extractor=self.feature_extractor)

    def __call__(self, inputs):
        """Classify one side-by-side image pair.

        Args:
            inputs: An image accepted by the pipeline (e.g. a PIL image)
                containing the two pictures side by side.

        Returns:
            The pipeline's classification output (label/score results).
        """
        return self.pipe(inputs)
34
class PairClassificationPipeline(ImageClassificationPipeline):
    """Image-classification pipeline for side-by-side image pairs.

    The input image is assumed to contain two pictures side by side. It is
    split down the vertical midline and the two halves are run through the
    feature extractor separately, then concatenated along the channel axis
    (dim=1), producing a 6-channel ``pixel_values`` tensor for the model.
    """

    # Tensor <-> PIL converters, kept available on the class.
    pipe_to_tensor = transforms.ToTensor()
    pipe_to_pil = transforms.ToPILImage()

    def preprocess(self, image):
        """Build model inputs from one side-by-side pair image."""
        halves = self.horizontal_split_image(image)
        return self.extract_split_feature(*halves)

    def horizontal_split_image(self, image):
        """Return the (left, right) halves of ``image``.

        For an odd width the rightmost pixel column is dropped so both
        halves have identical size.
        """
        width, height = image.size
        mid = width // 2
        return (
            image.crop([0, 0, mid, height]),
            image.crop([mid, 0, 2 * mid, height]),
        )

    def extract_split_feature(self, left_image, right_image):
        """Extract features for both halves and stack them channel-wise."""
        features = self.feature_extractor(images=left_image, return_tensors=self.framework)
        right_features = self.feature_extractor(images=right_image, return_tensors=self.framework)
        features['pixel_values'] = torch.cat(
            [features['pixel_values'], right_features['pixel_values']], dim=1
        )
        return features