Spaces:
Runtime error
Runtime error
File size: 4,883 Bytes
1f76ea6 2d8296f b8db24f 1f76ea6 b8db24f 1f76ea6 b8db24f 1f76ea6 b8db24f 1f76ea6 b8db24f 1f76ea6 b8db24f 1f76ea6 b8db24f 1f76ea6 b8db24f 1f76ea6 91291fb 1f76ea6 91291fb 1f76ea6 91291fb 1f76ea6 2d8296f 1f76ea6 b8db24f 1f76ea6 b8db24f 1f76ea6 2d8296f 1f76ea6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
import requests
import tempfile
import shutil
import torch
from pytorch_lightning import LightningModule
from safetensors.torch import save_file
from torch import nn
import gradio as gr
from modelalign import BERTAlignModel
# ===========================
# Utility Functions
# ===========================
def download_checkpoint(url: str, dest_path: str, timeout: float = 60.0):
    """
    Downloads the checkpoint from the specified URL to the destination path.

    Args:
        url: Direct HTTP(S) link to the checkpoint file.
        dest_path: Local file path the response body is written to.
        timeout: Seconds to wait for the connection and for each read
            before giving up. This bounds stalls, not total download time.

    Returns:
        A ``(success, message)`` tuple. Exceptions are caught and reported
        as ``(False, "Failed to download checkpoint: ...")``.
    """
    try:
        # Without an explicit timeout, requests waits indefinitely on a
        # server that accepts the connection but never answers.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(dest_path, 'wb') as f:
                # iter_content (unlike response.raw) undoes gzip/deflate
                # content encodings, so the bytes on disk are the real file.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        return True, "Checkpoint downloaded successfully."
    except Exception as e:
        return False, f"Failed to download checkpoint: {str(e)}"
def initialize_model(model_name: str, device: str = 'cpu'):
    """
    Build a BERTAlignModel for ``model_name`` and prepare it for inference.

    Args:
        model_name: Passed to BERTAlignModel as ``base_model_name``.
        device: Device the model is moved to (defaults to ``'cpu'``).

    Returns:
        ``(True, model)`` on success, otherwise
        ``(False, "Failed to initialize model: ...")``.
    """
    try:
        aligner = BERTAlignModel(base_model_name=model_name)
        aligner.to(device)
        # Switch to evaluation mode before handing the model back.
        aligner.eval()
    except Exception as err:
        return False, f"Failed to initialize model: {err!s}"
    return True, aligner
def load_checkpoint(model: LightningModule, checkpoint_path: str, device: str = 'cpu'):
    """
    Loads the checkpoint into the model.

    The file may either be a dict with a ``'state_dict'`` entry or a bare
    state dict. Loading is non-strict, so keys that do not line up between
    checkpoint and model are skipped; when that happens the returned
    message says how many were skipped instead of reporting a clean load.

    Args:
        model: Module whose parameters are overwritten in place.
        checkpoint_path: Path of the checkpoint file on disk.
        device: ``map_location`` used when deserializing tensors.

    Returns:
        ``(True, message)`` on success, otherwise
        ``(False, "Failed to load checkpoint: ...")``.
    """
    try:
        # SECURITY NOTE: torch.load deserializes with pickle. In this app the
        # file comes from a user-supplied URL, so a crafted checkpoint may be
        # able to run arbitrary code unless the installed torch restricts
        # unpickling (weights_only). Review before exposing publicly.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        # Use the nested 'state_dict' entry when present, else the object itself.
        if 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        result = model.load_state_dict(state_dict, strict=False)
        missing = len(result.missing_keys)
        unexpected = len(result.unexpected_keys)
        if missing or unexpected:
            # Still a success (non-strict load), but do not hide the mismatch:
            # skipped keys keep the model's freshly initialized values.
            return True, (
                "Checkpoint loaded successfully. "
                f"Warning: {missing} missing and {unexpected} unexpected key(s) were ignored."
            )
        return True, "Checkpoint loaded successfully."
    except Exception as e:
        return False, f"Failed to load checkpoint: {str(e)}"
def convert_to_safetensors(model: LightningModule, save_path: str):
    """
    Converts the model's state_dict to the safetensors format.

    Args:
        model: Module whose ``state_dict()`` is serialized.
        save_path: Destination path of the ``.safetensors`` file.

    Returns:
        ``(True, message)`` on success, otherwise
        ``(False, "Failed to convert to SafeTensors: ...")``.
    """
    try:
        # save_file refuses tensors that share storage (as tied weights do)
        # and tensors that are not contiguous. Give every entry its own
        # contiguous copy so all keys are written; this temporarily doubles
        # the memory held by the weights.
        state_dict = {
            name: tensor.detach().clone(memory_format=torch.contiguous_format)
            for name, tensor in model.state_dict().items()
        }
        save_file(state_dict, save_path)
        return True, "Model converted to SafeTensors successfully."
    except Exception as e:
        return False, f"Failed to convert to SafeTensors: {str(e)}"
# ===========================
# Gradio Interface Function
# ===========================
def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
    """
    Orchestrates the download, loading, conversion, and preparation for download.

    Args:
        checkpoint_url: Direct link to the ``.ckpt`` file.
        model_name: Name of the base model used to build the architecture.

    Returns:
        ``(path, status)`` where ``path`` is the generated safetensors file,
        or ``(None, error_message)`` if any step fails or an input is blank.
    """
    # Reject blank inputs up front; also drop stray whitespace/newlines that
    # a multi-line textbox can append to the URL.
    checkpoint_url = (checkpoint_url or "").strip()
    model_name = (model_name or "").strip()
    if not checkpoint_url:
        return None, "Please provide a checkpoint URL."
    if not model_name:
        return None, "Please provide a model name."

    # The result must outlive this function so the UI can serve it, so it
    # cannot live inside the self-deleting TemporaryDirectory used for the
    # downloaded checkpoint. This directory is kept on success (the caller
    # or the OS temp cleaner is responsible for it) and removed on failure.
    output_dir = tempfile.mkdtemp(prefix="safetensors_")
    safetensors_path = os.path.join(output_dir, "model.safetensors")
    succeeded = False
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            checkpoint_path = os.path.join(tmpdir, "model.ckpt")

            # Step 1: Download the checkpoint
            success, message = download_checkpoint(checkpoint_url, checkpoint_path)
            if not success:
                return None, message

            # Step 2: Initialize the model
            success, model_or_msg = initialize_model(model_name)
            if not success:
                return None, model_or_msg
            model = model_or_msg

            # Step 3: Load the checkpoint
            success, message = load_checkpoint(model, checkpoint_path)
            if not success:
                return None, message

            # Step 4: Convert to SafeTensors
            success, message = convert_to_safetensors(model, safetensors_path)
            if not success:
                return None, message

        # Step 5: Hand the file over for download
        succeeded = True
        return safetensors_path, "Conversion successful! Download your SafeTensors file below."
    finally:
        if not succeeded:
            shutil.rmtree(output_dir, ignore_errors=True)
# ===========================
# Gradio Interface Setup
# ===========================
# Heading and Markdown blurb shown at the top of the page.
title = "Checkpoint to SafeTensors Converter"
description = (
    "\n"
    "Convert your PyTorch Lightning .ckpt checkpoints to the secure safetensors format.\n"
    "**Inputs**:\n"
    "- **Checkpoint URL**: Direct link to the .ckpt file.\n"
    "- **Model Name**: Name of the base model (e.g., roberta-base, bert-base-uncased).\n"
    "**Output**:\n"
    "- Downloadable safetensors file.\n"
)

# Input widgets, in the order of convert_checkpoint_to_safetensors' parameters.
_checkpoint_url_box = gr.Textbox(
    lines=2,
    placeholder="Enter the checkpoint URL here...",
    label="Checkpoint URL",
)
_model_name_box = gr.Textbox(
    lines=1,
    placeholder="e.g., roberta-base",
    label="Model Name",
)

# Output widgets, in the order of the function's (file, status) return tuple.
_safetensors_file = gr.File(label="Download SafeTensors File")
_status_box = gr.Textbox(label="Status")

iface = gr.Interface(
    fn=convert_checkpoint_to_safetensors,
    inputs=[_checkpoint_url_box, _model_name_box],
    outputs=[_safetensors_file, _status_box],
    title=title,
    description=description,
    allow_flagging="never",
)

# ===========================
# Launch the Interface
# ===========================
# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()
|