JacobLinCool commited on
Commit
83ea110
1 Parent(s): a0da4cc

feat: download infer pack

Browse files
Files changed (1) hide show
  1. app/export.py +76 -6
app/export.py CHANGED
@@ -1,32 +1,80 @@
1
  from glob import glob
2
  import os
3
  import shutil
 
4
  import gradio as gr
5
  from infer.lib.train.process_ckpt import extract_small_model
 
6
 
7
 
8
  def download_weight(exp_dir: str) -> str:
9
- models = glob(f"{exp_dir}/G_*.pth")
10
- if not models:
11
- raise gr.Error("No model found")
12
 
13
- latest_model = max(models, key=os.path.getctime)
14
- print(f"Latest model: {latest_model}")
15
 
16
  name = os.path.basename(exp_dir)
17
  out = os.path.join(exp_dir, f"{name}.pth")
18
  extract_small_model(
19
- latest_model, out, "40k", True, "Model trained by ZeroGPU.", "v2"
20
  )
21
 
22
  return out
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def download_expdir(exp_dir: str) -> str:
26
  shutil.make_archive(exp_dir, "zip", exp_dir)
27
  return f"{exp_dir}.zip"
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  def remove_expdir(exp_dir: str) -> str:
31
  shutil.rmtree(exp_dir)
32
  return ""
@@ -48,12 +96,23 @@ class ExportTab:
48
  )
49
  self.download_weight_output = gr.File(label="Prune latest model")
50
 
 
 
 
 
 
 
51
  with gr.Row():
52
  self.download_expdir_btn = gr.Button(
53
  value="Download experiment directory", variant="primary"
54
  )
55
  self.download_expdir_output = gr.File(label="Archive experiment directory")
56
 
 
 
 
 
 
57
  with gr.Row():
58
  self.remove_expdir_btn = gr.Button(
59
  value="REMOVE experiment directory", variant="stop"
@@ -66,12 +125,23 @@ class ExportTab:
66
  outputs=[self.download_weight_output],
67
  )
68
 
 
 
 
 
 
 
69
  self.download_expdir_btn.click(
70
  fn=download_expdir,
71
  inputs=[exp_dir],
72
  outputs=[self.download_expdir_output],
73
  )
74
 
 
 
 
 
 
75
  self.remove_expdir_btn.click(
76
  fn=remove_expdir,
77
  inputs=[exp_dir],
 
1
  from glob import glob
2
  import os
3
  import shutil
4
+ import tempfile
5
  import gradio as gr
6
  from infer.lib.train.process_ckpt import extract_small_model
7
+ from app.train import train_index
8
 
9
 
10
  def download_weight(exp_dir: str) -> str:
11
+ checkpoints = glob(f"{exp_dir}/G_*.pth")
12
+ if not checkpoints:
13
+ raise gr.Error("No checkpoint found")
14
 
15
+ latest_checkpoint = max(checkpoints, key=os.path.getctime)
16
+ print(f"Latest checkpoint: {latest_checkpoint}")
17
 
18
  name = os.path.basename(exp_dir)
19
  out = os.path.join(exp_dir, f"{name}.pth")
20
  extract_small_model(
21
+ latest_checkpoint, out, "40k", True, "Model trained by ZeroGPU.", "v2"
22
  )
23
 
24
  return out
25
 
26
 
27
+ def download_inference_pack(exp_dir: str) -> str:
28
+ net_g = download_weight(exp_dir)
29
+ index = glob(f"{exp_dir}/added_*.index")
30
+ if not index:
31
+ train_index(exp_dir)
32
+ index = glob(f"{exp_dir}/added_*.index")
33
+ if not index:
34
+ raise gr.Error("Index not found")
35
+
36
+ # make zip of those two files
37
+ tmp = os.path.join(exp_dir, "inference_pack")
38
+ if os.path.exists(tmp):
39
+ shutil.rmtree(tmp)
40
+ os.makedirs(tmp)
41
+ shutil.copy(net_g, tmp)
42
+ shutil.copy(index[0], tmp)
43
+ shutil.make_archive(tmp, "zip", tmp)
44
+ shutil.rmtree(tmp)
45
+
46
+ return f"{tmp}.zip"
47
+
48
+
49
  def download_expdir(exp_dir: str) -> str:
50
  shutil.make_archive(exp_dir, "zip", exp_dir)
51
  return f"{exp_dir}.zip"
52
 
53
 
54
+ def remove_legacy_checkpoints(exp_dir: str):
55
+ checkpoints = glob(f"{exp_dir}/G_*.pth")
56
+ if not checkpoints:
57
+ raise gr.Error("No checkpoint found")
58
+
59
+ latest_checkpoint = max(checkpoints, key=os.path.getctime)
60
+ print(f"Latest checkpoint: {latest_checkpoint}")
61
+ for checkpoint in checkpoints:
62
+ if checkpoint != latest_checkpoint:
63
+ os.remove(checkpoint)
64
+ print(f"Removed: {checkpoint}")
65
+
66
+ checkpoints = glob(f"{exp_dir}/D_*.pth")
67
+ if not checkpoints:
68
+ raise gr.Error("No checkpoint found")
69
+
70
+ latest_checkpoint = max(checkpoints, key=os.path.getctime)
71
+ print(f"Latest checkpoint: {latest_checkpoint}")
72
+ for checkpoint in checkpoints:
73
+ if checkpoint != latest_checkpoint:
74
+ os.remove(checkpoint)
75
+ print(f"Removed: {checkpoint}")
76
+
77
+
78
  def remove_expdir(exp_dir: str) -> str:
79
  shutil.rmtree(exp_dir)
80
  return ""
 
96
  )
97
  self.download_weight_output = gr.File(label="Prune latest model")
98
 
99
+ with gr.Row():
100
+ self.download_inference_pack_btn = gr.Button(
101
+ value="Download inference pack (model + index)", variant="primary"
102
+ )
103
+ self.download_inference_pack_output = gr.File(label="Inference pack")
104
+
105
  with gr.Row():
106
  self.download_expdir_btn = gr.Button(
107
  value="Download experiment directory", variant="primary"
108
  )
109
  self.download_expdir_output = gr.File(label="Archive experiment directory")
110
 
111
+ with gr.Row():
112
+ self.remove_legacy_checkpoints_btn = gr.Button(
113
+ value="Remove legacy checkpoints"
114
+ )
115
+
116
  with gr.Row():
117
  self.remove_expdir_btn = gr.Button(
118
  value="REMOVE experiment directory", variant="stop"
 
125
  outputs=[self.download_weight_output],
126
  )
127
 
128
+ self.download_inference_pack_btn.click(
129
+ fn=download_inference_pack,
130
+ inputs=[exp_dir],
131
+ outputs=[self.download_inference_pack_output],
132
+ )
133
+
134
  self.download_expdir_btn.click(
135
  fn=download_expdir,
136
  inputs=[exp_dir],
137
  outputs=[self.download_expdir_output],
138
  )
139
 
140
+ self.remove_legacy_checkpoints_btn.click(
141
+ fn=remove_legacy_checkpoints,
142
+ inputs=[exp_dir],
143
+ )
144
+
145
  self.remove_expdir_btn.click(
146
  fn=remove_expdir,
147
  inputs=[exp_dir],