jhtonyKoo commited on
Commit
0ea0beb
·
1 Parent(s): dd08b4b

modify app

Browse files
Files changed (1) hide show
  1. inference.py +16 -4
inference.py CHANGED
@@ -105,12 +105,11 @@ class MasteringStyleTransfer:
105
  else:
106
  divergence_counter = 0
107
 
108
- # Log top 10 parameter differences
109
  if step == 0:
110
  initial_params = current_params
111
- top_10_diff = self.get_top_10_diff_string(initial_params, current_params)
112
- log_entry = f"Step {step + 1}, Loss: {total_loss.item():.4f}\n{top_10_diff}\n"
113
- yield log_entry, output_audio, current_params, step + 1
114
 
115
  if divergence_counter >= 10:
116
  print(f"Optimization stopped early due to divergence at step {step}")
@@ -119,8 +118,21 @@ class MasteringStyleTransfer:
119
  total_loss.backward()
120
  optimizer.step()
121
 
 
 
122
  return min_loss_output, min_loss_params, min_loss_embedding, min_loss_step + 1
123
 
 
 
 
 
 
 
 
 
 
 
 
124
  def preprocess_audio(self, audio, target_sample_rate=44100):
125
  sample_rate, data = audio
126
 
 
105
  else:
106
  divergence_counter = 0
107
 
108
+ # Log top 5 parameter differences
109
  if step == 0:
110
  initial_params = current_params
111
+ top_5_diff = self.get_top_5_diff_string(initial_params, current_params)
112
+ log_entry = f"Step {step + 1}, Loss: {total_loss.item():.4f}\n{top_5_diff}\n"
 
113
 
114
  if divergence_counter >= 10:
115
  print(f"Optimization stopped early due to divergence at step {step}")
 
118
  total_loss.backward()
119
  optimizer.step()
120
 
121
+ yield log_entry, output_audio.detach(), current_params, step + 1, total_loss.item()
122
+
123
  return min_loss_output, min_loss_params, min_loss_embedding, min_loss_step + 1
124
 
125
+ def get_top_5_diff_string(self, initial_params, current_params):
126
+ diff_dict = {}
127
+ for key in initial_params.keys():
128
+ diff = abs(current_params[key] - initial_params[key])
129
+ diff_dict[key] = diff
130
+
131
+ sorted_diff = sorted(diff_dict.items(), key=lambda x: x[1], reverse=True)
132
+ top_5_diff = sorted_diff[:5]
133
+
134
+ return "\n".join([f"{key}: {value:.4f}" for key, value in top_5_diff])
135
+
136
  def preprocess_audio(self, audio, target_sample_rate=44100):
137
  sample_rate, data = audio
138