"""Gradio tab for downloading/exporting trained models and cleaning up experiment dirs."""

from glob import glob
from glob import escape as glob_escape
import os
import shutil
import tempfile
from typing import List

import gradio as gr

from infer.lib.train.process_ckpt import extract_small_model
from app.train import train_index


def _normalize_exp_dir(exp_dir: str) -> str:
    """Validate a user-supplied experiment directory and return it normalized.

    ``os.path.normpath`` collapses redundant and trailing separators so that
    derived names (``<exp_dir>.zip``, ``<basename>.pth``) are well-formed even
    when the textbox value ends with ``/``.

    Raises:
        gr.Error: if *exp_dir* is empty or does not name an existing directory.
    """
    if not exp_dir:
        raise gr.Error("Experiment directory is empty")
    exp_dir = os.path.normpath(exp_dir)
    if not os.path.isdir(exp_dir):
        raise gr.Error(f"Experiment directory not found: {exp_dir}")
    return exp_dir


def _latest(paths: List[str]) -> str:
    """Return the entry of *paths* with the greatest ``os.path.getctime``."""
    return max(paths, key=os.path.getctime)


def _find_checkpoints(exp_dir: str, prefix: str) -> List[str]:
    """Return all ``<exp_dir>/<prefix>_*.pth`` files.

    The directory part is glob-escaped so names containing ``[``, ``*`` or
    ``?`` are matched literally.

    Raises:
        gr.Error: if no matching checkpoint exists.
    """
    checkpoints = glob(os.path.join(glob_escape(exp_dir), f"{prefix}_*.pth"))
    if not checkpoints:
        raise gr.Error("No checkpoint found")
    return checkpoints


def download_weight(exp_dir: str) -> str:
    """Extract a small model from the newest ``G_*.pth`` checkpoint.

    The output is written to ``<exp_dir>/<basename(exp_dir)>.pth``; an existing
    file at that path is passed to ``extract_small_model`` as the destination
    again on every call.

    Returns:
        The path of the extracted model file.

    Raises:
        gr.Error: if *exp_dir* is invalid or contains no ``G_*.pth`` checkpoint.
    """
    exp_dir = _normalize_exp_dir(exp_dir)
    latest_checkpoint = _latest(_find_checkpoints(exp_dir, "G"))
    print(f"Latest checkpoint: {latest_checkpoint}")
    name = os.path.basename(exp_dir)
    out = os.path.join(exp_dir, f"{name}.pth")
    # NOTE(review): positional args look like sample rate / f0 flag / info /
    # model version — confirm against extract_small_model's signature.
    extract_small_model(
        latest_checkpoint, out, "40k", True, "Model trained by ZeroGPU.", "v2"
    )
    return out


def download_inference_pack(exp_dir: str) -> str:
    """Bundle the extracted model and the feature index into one zip file.

    If no ``added_*.index`` file exists yet, ``train_index`` is invoked once to
    build it. When several index files exist, the newest (by ctime) is used.

    Returns:
        The path ``<exp_dir>/inference_pack.zip``.

    Raises:
        gr.Error: if *exp_dir* is invalid, has no checkpoint, or no index can be
            found after training.
    """
    exp_dir = _normalize_exp_dir(exp_dir)
    net_g = download_weight(exp_dir)

    index_pattern = os.path.join(glob_escape(exp_dir), "added_*.index")
    indexes = glob(index_pattern)
    if not indexes:
        train_index(exp_dir)
        indexes = glob(index_pattern)
    if not indexes:
        raise gr.Error("Index not found")
    index = _latest(indexes)

    # Stage the two files in a scratch dir, zip it, then always clean up the
    # scratch dir — even if copying or archiving fails.
    staging = os.path.join(exp_dir, "inference_pack")
    if os.path.exists(staging):
        shutil.rmtree(staging)
    os.makedirs(staging)
    try:
        shutil.copy(net_g, staging)
        shutil.copy(index, staging)
        shutil.make_archive(staging, "zip", staging)
    finally:
        shutil.rmtree(staging, ignore_errors=True)
    return f"{staging}.zip"


def download_expdir(exp_dir: str) -> str:
    """Zip the whole experiment directory into a sibling ``<exp_dir>.zip``.

    Raises:
        gr.Error: if *exp_dir* is empty or not an existing directory.
    """
    exp_dir = _normalize_exp_dir(exp_dir)
    shutil.make_archive(exp_dir, "zip", exp_dir)
    return f"{exp_dir}.zip"


def remove_legacy_checkpoints(exp_dir: str):
    """Delete every ``G_*.pth`` and ``D_*.pth`` except the newest of each set.

    Both sets are looked up before anything is deleted, so a missing ``D_``
    set raises without leaving the directory half-pruned.

    Raises:
        gr.Error: if *exp_dir* is invalid or either checkpoint set is empty.
    """
    exp_dir = _normalize_exp_dir(exp_dir)
    groups = [_find_checkpoints(exp_dir, prefix) for prefix in ("G", "D")]
    for checkpoints in groups:
        latest_checkpoint = _latest(checkpoints)
        print(f"Latest checkpoint: {latest_checkpoint}")
        for checkpoint in checkpoints:
            if checkpoint != latest_checkpoint:
                os.remove(checkpoint)
                print(f"Removed: {checkpoint}")


def remove_expdir(exp_dir: str) -> str:
    """Recursively delete the experiment directory.

    Returns an empty string so the bound textbox is cleared.

    Raises:
        gr.Error: if *exp_dir* is invalid, or if it is the current working
            directory, one of its ancestors, or a filesystem root — deleting
            those from a free-text field would destroy the running app.
    """
    exp_dir = _normalize_exp_dir(exp_dir)
    abs_dir = os.path.abspath(exp_dir)
    try:
        contains_cwd = os.path.commonpath([abs_dir, os.getcwd()]) == abs_dir
    except ValueError:  # paths on different drives (Windows) share no prefix
        contains_cwd = False
    if contains_cwd:
        raise gr.Error(f"Refusing to remove {exp_dir}")
    shutil.rmtree(exp_dir)
    return ""


class ExportTab:
    """UI section offering model/experiment downloads and cleanup actions."""

    def __init__(self):
        pass

    def ui(self):
        """Create the tab's Markdown, buttons and file outputs.

        Must be called inside an active Gradio Blocks context; the created
        components are stored on ``self`` for :meth:`build`.
        """
        gr.Markdown("# Download Model or Experiment Directory")
        gr.Markdown(
            "You can download the latest model or the entire experiment directory here."
        )
        with gr.Row():
            self.download_weight_btn = gr.Button(
                value="Latest model (for inferencing)", variant="primary"
            )
            self.download_weight_output = gr.File(label="Prune latest model")
        with gr.Row():
            self.download_inference_pack_btn = gr.Button(
                value="Download inference pack (model + index)", variant="primary"
            )
            self.download_inference_pack_output = gr.File(label="Inference pack")
        with gr.Row():
            self.download_expdir_btn = gr.Button(
                value="Download experiment directory", variant="primary"
            )
            self.download_expdir_output = gr.File(label="Archive experiment directory")
        with gr.Row():
            self.remove_legacy_checkpoints_btn = gr.Button(
                value="Remove legacy checkpoints"
            )
        with gr.Row():
            self.remove_expdir_btn = gr.Button(
                value="REMOVE experiment directory", variant="stop"
            )

    def build(self, exp_dir: gr.Textbox):
        """Wire the buttons created in :meth:`ui` to their handlers.

        Args:
            exp_dir: Textbox holding the experiment directory path; it is the
                input of every handler and is cleared after directory removal.
        """
        self.download_weight_btn.click(
            fn=download_weight,
            inputs=[exp_dir],
            outputs=[self.download_weight_output],
        )
        self.download_inference_pack_btn.click(
            fn=download_inference_pack,
            inputs=[exp_dir],
            outputs=[self.download_inference_pack_output],
        )
        self.download_expdir_btn.click(
            fn=download_expdir,
            inputs=[exp_dir],
            outputs=[self.download_expdir_output],
        )
        self.remove_legacy_checkpoints_btn.click(
            fn=remove_legacy_checkpoints,
            inputs=[exp_dir],
        )
        self.remove_expdir_btn.click(
            fn=remove_expdir,
            inputs=[exp_dir],
            outputs=[exp_dir],
        )