Spaces:
Sleeping
Sleeping
Nithya
commited on
Commit
·
f2917d8
1
Parent(s):
ab069bc
made the files run on gpu
Browse files- app.py +1 -1
- src/generate_utils.py +2 -2
app.py
CHANGED
@@ -29,7 +29,7 @@ audio_path = 'models/pitch_to_audio/'
|
|
29 |
# audio_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4835364'
|
30 |
# db_path_audio = '/home/mila/n/nithya.shikarpur/scratch/pitch-diffusion/data/merged_data-finalest/cached-audio-pitch-16k'
|
31 |
|
32 |
-
device = '
|
33 |
|
34 |
global_ind = -1
|
35 |
global_audios = np.array([0.0])
|
|
|
29 |
# audio_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4835364'
|
30 |
# db_path_audio = '/home/mila/n/nithya.shikarpur/scratch/pitch-diffusion/data/merged_data-finalest/cached-audio-pitch-16k'
|
31 |
|
32 |
+
device = 'cuda'
|
33 |
|
34 |
global_ind = -1
|
35 |
global_audios = np.array([0.0])
|
src/generate_utils.py
CHANGED
@@ -61,7 +61,7 @@ def load_processed_pitch(pitch,
|
|
61 |
def load_pitch_model(config, ckpt, qt = None, prime_file=None, device='cuda'):
|
62 |
gin.parse_config_file(config)
|
63 |
model = UNet()
|
64 |
-
model.load_state_dict(torch.load(ckpt, map_location='
|
65 |
model.to(device)
|
66 |
if qt is not None:
|
67 |
qt = joblib.load(qt)
|
@@ -80,7 +80,7 @@ def load_pitch_model(config, ckpt, qt = None, prime_file=None, device='cuda'):
|
|
80 |
def load_audio_model(config, ckpt, qt = None, device='cuda'):
|
81 |
gin.parse_config_file(config)
|
82 |
model = UNetPitchConditioned() # there are no gin parameters for some reason
|
83 |
-
model.load_state_dict(torch.load(ckpt, map_location='
|
84 |
model.to(device)
|
85 |
if qt is not None:
|
86 |
qt = joblib.load(qt)
|
|
|
61 |
def load_pitch_model(config, ckpt, qt = None, prime_file=None, device='cuda'):
|
62 |
gin.parse_config_file(config)
|
63 |
model = UNet()
|
64 |
+
model.load_state_dict(torch.load(ckpt, map_location='cuda')['state_dict'])
|
65 |
model.to(device)
|
66 |
if qt is not None:
|
67 |
qt = joblib.load(qt)
|
|
|
80 |
def load_audio_model(config, ckpt, qt = None, device='cuda'):
|
81 |
gin.parse_config_file(config)
|
82 |
model = UNetPitchConditioned() # there are no gin parameters for some reason
|
83 |
+
model.load_state_dict(torch.load(ckpt, map_location='cuda')['state_dict'])
|
84 |
model.to(device)
|
85 |
if qt is not None:
|
86 |
qt = joblib.load(qt)
|