File size: 4,215 Bytes
1df74c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import gradio as gr
from modules.Enhancer.ResembleEnhance import unload_enhancer
from modules.webui import webui_config
from modules.webui.webui_utils import get_speaker_names
from .ft_ui_utils import get_datasets_listfile, run_speaker_ft
from .ProcessMonitor import ProcessMonitor
from modules.speaker import speaker_mgr
from modules.models import unload_chat_tts
class SpeakerFt:
    """Start, stop and report on the speaker fine-tuning subprocess.

    The training command is handed to a ``ProcessMonitor``; ``status_str``
    holds a one-line status that ``flush`` prepends to the captured output.
    """

    # Separator between the leading part of a dropdown label and the
    # speaker name (the original code took the part after it as the name).
    _LABEL_SEPARATOR = " : "

    def __init__(self):
        self.process_monitor = ProcessMonitor()
        self.status_str = "idle"

    def unload_main_thread_models(self):
        """Call the ChatTTS and enhancer unload helpers."""
        unload_chat_tts()
        unload_enhancer()

    @staticmethod
    def _speaker_name_from_label(label: str) -> str:
        """Extract the speaker name from a dropdown label.

        Returns the text after the first ``" : "`` separator, stripped. A
        label without a separator is treated as the name itself instead of
        raising ``IndexError``, and a name that contains the separator is
        kept whole instead of being cut at its second occurrence.
        """
        _, sep, name = label.partition(SpeakerFt._LABEL_SEPARATOR)
        return (name if sep else label).strip()

    def run(
        self,
        batch_size: int,
        epochs: int,
        lr: str,
        train_text: bool,
        data_path: str,
        select_speaker: str = "",
    ):
        """Launch ``modules.finetune.train_speaker`` as a monitored process.

        Args:
            batch_size: Forwarded as ``--batch_size``.
            epochs: Forwarded as ``--epochs``.
            lr: Learning rate from the UI.
                NOTE(review): this value is accepted but is not forwarded to
                the training command; confirm whether the trainer exposes a
                matching flag before wiring it through.
            train_text: When true, ``--train_text`` is appended.
            data_path: Forwarded as ``--data_path``.
            select_speaker: Dropdown label of the initial speaker; ``""`` or
                ``"none"`` means no initial speaker.

        Returns:
            ``None`` when the process was started or when one is already
            running; ``["Speaker not found"]`` when the selected speaker
            cannot be resolved (the same text is also stored as the status).
        """
        # Function-scope import keeps the dependency local to this block.
        import sys

        if self.process_monitor.process:
            return

        # Resolve the speaker before unloading anything, so that a failed
        # lookup leaves the currently loaded models untouched.
        spk_path = None
        if select_speaker != "" and select_speaker != "none":
            speaker_name = self._speaker_name_from_label(select_speaker)
            spk = speaker_mgr.get_speaker(speaker_name)
            if spk is None:
                # Surface the failure through the status line as well, so it
                # is visible in the log view fed by ``flush``.
                self.status("Speaker not found")
                return ["Speaker not found"]
            spk_filename = speaker_mgr.get_speaker_filename(spk.id)
            spk_path = f"./data/speakers/{spk_filename}"

        self.unload_main_thread_models()

        # Run the trainer with the interpreter that is running this process
        # (same environment / installed packages); fall back to the original
        # "python3" only if the interpreter path is unavailable.
        python_bin = sys.executable or "python3"
        command = [python_bin, "-m", "modules.finetune.train_speaker"]
        command += [
            f"--batch_size={batch_size}",
            f"--epochs={epochs}",
            f"--data_path={data_path}",
        ]
        if train_text:
            command.append("--train_text")
        if spk_path:
            command.append(f"--init_speaker={spk_path}")

        self.status("Training process starting")
        self.process_monitor.start_process(command)
        self.status("Training started")

    def status(self, text: str):
        """Replace the current status line with *text*."""
        self.status_str = text

    def flush(self):
        """Return the status line followed by the monitor's stdout and stderr."""
        stdout, stderr = self.process_monitor.get_output()
        return f"{self.status_str}\n{stdout}\n{stderr}"

    def clear(self):
        """Reset the monitor's captured stdout/stderr and note it in the status."""
        self.process_monitor.stdout = ""
        self.process_monitor.stderr = ""
        self.status("Logs cleared")

    def stop(self):
        """Ask the monitor to stop the process and update the status line."""
        self.process_monitor.stop_process()
        self.status("Training stopped")
def create_speaker_ft_tab(demo: gr.Blocks):
    """Build the speaker fine-tuning tab and wire its buttons.

    Lays out a hyper-parameter column (dataset, learning rate, epochs,
    batch size, text-loss toggle, initial speaker) with the three control
    buttons, next to a log column. The buttons call ``SpeakerFt.run``,
    ``SpeakerFt.stop`` and ``SpeakerFt.clear``. When
    ``webui_config.experimental`` is truthy, a load event on *demo* calls
    ``SpeakerFt.flush`` with ``every=1`` to refresh the log textbox.
    """
    trainer = SpeakerFt()

    # Only the display names are needed here; "none" means no initial speaker.
    _, names = get_speaker_names()
    speaker_choices = ["none"] + names

    with gr.Row():
        # Left column: hyper-parameters and control buttons.
        with gr.Column(scale=2):
            with gr.Group():
                gr.Markdown("🎛️hparams")
                dataset_choices = get_datasets_listfile()
                dataset_input = gr.Dropdown(label="Dataset", choices=dataset_choices)
                lr_input = gr.Textbox(label="Learning Rate", value="1e-2")
                epochs_input = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=10,
                    label="Epochs",
                )
                batch_size_input = gr.Slider(
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=4,
                    label="Batch Size",
                )
                train_text_checkbox = gr.Checkbox(label="Train text_loss", value=True)
                init_spk_dropdown = gr.Dropdown(
                    choices=speaker_choices,
                    value="none",
                    label="Initial Speaker",
                )

            with gr.Group():
                start_train_btn = gr.Button("Start Training")
                stop_train_btn = gr.Button("Stop Training")
                clear_train_btn = gr.Button("Clear logs")

        # Right column: training log.
        with gr.Column(scale=5):
            with gr.Group():
                gr.Markdown("📜logs")
                log_output = gr.Textbox(
                    label="Log",
                    show_label=False,
                    value="",
                    lines=20,
                    interactive=True,
                )

    # Input order must match the positional parameters of SpeakerFt.run.
    run_inputs = [
        batch_size_input,
        epochs_input,
        lr_input,
        train_text_checkbox,
        dataset_input,
        init_spk_dropdown,
    ]
    start_train_btn.click(trainer.run, inputs=run_inputs, outputs=[])
    stop_train_btn.click(trainer.stop)
    clear_train_btn.click(trainer.clear)

    # Log polling is only registered in experimental mode.
    if not webui_config.experimental:
        return
    demo.load(trainer.flush, every=1, outputs=[log_output])
|