|
import json
import os
import queue
import subprocess
import sys
import threading

import gradio as gr
|
|
|
|
|
def run_subprocess(cmd, output_queue):
    """Run *cmd* to completion, pushing its combined output onto *output_queue*.

    stdout and stderr are merged. Each output line is queued with its trailing
    newline removed. After the process exits, one final status message is
    queued: a success notice for return code 0, otherwise the failing return
    code. If the process cannot be started or its output cannot be read, an
    "An error occurred: ..." message is queued instead. Nothing is raised to
    the caller, because this is meant to be the top-level function of a
    worker.

    Args:
        cmd: Argument list handed to ``subprocess.Popen`` (no shell involved).
        output_queue: Object with a ``put`` method, e.g. ``queue.Queue``.
    """
    try:
        # The context manager closes the stdout pipe and waits for the child
        # on exit, including when reading the output fails part-way.
        # errors="replace" keeps a byte the locale encoding cannot decode from
        # aborting the whole log stream.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",
        ) as process:
            for line in process.stdout:
                # Drop the newline so consumers can join lines with "\n"
                # without producing a blank line between every entry.
                output_queue.put(line.rstrip("\r\n"))
            returncode = process.wait()
        if returncode == 0:
            output_queue.put("Process completed successfully!")
        else:
            output_queue.put(f"Process failed with return code {returncode}")
    except Exception as e:
        # Broad catch on purpose: report any failure through the queue instead
        # of letting it escape the worker.
        output_queue.put(f"An error occurred: {e}")
|
|
|
|
|
class ConfigManager:
    """Persist GUI settings in a JSON file.

    The stored document is a JSON object. Saved presets live under its
    ``"saved_combinations"`` key as a mapping of preset name -> preset dict.
    Other top-level keys found in the file are preserved on save.
    """

    def __init__(self, filepath="settings.json"):
        self.filepath = filepath
        self.settings = self.load_settings()

    def save_settings(self):
        """Write ``self.settings`` to ``self.filepath`` as UTF-8 JSON."""
        # ensure_ascii=False writes non-ASCII characters verbatim, so the file
        # encoding must be pinned. A platform default such as cp1252 cannot
        # encode every preset name and would fail part-way through the write.
        with open(self.filepath, "w", encoding="utf-8") as f:
            json.dump(self.settings, f, indent=2, ensure_ascii=False)

    def load_settings(self):
        """Return the parsed settings file, or a fresh default when it is absent."""
        if os.path.exists(self.filepath):
            with open(self.filepath, "r", encoding="utf-8") as f:
                return json.load(f)
        return {"saved_combinations": {}}

    def get_saved_combinations(self):
        """Return the live preset mapping stored in ``self.settings``.

        A missing ``"saved_combinations"`` key is created in memory (not on
        disk). This keeps the returned dict the same object that
        ``add_combination`` later updates.
        """
        return self.settings.setdefault("saved_combinations", {})

    def add_combination(self, name, combination):
        """Store *combination* under *name* and write the settings file."""
        # setdefault tolerates a settings file that predates the presets key.
        self.settings.setdefault("saved_combinations", {})[name] = combination
        self.save_settings()
|
|
|
# Module-wide settings store, backed by "settings.json" resolved against the
# current working directory.
config_manager = ConfigManager(filepath="settings.json")
|
|
|
|
|
def run_training(model_type, config_path, start_checkpoint, results_path, data_paths, valid_paths, num_workers, device_ids):
    """Launch ``train.py`` with the GUI's settings and return its log.

    Blocks until the training process exits. Returns the collected output:
    one entry per output line, followed by a final status line.

    Args:
        model_type: Value for ``--model_type``.
        config_path: Value for ``--config_path``.
        start_checkpoint: Optional value for ``--start_check_point``;
            omitted when empty.
        results_path: Value for ``--results_path``.
        data_paths: ';'-separated training folders for ``--data_path``.
        valid_paths: ';'-separated validation folders for ``--valid_path``.
        num_workers: Worker count. A Gradio Number widget may deliver a
            float, which is truncated to an int here.
        device_ids: Comma-separated device ids for ``--device_ids``.

    Returns:
        The process log as one string, or an "Error: ..." message when a
        required input is missing or ``num_workers`` is not a number.
    """
    # Blank segments (e.g. from a trailing ';') would otherwise be passed
    # as empty path arguments.
    train_dirs = [p.strip() for p in (data_paths or "").split(";") if p.strip()]
    valid_dirs = [p.strip() for p in (valid_paths or "").split(";") if p.strip()]
    if not (model_type and config_path and results_path and train_dirs and valid_dirs):
        return "Error: Missing required inputs for training."

    # str(4.0) would give "4.0", which an integer CLI option rejects.
    try:
        workers = int(num_workers)
    except (TypeError, ValueError):
        return "Error: Number of workers must be an integer."

    # sys.executable runs the script with the same interpreter (and
    # virtualenv) as this GUI, not whichever "python" is first on PATH.
    cmd = [
        sys.executable, "train.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--results_path", results_path,
        "--data_path", *train_dirs,
        "--valid_path", *valid_dirs,
        "--num_workers", str(workers),
    ]

    # NOTE(review): assumes train.py declares --device_ids with nargs="+",
    # so "0,1" must reach it as separate arguments. Confirm against train.py.
    raw_devices = "" if device_ids is None else str(device_ids)
    devices = [d.strip() for d in raw_devices.split(",") if d.strip()]
    if devices:
        cmd += ["--device_ids", *devices]

    if start_checkpoint:
        cmd += ["--start_check_point", start_checkpoint]

    # Run to completion before reading the queue. Draining it right after
    # starting a background thread returns before the child has written
    # anything, so the log would almost always be empty.
    output_queue = queue.Queue()
    run_subprocess(cmd, output_queue)

    lines = []
    while not output_queue.empty():
        lines.append(output_queue.get_nowait().rstrip("\r\n"))
    return "\n".join(lines)
|
|
|
def run_inference(model_type, config_path, start_checkpoint, input_folder, store_dir, extract_instrumental):
    """Launch ``inference.py`` with the GUI's settings and return its log.

    Blocks until the inference process exits. Returns the collected output:
    one entry per output line, followed by a final status line.

    Args:
        model_type: Value for ``--model_type``.
        config_path: Value for ``--config_path``.
        start_checkpoint: Optional value for ``--start_check_point``;
            omitted when empty.
        input_folder: Value for ``--input_folder``.
        store_dir: Value for ``--store_dir``.
        extract_instrumental: When truthy, adds ``--extract_instrumental``.

    Returns:
        The process log as one string, or an "Error: ..." message when a
        required input is missing.
    """
    if not (model_type and config_path and input_folder and store_dir):
        return "Error: Missing required inputs for inference."

    # sys.executable runs the script with the same interpreter (and
    # virtualenv) as this GUI, not whichever "python" is first on PATH.
    cmd = [
        sys.executable, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--input_folder", input_folder,
        "--store_dir", store_dir,
    ]

    if start_checkpoint:
        cmd += ["--start_check_point", start_checkpoint]
    if extract_instrumental:
        cmd.append("--extract_instrumental")

    # Run to completion before reading the queue. Draining it right after
    # starting a background thread returns before the child has written
    # anything, so the log would almost always be empty.
    output_queue = queue.Queue()
    run_subprocess(cmd, output_queue)

    lines = []
    while not output_queue.empty():
        lines.append(output_queue.get_nowait().rstrip("\r\n"))
    return "\n".join(lines)
|
|
|
|
|
def add_preset(name, model_type, config_path, checkpoint):
    """Save a (model_type, config_path, checkpoint) preset under *name*.

    Surrounding whitespace is removed from *name* before it is used as the
    preset key. A blank name is rejected.

    Returns:
        A human-readable status message for the feedback box. Failures are
        reported as "Error: ..." strings rather than raised.
    """
    name = (name or "").strip()
    if not name:
        return "Error: Name is required to save a preset."

    try:
        config_manager.add_combination(name, {
            "model_type": model_type,
            "config_path": config_path,
            "checkpoint": checkpoint,
        })
    except OSError as exc:
        # e.g. settings file not writable: report it in the UI instead of
        # surfacing an unhandled error from the click handler.
        return f"Error: Could not save preset '{name}': {exc}"

    return f"Preset '{name}' saved successfully."
|
|
|
# Presets present at start-up; their names seed the "Load Preset" dropdown.
saved_presets = config_manager.get_saved_combinations()
preset_names = [*saved_presets]
|
|
|
def load_preset(name):
    """Return ``(model_type, config_path, checkpoint)`` for the preset *name*.

    Presets are looked up in ``config_manager`` at call time, so presets
    saved after start-up are found too. An empty or unknown name yields
    three empty strings. A field missing from a stored preset yields "".
    """
    if not name:
        return "", "", ""

    preset = config_manager.get_saved_combinations().get(name)
    if not isinstance(preset, dict):
        return "", "", ""

    # .get tolerates hand-edited presets that lack one of the fields.
    return (
        preset.get("model_type", ""),
        preset.get("config_path", ""),
        preset.get("checkpoint", ""),
    )
|
|
|
with gr.Blocks() as gui:
    gr.Markdown("# Music Source Separation Training & Inference GUI")

    # --- Training: collects train.py arguments and shows its log. ----------
    with gr.Accordion("Training Configuration", open=False):
        model_type = gr.Dropdown(
            choices=["apollo", "bandit", "htdemucs", "scnet"], label="Model Type"
        )
        config_path = gr.Textbox(label="Config File Path")
        start_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
        results_path = gr.Textbox(label="Results Path")
        data_paths = gr.Textbox(label="Data Paths (separated by ';')")
        valid_paths = gr.Textbox(label="Validation Paths (separated by ';')")
        # precision=0 makes the widget deliver an int instead of a float, so
        # the command line receives "4" rather than "4.0".
        num_workers = gr.Number(label="Number of Workers", value=4, precision=0)
        device_ids = gr.Textbox(label="Device IDs (comma-separated)", value="0")
        train_output = gr.Textbox(label="Training Output", interactive=False)

        gr.Button("Run Training").click(
            run_training,
            inputs=[
                model_type, config_path, start_checkpoint, results_path, data_paths,
                valid_paths, num_workers, device_ids
            ],
            outputs=train_output
        )

    # --- Inference: collects inference.py arguments and shows its log. -----
    with gr.Accordion("Inference Configuration", open=False):
        infer_model_type = gr.Dropdown(
            choices=["apollo", "bandit", "htdemucs", "scnet"], label="Model Type"
        )
        infer_config_path = gr.Textbox(label="Config File Path")
        infer_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
        input_folder = gr.Textbox(label="Input Folder")
        store_dir = gr.Textbox(label="Output Folder")
        extract_instrumental = gr.Checkbox(label="Extract Instrumental", value=False)
        infer_output = gr.Textbox(label="Inference Output", interactive=False)

        gr.Button("Run Inference").click(
            run_inference,
            inputs=[
                infer_model_type, infer_config_path, infer_checkpoint, input_folder,
                store_dir, extract_instrumental
            ],
            outputs=infer_output
        )

    # --- Presets: save/load (model_type, config_path, checkpoint) triples. -
    with gr.Accordion("Presets", open=False):
        preset_name = gr.Textbox(label="Preset Name")
        preset_model_type = gr.Textbox(label="Model Type")
        preset_config_path = gr.Textbox(label="Config Path")
        preset_checkpoint = gr.Textbox(label="Checkpoint")
        preset_feedback = gr.Textbox(label="Feedback", interactive=False)

        save_preset_button = gr.Button("Save Preset")

        preset_dropdown = gr.Dropdown(
            choices=preset_names, label="Load Preset"
        )
        load_preset_button = gr.Button("Load Preset")

        # Listeners are attached once the dropdown exists, so saving a preset
        # can also refresh its choices. Otherwise a newly saved preset would
        # not be selectable until the app restarts.
        save_preset_button.click(
            add_preset,
            inputs=[preset_name, preset_model_type, preset_config_path, preset_checkpoint],
            outputs=preset_feedback
        ).then(
            lambda: gr.update(choices=list(config_manager.get_saved_combinations())),
            inputs=None,
            outputs=preset_dropdown
        )
        load_preset_button.click(
            load_preset, inputs=preset_dropdown, outputs=[preset_model_type, preset_config_path, preset_checkpoint]
        )

# NOTE(review): share=True publishes a public URL. Anyone holding it can make
# this machine run train.py / inference.py with paths of their choosing.
# Consider restricting it (e.g. disabling share or adding auth).
gui.launch(share=True)