Kr08 commited on
Commit
40a48e2
·
verified ·
1 Parent(s): 7dccca6

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +6 -0
model_utils.py CHANGED
@@ -2,12 +2,14 @@ import torch
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import whisper
4
  from config import WHISPER_MODEL_SIZE
 
5
 
6
  # Global variables to store models
7
  whisper_processor = None
8
  whisper_model = None
9
  whisper_model_small = None
10
 
 
11
  def load_models():
12
  global whisper_processor, whisper_model, whisper_model_small
13
  if whisper_processor is None:
@@ -17,21 +19,25 @@ def load_models():
17
  if whisper_model_small is None:
18
  whisper_model_small = whisper.load_model(WHISPER_MODEL_SIZE)
19
 
 
20
def get_device():
    """Return the torch device string: first CUDA device if available, else CPU."""
    if torch.cuda.is_available():
        return "cuda:0"
    return "cpu"
22
 
 
23
def get_processor():
    """Lazily initialise and return the shared Whisper processor.

    Triggers load_models() the first time it is called; afterwards the
    cached module-level instance is returned directly.
    """
    global whisper_processor
    if whisper_processor is not None:
        return whisper_processor
    load_models()
    return whisper_processor
28
 
 
29
def get_model():
    """Lazily initialise and return the shared Whisper model.

    Calls load_models() on first use, then serves the cached
    module-level instance on subsequent calls.
    """
    global whisper_model
    if whisper_model is not None:
        return whisper_model
    load_models()
    return whisper_model
34
 
 
35
  def get_whisper_model_small():
36
  global whisper_model_small
37
  if whisper_model_small is None:
 
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import whisper
4
  from config import WHISPER_MODEL_SIZE
5
+ import spaces
6
 
7
  # Global variables to store models
8
  whisper_processor = None
9
  whisper_model = None
10
  whisper_model_small = None
11
 
12
+ +@spaces.GPU
13
  def load_models():
14
  global whisper_processor, whisper_model, whisper_model_small
15
  if whisper_processor is None:
 
19
  if whisper_model_small is None:
20
  whisper_model_small = whisper.load_model(WHISPER_MODEL_SIZE)
21
 
22
# Fix: the committed decorator line was "+@spaces.GPU" (stray leading "+"),
# which is a SyntaxError. The decorator itself is kept as intended.
# NOTE(review): @spaces.GPU requests a GPU slot on Hugging Face Spaces
# (ZeroGPU); decorating a pure device-query helper may be unnecessary — confirm.
@spaces.GPU
def get_device():
    """Return the torch device string: "cuda:0" if CUDA is available, else "cpu"."""
    return "cuda:0" if torch.cuda.is_available() else "cpu"
25
 
26
# Fix: the committed decorator line was "+@spaces.GPU" (stray leading "+"),
# which is a SyntaxError. The decorator itself is kept as intended.
@spaces.GPU
def get_processor():
    """Return the shared Whisper processor, loading all models on first access.

    Returns:
        The module-level whisper_processor instance, initialised by
        load_models() if it is still None.
    """
    global whisper_processor
    if whisper_processor is None:
        load_models()
    return whisper_processor
32
 
33
# Fix: the committed decorator line was "+@spaces.GPU" (stray leading "+"),
# which is a SyntaxError. The decorator itself is kept as intended.
@spaces.GPU
def get_model():
    """Return the shared Whisper model, loading all models on first access.

    Returns:
        The module-level whisper_model instance, initialised by
        load_models() if it is still None.
    """
    global whisper_model
    if whisper_model is None:
        load_models()
    return whisper_model
39
 
40
+ +@spaces.GPU
41
  def get_whisper_model_small():
42
  global whisper_model_small
43
  if whisper_model_small is None: