srd4 commited on
Commit
e87d782
1 Parent(s): 4baee5b

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +17 -14
handler.py CHANGED
@@ -1,31 +1,34 @@
1
  from typing import Dict
2
- from faster_whisper import WhisperModel
3
  import io
 
4
 
5
  class EndpointHandler:
6
  def __init__(self, model_dir=None):
7
- # Set model size, assuming installation has been done with appropriate model files and setup
 
 
 
8
  model_size = "medium" if model_dir is None else model_dir
9
- # Change to 'cuda' to use the GPU, and set compute_type for faster computation
10
- self.model = WhisperModel(model_size, device="cuda", compute_type="float16")
11
 
12
  def __call__(self, data: Dict) -> Dict[str, str]:
13
- # Process the input data expected to be in 'inputs' key containing audio file bytes
14
  audio_bytes = data["inputs"]
15
-
16
- # Convert bytes to a file-like object
17
  audio_file = io.BytesIO(audio_bytes)
18
-
19
- # Perform transcription using the model
20
- segments, info = self.model.transcribe(audio_file)
21
-
22
- # Compile the results into a text string and extract language information
23
- # Strip whitespace from each segment before joining them
 
 
24
  text = " ".join(segment.text.strip() for segment in segments)
 
 
25
  language_code = info.language
26
  language_prob = info.language_probability
27
 
28
- # Compile the response dictionary
29
  result = {
30
  "text": text,
31
  "language": language_code,
 
1
  from typing import Dict
2
+ from faster_whisper import WhisperModel, Streaming
3
  import io
4
+ import re
5
 
6
  class EndpointHandler:
7
  def __init__(self, model_dir=None):
8
+ # Use int8 on CPU to reduce memory usage and potentially increase speed.
9
+ compute_type = "int8" if model_dir == "cpu" else "float16"
10
+
11
+ # Initialize WhisperModel with given model_size and compute_type
12
  model_size = "medium" if model_dir is None else model_dir
13
+ self.model = WhisperModel(model_size, device=model_dir, compute_type=compute_type)
 
14
 
15
  def __call__(self, data: Dict) -> Dict[str, str]:
 
16
  audio_bytes = data["inputs"]
 
 
17
  audio_file = io.BytesIO(audio_bytes)
18
+
19
+ # Use Streaming interface to leverage VAD and potential speed improvements.
20
+ # Small beam size to speed up transcription. Adjust based on performance/accuracy needs.
21
+ beam_size = 1
22
+ streaming = Streaming(device=model_dir, compute_type=compute_type, vad=True)
23
+ segments, info = streaming.transcribe(audio_file, beam_size=beam_size)
24
+
25
+ # Aggregate transcribed text and remove any extra spaces.
26
  text = " ".join(segment.text.strip() for segment in segments)
27
+ text = re.sub(' +', ' ', text)
28
+
29
  language_code = info.language
30
  language_prob = info.language_probability
31
 
 
32
  result = {
33
  "text": text,
34
  "language": language_code,