sanchit-gandhi HF staff commited on
Commit
b4d4d63
1 Parent(s): 3dc00eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -0
app.py CHANGED
@@ -155,6 +155,14 @@ def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0):
155
  max_new_tokens = int(frame_rate * audio_length_in_s)
156
  play_steps = int(frame_rate * play_steps_in_s)
157
 
 
 
 
 
 
 
 
 
158
  streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
159
 
160
  generation_kwargs = dict(
 
155
  max_new_tokens = int(frame_rate * audio_length_in_s)
156
  play_steps = int(frame_rate * play_steps_in_s)
157
 
158
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
159
+
160
+ if device != model.device:
161
+ model.to(device)
162
+
163
+ if device == "cuda:0":
164
+ model.to(device).half();
165
+
166
  streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
167
 
168
  generation_kwargs = dict(