Spaces:
Running
Running
renew
Browse files
app.py
CHANGED
@@ -236,28 +236,60 @@ class MambaIRShadowRemoval(nn.Module):
|
|
236 |
|
237 |
# Load the model with weights
|
238 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
239 |
-
|
240 |
-
|
241 |
-
model
|
242 |
-
model.
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
input_tensor = transform(image).unsqueeze(0).to(device)
|
249 |
with torch.no_grad():
|
250 |
output_tensor = model(input_tensor)
|
251 |
output_image = transforms.ToPILImage()(output_tensor.squeeze(0).cpu())
|
252 |
return output_image
|
253 |
|
254 |
-
#
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
iface.launch()
|
|
|
236 |
|
237 |
# Load the model with weights
|
238 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
239 |
+
# Define function to load model with specified weights
|
240 |
+
def load_model(weights_path):
|
241 |
+
model = MambaIRShadowRemoval(img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1], d_state=64)
|
242 |
+
model.load_state_dict(torch.load(weights_path, map_location=device))
|
243 |
+
model.to(device)
|
244 |
+
model.eval()
|
245 |
+
return model
|
246 |
+
|
247 |
+
# Preload models for ISTD+ and SRD
|
248 |
+
models = {
|
249 |
+
"ISTD+": load_model("ISTD+.pth"),
|
250 |
+
"SRD": load_model("SRD.pth")
|
251 |
+
}
|
252 |
+
|
253 |
+
# Define transformation
|
254 |
+
transform = transforms.Compose([
|
255 |
+
transforms.ToTensor(),
|
256 |
+
])
|
257 |
+
|
258 |
+
# Define function to perform shadow removal
|
259 |
+
def remove_shadow(image, dataset):
|
260 |
+
model = models[dataset] # Select the appropriate model based on dataset choice
|
261 |
input_tensor = transform(image).unsqueeze(0).to(device)
|
262 |
with torch.no_grad():
|
263 |
output_tensor = model(input_tensor)
|
264 |
output_image = transforms.ToPILImage()(output_tensor.squeeze(0).cpu())
|
265 |
return output_image
|
266 |
|
267 |
+
# Define example paths for ISTD+ and SRD
|
268 |
+
examples = [
|
269 |
+
["ISTD+.jpg", "ISTD+"],
|
270 |
+
["SRD.jpg", "SRD"]
|
271 |
+
]
|
272 |
+
|
273 |
+
# Gradio Interface with dropdown and examples
|
274 |
+
with gr.Blocks() as iface:
|
275 |
+
gr.Markdown("## Shadow Removal Model")
|
276 |
+
gr.Markdown("Upload an image to remove shadows using the trained model. Choose the dataset to load the corresponding weights and example images.")
|
277 |
+
|
278 |
+
with gr.Row():
|
279 |
+
dataset_choice = gr.Dropdown(["ISTD+", "SRD"], label="Choose Dataset", value="ISTD+")
|
280 |
+
|
281 |
+
example_image = gr.Image(type="pil", label="Input Image")
|
282 |
+
output_image = gr.Image(type="pil", label="Output Image")
|
283 |
+
|
284 |
+
# Display examples and map them to dataset and images
|
285 |
+
gr.Examples(
|
286 |
+
examples=examples,
|
287 |
+
inputs=[example_image, dataset_choice],
|
288 |
+
label="Examples",
|
289 |
+
)
|
290 |
+
|
291 |
+
submit_btn = gr.Button("Submit")
|
292 |
+
|
293 |
+
submit_btn.click(remove_shadow, inputs=[example_image, dataset_choice], outputs=output_image)
|
294 |
|
295 |
iface.launch()
|