r3gm commited on
Commit
3ace12e
1 Parent(s): 1a37ac3

Update src/mdx.py

Browse files
Files changed (1) hide show
  1. src/mdx.py +3 -1
src/mdx.py CHANGED
@@ -166,6 +166,8 @@ class MDX:
166
  waves = np.array(wave_p[:, i:i + self.model.chunk_size])
167
  mix_waves.append(waves)
168
 
 
 
169
  mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
170
 
171
  return mix_waves, pad, trim
@@ -240,7 +242,7 @@ def run_mdx(model_params, output_dir, model_path, filename, exclude_main=False,
240
 
241
  #device_properties = torch.cuda.get_device_properties(device)
242
  print("Device", device)
243
- vram_gb = 6 #device_properties.total_memory / 1024**3
244
  m_threads = 1 if vram_gb < 8 else 2
245
 
246
  model_hash = MDX.get_hash(model_path)
 
166
  waves = np.array(wave_p[:, i:i + self.model.chunk_size])
167
  mix_waves.append(waves)
168
 
169
+ print(self.device)
170
+
171
  mix_waves = torch.tensor(mix_waves, dtype=torch.float32).to(self.device)
172
 
173
  return mix_waves, pad, trim
 
242
 
243
  #device_properties = torch.cuda.get_device_properties(device)
244
  print("Device", device)
245
+ vram_gb = 12 #device_properties.total_memory / 1024**3
246
  m_threads = 1 if vram_gb < 8 else 2
247
 
248
  model_hash = MDX.get_hash(model_path)