mrfakename committed on
Commit
75e5a1a
·
verified ·
1 Parent(s): c0fb8c8

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the upstream GitHub repo rather than to this Space.

Files changed (1) hide show
  1. src/f5_tts/infer/utils_infer.py +11 -4
src/f5_tts/infer/utils_infer.py CHANGED
@@ -139,7 +139,9 @@ asr_pipe = None
139
  def initialize_asr_pipeline(device=device, dtype=None):
140
  if dtype is None:
141
  dtype = (
142
- torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
 
 
143
  )
144
  global asr_pipe
145
  asr_pipe = pipeline(
@@ -172,7 +174,9 @@ def transcribe(ref_audio, language=None):
172
  def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
173
  if dtype is None:
174
  dtype = (
175
- torch.float16 if device == "cuda" and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
 
 
176
  )
177
  model = model.to(dtype)
178
 
@@ -180,9 +184,9 @@ def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
180
  if ckpt_type == "safetensors":
181
  from safetensors.torch import load_file
182
 
183
- checkpoint = load_file(ckpt_path)
184
  else:
185
- checkpoint = torch.load(ckpt_path, weights_only=True)
186
 
187
  if use_ema:
188
  if ckpt_type == "safetensors":
@@ -204,6 +208,9 @@ def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
204
  checkpoint = {"model_state_dict": checkpoint}
205
  model.load_state_dict(checkpoint["model_state_dict"])
206
 
 
 
 
207
  return model.to(device)
208
 
209
 
 
139
  def initialize_asr_pipeline(device=device, dtype=None):
140
  if dtype is None:
141
  dtype = (
142
+ torch.float16
143
+ if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 6
144
+ else torch.float32
145
  )
146
  global asr_pipe
147
  asr_pipe = pipeline(
 
174
  def load_checkpoint(model, ckpt_path, device, dtype=None, use_ema=True):
175
  if dtype is None:
176
  dtype = (
177
+ torch.float16
178
+ if torch.cuda.is_available() and torch.cuda.get_device_properties(device).major >= 6
179
+ else torch.float32
180
  )
181
  model = model.to(dtype)
182
 
 
184
  if ckpt_type == "safetensors":
185
  from safetensors.torch import load_file
186
 
187
+ checkpoint = load_file(ckpt_path, device=device)
188
  else:
189
+ checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
190
 
191
  if use_ema:
192
  if ckpt_type == "safetensors":
 
208
  checkpoint = {"model_state_dict": checkpoint}
209
  model.load_state_dict(checkpoint["model_state_dict"])
210
 
211
+ del checkpoint
212
+ torch.cuda.empty_cache()
213
+
214
  return model.to(device)
215
 
216