zhzluke96 commited on
Commit
8a3a4ec
1 Parent(s): 43c9aae
Files changed (1) hide show
  1. modules/models.py +7 -8
modules/models.py CHANGED
@@ -10,10 +10,7 @@ import gc
10
  logger = logging.getLogger(__name__)
11
 
12
  chat_tts = None
13
- # 某些平台上,不让在主线程中加载模型,否则会出现错误
14
- # huggingface Error:
15
- # RuntimeError: CUDA must not be initialized in the main process on Spaces with Stateless GPU environment.
16
- # You can look at this Stacktrace to find out which part of your code triggered a CUDA init
17
 
18
 
19
  def load_chat_tts_in_thread():
@@ -40,14 +37,16 @@ def load_chat_tts_in_thread():
40
 
41
 
42
  def initialize_chat_tts():
43
- model_thread = threading.Thread(target=load_chat_tts_in_thread)
44
- model_thread.start()
45
- return model_thread
 
 
46
 
47
 
48
  def load_chat_tts():
49
  if chat_tts is None:
50
- initialize_chat_tts().join()
51
  if chat_tts is None:
52
  raise Exception("Failed to load ChatTTS models")
53
  return chat_tts
 
10
  logger = logging.getLogger(__name__)
11
 
12
  chat_tts = None
13
+ lock = threading.Lock()
 
 
 
14
 
15
 
16
  def load_chat_tts_in_thread():
 
37
 
38
 
39
  def initialize_chat_tts():
40
+ with lock:
41
+ if chat_tts is None:
42
+ model_thread = threading.Thread(target=load_chat_tts_in_thread)
43
+ model_thread.start()
44
+ model_thread.join()
45
 
46
 
47
  def load_chat_tts():
48
  if chat_tts is None:
49
+ initialize_chat_tts()
50
  if chat_tts is None:
51
  raise Exception("Failed to load ChatTTS models")
52
  return chat_tts