JacobLinCool commited on
Commit
b4b3999
1 Parent(s): 4a7b229

feat: upload expdir to hugging face

Browse files
Files changed (1) hide show
  1. app/export.py +26 -0
app/export.py CHANGED
@@ -4,6 +4,7 @@ import shutil
4
  import gradio as gr
5
  from infer.lib.train.process_ckpt import extract_small_model
6
  from app.train import train_index
 
7
 
8
 
9
  def download_weight(exp_dir: str) -> str:
@@ -49,6 +50,11 @@ def download_expdir(exp_dir: str) -> str:
49
  return f"{exp_dir}.zip"
50
 
51
 
 
 
 
 
 
52
  def remove_legacy_checkpoints(exp_dir: str):
53
  checkpoints = glob(f"{exp_dir}/G_*.pth")
54
  if not checkpoints:
@@ -106,6 +112,20 @@ class ExportTab:
106
  )
107
  self.download_expdir_output = gr.File(label="Archive experiment directory")
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  with gr.Row():
110
  self.remove_legacy_checkpoints_btn = gr.Button(
111
  value="Remove legacy checkpoints"
@@ -135,6 +155,12 @@ class ExportTab:
135
  outputs=[self.download_expdir_output],
136
  )
137
 
 
 
 
 
 
 
138
  self.remove_legacy_checkpoints_btn.click(
139
  fn=remove_legacy_checkpoints,
140
  inputs=[exp_dir],
 
4
  import gradio as gr
5
  from infer.lib.train.process_ckpt import extract_small_model
6
  from app.train import train_index
7
+ from huggingface_hub import upload_folder
8
 
9
 
10
  def download_weight(exp_dir: str) -> str:
 
50
  return f"{exp_dir}.zip"
51
 
52
 
53
+ def upload_to_huggingface(exp_dir: str, repo_id: str, token: str) -> str:
54
+ commit = upload_folder(repo_id=repo_id, folder_path=exp_dir, token=token)
55
+ return commit.commit_url
56
+
57
+
58
  def remove_legacy_checkpoints(exp_dir: str):
59
  checkpoints = glob(f"{exp_dir}/G_*.pth")
60
  if not checkpoints:
 
112
  )
113
  self.download_expdir_output = gr.File(label="Archive experiment directory")
114
 
115
+ with gr.Row():
116
+ with gr.Column():
117
+ gr.Markdown("### Upload to Hugging Face")
118
+ gr.Markdown(
119
+ "You can upload the entire experiment directory to Hugging Face."
120
+ )
121
+ self.commit_link = gr.Markdown("")
122
+ with gr.Column():
123
+ self.repo_id = gr.Textbox(label="Repository ID")
124
+ self.token = gr.Textbox(label="Personal access token")
125
+ self.upload_to_huggingface_btn = gr.Button(
126
+ value="Upload to Hugging Face", variant="primary"
127
+ )
128
+
129
  with gr.Row():
130
  self.remove_legacy_checkpoints_btn = gr.Button(
131
  value="Remove legacy checkpoints"
 
155
  outputs=[self.download_expdir_output],
156
  )
157
 
158
+ self.upload_to_huggingface_btn.click(
159
+ fn=upload_to_huggingface,
160
+ inputs=[exp_dir, self.repo_id, self.token],
161
+ outputs=[self.commit_link],
162
+ )
163
+
164
  self.remove_legacy_checkpoints_btn.click(
165
  fn=remove_legacy_checkpoints,
166
  inputs=[exp_dir],