namelessai commited on
Commit
19fb0f0
1 Parent(s): a9fe1be

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from safetensors.torch import save_file
4
+ import os
5
+
6
+ def convert_ckpt_to_safetensors(ckpt_file):
7
+ if not ckpt_file.name.endswith('.ckpt'):
8
+ return "Please upload a .ckpt file."
9
+
10
+ try:
11
+ # Load the checkpoint
12
+ ckpt = torch.load(ckpt_file.name, map_location="cpu")
13
+
14
+ # Extract the state dict
15
+ if "state_dict" in ckpt:
16
+ state_dict = ckpt["state_dict"]
17
+ else:
18
+ state_dict = ckpt
19
+
20
+ # Create the output filename
21
+ output_file = os.path.splitext(ckpt_file.name)[0] + ".safetensors"
22
+
23
+ # Save as safetensors
24
+ save_file(state_dict, output_file)
25
+
26
+ return f"Conversion successful. Saved as {output_file}"
27
+ except Exception as e:
28
+ return f"Error during conversion: {str(e)}"
29
+
30
+ iface = gr.Interface(
31
+ fn=convert_ckpt_to_safetensors,
32
+ inputs=gr.File(label="Upload .ckpt file"),
33
+ outputs="text",
34
+ title="CKPT to Safetensors Converter",
35
+ description="Upload a .ckpt file to convert it to the safetensors format."
36
+ )
37
+
38
+ iface.launch()