"""Export tab: download, upload, and clean up artifacts of an experiment directory."""

from glob import escape, glob
import os
import shutil
from typing import List

import gradio as gr
from huggingface_hub import upload_folder

from app.train import train_index
from infer.lib.train.process_ckpt import extract_small_model


def _require_exp_dir(exp_dir: str) -> None:
    """Raise ``gr.Error`` unless *exp_dir* names an existing directory.

    The experiment-directory textbox is reset to ``""`` by ``remove_expdir``,
    so an empty value is easy to reach from the UI. Without this guard an
    empty path would make the callers operate on the filesystem root or the
    current working directory.
    """
    if not exp_dir or not os.path.isdir(exp_dir):
        raise gr.Error("Experiment directory not found")


def _list_checkpoints(exp_dir: str, prefix: str) -> List[str]:
    """Return every ``{prefix}_*.pth`` file directly inside *exp_dir*.

    Raises:
        gr.Error: if *exp_dir* is not a directory or no file matches.
    """
    _require_exp_dir(exp_dir)
    # escape() keeps glob metacharacters in the directory name ("[", "*", "?")
    # from being interpreted as part of the pattern.
    checkpoints = glob(f"{escape(exp_dir)}/{prefix}_*.pth")
    if not checkpoints:
        raise gr.Error("No checkpoint found")
    return checkpoints


def _latest_checkpoint(checkpoints: List[str]) -> str:
    """Return the checkpoint with the greatest ``os.path.getctime`` and log it."""
    # NOTE(review): getctime is creation time on Windows but metadata-change
    # time on Unix; kept as-is to preserve the existing selection behaviour.
    latest_checkpoint = max(checkpoints, key=os.path.getctime)
    print(f"Latest checkpoint: {latest_checkpoint}")
    return latest_checkpoint


def _prune_checkpoints(exp_dir: str, prefix: str) -> None:
    """Delete all ``{prefix}_*.pth`` files in *exp_dir* except the latest one."""
    checkpoints = _list_checkpoints(exp_dir, prefix)
    latest_checkpoint = _latest_checkpoint(checkpoints)
    for checkpoint in checkpoints:
        if checkpoint != latest_checkpoint:
            os.remove(checkpoint)
            print(f"Removed: {checkpoint}")


def download_weight(exp_dir: str) -> str:
    """Extract a small model from the latest ``G_*.pth`` checkpoint.

    Args:
        exp_dir: Experiment directory containing the checkpoints.

    Returns:
        Path of the written file, ``<exp_dir>/model.pth``.

    Raises:
        gr.Error: if the directory does not exist or holds no ``G_*.pth``.
    """
    latest_checkpoint = _latest_checkpoint(_list_checkpoints(exp_dir, "G"))

    out = os.path.join(exp_dir, "model.pth")
    # Positional arguments are passed through unchanged; presumably they are
    # sample rate, pitch-guidance flag, info string and model version —
    # TODO confirm against extract_small_model's signature.
    extract_small_model(
        latest_checkpoint, out, "40k", True, "Model trained by ZeroGPU.", "v2"
    )

    return out


def download_inference_pack(exp_dir: str) -> str:
    """Bundle the extracted model and an ``added_*.index`` file into a zip.

    If no index file exists yet, ``train_index`` is invoked once to build it.

    Args:
        exp_dir: Experiment directory.

    Returns:
        Path of the archive, ``<exp_dir>/inference_pack.zip``.

    Raises:
        gr.Error: if no checkpoint is found, or no index file exists even
            after calling ``train_index``.
    """
    net_g = download_weight(exp_dir)

    index_pattern = f"{escape(exp_dir)}/added_*.index"
    index = glob(index_pattern)
    if not index:
        train_index(exp_dir)
        index = glob(index_pattern)
        if not index:
            raise gr.Error("Index not found")

    # Stage the two files in a scratch directory, zip it, then drop it.
    tmp = os.path.join(exp_dir, "inference_pack")
    if os.path.exists(tmp):
        shutil.rmtree(tmp)
    os.makedirs(tmp)
    try:
        shutil.copy(net_g, tmp)
        shutil.copy(index[0], tmp)
        shutil.make_archive(tmp, "zip", tmp)
    finally:
        # Always remove the staging directory, even when copying or
        # archiving fails, so a later run starts clean.
        shutil.rmtree(tmp)

    return f"{tmp}.zip"


def download_expdir(exp_dir: str) -> str:
    """Zip the whole experiment directory.

    Args:
        exp_dir: Experiment directory to archive.

    Returns:
        Path of the archive, ``<exp_dir>.zip``.

    Raises:
        gr.Error: if *exp_dir* is empty or not a directory.
    """
    _require_exp_dir(exp_dir)
    shutil.make_archive(exp_dir, "zip", exp_dir)
    return f"{exp_dir}.zip"


def upload_to_huggingface(exp_dir: str, repo_id: str, token: str) -> str:
    """Upload the experiment directory to a Hugging Face repository.

    Entries matching ``_data``, ``*.zip`` and ``tmp.wav`` are excluded.

    Args:
        exp_dir: Experiment directory to upload.
        repo_id: Target repository ID.
        token: Access token; forwarded only when it starts with ``hf_``,
            otherwise ``None`` is passed to ``upload_folder``.

    Returns:
        The ``commit_url`` of the commit returned by ``upload_folder``.

    Raises:
        gr.Error: if *exp_dir* is empty or not a directory.
    """
    _require_exp_dir(exp_dir)
    commit = upload_folder(
        repo_id=repo_id,
        folder_path=exp_dir,
        ignore_patterns=["_data", "*.zip", "tmp.wav"],
        # Guard against a missing value before inspecting the prefix.
        token=token if token and token.startswith("hf_") else None,
    )
    return commit.commit_url


def remove_legacy_checkpoints(exp_dir: str):
    """Keep only the latest ``G_*.pth`` and ``D_*.pth`` checkpoints.

    Generator checkpoints are pruned first, then discriminator checkpoints.

    Args:
        exp_dir: Experiment directory containing the checkpoints.

    Raises:
        gr.Error: if the directory does not exist, or either checkpoint
            family has no files. A missing ``D_*.pth`` is reported only
            after the ``G_*.pth`` files have already been pruned.
    """
    for prefix in ("G", "D"):
        _prune_checkpoints(exp_dir, prefix)


def remove_expdir(exp_dir: str) -> str:
    """Delete the experiment directory recursively.

    Args:
        exp_dir: Experiment directory to delete.

    Returns:
        An empty string, used to clear the experiment-directory textbox.

    Raises:
        gr.Error: if *exp_dir* is empty or not a directory.
    """
    _require_exp_dir(exp_dir)
    shutil.rmtree(exp_dir)
    return ""


class ExportTab:
    """Gradio tab exposing the download, upload and clean-up actions.

    Call ``ui()`` inside a Blocks context to create the components, then
    ``build()`` to attach the click handlers.
    """

    def __init__(self):
        pass

    def ui(self):
        """Create the tab's components and store them on the instance."""
        gr.Markdown("# Download Model or Experiment Directory")
        gr.Markdown(
            "You can download the latest model or the entire experiment directory here."
        )

        with gr.Row():
            self.download_weight_btn = gr.Button(
                value="Latest model (for inferencing)", variant="primary"
            )
            self.download_weight_output = gr.File(label="Prune latest model")

        with gr.Row():
            self.download_inference_pack_btn = gr.Button(
                value="Download inference pack (model + index)", variant="primary"
            )
            self.download_inference_pack_output = gr.File(label="Inference pack")

        with gr.Row():
            self.download_expdir_btn = gr.Button(
                value="Download experiment directory", variant="primary"
            )
            self.download_expdir_output = gr.File(label="Archive experiment directory")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Upload to Hugging Face")
                gr.Markdown(
                    "You can upload the entire experiment directory to Hugging Face."
                )
                self.commit_link = gr.Markdown("")
            with gr.Column():
                self.repo_id = gr.Textbox(label="Repository ID")
                # Mask the token so it is not shown in clear text on screen.
                self.token = gr.Textbox(
                    label="Personal access token", type="password"
                )
                self.upload_to_huggingface_btn = gr.Button(
                    value="Upload to Hugging Face", variant="primary"
                )

        with gr.Row():
            self.remove_legacy_checkpoints_btn = gr.Button(
                value="Remove legacy checkpoints"
            )

        with gr.Row():
            self.remove_expdir_btn = gr.Button(
                value="REMOVE experiment directory", variant="stop"
            )

    def build(self, exp_dir: gr.Textbox):
        """Wire each button created by ``ui()`` to its handler.

        Args:
            exp_dir: Textbox holding the experiment directory path; it is the
                input of every handler and is cleared by the remove action.
        """
        self.download_weight_btn.click(
            fn=download_weight,
            inputs=[exp_dir],
            outputs=[self.download_weight_output],
        )

        self.download_inference_pack_btn.click(
            fn=download_inference_pack,
            inputs=[exp_dir],
            outputs=[self.download_inference_pack_output],
        )

        self.download_expdir_btn.click(
            fn=download_expdir,
            inputs=[exp_dir],
            outputs=[self.download_expdir_output],
        )

        self.upload_to_huggingface_btn.click(
            fn=upload_to_huggingface,
            inputs=[exp_dir, self.repo_id, self.token],
            outputs=[self.commit_link],
        )

        self.remove_legacy_checkpoints_btn.click(
            fn=remove_legacy_checkpoints,
            inputs=[exp_dir],
        )

        # remove_expdir returns "", which resets the exp_dir textbox.
        self.remove_expdir_btn.click(
            fn=remove_expdir,
            inputs=[exp_dir],
            outputs=[exp_dir],
        )