stupidog04 committed on
Commit
4a87051
1 Parent(s): 702d195

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +15 -30
pipeline.py CHANGED
@@ -1,34 +1,21 @@
1
- import numpy as np
2
- from typing import Dict
3
-
4
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
- from pyctcdecode import Alphabet, BeamSearchDecoderCTC
6
 
7
  class PreTrainedPipeline():
8
  def __init__(self, path):
9
  """
10
  Initialize model
11
  """
12
- self.processor = Wav2Vec2Processor.from_pretrained(path)
13
- self.model = Wav2Vec2ForCTC.from_pretrained(path)
14
- vocab_list = list(self.processor.tokenizer.get_vocab().keys())
15
-
16
- # convert ctc blank character representation
17
- vocab_list[0] = ""
18
-
19
- # replace special characters
20
- vocab_list[1] = "⁇"
21
- vocab_list[2] = "⁇"
22
- vocab_list[3] = "⁇"
23
-
24
- # convert space character representation
25
- vocab_list[4] = " "
26
-
27
- alphabet = Alphabet.build_alphabet(vocab_list, ctc_token_idx=0)
28
-
29
- self.decoder = BeamSearchDecoderCTC(alphabet)
30
- self.sampling_rate = 16000
31
-
32
 
33
  def __call__(self, inputs)-> Dict[str, str]:
34
  """
@@ -39,8 +26,6 @@ class PreTrainedPipeline():
39
  A :obj:`dict`. The returned object should look like {"text": "XXX"}, containing
40
  the detected text from the input audio.
41
  """
42
- input_values = self.processor(inputs, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
43
- logits = self.model(input_values).logits.cpu().detach().numpy()[0]
44
- return {
45
- "text": self.decoder.decode(logits)
46
- }
 
1
+ from torchvision import transforms
2
+ from pair_classification import PairClassificationPipeline
 
 
 
3
 
4
  class PreTrainedPipeline():
5
  def __init__(self, path):
6
  """
7
  Initialize model
8
  """
9
+ model_flag = 'google/vit-base-patch16-224-in21k'
10
+ # self.processor = feature_extractor = ViTFeatureExtractor.from_pretrained(model_flag)
11
+ self.pipe = pipeline("pair-classification", model=model_flag , feature_extractor=model_flag ,
12
+ model_kwargs={'num_labels':len(label2id),
13
+ 'label2id':label2id,
14
+ 'id2label':id2label,
15
+ 'num_channels':6,
16
+ 'ignore_mismatched_sizes': True })
17
+ self.model = self.pipe.model.from_pretrained(path)
18
+
 
 
 
 
 
 
 
 
 
 
19
 
20
  def __call__(self, inputs)-> Dict[str, str]:
21
  """
 
26
  A :obj:`dict`. The returned object should look like {"text": "XXX"}, containing
27
  the detected text from the input audio.
28
  """
29
+ # input_values = self.processor(inputs, return_tensors="pt", sampling_rate=self.sampling_rate).input_values # Batch size 1
30
+ # logits = self.model(input_values).logits.cpu().detach().numpy()[0]
31
+ return self.pipe(inputs)