jhtonyKoo committed on
Commit
cac2c49
·
1 Parent(s): 879e4b5

update loss

Browse files
Files changed (2) hide show
  1. app.py +27 -4
  2. inference.py +2 -2
app.py CHANGED
@@ -222,10 +222,24 @@ with gr.Blocks() as demo:
222
  num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps")
223
  optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
224
  learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
225
- af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")
226
  loss_function = gr.Radio(["AudioFeatureLoss", "CLAPFeatureLoss"], label="Loss Function", value="AudioFeatureLoss")
227
- clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio", visible=False)
228
- clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  def update_clap_options(loss_function):
231
  if loss_function == "CLAPFeatureLoss":
@@ -236,9 +250,18 @@ with gr.Blocks() as demo:
236
  loss_function.change(
237
  update_clap_options,
238
  inputs=[loss_function],
239
- outputs=[clap_target_type, clap_text_prompt]
240
  )
241
 
 
 
 
 
 
 
 
 
 
242
  ito_button = gr.Button("Perform ITO")
243
 
244
  with gr.Row():
 
222
  num_steps = gr.Slider(minimum=1, maximum=100, value=10, step=1, label="Number of Steps")
223
  optimizer = gr.Dropdown(["Adam", "RAdam", "SGD"], value="RAdam", label="Optimizer")
224
  learning_rate = gr.Slider(minimum=0.0001, maximum=0.1, value=0.001, step=0.0001, label="Learning Rate")
 
225
  loss_function = gr.Radio(["AudioFeatureLoss", "CLAPFeatureLoss"], label="Loss Function", value="AudioFeatureLoss")
226
+
227
+ # af_weights = gr.Textbox(label="AudioFeatureLoss Weights (comma-separated)", value="0.1,0.001,1.0,1.0,0.1")
228
+ # clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio", visible=False)
229
+ # clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
230
+
231
+ # Audio Feature Loss weights
232
+ with gr.Column(visible=True) as audio_feature_weights:
233
+ af_weights = gr.Textbox(
234
+ label="AudioFeatureLoss Weights (comma-separated)",
235
+ value="0.1,0.001,1.0,1.0,0.1",
236
+ info="RMS, Crest Factor, Stereo Width, Stereo Imbalance, Bark Spectrum"
237
+ )
238
+
239
+ # CLAP Loss options
240
+ with gr.Column(visible=False) as clap_options:
241
+ clap_target_type = gr.Radio(["Audio", "Text"], label="CLAP Target Type", value="Audio")
242
+ clap_text_prompt = gr.Textbox(label="CLAP Text Prompt", visible=False)
243
 
244
  def update_clap_options(loss_function):
245
  if loss_function == "CLAPFeatureLoss":
 
250
  loss_function.change(
251
  update_clap_options,
252
  inputs=[loss_function],
253
+ outputs=[audio_feature_weights, clap_options]
254
  )
255
 
256
+ def update_clap_text_prompt(clap_target_type):
257
+ return gr.update(visible=clap_target_type == "Text")
258
+
259
+ clap_target_type.change(
260
+ update_clap_text_prompt,
261
+ inputs=[clap_target_type],
262
+ outputs=[clap_text_prompt]
263
+ )
264
+
265
  ito_button = gr.Button("Perform ITO")
266
 
267
  with gr.Row():
inference.py CHANGED
@@ -93,10 +93,10 @@ class MasteringStyleTransfer:
93
  losses = af_loss(output_audio, reference_tensor)
94
  elif ito_config['loss_function'] == 'CLAPFeatureLoss':
95
  if ito_config['clap_target_type'] == 'Audio':
96
- target = ito_reference_tensor
97
  else:
98
  target = ito_config['clap_text_prompt']
99
- losses = self.clap_loss(est_targets, target, self.args.sample_rate)
100
  total_loss = sum(losses.values())
101
 
102
  if total_loss < min_loss:
 
93
  losses = af_loss(output_audio, reference_tensor)
94
  elif ito_config['loss_function'] == 'CLAPFeatureLoss':
95
  if ito_config['clap_target_type'] == 'Audio':
96
+ target = reference_tensor
97
  else:
98
  target = ito_config['clap_text_prompt']
99
+ losses = self.clap_loss(output_audio, target, self.args.sample_rate)
100
  total_loss = sum(losses.values())
101
 
102
  if total_loss < min_loss: