Spaces:
Running
Running
File size: 1,108 Bytes
19fb0f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import gradio as gr
import torch
from safetensors.torch import save_file
import os
def convert_ckpt_to_safetensors(ckpt_file):
    """Convert an uploaded ``.ckpt`` checkpoint to a ``.safetensors`` file.

    The output is written next to the input file, with the extension
    replaced by ``.safetensors``.

    Args:
        ckpt_file: The uploaded file. Either an object exposing the on-disk
            path as ``.name`` (a tempfile-style wrapper), a path string /
            ``os.PathLike``, or ``None`` when nothing was uploaded.

    Returns:
        A human-readable status message. Errors are reported as text rather
        than raised, because this function is a UI callback.
    """
    # Depending on the Gradio version / gr.File ``type`` setting, the handler
    # receives either a path or a wrapper object exposing ``.name``. Check for
    # real paths first: a ``pathlib.Path`` also has ``.name``, but that is only
    # the basename.
    if isinstance(ckpt_file, (str, os.PathLike)):
        path = os.fspath(ckpt_file)
    else:
        path = getattr(ckpt_file, "name", None)
    if not isinstance(path, str) or not path.lower().endswith(".ckpt"):
        return "Please upload a .ckpt file."
    try:
        # SECURITY: torch.load unpickles the file, and this file comes from an
        # arbitrary uploader. A crafted .ckpt can execute code here. Consider
        # passing weights_only=True (supported on recent torch releases) once
        # the checkpoints you need are confirmed to load that way.
        ckpt = torch.load(path, map_location="cpu")
        # Training checkpoints often nest the weights under "state_dict";
        # a bare state dict is used as-is.
        if isinstance(ckpt, dict) and "state_dict" in ckpt:
            state_dict = ckpt["state_dict"]
        else:
            state_dict = ckpt
        if not isinstance(state_dict, dict):
            return "Error during conversion: checkpoint does not contain a state dict."
        # save_file only accepts a {str: Tensor} mapping of contiguous
        # tensors. Drop bookkeeping entries (step counters, nested dicts, ...)
        # and make every tensor contiguous so the save does not fail.
        tensors = {
            key: value.contiguous()
            for key, value in state_dict.items()
            if isinstance(key, str) and isinstance(value, torch.Tensor)
        }
        if not tensors:
            return "Error during conversion: no tensors found in checkpoint."
        output_file = os.path.splitext(path)[0] + ".safetensors"
        save_file(tensors, output_file)
        return f"Conversion successful. Saved as {output_file}"
    except Exception as e:  # UI boundary: report any failure to the user as text
        return f"Error during conversion: {e}"
# Wire the converter into a single-upload, text-output Gradio form and serve it.
_CKPT_UPLOAD = gr.File(label="Upload .ckpt file")
iface = gr.Interface(
    title="CKPT to Safetensors Converter",
    description="Upload a .ckpt file to convert it to the safetensors format.",
    fn=convert_ckpt_to_safetensors,
    inputs=_CKPT_UPLOAD,
    outputs="text",
)
iface.launch()
|