"""Gradio app that converts PyTorch ``.ckpt`` checkpoints to ``.safetensors``."""

import os
import shutil
import tempfile

import gradio as gr
import torch
from safetensors.torch import save_file


def _input_path(ckpt_file):
    """Return the filesystem path of a Gradio upload.

    Gradio 3.x passes a temp-file wrapper exposing ``.name``; Gradio 4.x
    passes the path as a plain string. Both are accepted.
    """
    return getattr(ckpt_file, "name", ckpt_file)


def _extract_state_dict(ckpt):
    """Return the weight mapping stored in a loaded checkpoint object.

    Raises:
        ValueError: if the checkpoint (or its ``"state_dict"`` entry) is not a dict.
    """
    if not isinstance(ckpt, dict):
        raise ValueError("checkpoint does not contain a dictionary of weights")
    # Some checkpoints nest the weights under "state_dict"; others are the
    # weight mapping itself (mirrors the original lookup).
    state_dict = ckpt.get("state_dict", ckpt)
    if not isinstance(state_dict, dict):
        raise ValueError("'state_dict' entry is not a dictionary")
    return state_dict


def _storage_ptr(tensor):
    """Return the data pointer of the storage backing *tensor*."""
    # untyped_storage() exists on torch >= 2.0; storage() on older releases.
    if hasattr(tensor, "untyped_storage"):
        return tensor.untyped_storage().data_ptr()
    return tensor.storage().data_ptr()


def _prepare_tensors(state_dict):
    """Return only the tensor entries, made safe for ``safetensors.save_file``.

    Two kinds of entry are handled:

    * Non-tensor values, such as step counters, are dropped, because
      safetensors can only store tensors.
    * Tensors that share storage with an earlier entry, such as tied weights
      or views, are cloned, because ``save_file`` rejects shared memory.
    """
    tensors = {}
    seen_storages = set()
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            continue
        tensor = value.detach().contiguous()
        ptr = _storage_ptr(tensor)
        if ptr in seen_storages:
            tensor = tensor.clone()
        else:
            seen_storages.add(ptr)
        tensors[key] = tensor
    return tensors


def convert_ckpt_to_safetensors(ckpt_file):
    """Convert an uploaded ``.ckpt`` checkpoint to a ``.safetensors`` file.

    Args:
        ckpt_file: The Gradio upload. Depending on the Gradio version, this is
            a temp-file wrapper with ``.name`` or a plain path string. It is
            ``None`` when nothing was uploaded.

    Returns:
        Path of the written ``.safetensors`` file. The file is named after the
        uploaded checkpoint and placed in a fresh temporary directory.

    Raises:
        gr.Error: if nothing was uploaded, the file is not a ``.ckpt``, or
            loading or converting it fails. Gradio shows the message to the user.
    """
    if ckpt_file is None:
        raise gr.Error("Please upload a .ckpt file.")
    input_path = _input_path(ckpt_file)
    if not str(input_path).lower().endswith(".ckpt"):
        raise gr.Error("Only .ckpt files are supported.")

    try:
        # SECURITY: .ckpt files are pickles, and the upload is untrusted.
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which prevents arbitrary code execution on the host.
        ckpt = torch.load(input_path, map_location="cpu", weights_only=True)
        tensors = _prepare_tensors(_extract_state_dict(ckpt))
    except Exception as err:
        raise gr.Error(f"Could not read checkpoint: {err}") from err
    if not tensors:
        raise gr.Error("No tensors were found in the checkpoint.")

    stem = os.path.splitext(os.path.basename(str(input_path)))[0]
    # A private temp directory lets the download keep a meaningful file name.
    output_dir = tempfile.mkdtemp()
    output_file = os.path.join(output_dir, f"{stem}.safetensors")
    try:
        save_file(tensors, output_file)
    except Exception as err:
        # Don't leave a partial output behind when saving fails.
        shutil.rmtree(output_dir, ignore_errors=True)
        raise gr.Error(f"Could not write safetensors file: {err}") from err
    return output_file


iface = gr.Interface(
    fn=convert_ckpt_to_safetensors,
    inputs=gr.File(label="Upload .ckpt file"),
    outputs=gr.File(label="Download converted safetensors file"),
    title="CKPT to Safetensors Converter",
    description="Upload a .ckpt file to convert it to the safetensors format. If you enjoy the app, please ❤ it!"
)

if __name__ == "__main__":
    iface.launch()