PANH commited on
Commit
b8db24f
1 Parent(s): ce3706d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -15
app.py CHANGED
@@ -7,10 +7,9 @@ import torch
7
  from pytorch_lightning import LightningModule
8
  from safetensors.torch import save_file
9
  from torch import nn
10
- from modelalign import BERTAlignModel
11
 
12
  import gradio as gr
13
-
14
 
15
  # ===========================
16
  # Utility Functions
@@ -48,7 +47,11 @@ def load_checkpoint(model: LightningModule, checkpoint_path: str, device: str =
48
  try:
49
  # Load the checkpoint; adjust map_location based on device
50
  checkpoint = torch.load(checkpoint_path, map_location=device)
51
- model.load_state_dict(checkpoint['state_dict'], strict=False)
 
 
 
 
52
  return True, "Checkpoint loaded successfully."
53
  except Exception as e:
54
  return False, f"Failed to load checkpoint: {str(e)}"
@@ -80,31 +83,29 @@ def convert_checkpoint_to_safetensors(checkpoint_url: str, model_name: str):
80
  # Step 1: Download the checkpoint
81
  success, message = download_checkpoint(checkpoint_url, checkpoint_path)
82
  if not success:
83
- return gr.update(value=None, visible=False), message
84
 
85
  # Step 2: Initialize the model
86
  success, model_or_msg = initialize_model(model_name)
87
  if not success:
88
- return gr.update(value=None, visible=False), model_or_msg
89
  model = model_or_msg
90
 
91
  # Step 3: Load the checkpoint
92
  success, message = load_checkpoint(model, checkpoint_path)
93
  if not success:
94
- return gr.update(value=None, visible=False), message
95
 
96
  # Step 4: Convert to SafeTensors
97
  success, message = convert_to_safetensors(model, safetensors_path)
98
  if not success:
99
- return gr.update(value=None, visible=False), message
100
 
101
  # Step 5: Read the safetensors file for download
102
  try:
103
- with open(safetensors_path, "rb") as f:
104
- safetensors_bytes = f.read()
105
- return safetensors_bytes, "Conversion successful! Download your SafeTensors file below."
106
  except Exception as e:
107
- return gr.update(value=None, visible=False), f"Failed to prepare download: {str(e)}"
108
 
109
  # ===========================
110
  # Gradio Interface Setup
@@ -125,12 +126,20 @@ Convert your PyTorch Lightning `.ckpt` checkpoints to the secure `safetensors` f
125
  iface = gr.Interface(
126
  fn=convert_checkpoint_to_safetensors,
127
  inputs=[
128
- gr.inputs.Textbox(lines=2, placeholder="Enter the checkpoint URL here...", label="Checkpoint URL"),
129
- gr.inputs.Textbox(lines=1, placeholder="e.g., roberta-base", label="Model Name")
 
 
 
 
 
 
 
 
130
  ],
131
  outputs=[
132
- gr.outputs.File(label="Download SafeTensors File"),
133
- gr.outputs.Textbox(label="Status")
134
  ],
135
  title=title,
136
  description=description,
 
7
  from pytorch_lightning import LightningModule
8
  from safetensors.torch import save_file
9
  from torch import nn
 
10
 
11
  import gradio as gr
12
+ from modelalign import BERTAlignModel
13
 
14
  # ===========================
15
  # Utility Functions
 
47
  try:
48
  # Load the checkpoint; adjust map_location based on device
49
  checkpoint = torch.load(checkpoint_path, map_location=device)
50
+ # Assuming the checkpoint has a 'state_dict' key
51
+ if 'state_dict' in checkpoint:
52
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
53
+ else:
54
+ model.load_state_dict(checkpoint, strict=False)
55
  return True, "Checkpoint loaded successfully."
56
  except Exception as e:
57
  return False, f"Failed to load checkpoint: {str(e)}"
 
83
  # Step 1: Download the checkpoint
84
  success, message = download_checkpoint(checkpoint_url, checkpoint_path)
85
  if not success:
86
+ return None, message
87
 
88
  # Step 2: Initialize the model
89
  success, model_or_msg = initialize_model(model_name)
90
  if not success:
91
+ return None, model_or_msg
92
  model = model_or_msg
93
 
94
  # Step 3: Load the checkpoint
95
  success, message = load_checkpoint(model, checkpoint_path)
96
  if not success:
97
+ return None, message
98
 
99
  # Step 4: Convert to SafeTensors
100
  success, message = convert_to_safetensors(model, safetensors_path)
101
  if not success:
102
+ return None, message
103
 
104
  # Step 5: Read the safetensors file for download
105
  try:
106
+ return safetensors_path, "Conversion successful! Download your SafeTensors file below."
 
 
107
  except Exception as e:
108
+ return None, f"Failed to prepare download: {str(e)}"
109
 
110
  # ===========================
111
  # Gradio Interface Setup
 
126
  iface = gr.Interface(
127
  fn=convert_checkpoint_to_safetensors,
128
  inputs=[
129
+ gr.Textbox(
130
+ lines=2,
131
+ placeholder="Enter the checkpoint URL here...",
132
+ label="Checkpoint URL"
133
+ ),
134
+ gr.Textbox(
135
+ lines=1,
136
+ placeholder="e.g., roberta-base",
137
+ label="Model Name"
138
+ )
139
  ],
140
  outputs=[
141
+ gr.File(label="Download SafeTensors File"),
142
+ gr.Textbox(label="Status")
143
  ],
144
  title=title,
145
  description=description,