File size: 6,455 Bytes
95261ed
7fd072f
33a2c1e
44d964a
33a2c1e
95261ed
33a2c1e
 
44d964a
33a2c1e
95261ed
c0e541b
31f7bdb
95261ed
44d964a
 
95261ed
 
01fddc0
31f7bdb
95261ed
 
 
44d964a
 
 
95261ed
 
 
31f7bdb
 
c0e541b
31f7bdb
c0e541b
 
95261ed
 
7c5d37e
 
 
 
 
 
 
7fd072f
44d964a
7fd072f
 
 
 
7c5d37e
7fd072f
44d964a
 
 
7c5d37e
44d964a
7c5d37e
 
 
 
 
44d964a
 
 
 
 
 
 
 
 
c0e541b
 
44d964a
 
 
 
 
 
c0e541b
95261ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44d964a
95261ed
 
 
 
c0e541b
44d964a
95261ed
31f7bdb
c0e541b
95261ed
 
 
 
 
 
 
 
 
 
33a2c1e
95261ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33a2c1e
 
 
 
 
 
 
95261ed
33a2c1e
 
 
 
95261ed
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# External programs
import os
import sys
from typing import List

import whisper
from whisper import Whisper

from src.config import ModelConfig
from src.hooks.whisperProgressHook import ProgressListener, create_progress_listener_handle

from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache

class WhisperContainer:
    """
    Holds the settings needed to create a Whisper model, and creates the model on demand.

    The model is only loaded on the first call to get_model(), optionally through a cache.
    The loaded model and the cache are not part of the pickled state: a depickled container
    starts without a model and uses GLOBAL_MODEL_CACHE.
    """
    def __init__(self, model_name: str, device: str = None, download_root: str = None,
                 cache: ModelCache = None, models: List[ModelConfig] = None):
        """
        Parameters
        ----------
        model_name: str
            The name of the model. It is matched against the names in `models` and against
            the official Whisper model list.
        device: str
            The device passed to whisper.load_model. May be None.
        download_root: str
            The directory passed to the model download/load calls. May be None.
        cache: ModelCache
            Optional cache used by get_model(). If None, the model is created directly.
        models: List[ModelConfig]
            The list of known model configurations. Defaults to a new empty list.
        """
        self.model_name = model_name
        self.device = device
        self.download_root = download_root
        self.cache = cache

        # Will be created on demand
        self.model = None

        # List of known models. A fresh list is created per instance when none is given,
        # so that containers never share (and accidentally mutate) a common default list.
        self.models = models if models is not None else []

    def get_model(self):
        """
        Return the Whisper model, creating it (or fetching it from the cache) on first use.
        """
        if self.model is None:
            if self.cache is None:
                self.model = self._create_model()
            else:
                # The key combines the model name and the device (empty if no device is set)
                model_key = f"WhisperContainer.{self.model_name}:{self.device or ''}"
                self.model = self.cache.get(model_key, self._create_model)
        return self.model

    def ensure_downloaded(self):
        """
        Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
        passing the container to a subprocess.

        Returns
        -------
        True if the pre-download step completed, False if it failed for any reason
        (the error is printed and never propagated).
        """
        # Warning: Using private API here
        try:
            root_dir = self.download_root

            if root_dir is None:
                # NOTE(review): whisper.load_model may also take XDG_CACHE_HOME into account
                # when no download root is given - confirm this default matches it.
                root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")

            if self.model_name in whisper._MODELS:
                whisper._download(whisper._MODELS[self.model_name], root_dir, False)
            else:
                # If the model is not in the official list, see if it needs to be downloaded
                model_config = self.get_model_config()

                if model_config is None:
                    # Report the real problem instead of an AttributeError on None
                    raise ValueError("No model configuration found for model " + self.model_name)
                model_config.download_url(root_dir)
            return True

        except Exception as e:
            # Given that the API is private, it could change at any time. We don't want to crash the program
            print("Error pre-downloading model: " + str(e))
            return False

    def get_model_config(self) -> ModelConfig:
        """
        Get the model configuration for the model.

        Returns
        -------
        The first entry in the list of known models whose name equals the model name,
        or None if there is no such entry.
        """
        return next((model for model in self.models if model.name == self.model_name), None)

    def _create_model(self):
        """
        Load the Whisper model using the configured device and download root.
        """
        print("Loading whisper model " + self.model_name)

        model_config = self.get_model_config()

        if model_config is None:
            # No configuration is registered for this name: hand the name to Whisper as-is
            # and let it resolve (or reject) it, rather than failing on a None config.
            model_path = self.model_name
        else:
            # Note that the model will not be downloaded in the case of an official Whisper model
            model_path = model_config.download_url(self.download_root)

        return whisper.load_model(model_path, device=self.device, download_root=self.download_root)

    def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        """
        Create a WhisperCallback object that can be used to transcript audio files.

        Parameters
        ----------
        language: str
            The target language of the transcription. If not specified, the language will be inferred from the audio content.
        task: str
            The task - either translate or transcribe.
        initial_prompt: str
            The initial prompt to use for the transcription.
        decodeOptions: dict
            Additional options to pass to the decoder. Must be pickleable.

        Returns
        -------
        A WhisperCallback object.
        """
        return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)

    # This is required for multiprocessing
    def __getstate__(self):
        # Only the settings are pickled - the loaded model and the cache are left out
        return {
            "model_name": self.model_name,
            "device": self.device,
            "download_root": self.download_root,
            "models": self.models,
        }

    def __setstate__(self, state):
        self.model_name = state["model_name"]
        self.device = state["device"]
        self.download_root = state["download_root"]
        self.models = state["models"]
        # The model is recreated on demand after depickling
        self.model = None
        # Depickled objects must use the global cache
        self.cache = GLOBAL_MODEL_CACHE


class WhisperCallback:
    """
    Bundles a model container with fixed transcription settings, so that the same
    settings can be applied to each audio segment passed to invoke().
    """
    def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
        # Extra keyword arguments are kept and forwarded to every transcribe call
        self.decodeOptions = decodeOptions
        self.initial_prompt = initial_prompt
        self.task = task
        self.language = language
        self.model_container = model_container

    def invoke(self, audio, segment_index: int, prompt: str, detected_language: str, progress_listener: ProgressListener = None):
        """
        Perform the transcription of the given audio file or data.

        Parameters
        ----------
        audio: Union[str, np.ndarray, torch.Tensor]
            The audio file to transcribe, or the audio data as a numpy array or torch tensor.
        segment_index: int
            The index of the segment. The initial prompt of this callback is only
            combined with the given prompt when the index is 0.
        prompt: str
            The prompt to use for the transcription.
        detected_language: str
            The detected language of the audio file. Only used if this callback has no language set.
        progress_listener: ProgressListener
            Optional listener. If given, the transcription runs inside a progress listener handle.

        Returns
        -------
        The result of the Whisper call.
        """
        whisper_model = self.model_container.get_model()

        # Without a listener there is nothing to hook up - transcribe directly
        if progress_listener is None:
            return self._transcribe(whisper_model, audio, segment_index, prompt, detected_language)

        with create_progress_listener_handle(progress_listener):
            return self._transcribe(whisper_model, audio, segment_index, prompt, detected_language)

    def _transcribe(self, model: Whisper, audio, segment_index: int, prompt: str, detected_language: str):
        # A language set on the callback takes precedence over the detected one
        chosen_language = self.language or detected_language

        # The callback's own initial prompt is only applied to the first segment
        if segment_index == 0:
            segment_prompt = self._concat_prompt(self.initial_prompt, prompt)
        else:
            segment_prompt = prompt

        return model.transcribe(
            audio,
            language=chosen_language,
            task=self.task,
            initial_prompt=segment_prompt,
            **self.decodeOptions
        )

    def _concat_prompt(self, prompt1, prompt2):
        # If either prompt is missing, the other one is returned unchanged (possibly None)
        if prompt1 is None:
            return prompt2
        if prompt2 is None:
            return prompt1
        return prompt1 + " " + prompt2