# Hugging Face Space (status: Running)
import logging
import os
import tempfile

import gradio as gr
import torch
from safetensors.torch import save_file
def convert_ckpt_to_safetensors(ckpt_file):
    """Convert an uploaded ``.ckpt`` checkpoint into a ``.safetensors`` file.

    Args:
        ckpt_file: The value passed in by Gradio's ``gr.File`` component.
            This is either an object exposing the upload's path as ``.name``
            or the path itself (``str`` / ``os.PathLike``). It is ``None``
            when nothing was uploaded.

    Returns:
        Path to a newly created ``.safetensors`` file, or ``None`` in any of
        these cases:
        - the input is missing or is not a ``.ckpt`` file (case-insensitive);
        - it cannot be loaded;
        - it holds no tensors;
        - writing the output fails.

        Failures are logged instead of raised, so the UI shows no output file.
    """
    logger = logging.getLogger(__name__)

    if ckpt_file is None:
        return None

    # Accept both a file-like wrapper carrying ``.name`` and a plain path.
    ckpt_path = getattr(ckpt_file, "name", ckpt_file)
    if isinstance(ckpt_path, os.PathLike):
        ckpt_path = os.fspath(ckpt_path)
    if not isinstance(ckpt_path, str) or not ckpt_path.lower().endswith(".ckpt"):
        logger.warning("Rejected upload %r: expected a .ckpt file", ckpt_path)
        return None

    try:
        # NOTE(security): torch.load unpickles the file. Unpickling untrusted
        # data can execute arbitrary code, and this app loads user uploads.
        # Consider passing weights_only=True once it is confirmed to load
        # the checkpoints this app is meant to handle.
        ckpt = torch.load(ckpt_path, map_location="cpu")
    except Exception:
        logger.exception("Failed to load checkpoint %s", ckpt_path)
        return None

    # Some checkpoints nest the weights under "state_dict"; otherwise treat
    # the loaded object itself as the state dict.
    state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    if not isinstance(state_dict, dict):
        logger.error(
            "Checkpoint %s does not contain a dict of tensors (got %s)",
            ckpt_path,
            type(state_dict).__name__,
        )
        return None

    # save_file expects a flat str -> contiguous tensor mapping. Any other
    # entries the checkpoint may carry (counters, metadata, ...) would make
    # it fail, so keep only the tensors.
    tensors = {
        key: value.contiguous()
        for key, value in state_dict.items()
        if isinstance(key, str) and isinstance(value, torch.Tensor)
    }
    skipped = len(state_dict) - len(tensors)
    if skipped:
        logger.info("Skipped %d non-tensor entries in %s", skipped, ckpt_path)
    if not tensors:
        logger.error("Checkpoint %s contains no tensors to convert", ckpt_path)
        return None

    # Prefix the output with the input's name so the download is recognisable.
    stem = os.path.splitext(os.path.basename(ckpt_path))[0]
    output_file = None
    try:
        with tempfile.NamedTemporaryFile(
            delete=False, prefix=f"{stem}_", suffix=".safetensors"
        ) as temp_file:
            output_file = temp_file.name
        save_file(tensors, output_file)
    except Exception:
        logger.exception("Failed to write safetensors file for %s", ckpt_path)
        # Best-effort cleanup so a failed conversion leaves no stray file.
        if output_file is not None:
            try:
                os.remove(output_file)
            except OSError:
                pass
        return None
    return output_file
# Gradio UI: a single file-in / file-out interface around the converter.
iface = gr.Interface(
    fn=convert_ckpt_to_safetensors,
    inputs=gr.File(label="Upload .ckpt file"),
    outputs=gr.File(label="Download converted safetensors file"),
    title="CKPT to Safetensors Converter",
    description="Upload a .ckpt file to convert it to the safetensors format. If you enjoy the app, please ❤ it!",
)

# Start the server only when run as a script (e.g. `python app.py`), so
# importing this module does not launch a web server as a side effect.
if __name__ == "__main__":
    iface.launch()