import os
import requests
import tempfile
import shutil
import torch
from pytorch_lightning import LightningModule
from safetensors.torch import save_file
from torch import nn
from modelalign import BERTAlignModel
import gradio as gr
# ===========================
# Utility Functions
# ===========================
def download_checkpoint(url: str, dest_path: str):
    """
    Downloads the checkpoint from the specified URL to the destination path.
    """
    try:
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            with open(dest_path, 'wb') as f:
                shutil.copyfileobj(response.raw, f)
        return True, "Checkpoint downloaded successfully."
    except Exception as e:
        return False, f"Failed to download checkpoint: {str(e)}"
def initialize_model(model_name: str, device: str = 'cpu'):
    """
    Initializes the BERTAlignModel based on the provided model name.
    """
    try:
        model = BERTAlignModel(base_model_name=model_name)
        model.to(device)
        model.eval()  # Set to evaluation mode
        return True, model
    except Exception as e:
        return False, f"Failed to initialize model: {str(e)}"
def load_checkpoint(model: LightningModule, checkpoint_path: str, device: str = 'cpu'):
    """
    Loads the checkpoint into the model.
    """
    try:
        # Load the checkpoint; adjust map_location based on device
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        return True, "Checkpoint loaded successfully."
    except Exception as e:
        return False, f"Failed to load checkpoint: {str(e)}"
def convert_to_safetensors(model: LightningModule, save_path: str):
    """
    Converts the model's state_dict to the safetensors format.
    """
    try:
        state_dict = model.state_dict()
        save_file(state_dict, save_path)
        return True, "Model converted to SafeTensors successfully."
    except Exception as e:
        return False, f"Failed to convert to SafeTensors: {str(e)}"
# ===========================
# Gradio Interface Function
# ===========================
def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
    """
    Orchestrates the download, loading, and conversion steps.
    Returns the path to the safetensors file (or None) and a status message.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_path = os.path.join(tmpdir, "model.ckpt")

        # Step 1: Download the checkpoint
        success, message = download_checkpoint(checkpoint_url, checkpoint_path)
        if not success:
            return None, message

        # Step 2: Initialize the model
        success, model_or_msg = initialize_model(model_name)
        if not success:
            return None, model_or_msg
        model = model_or_msg

        # Step 3: Load the checkpoint
        success, message = load_checkpoint(model, checkpoint_path)
        if not success:
            return None, message

        # Step 4: Convert to SafeTensors.
        # Write the output outside the temporary directory so the file still
        # exists when Gradio serves it for download after this block exits.
        output_dir = tempfile.mkdtemp()
        safetensors_path = os.path.join(output_dir, "model.safetensors")
        success, message = convert_to_safetensors(model, safetensors_path)
        if not success:
            return None, message

        return safetensors_path, "Conversion successful! Download your SafeTensors file below."
# ===========================
# Gradio Interface Setup
# ===========================
title = "Checkpoint to SafeTensors Converter"
description = """
Convert your PyTorch Lightning `.ckpt` checkpoints to the secure `safetensors` format.
**Inputs**:
- **Checkpoint URL**: Direct link to the `.ckpt` file.
- **Model Name**: Name of the base model (e.g., `roberta-base`, `bert-base-uncased`).
**Output**:
- Downloadable `safetensors` file.
"""
iface = gr.Interface(
    fn=convert_checkpoint_to_safetensors,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter the checkpoint URL here...", label="Checkpoint URL"),
        gr.Textbox(lines=1, placeholder="e.g., roberta-base", label="Model Name")
    ],
    outputs=[
        gr.File(label="Download SafeTensors File"),
        gr.Textbox(label="Status")
    ],
    title=title,
    description=description,
    examples=[
        [
            "https://huggingface.co/yzha/AlignScore/resolve/main/AlignScore-base.ckpt?download=true",
            "roberta-base"
        ],
        [
            "https://path.to/your/checkpoint.ckpt",
            "bert-base-uncased"
        ]
    ],
    allow_flagging="never"
)
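# Optional hardening sketch (an assumption, not part of the original setup):
# large downloads and conversions can outlast the default request timeout on
# hosted Spaces, and routing them through Gradio's queue is the usual remedy:
#
#   iface.queue()  # call before launch() if conversions are slow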
# ===========================
# Launch the Interface
# ===========================
if __name__ == "__main__":
    iface.launch()