zhzluke96 commited on
Commit
650b56c
1 Parent(s): 8f52106
Files changed (1) hide show
  1. modules/models.py +10 -3
modules/models.py CHANGED
@@ -10,15 +10,20 @@ import gc
10
  logger = logging.getLogger(__name__)
11
 
12
  chat_tts = None
 
 
 
 
13
  load_event = threading.Event()
14
 
15
 
16
  def load_chat_tts_in_thread():
17
  global chat_tts
18
  if chat_tts:
19
- load_event.set() # 如果已经加载过,直接设置事件
20
  return
21
 
 
22
  chat_tts = ChatTTS.Chat()
23
  chat_tts.load_models(
24
  compile=config.runtime_env_vars.compile,
@@ -33,17 +38,19 @@ def load_chat_tts_in_thread():
33
  )
34
 
35
  devices.torch_gc()
36
- load_event.set() # 设置事件,表示加载完成
 
37
 
38
 
39
  def initialize_chat_tts():
40
  model_thread = threading.Thread(target=load_chat_tts_in_thread)
41
  model_thread.start()
 
42
 
43
 
44
  def load_chat_tts():
45
  if chat_tts is None:
46
- initialize_chat_tts()
47
  load_event.wait()
48
  return chat_tts
49
 
 
10
  logger = logging.getLogger(__name__)
11
 
12
  chat_tts = None
13
+ # 某些平台上,不让在主线程中加载模型,否则会出现错误
14
+ # huggingface Error:
15
+ # RuntimeError: CUDA must not be initialized in the main process on Spaces with Stateless GPU environment.
16
+ # You can look at this Stacktrace to find out which part of your code triggered a CUDA init
17
  load_event = threading.Event()
18
 
19
 
20
  def load_chat_tts_in_thread():
21
  global chat_tts
22
  if chat_tts:
23
+ load_event.set()
24
  return
25
 
26
+ logger.info("Loading ChatTTS models")
27
  chat_tts = ChatTTS.Chat()
28
  chat_tts.load_models(
29
  compile=config.runtime_env_vars.compile,
 
38
  )
39
 
40
  devices.torch_gc()
41
+ load_event.set()
42
+ logger.info("ChatTTS models loaded")
43
 
44
 
45
  def initialize_chat_tts():
46
  model_thread = threading.Thread(target=load_chat_tts_in_thread)
47
  model_thread.start()
48
+ return model_thread
49
 
50
 
51
  def load_chat_tts():
52
  if chat_tts is None:
53
+ initialize_chat_tts().join()
54
  load_event.wait()
55
  return chat_tts
56