# KKKKDDME / gui-gradio.py
# Author: ArianatorQualquer
# Commit e1952b6: "Create gui-gradio.py" (verified)
import json
import os
import queue
import subprocess
import sys
import threading

import gradio as gr
# Run a command as a subprocess and report its output through a queue.
def run_subprocess(cmd, output_queue):
    """Run *cmd* and push its output, line by line, onto *output_queue*.

    stderr is merged into stdout, so both streams arrive on the queue in the
    order the child wrote them. Each line is queued without its trailing
    newline (callers join the collected lines themselves). After the process
    exits, one status line is queued: a success message for return code 0,
    otherwise a failure message containing the return code. Any exception
    raised while starting or reading the process is reported on the queue as
    an error message instead of propagating.

    Args:
        cmd: Argument list handed to ``subprocess.Popen`` (no shell).
        output_queue: Object with a ``put`` method (e.g. ``queue.Queue``)
            that receives the output strings.
    """
    try:
        # The context manager closes the stdout pipe and waits for the child
        # even if reading raises, so the pipe is never leaked.
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        ) as process:
            for line in process.stdout:
                output_queue.put(line.rstrip("\n"))
            process.wait()
        if process.returncode == 0:
            output_queue.put("Process completed successfully!")
        else:
            output_queue.put(f"Process failed with return code {process.returncode}")
    except Exception as e:
        output_queue.put(f"An error occurred: {str(e)}")
# Class that manages the saved settings (presets) on disk.
class ConfigManager:
    """Load and persist GUI settings stored as a JSON file.

    ``self.settings`` always contains a ``"saved_combinations"`` dict mapping
    preset name -> preset dict. Every change made through
    ``add_combination`` is written back to disk immediately.
    """

    def __init__(self, filepath="settings.json"):
        # Path of the JSON settings file; a relative path is resolved
        # against the current working directory.
        self.filepath = filepath
        self.settings = self.load_settings()

    def save_settings(self):
        """Write ``self.settings`` to ``self.filepath`` as UTF-8 JSON."""
        # ensure_ascii=False writes non-ASCII characters verbatim, so the file
        # must be encoded as UTF-8 rather than the platform locale encoding.
        with open(self.filepath, "w", encoding="utf-8") as f:
            json.dump(self.settings, f, indent=2, ensure_ascii=False)

    def load_settings(self):
        """Return the settings from disk, or defaults if the file is absent.

        The returned dict is guaranteed to contain a ``"saved_combinations"``
        dict, even when an existing file lacks that key.
        """
        settings = {}
        if os.path.exists(self.filepath):
            with open(self.filepath, "r", encoding="utf-8") as f:
                settings = json.load(f)
        settings.setdefault("saved_combinations", {})
        return settings

    def get_saved_combinations(self):
        """Return the stored preset mapping (name -> preset dict)."""
        return self.settings.get("saved_combinations", {})

    def add_combination(self, name, combination):
        """Store *combination* under *name* (overwriting) and save to disk."""
        # setdefault guards against a settings dict that lacks the key.
        self.settings.setdefault("saved_combinations", {})[name] = combination
        self.save_settings()
# Shared settings manager backed by "settings.json" in the working directory.
config_manager = ConfigManager(filepath="settings.json")
# Training and inference handlers.
def run_training(model_type, config_path, start_checkpoint, results_path, data_paths, valid_paths, num_workers, device_ids):
    """Run ``train.py`` with the given settings and return its output.

    Args:
        model_type: Passed as ``--model_type``.
        config_path: Passed as ``--config_path``.
        start_checkpoint: Optional; passed as ``--start_check_point`` when
            non-empty.
        results_path: Passed as ``--results_path``.
        data_paths: One or more paths separated by ``;``. Each non-empty,
            whitespace-stripped entry becomes a separate argument after
            ``--data_path``.
        valid_paths: Same format as *data_paths*, for ``--valid_path``.
        num_workers: Worker count; converted to ``int`` because a Gradio
            ``Number`` component delivers a float by default, which would
            otherwise be passed on as e.g. ``"4.0"``.
        device_ids: Passed through unchanged as one ``--device_ids`` argument.

    Returns:
        The process output (stdout merged with stderr) followed by a status
        line, joined with newlines; or an error message when a required
        input is missing.
    """
    data_list = [p.strip() for p in (data_paths or "").split(";") if p.strip()]
    valid_list = [p.strip() for p in (valid_paths or "").split(";") if p.strip()]
    if num_workers is None or not (
        model_type and config_path and results_path and data_list and valid_list
    ):
        return "Error: Missing required inputs for training."
    cmd = [
        # Use the interpreter running this GUI so the child process sees the
        # same environment instead of whatever "python" is first on PATH.
        # "train.py" is resolved relative to the current working directory.
        sys.executable, "train.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--results_path", results_path,
        "--data_path", *data_list,
        "--valid_path", *valid_list,
        "--num_workers", str(int(num_workers)),
        # NOTE(review): the UI labels this "comma-separated" but it is passed
        # as a single argument; confirm that is how train.py parses it.
        "--device_ids", device_ids,
    ]
    if start_checkpoint:
        cmd += ["--start_check_point", start_checkpoint]
    output_queue = queue.Queue()
    # Run to completion before collecting: draining the queue right after
    # starting a background thread would almost always find it still empty.
    run_subprocess(cmd, output_queue)
    output = []
    while not output_queue.empty():
        output.append(output_queue.get())
    return "\n".join(output)
def run_inference(model_type, config_path, start_checkpoint, input_folder, store_dir, extract_instrumental):
    """Run ``inference.py`` with the given settings and return its output.

    Args:
        model_type: Passed as ``--model_type``.
        config_path: Passed as ``--config_path``.
        start_checkpoint: Optional; passed as ``--start_check_point`` when
            non-empty.
        input_folder: Passed as ``--input_folder``.
        store_dir: Passed as ``--store_dir``.
        extract_instrumental: When truthy, adds the ``--extract_instrumental``
            flag.

    Returns:
        The process output (stdout merged with stderr) followed by a status
        line, joined with newlines; or an error message when a required
        input is missing.
    """
    if not (model_type and config_path and input_folder and store_dir):
        return "Error: Missing required inputs for inference."
    cmd = [
        # Use the interpreter running this GUI so the child process sees the
        # same environment instead of whatever "python" is first on PATH.
        # "inference.py" is resolved relative to the current working directory.
        sys.executable, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--input_folder", input_folder,
        "--store_dir", store_dir,
    ]
    if start_checkpoint:
        cmd += ["--start_check_point", start_checkpoint]
    if extract_instrumental:
        cmd.append("--extract_instrumental")
    output_queue = queue.Queue()
    # Run to completion before collecting: draining the queue right after
    # starting a background thread would almost always find it still empty.
    run_subprocess(cmd, output_queue)
    output = []
    while not output_queue.empty():
        output.append(output_queue.get())
    return "\n".join(output)
# Gradio interface callbacks.
def add_preset(name, model_type, config_path, checkpoint):
    """Save a preset under *name* and return a status message for the UI.

    Args:
        name: Preset name; surrounding whitespace is stripped, and a blank
            or missing name is rejected.
        model_type: Stored under the ``"model_type"`` key.
        config_path: Stored under the ``"config_path"`` key.
        checkpoint: Stored under the ``"checkpoint"`` key.

    Returns:
        A success message naming the preset, or an error message when no
        usable name was given.
    """
    name = (name or "").strip()
    if not name:
        return "Error: Name is required to save a preset."
    config_manager.add_combination(name, {
        "model_type": model_type,
        "config_path": config_path,
        "checkpoint": checkpoint,
    })
    return f"Preset '{name}' saved successfully."
# Preset mapping and the list of names known at startup; the name list is
# computed once and seeds the "Load Preset" dropdown's initial choices.
saved_presets = config_manager.get_saved_combinations()
preset_names = [*saved_presets]
def load_preset(name):
    """Return ``(model_type, config_path, checkpoint)`` for preset *name*.

    An unknown name (including an empty dropdown selection) yields three
    empty strings, and any field missing from a stored preset yields an
    empty string, so the bound textboxes are cleared instead of the callback
    raising ``KeyError``.
    """
    # Look presets up through the manager each time rather than through a
    # module-level reference captured at import time.
    preset = config_manager.get_saved_combinations().get(name)
    if preset is None:
        return "", "", ""
    return (
        preset.get("model_type", ""),
        preset.get("config_path", ""),
        preset.get("checkpoint", ""),
    )
# Gradio UI: training, inference and preset-management sections.
MODEL_TYPES = ["apollo", "bandit", "htdemucs", "scnet"]


def _refresh_preset_choices():
    """Return a Dropdown update whose choices are the currently saved preset names."""
    return gr.update(choices=list(config_manager.get_saved_combinations()))


with gr.Blocks() as gui:
    gr.Markdown("# Music Source Separation Training & Inference GUI")

    # Training
    with gr.Accordion("Training Configuration", open=False):
        model_type = gr.Dropdown(choices=MODEL_TYPES, label="Model Type")
        config_path = gr.Textbox(label="Config File Path")
        start_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
        results_path = gr.Textbox(label="Results Path")
        data_paths = gr.Textbox(label="Data Paths (separated by ';')")
        valid_paths = gr.Textbox(label="Validation Paths (separated by ';')")
        # precision=0 makes the component deliver an int rather than a float,
        # so the worker count is passed on as e.g. "4", not "4.0".
        num_workers = gr.Number(label="Number of Workers", value=4, precision=0)
        device_ids = gr.Textbox(label="Device IDs (comma-separated)", value="0")
        train_output = gr.Textbox(label="Training Output", interactive=False)
        gr.Button("Run Training").click(
            run_training,
            inputs=[
                model_type, config_path, start_checkpoint, results_path, data_paths,
                valid_paths, num_workers, device_ids,
            ],
            outputs=train_output,
        )

    # Inference
    with gr.Accordion("Inference Configuration", open=False):
        infer_model_type = gr.Dropdown(choices=MODEL_TYPES, label="Model Type")
        infer_config_path = gr.Textbox(label="Config File Path")
        infer_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
        input_folder = gr.Textbox(label="Input Folder")
        store_dir = gr.Textbox(label="Output Folder")
        extract_instrumental = gr.Checkbox(label="Extract Instrumental", value=False)
        infer_output = gr.Textbox(label="Inference Output", interactive=False)
        gr.Button("Run Inference").click(
            run_inference,
            inputs=[
                infer_model_type, infer_config_path, infer_checkpoint, input_folder,
                store_dir, extract_instrumental,
            ],
            outputs=infer_output,
        )

    # Preset management
    with gr.Accordion("Presets", open=False):
        preset_name = gr.Textbox(label="Preset Name")
        preset_model_type = gr.Textbox(label="Model Type")
        preset_config_path = gr.Textbox(label="Config Path")
        preset_checkpoint = gr.Textbox(label="Checkpoint")
        preset_feedback = gr.Textbox(label="Feedback", interactive=False)
        save_event = gr.Button("Save Preset").click(
            add_preset,
            inputs=[preset_name, preset_model_type, preset_config_path, preset_checkpoint],
            outputs=preset_feedback,
        )
        # Initial choices are the preset names known at startup.
        preset_dropdown = gr.Dropdown(choices=preset_names, label="Load Preset")
        # Refresh the dropdown after each save so newly saved presets can be
        # selected without restarting the app.
        save_event.then(_refresh_preset_choices, outputs=preset_dropdown)
        gr.Button("Load Preset").click(
            load_preset,
            inputs=preset_dropdown,
            outputs=[preset_model_type, preset_config_path, preset_checkpoint],
        )

if __name__ == "__main__":
    # share=True asks Gradio for a public share link; anyone holding that
    # link can start training/inference runs on this machine.
    gui.launch(share=True)