Spaces:
Running
Running
import gradio as gr | |
import torch | |
from safetensors.torch import save_file | |
import os | |
def convert_ckpt_to_safetensors(ckpt_file):
    """Convert an uploaded PyTorch ``.ckpt`` checkpoint to safetensors.

    Args:
        ckpt_file: The uploaded file. Either an object exposing the on-disk
            path through a ``.name`` attribute, a plain path string, or
            ``None`` when nothing was uploaded.

    Returns:
        A human-readable status string: a prompt to upload a ``.ckpt`` file,
        a success message containing the output path, or an error message.
        This function never raises; failures are reported in the string.
    """
    # Nothing submitted: answer with the same prompt as for a wrong file type
    # instead of crashing on the attribute access below.
    if ckpt_file is None:
        return "Please upload a .ckpt file."

    # Accept both file-like wrappers (path in ``.name``) and bare path strings.
    ckpt_path = ckpt_file if isinstance(ckpt_file, str) else ckpt_file.name

    # Case-insensitive so that e.g. ``MODEL.CKPT`` is accepted as well.
    if not ckpt_path.lower().endswith(".ckpt"):
        return "Please upload a .ckpt file."

    try:
        # SECURITY NOTE(review): torch.load unpickles the file, and unpickling
        # untrusted uploads can execute arbitrary code. Consider passing
        # ``weights_only=True`` if the checkpoints you need to support load
        # under that restriction.
        ckpt = torch.load(ckpt_path, map_location="cpu")

        # Some checkpoints nest the weights under "state_dict"; otherwise the
        # loaded object itself is treated as the state dict.
        state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

        # save_file only accepts contiguous tensors, so drop non-tensor
        # entries and make every remaining tensor contiguous.
        tensors = {
            key: value.contiguous()
            for key, value in state_dict.items()
            if isinstance(value, torch.Tensor)
        }
        if not tensors:
            # Avoid reporting success for a file that would contain no weights.
            raise ValueError("no tensors found in the checkpoint")

        # Same location and base name as the input, with the new extension.
        output_file = os.path.splitext(ckpt_path)[0] + ".safetensors"
        save_file(tensors, output_file)
        return f"Conversion successful. Saved as {output_file}"
    except Exception as e:
        # Top-level UI boundary: report the failure to the user as text.
        return f"Error during conversion: {str(e)}"
# Build the web UI: a single file-upload input wired to the converter, with
# the returned status string rendered as plain text.
iface = gr.Interface(
    convert_ckpt_to_safetensors,
    gr.File(label="Upload .ckpt file"),
    "text",
    title="CKPT to Safetensors Converter",
    description="Upload a .ckpt file to convert it to the safetensors format.",
)

# Start serving the interface.
iface.launch()