Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,10 +7,9 @@ import torch
|
|
7 |
from pytorch_lightning import LightningModule
|
8 |
from safetensors.torch import save_file
|
9 |
from torch import nn
|
10 |
-
from modelalign import BERTAlignModel
|
11 |
|
12 |
import gradio as gr
|
13 |
-
|
14 |
|
15 |
# ===========================
|
16 |
# Utility Functions
|
@@ -48,7 +47,11 @@ def load_checkpoint(model: LightningModule, checkpoint_path: str, device: str =
|
|
48 |
try:
|
49 |
# Load the checkpoint; adjust map_location based on device
|
50 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
return True, "Checkpoint loaded successfully."
|
53 |
except Exception as e:
|
54 |
return False, f"Failed to load checkpoint: {str(e)}"
|
@@ -80,31 +83,29 @@ def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
|
|
80 |
# Step 1: Download the checkpoint
|
81 |
success, message = download_checkpoint(checkpoint_url, checkpoint_path)
|
82 |
if not success:
|
83 |
-
return
|
84 |
|
85 |
# Step 2: Initialize the model
|
86 |
success, model_or_msg = initialize_model(model_name)
|
87 |
if not success:
|
88 |
-
return
|
89 |
model = model_or_msg
|
90 |
|
91 |
# Step 3: Load the checkpoint
|
92 |
success, message = load_checkpoint(model, checkpoint_path)
|
93 |
if not success:
|
94 |
-
return
|
95 |
|
96 |
# Step 4: Convert to SafeTensors
|
97 |
success, message = convert_to_safetensors(model, safetensors_path)
|
98 |
if not success:
|
99 |
-
return
|
100 |
|
101 |
# Step 5: Read the safetensors file for download
|
102 |
try:
|
103 |
-
|
104 |
-
safetensors_bytes = f.read()
|
105 |
-
return safetensors_bytes, "Conversion successful! Download your SafeTensors file below."
|
106 |
except Exception as e:
|
107 |
-
return
|
108 |
|
109 |
# ===========================
|
110 |
# Gradio Interface Setup
|
@@ -125,12 +126,20 @@ Convert your PyTorch Lightning `.ckpt` checkpoints to the secure `safetensors` f
|
|
125 |
iface = gr.Interface(
|
126 |
fn=convert_checkpoint_to_safetensors,
|
127 |
inputs=[
|
128 |
-
gr.
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
],
|
131 |
outputs=[
|
132 |
-
gr.
|
133 |
-
gr.
|
134 |
],
|
135 |
title=title,
|
136 |
description=description,
|
|
|
7 |
from pytorch_lightning import LightningModule
|
8 |
from safetensors.torch import save_file
|
9 |
from torch import nn
|
|
|
10 |
|
11 |
import gradio as gr
|
12 |
+
from modelalign import BERTAlignModel
|
13 |
|
14 |
# ===========================
|
15 |
# Utility Functions
|
|
|
47 |
try:
|
48 |
# Load the checkpoint; adjust map_location based on device
|
49 |
checkpoint = torch.load(checkpoint_path, map_location=device)
|
50 |
+
# Assuming the checkpoint has a 'state_dict' key
|
51 |
+
if 'state_dict' in checkpoint:
|
52 |
+
model.load_state_dict(checkpoint['state_dict'], strict=False)
|
53 |
+
else:
|
54 |
+
model.load_state_dict(checkpoint, strict=False)
|
55 |
return True, "Checkpoint loaded successfully."
|
56 |
except Exception as e:
|
57 |
return False, f"Failed to load checkpoint: {str(e)}"
|
|
|
83 |
# Step 1: Download the checkpoint
|
84 |
success, message = download_checkpoint(checkpoint_url, checkpoint_path)
|
85 |
if not success:
|
86 |
+
return None, message
|
87 |
|
88 |
# Step 2: Initialize the model
|
89 |
success, model_or_msg = initialize_model(model_name)
|
90 |
if not success:
|
91 |
+
return None, model_or_msg
|
92 |
model = model_or_msg
|
93 |
|
94 |
# Step 3: Load the checkpoint
|
95 |
success, message = load_checkpoint(model, checkpoint_path)
|
96 |
if not success:
|
97 |
+
return None, message
|
98 |
|
99 |
# Step 4: Convert to SafeTensors
|
100 |
success, message = convert_to_safetensors(model, safetensors_path)
|
101 |
if not success:
|
102 |
+
return None, message
|
103 |
|
104 |
# Step 5: Read the safetensors file for download
|
105 |
try:
|
106 |
+
return safetensors_path, "Conversion successful! Download your SafeTensors file below."
|
|
|
|
|
107 |
except Exception as e:
|
108 |
+
return None, f"Failed to prepare download: {str(e)}"
|
109 |
|
110 |
# ===========================
|
111 |
# Gradio Interface Setup
|
|
|
126 |
iface = gr.Interface(
|
127 |
fn=convert_checkpoint_to_safetensors,
|
128 |
inputs=[
|
129 |
+
gr.Textbox(
|
130 |
+
lines=2,
|
131 |
+
placeholder="Enter the checkpoint URL here...",
|
132 |
+
label="Checkpoint URL"
|
133 |
+
),
|
134 |
+
gr.Textbox(
|
135 |
+
lines=1,
|
136 |
+
placeholder="e.g., roberta-base",
|
137 |
+
label="Model Name"
|
138 |
+
)
|
139 |
],
|
140 |
outputs=[
|
141 |
+
gr.File(label="Download SafeTensors File"),
|
142 |
+
gr.Textbox(label="Status")
|
143 |
],
|
144 |
title=title,
|
145 |
description=description,
|