from glob import escape, glob
import os
import shutil

import gradio as gr

from infer.lib.train.process_ckpt import extract_small_model


def _resolve_expdir(exp_dir: str) -> str:
    """Validate the experiment directory typed by the user and normalize it.

    Args:
        exp_dir: path entered in the experiment-directory textbox.

    Returns:
        The normalized path (redundant separators and any trailing slash
        removed, so ``os.path.basename`` yields the directory's own name).

    Raises:
        gr.Error: if the value is blank or does not name an existing directory.
    """
    # Check for blank input *before* normpath: normpath("") == ".", which would
    # silently redirect every operation to the current working directory.
    if not exp_dir or not exp_dir.strip():
        raise gr.Error("Experiment directory is empty")
    path = os.path.normpath(exp_dir)
    if not os.path.isdir(path):
        raise gr.Error(f"Experiment directory not found: {exp_dir}")
    return path


def download_weight(exp_dir: str) -> str:
    """Extract a small inference model from the newest ``G_*.pth`` checkpoint.

    Args:
        exp_dir: experiment directory containing generator checkpoints.

    Returns:
        Path of the extracted model, ``<exp_dir>/<directory name>.pth``.

    Raises:
        gr.Error: if ``exp_dir`` is invalid, holds no ``G_*.pth`` checkpoint,
            or the extracted model file was not produced.
    """
    exp_dir = _resolve_expdir(exp_dir)
    # Escape the directory part so names containing "[", "*" or "?" are taken
    # literally; only the file-name pattern is a wildcard.
    models = glob(os.path.join(escape(exp_dir), "G_*.pth"))
    if not models:
        raise gr.Error("No model found")
    # getmtime is the last write of the checkpoint; getctime is the metadata
    # change time on Unix, which is not a reliable "newest checkpoint" key.
    latest_model = max(models, key=os.path.getmtime)
    print(f"Latest model: {latest_model}")

    name = os.path.basename(exp_dir)
    out = os.path.join(exp_dir, f"{name}.pth")
    # NOTE(review): positional args presumably are (ckpt path, output path,
    # sample rate, pitch-guidance flag, info text, model version) — confirm
    # against infer.lib.train.process_ckpt.extract_small_model.
    extract_small_model(
        latest_model, out, "40k", True, "Model trained by ZeroGPU.", "v2"
    )
    # Fail here with a readable UI error rather than handing Gradio a path
    # that does not exist.
    if not os.path.isfile(out):
        raise gr.Error(f"Failed to extract model to {out}")
    return out


def download_expdir(exp_dir: str) -> str:
    """Zip the whole experiment directory and return the archive path.

    The archive is written next to the directory as ``<exp_dir>.zip``.

    Raises:
        gr.Error: if ``exp_dir`` is blank or not an existing directory.
    """
    exp_dir = _resolve_expdir(exp_dir)
    # make_archive appends ".zip" to base_name and returns the path it wrote.
    return shutil.make_archive(exp_dir, "zip", exp_dir)


def remove_expdir(exp_dir: str) -> str:
    """Recursively delete the experiment directory.

    Returns:
        An empty string; ``ExportTab.build`` routes it back into the
        experiment-directory textbox to clear it.

    Raises:
        gr.Error: if ``exp_dir`` is blank, not an existing directory, or is the
            current working directory or one of its ancestors.
    """
    exp_dir = _resolve_expdir(exp_dir)
    target = os.path.abspath(exp_dir)
    cwd = os.getcwd()
    # Guard against destructive input such as ".", ".." or "/": never remove
    # the working directory or anything containing it (filesystem root too).
    if target == cwd or cwd.startswith(target.rstrip(os.sep) + os.sep):
        raise gr.Error(f"Refusing to remove {exp_dir}")
    shutil.rmtree(target)
    return ""


class ExportTab:
    """Gradio tab for exporting or deleting a training experiment directory."""

    def ui(self):
        """Create the tab's widgets; must be called inside a Gradio layout."""
        gr.Markdown("# Download Model or Experiment Directory")
        gr.Markdown(
            "You can download the latest model or the entire experiment directory here."
        )

        with gr.Row():
            self.download_weight_btn = gr.Button(
                value="Latest model (for inferencing)", variant="primary"
            )
            self.download_weight_output = gr.File(label="Prune latest model")

        with gr.Row():
            self.download_expdir_btn = gr.Button(
                value="Download experiment directory", variant="primary"
            )
            self.download_expdir_output = gr.File(label="Archive experiment directory")

        with gr.Row():
            self.remove_expdir_btn = gr.Button(
                value="REMOVE experiment directory", variant="stop"
            )

    def build(self, exp_dir: gr.Textbox):
        """Wire the buttons created by ``ui`` to their handlers.

        Args:
            exp_dir: textbox holding the experiment directory path; it is the
                input of every handler and is cleared after a removal.
        """
        self.download_weight_btn.click(
            fn=download_weight,
            inputs=[exp_dir],
            outputs=[self.download_weight_output],
        )

        self.download_expdir_btn.click(
            fn=download_expdir,
            inputs=[exp_dir],
            outputs=[self.download_expdir_output],
        )

        self.remove_expdir_btn.click(
            fn=remove_expdir,
            inputs=[exp_dir],
            outputs=[exp_dir],
        )