Nithya commited on
Commit
f2917d8
·
1 Parent(s): ab069bc

made the files run on gpu

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. src/generate_utils.py +2 -2
app.py CHANGED
@@ -29,7 +29,7 @@ audio_path = 'models/pitch_to_audio/'
29
  # audio_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4835364'
30
  # db_path_audio = '/home/mila/n/nithya.shikarpur/scratch/pitch-diffusion/data/merged_data-finalest/cached-audio-pitch-16k'
31
 
32
- device = 'cpu'
33
 
34
  global_ind = -1
35
  global_audios = np.array([0.0])
 
29
  # audio_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4835364'
30
  # db_path_audio = '/home/mila/n/nithya.shikarpur/scratch/pitch-diffusion/data/merged_data-finalest/cached-audio-pitch-16k'
31
 
32
+ device = 'cuda'
33
 
34
  global_ind = -1
35
  global_audios = np.array([0.0])
src/generate_utils.py CHANGED
@@ -61,7 +61,7 @@ def load_processed_pitch(pitch,
61
  def load_pitch_model(config, ckpt, qt = None, prime_file=None, device='cuda'):
62
  gin.parse_config_file(config)
63
  model = UNet()
64
- model.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
65
  model.to(device)
66
  if qt is not None:
67
  qt = joblib.load(qt)
@@ -80,7 +80,7 @@ def load_pitch_model(config, ckpt, qt = None, prime_file=None, device='cuda'):
80
  def load_audio_model(config, ckpt, qt = None, device='cuda'):
81
  gin.parse_config_file(config)
82
  model = UNetPitchConditioned() # there are no gin parameters for some reason
83
- model.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
84
  model.to(device)
85
  if qt is not None:
86
  qt = joblib.load(qt)
 
61
  def load_pitch_model(config, ckpt, qt = None, prime_file=None, device='cuda'):
62
  gin.parse_config_file(config)
63
  model = UNet()
64
+ model.load_state_dict(torch.load(ckpt, map_location='cuda')['state_dict'])
65
  model.to(device)
66
  if qt is not None:
67
  qt = joblib.load(qt)
 
80
  def load_audio_model(config, ckpt, qt = None, device='cuda'):
81
  gin.parse_config_file(config)
82
  model = UNetPitchConditioned() # there are no gin parameters for some reason
83
+ model.load_state_dict(torch.load(ckpt, map_location='cuda')['state_dict'])
84
  model.to(device)
85
  if qt is not None:
86
  qt = joblib.load(qt)