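"""Gradio GUI for Music Source Separation training and inference.

Wraps the project's train.py and inference.py scripts in a simple web UI and
stores reusable model/config/checkpoint presets in settings.json.
"""
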
import gradio as gr
import subprocess
import os
import threading
import queue
import json

# Helper that runs a subprocess and streams its output into a queue
def run_subprocess(cmd, output_queue):
    try:
        process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        # Forward combined stdout/stderr line by line to the caller
        for line in process.stdout:
            output_queue.put(line.rstrip("\n"))
        process.wait()
        if process.returncode == 0:
            output_queue.put("Process completed successfully!")
        else:
            output_queue.put(f"Process failed with return code {process.returncode}")
    except Exception as e:
        output_queue.put(f"An error occurred: {e}")

# Class to manage saved settings
class ConfigManager:
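    """Persists saved model/config/checkpoint combinations to a JSON settings file."""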
    def __init__(self, filepath="settings.json"):
        self.filepath = filepath
        self.settings = self.load_settings()

    def save_settings(self):
        with open(self.filepath, "w") as f:
            json.dump(self.settings, f, indent=2, ensure_ascii=False)

    def load_settings(self):
        if os.path.exists(self.filepath):
            with open(self.filepath, "r") as f:
                return json.load(f)
        return {"saved_combinations": {}}

    def get_saved_combinations(self):
        return self.settings.get("saved_combinations", {})

    def add_combination(self, name, combination):
        # setdefault keeps this working even if an older settings file lacks the key
        self.settings.setdefault("saved_combinations", {})[name] = combination
        self.save_settings()

config_manager = ConfigManager()

# Training and inference functions
def run_training(model_type, config_path, start_checkpoint, results_path, data_paths, valid_paths, num_workers, device_ids):
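    """Builds the train.py command from the form inputs and returns its captured output."""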
    if not (model_type and config_path and results_path and data_paths and valid_paths):
        return "Error: Missing required inputs for training."

    cmd = [
        "python", "train.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--results_path", results_path,
        "--data_path", *data_paths.split(';'),
        "--valid_path", *valid_paths.split(';'),
        "--num_workers", str(num_workers),
        "--device_ids", device_ids
    ]

    if start_checkpoint:
        cmd += ["--start_check_point", start_checkpoint]

    output_queue = queue.Queue()
    worker = threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True)
    worker.start()
    # Wait for the process to finish; draining the queue immediately would miss its output
    worker.join()

    output = []
    while not output_queue.empty():
        output.append(output_queue.get())
    return "\n".join(output)

def run_inference(model_type, config_path, start_checkpoint, input_folder, store_dir, extract_instrumental):
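    """Builds the inference.py command from the form inputs and returns its captured output."""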
    if not (model_type and config_path and input_folder and store_dir):
        return "Error: Missing required inputs for inference."

    cmd = [
        "python", "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--input_folder", input_folder,
        "--store_dir", store_dir
    ]

    if start_checkpoint:
        cmd += ["--start_check_point", start_checkpoint]
    if extract_instrumental:
        cmd += ["--extract_instrumental"]

    output_queue = queue.Queue()
    worker = threading.Thread(target=run_subprocess, args=(cmd, output_queue), daemon=True)
    worker.start()
    # Block until the process ends so the full log is captured
    worker.join()

    output = []
    while not output_queue.empty():
        output.append(output_queue.get())
    return "\n".join(output)

# Gradio interface
def add_preset(name, model_type, config_path, checkpoint):
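    """Saves a named model/config/checkpoint combination through the ConfigManager."""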
    if not name:
        return "Error: Name is required to save a preset."

    config_manager.add_combination(name, {
        "model_type": model_type,
        "config_path": config_path,
        "checkpoint": checkpoint
    })
    return f"Preset '{name}' saved successfully."

# Presets loaded at startup; these populate the "Load Preset" dropdown
saved_presets = config_manager.get_saved_combinations()
preset_names = list(saved_presets.keys())

def load_preset(name):
    # Re-read from the config manager so presets saved during this session are found
    presets = config_manager.get_saved_combinations()
    if name in presets:
        preset = presets[name]
        return preset.get("model_type", ""), preset.get("config_path", ""), preset.get("checkpoint", "")
    return "", "", ""

with gr.Blocks() as gui:
    gr.Markdown("# Music Source Separation Training & Inference GUI")

    # Training
    with gr.Accordion("Training Configuration", open=False):
        model_type = gr.Dropdown(
            choices=["apollo", "bandit", "htdemucs", "scnet"], label="Model Type"
        )
        config_path = gr.Textbox(label="Config File Path")
        start_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
        results_path = gr.Textbox(label="Results Path")
        data_paths = gr.Textbox(label="Data Paths (separated by ';')")
        valid_paths = gr.Textbox(label="Validation Paths (separated by ';')")
        num_workers = gr.Number(label="Number of Workers", value=4)
        device_ids = gr.Textbox(label="Device IDs (comma-separated)", value="0")
        train_output = gr.Textbox(label="Training Output", interactive=False)

        gr.Button("Run Training").click(
            run_training,
            inputs=[
                model_type, config_path, start_checkpoint, results_path, data_paths,
                valid_paths, num_workers, device_ids
            ],
            outputs=train_output
        )

    # Inference
    with gr.Accordion("Inference Configuration", open=False):
        infer_model_type = gr.Dropdown(
            choices=["apollo", "bandit", "htdemucs", "scnet"], label="Model Type"
        )
        infer_config_path = gr.Textbox(label="Config File Path")
        infer_checkpoint = gr.Textbox(label="Checkpoint (Optional)")
        input_folder = gr.Textbox(label="Input Folder")
        store_dir = gr.Textbox(label="Output Folder")
        extract_instrumental = gr.Checkbox(label="Extract Instrumental", value=False)
        infer_output = gr.Textbox(label="Inference Output", interactive=False)

        gr.Button("Run Inference").click(
            run_inference,
            inputs=[
                infer_model_type, infer_config_path, infer_checkpoint, input_folder,
                store_dir, extract_instrumental
            ],
            outputs=infer_output
        )

    # Preset management
    with gr.Accordion("Presets", open=False):
        preset_name = gr.Textbox(label="Preset Name")
        preset_model_type = gr.Textbox(label="Model Type")
        preset_config_path = gr.Textbox(label="Config Path")
        preset_checkpoint = gr.Textbox(label="Checkpoint")
        preset_feedback = gr.Textbox(label="Feedback", interactive=False)

        gr.Button("Save Preset").click(
            add_preset,
            inputs=[preset_name, preset_model_type, preset_config_path, preset_checkpoint],
            outputs=preset_feedback
        )

        preset_dropdown = gr.Dropdown(
            choices=preset_names, label="Load Preset"
        )
        gr.Button("Load Preset").click(
            load_preset, inputs=preset_dropdown, outputs=[preset_model_type, preset_config_path, preset_checkpoint]
        )

# share=True serves the app locally and also creates a temporary public Gradio link
gui.launch(share=True)