Kr08 commited on
Commit
1dc65c3
·
verified ·
1 Parent(s): 43f1b5e

Update model_utils.py

Browse files
Files changed (1) hide show
  1. model_utils.py +4 -4
model_utils.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import whisper
4
  from config import WHISPER_MODEL_SIZE
5
- import spaces
6
 
7
  # Global variables to store models
8
  whisper_processor = None
@@ -15,13 +15,13 @@ def load_models():
15
  if whisper_processor is None:
16
  whisper_processor = WhisperProcessor.from_pretrained(f"openai/whisper-{WHISPER_MODEL_SIZE}")
17
  if whisper_model is None:
18
- whisper_model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{WHISPER_MODEL_SIZE}").to(get_device())
19
  if whisper_model_small is None:
20
  whisper_model_small = whisper.load_model(WHISPER_MODEL_SIZE)
21
 
22
  @spaces.GPU
23
  def get_device():
24
- return "cuda:0" if torch.cuda.is_available() else "cpu"
25
 
26
  @spaces.GPU
27
  def get_processor():
@@ -35,7 +35,7 @@ def get_model():
35
  global whisper_model
36
  if whisper_model is None:
37
  load_models()
38
- return whisper_model
39
 
40
  @spaces.GPU
41
  def get_whisper_model_small():
 
2
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
3
  import whisper
4
  from config import WHISPER_MODEL_SIZE
5
+ import spaces
6
 
7
  # Global variables to store models
8
  whisper_processor = None
 
15
  if whisper_processor is None:
16
  whisper_processor = WhisperProcessor.from_pretrained(f"openai/whisper-{WHISPER_MODEL_SIZE}")
17
  if whisper_model is None:
18
+ whisper_model = WhisperForConditionalGeneration.from_pretrained(f"openai/whisper-{WHISPER_MODEL_SIZE}")
19
  if whisper_model_small is None:
20
  whisper_model_small = whisper.load_model(WHISPER_MODEL_SIZE)
21
 
22
@spaces.GPU
def get_device():
    """Select the compute device string for model placement.

    Returns:
        str: "cuda" when a CUDA-capable GPU is visible to torch,
        otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
25
 
26
  @spaces.GPU
27
  def get_processor():
 
35
  global whisper_model
36
  if whisper_model is None:
37
  load_models()
38
+ return whisper_model.to(get_device())
39
 
40
  @spaces.GPU
41
  def get_whisper_model_small():