stupidog04 commited on
Commit
02cf913
1 Parent(s): aa1dbe1

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +46 -32
pipeline.py CHANGED
@@ -1,32 +1,46 @@
1
- from torchvision import transforms
2
- from transformers import ImageClassificationPipeline
3
- import torch
4
-
5
-
6
- class PairClassificationPipeline(ImageClassificationPipeline):
7
- pipe_to_tensor = transforms.ToTensor()
8
- pipe_to_pil = transforms.ToPILImage()
9
-
10
- def __init__():
11
- super().__init__()
12
-
13
- def preprocess(self, image):
14
- left_image, right_image = self.horizontal_split_image(image)
15
- model_inputs = self.extract_split_feature(left_image, right_image)
16
- # model_inputs = super().preprocess(image)
17
- # print(model_inputs['pixel_values'].shape)
18
- return model_inputs
19
-
20
- def horizontal_split_image(self, image):
21
- # image = image.resize((448,224))
22
- w, h = image.size
23
- half_w = w//2
24
- left_image = image.crop([0,0,half_w,h])
25
- right_image = image.crop([half_w,0,2*half_w,h])
26
- return left_image, right_image
27
-
28
- def extract_split_feature(self, left_image, right_image):
29
- model_inputs = self.feature_extractor(images=left_image, return_tensors=self.framework)
30
- right_inputs = self.feature_extractor(images=right_image, return_tensors=self.framework)
31
- model_inputs['pixel_values'] = torch.cat([model_inputs['pixel_values'],right_inputs['pixel_values']], dim=1)
32
- return model_inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import Dict
3
+
4
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
+ from pyctcdecode import Alphabet, BeamSearchDecoderCTC
6
+
7
+ class PreTrainedPipeline():
8
+ def __init__(self, path):
9
+ """
10
+ Initialize model
11
+ """
12
+ self.processor = Wav2Vec2Processor.from_pretrained(path)
13
+ self.model = Wav2Vec2ForCTC.from_pretrained(path)
14
+ vocab_list = list(self.processor.tokenizer.get_vocab().keys())
15
+
16
+ # convert ctc blank character representation
17
+ vocab_list[0] = ""
18
+
19
+ # replace special characters
20
+ vocab_list[1] = "⁇"
21
+ vocab_list[2] = "⁇"
22
+ vocab_list[3] = "⁇"
23
+
24
+ # convert space character representation
25
+ vocab_list[4] = " "
26
+
27
+ alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=0)
28
+
29
+ self.decoder = BeamSearchDecoderCTC(alphabet)
30
+ self.sampling_rate = 16000
31
+
32
+
33
+ def __call__(self, inputs)-> Dict[str, str]:
34
+ """
35
+ Args:
36
+ inputs (:obj:`np.array`):
37
+ The raw waveform of audio received. By default at 16KHz.
38
+ Return:
39
+ A :obj:`dict`:. The object return should be liked {"text": "XXX"} containing
40
+ the detected text from the input audio.
41
+ """
42
+ input_values = self.processor(inputs, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
43
+ logits = self.model(input_values).logits.cpu().detach().numpy()[0]
44
+ return {
45
+ "text": self.decoder.decode(logits)
46
+ }