AnasHXH commited on
Commit
af440c4
1 Parent(s): f960fad
Files changed (1) hide show
  1. app.py +49 -17
app.py CHANGED
@@ -236,28 +236,60 @@ class MambaIRShadowRemoval(nn.Module):
236
 
237
  # Load the model with weights
238
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
239
- model = MambaIRShadowRemoval(img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1], d_state=64)
240
- model.load_state_dict(torch.load("shadow_removal_model.pth", map_location=device))
241
- model.to(device)
242
- model.eval()
243
-
244
- # Define the Gradio function
245
- transform = transforms.Compose([transforms.ToTensor()])
246
-
247
- def remove_shadow(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  input_tensor = transform(image).unsqueeze(0).to(device)
249
  with torch.no_grad():
250
  output_tensor = model(input_tensor)
251
  output_image = transforms.ToPILImage()(output_tensor.squeeze(0).cpu())
252
  return output_image
253
 
254
- # Set up Gradio interface
255
- iface = gr.Interface(
256
- fn=remove_shadow,
257
- inputs=gr.Image(type="pil"),
258
- outputs=gr.Image(type="pil"),
259
- title="Shadow Removal Model",
260
- description="Upload an image to remove shadows using the trained model."
261
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  iface.launch()
 
236
 
237
  # Load the model with weights
238
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
239
+ # Define function to load model with specified weights
240
+ def load_model(weights_path):
241
+ model = MambaIRShadowRemoval(img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1], d_state=64)
242
+ model.load_state_dict(torch.load(weights_path, map_location=device))
243
+ model.to(device)
244
+ model.eval()
245
+ return model
246
+
247
+ # Preload models for ISTD+ and SRD
248
+ models = {
249
+ "ISTD+": load_model("ISTD+.pth"),
250
+ "SRD": load_model("SRD.pth")
251
+ }
252
+
253
+ # Define transformation
254
+ transform = transforms.Compose([
255
+ transforms.ToTensor(),
256
+ ])
257
+
258
+ # Define function to perform shadow removal
259
+ def remove_shadow(image, dataset):
260
+ model = models[dataset] # Select the appropriate model based on dataset choice
261
  input_tensor = transform(image).unsqueeze(0).to(device)
262
  with torch.no_grad():
263
  output_tensor = model(input_tensor)
264
  output_image = transforms.ToPILImage()(output_tensor.squeeze(0).cpu())
265
  return output_image
266
 
267
+ # Define example paths for ISTD+ and SRD
268
+ examples = [
269
+ ["ISTD+.jpg", "ISTD+"],
270
+ ["SRD.jpg", "SRD"]
271
+ ]
272
+
273
+ # Gradio Interface with dropdown and examples
274
+ with gr.Blocks() as iface:
275
+ gr.Markdown("## Shadow Removal Model")
276
+ gr.Markdown("Upload an image to remove shadows using the trained model. Choose the dataset to load the corresponding weights and example images.")
277
+
278
+ with gr.Row():
279
+ dataset_choice = gr.Dropdown(["ISTD+", "SRD"], label="Choose Dataset", value="ISTD+")
280
+
281
+ example_image = gr.Image(type="pil", label="Input Image")
282
+ output_image = gr.Image(type="pil", label="Output Image")
283
+
284
+ # Display examples and map them to dataset and images
285
+ gr.Examples(
286
+ examples=examples,
287
+ inputs=[example_image, dataset_choice],
288
+ label="Examples",
289
+ )
290
+
291
+ submit_btn = gr.Button("Submit")
292
+
293
+ submit_btn.click(remove_shadow, inputs=[example_image, dataset_choice], outputs=output_image)
294
 
295
  iface.launch()