File size: 5,191 Bytes
1f76ea6
 
 
 
 
 
 
 
 
 
 
2d8296f
 
1f76ea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d8296f
 
1f76ea6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d8296f
 
1f76ea6
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
import requests
import tempfile
import shutil

import torch
from pytorch_lightning import LightningModule
from safetensors.torch import save_file
from torch import nn
from modelalign import BERTAlignModel 

import gradio as gr


# ===========================
# Utility Functions
# ===========================

def download_checkpoint(url: str, dest_path: str, timeout: float = 60.0):
    """
    Downloads the checkpoint from the specified URL to the destination path.

    The response body is streamed to disk in chunks, so large checkpoints are
    never held fully in memory.

    Args:
        url: HTTP(S) URL of the checkpoint file.
        dest_path: Local file path to write to (overwritten if it exists).
        timeout: Seconds to wait for the server to connect / send data.
            Without a timeout, a stalled server would block this call forever.

    Returns:
        ``(True, message)`` on success, ``(False, message)`` on failure.
    """
    # NOTE(review): in this app ``url`` is typed by the user in the UI, so the
    # server will fetch any address it is given (including internal hosts).
    # Consider an allow-list of hosts if the app is publicly exposed.
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(dest_path, 'wb') as f:
                # iter_content (unlike response.raw) decodes any transfer
                # Content-Encoding such as gzip, so the bytes written are the
                # actual checkpoint file.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        return True, "Checkpoint downloaded successfully."
    except (requests.RequestException, OSError) as e:
        return False, f"Failed to download checkpoint: {str(e)}"

def initialize_model(model_name: str, device: str = 'cpu'):
    """
    Build a BERTAlignModel on top of ``model_name`` and ready it for export.

    The model is moved to ``device`` and switched to evaluation mode.

    Returns:
        ``(True, model)`` on success, or ``(False, message)`` describing why
        construction failed.
    """
    try:
        aligner = BERTAlignModel(base_model_name=model_name)
        aligner.to(device)
        # Put the model in evaluation (inference) mode.
        aligner.eval()
    except Exception as err:
        return False, f"Failed to initialize model: {str(err)}"
    return True, aligner

def load_checkpoint(model: nn.Module, checkpoint_path: str, device: str = 'cpu'):
    """
    Loads the checkpoint into the model.

    Accepts either a PyTorch Lightning checkpoint (a dict whose ``'state_dict'``
    entry holds the weights) or a bare state_dict. Loading is non-strict, so
    extra or missing keys do not abort the load. They are, however, counted in
    the returned message. A checkpoint that matches none of the model's keys is
    reported as a failure; otherwise the exported file would silently contain
    the model's freshly initialised weights.

    Args:
        model: Module to load the weights into (modified in place).
        checkpoint_path: Path of the file to load.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        ``(True, message)`` on success, ``(False, message)`` on failure.
    """
    try:
        # SECURITY(review): torch.load unpickles the file, and in this app the
        # checkpoint is downloaded from a user-supplied URL. Unpickling
        # untrusted data can execute arbitrary code; consider passing
        # ``weights_only=True`` if the checkpoints in use load under it.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            # Not a Lightning checkpoint: treat the object as a bare state_dict.
            state_dict = checkpoint
        result = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        return False, f"Failed to load checkpoint: {str(e)}"

    model_keys = set(model.state_dict())
    if model_keys and set(result.missing_keys) >= model_keys:
        return False, (
            "Failed to load checkpoint: none of the checkpoint's keys match "
            "the model's parameters."
        )

    message = "Checkpoint loaded successfully."
    if result.missing_keys or result.unexpected_keys:
        message += (
            f" Missing keys: {len(result.missing_keys)},"
            f" unexpected keys: {len(result.unexpected_keys)}."
        )
    return True, message

def convert_to_safetensors(model: nn.Module, save_path: str):
    """
    Converts the model's state_dict to the safetensors format.

    ``safetensors.torch.save_file`` refuses tensors that share storage (such
    as tied weights) and tensors that are not contiguous. Each tensor is
    therefore copied into its own contiguous buffer before saving. This
    temporarily doubles the memory used by the weights.

    Args:
        model: Module whose ``state_dict()`` is exported.
        save_path: Destination ``.safetensors`` file path.

    Returns:
        ``(True, message)`` on success, ``(False, message)`` on failure.
    """
    try:
        state_dict = {
            name: tensor.detach().clone(memory_format=torch.contiguous_format)
            for name, tensor in model.state_dict().items()
        }
        # "format": "pt" marks the file as PyTorch tensors for loaders that
        # inspect the safetensors metadata header.
        save_file(state_dict, save_path, metadata={"format": "pt"})
        return True, "Model converted to SafeTensors successfully."
    except Exception as e:
        return False, f"Failed to convert to SafeTensors: {str(e)}"

# ===========================
# Gradio Interface Function
# ===========================

def _failure_outputs(message: str):
    """Return the (cleared and hidden file output, status message) pair used on every failure."""
    return gr.update(value=None, visible=False), message


def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
    """
    Orchestrates the download, loading, conversion, and preparation for download.

    Args:
        checkpoint_url: URL of the ``.ckpt`` file, as entered in the UI.
        model_name: Base model name passed to ``initialize_model``.

    Returns:
        A ``(file_output, status_message)`` pair for the two Gradio outputs.
        On success the file output is the path of the exported
        ``model.safetensors`` and is made visible (again, if an earlier failed
        run hid it). On failure it is cleared and hidden.
    """
    checkpoint_url = (checkpoint_url or "").strip()
    model_name = (model_name or "").strip()
    if not checkpoint_url:
        return _failure_outputs("Please provide a checkpoint URL.")
    if not model_name:
        return _failure_outputs("Please provide a model name.")

    # Scratch space: the downloaded checkpoint is deleted with tmpdir.
    with tempfile.TemporaryDirectory() as tmpdir:
        checkpoint_path = os.path.join(tmpdir, "model.ckpt")
        safetensors_path = os.path.join(tmpdir, "model.safetensors")

        # Step 1: Download the checkpoint
        success, message = download_checkpoint(checkpoint_url, checkpoint_path)
        if not success:
            return _failure_outputs(message)

        # Step 2: Initialize the model
        success, model_or_msg = initialize_model(model_name)
        if not success:
            return _failure_outputs(model_or_msg)
        model = model_or_msg

        # Step 3: Load the checkpoint
        success, message = load_checkpoint(model, checkpoint_path)
        if not success:
            return _failure_outputs(message)

        # Step 4: Convert to SafeTensors
        success, message = convert_to_safetensors(model, safetensors_path)
        if not success:
            return _failure_outputs(message)

        # Step 5: Move the result out of tmpdir. The file output receives a
        # path, and that path must still exist after this function returns,
        # whereas tmpdir is deleted as soon as this `with` block exits.
        # TODO(review): these output directories are never cleaned up.
        try:
            output_dir = tempfile.mkdtemp(prefix="safetensors_")
            output_path = shutil.move(
                safetensors_path, os.path.join(output_dir, "model.safetensors")
            )
        except OSError as e:
            return _failure_outputs(f"Failed to prepare download: {str(e)}")

    return (
        gr.update(value=output_path, visible=True),
        "Conversion successful! Download your SafeTensors file below.",
    )

# ===========================
# Gradio Interface Setup
# ===========================

title = "Checkpoint to SafeTensors Converter"
description = """
Convert your PyTorch Lightning `.ckpt` checkpoints to the secure `safetensors` format.

**Inputs**:
- **Checkpoint URL**: Direct link to the `.ckpt` file.
- **Model Name**: Name of the base model (e.g., `roberta-base`, `bert-base-uncased`).

**Output**:
- Downloadable `safetensors` file.
"""

# The legacy `gr.inputs` / `gr.outputs` namespaces were deprecated in Gradio 3
# and removed in Gradio 4; the top-level components work in both versions.
iface = gr.Interface(
    fn=convert_checkpoint_to_safetensors,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter the checkpoint URL here...", label="Checkpoint URL"),
        gr.Textbox(lines=1, placeholder="e.g., roberta-base", label="Model Name"),
    ],
    outputs=[
        gr.File(label="Download SafeTensors File"),
        gr.Textbox(label="Status"),
    ],
    title=title,
    description=description,
    examples=[
        [
            "https://huggingface.co/yzha/AlignScore/resolve/main/AlignScore-base.ckpt?download=true",
            "roberta-base",
        ],
        [
            # Placeholder example: this URL is illustrative and will not download.
            "https://path.to/your/checkpoint.ckpt",
            "bert-base-uncased",
        ],
    ],
    allow_flagging="never",
)

# ===========================
# Launch the Interface
# ===========================

if __name__ == "__main__":
    # Start the Gradio app only when this file is executed directly.
    iface.launch()