AnasHXH commited on
Commit
6f0136b
1 Parent(s): ccd0fd7

upload app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -0
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
+ from PIL import Image
7
+
8
+ # Define model classes (same as before)
9
+ class SimpleGate(nn.Module):
10
+ def forward(self, x):
11
+ x1, x2 = x.chunk(2, dim=-1)
12
+ return x1 * x2
13
+
14
+ class ASPP(nn.Module):
15
+ def __init__(self, in_channels, out_channels):
16
+ super(ASPP, self).__init__()
17
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
18
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, padding=6, dilation=6, bias=False)
19
+ self.conv3 = nn.Conv2d(in_channels, out_channels, 3, padding=12, dilation=12, bias=False)
20
+ self.conv4 = nn.Conv2d(in_channels, out_channels, 3, padding=18, dilation=18, bias=False)
21
+ self.pool = nn.AdaptiveAvgPool2d(1)
22
+ self.conv5 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
23
+ self.conv_out = nn.Conv2d(out_channels * 5, out_channels, 1, bias=False)
24
+ self.norm = nn.LayerNorm(out_channels)
25
+ self.act = nn.SiLU()
26
+
27
+ def forward(self, x):
28
+ size = x.shape[-2:]
29
+ feat1 = self.conv1(x)
30
+ feat2 = self.conv2(x)
31
+ feat3 = self.conv3(x)
32
+ feat4 = self.conv4(x)
33
+ feat5 = F.interpolate(self.conv5(self.pool(x)), size=size, mode='bilinear', align_corners=False)
34
+ out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
35
+ out = self.conv_out(out)
36
+ out = out.permute(0, 2, 3, 1) # Change to (B, H, W, C)
37
+ out = self.norm(out)
38
+ out = out.permute(0, 3, 1, 2) # Change back to (B, C, H, W)
39
+ return self.act(out)
40
+
41
+ class ChannelwiseSelfAttention(nn.Module):
42
+ def __init__(self, dim):
43
+ super(ChannelwiseSelfAttention, self).__init__()
44
+ self.dim = dim
45
+ self.query_conv = nn.Linear(dim, dim)
46
+ self.key_conv = nn.Linear(dim, dim)
47
+ self.value_conv = nn.Linear(dim, dim)
48
+ self.scale = dim ** -0.5
49
+ self.pos_embedding = nn.Parameter(torch.randn(1, 1, 1, dim))
50
+
51
+ def forward(self, x):
52
+ # x: (B, H, W, C)
53
+ B, H, W, C = x.shape
54
+ x = x + self.pos_embedding # Positional embedding
55
+ x = x.view(B, H * W, C) # Reshape to (B, N, C)
56
+
57
+ # Linear projections
58
+ q = self.query_conv(x) # (B, N, C)
59
+ k = self.key_conv(x) # (B, N, C)
60
+ v = self.value_conv(x) # (B, N, C)
61
+
62
+ # Compute attention over channels at each spatial location
63
+ q = q.view(B, H * W, 1, C) # (B, N, 1, C)
64
+ k = k.view(B, H * W, C, 1) # (B, N, C, 1)
65
+ attn = torch.matmul(q, k).squeeze(2) * self.scale # (B, N, C)
66
+ attn = attn.softmax(dim=-1) # Softmax over channels
67
+
68
+ # Apply attention to values
69
+ out = attn * v # Element-wise multiplication
70
+ out = out.view(B, H, W, C) # Reshape back to (B, H, W, C)
71
+ return out
72
+
73
+ class EnhancedSS2D(nn.Module):
74
+ def __init__(self, d_model, d_state=16, d_conv=3, expand=2., dt_rank=64, dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0):
75
+ super().__init__()
76
+ self.d_model = d_model
77
+ self.d_state = d_state
78
+ self.d_conv = d_conv
79
+ self.expand = expand
80
+ self.d_inner = int(self.expand * self.d_model) # self.d_inner = 2 * d_model
81
+ self.dt_rank = dt_rank
82
+
83
+ self.in_proj = nn.Linear(self.d_model, self.d_inner * 2)
84
+ self.conv2d = nn.Conv2d(self.d_inner, self.d_inner, kernel_size=d_conv, padding=(d_conv - 1) // 2, groups=self.d_inner)
85
+ self.act = nn.SiLU()
86
+
87
+ self.x_proj = nn.Linear(self.d_inner, self.d_inner * 2)
88
+ self.dt_proj = nn.Linear(self.d_inner, self.d_inner)
89
+
90
+ self.out_norm = nn.LayerNorm(self.d_inner)
91
+
92
+ # Update here
93
+ self.out_proj = nn.Linear(self.d_inner // 2, d_model)
94
+
95
+ # New components
96
+ self.simple_gate = SimpleGate()
97
+ self.aspp = ASPP(d_model, d_model)
98
+ self.channel_attn = ChannelwiseSelfAttention(d_model)
99
+
100
+ def forward(self, x):
101
+ B, H, W, C = x.shape
102
+
103
+ # Apply ASPP
104
+ x_aspp = self.aspp(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
105
+
106
+ # Original SS2D operations
107
+ x = self.in_proj(x)
108
+ x, z = x.chunk(2, dim=-1)
109
+ x = x.permute(0, 3, 1, 2)
110
+ x = self.conv2d(x)
111
+ x = x.permute(0, 2, 3, 1)
112
+ x = self.act(x)
113
+ y = self.selective_scan(x)
114
+ y = self.out_norm(y)
115
+ y = y * F.silu(z)
116
+
117
+ # Apply SimpleGate
118
+ y = self.simple_gate(y)
119
+
120
+ # Apply Channel-wise Self-Attention
121
+ y = self.channel_attn(y)
122
+
123
+ # Combine with ASPP output
124
+ y = y + x_aspp
125
+
126
+ out = self.out_proj(y)
127
+ return out
128
+
129
+ def selective_scan(self, x):
130
+ B, H, W, C = x.shape
131
+ x_flat = x.reshape(B, H*W, C)
132
+ x_dbl = self.x_proj(x_flat)
133
+ x_dbl = x_dbl.view(B, H, W, -1)
134
+ dt, x_proj = x_dbl.chunk(2, dim=-1)
135
+ dt = F.softplus(self.dt_proj(dt))
136
+ y = x * torch.sigmoid(dt) + x_proj * torch.tanh(x_proj)
137
+ return y
138
+
139
+ class EnhancedVSSBlock(nn.Module):
140
+ def __init__(self, d_model, d_state=16):
141
+ super().__init__()
142
+ self.ln_1 = nn.LayerNorm(d_model)
143
+ self.ss2d = EnhancedSS2D(d_model, d_state)
144
+ self.ln_2 = nn.LayerNorm(d_model)
145
+ self.conv_blk = nn.Sequential(
146
+ nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
147
+ nn.ReLU(inplace=True),
148
+ nn.Conv2d(d_model, d_model, kernel_size=3, padding=1)
149
+ )
150
+
151
+ def forward(self, x):
152
+ residual = x
153
+ x = self.ln_1(x)
154
+ x = residual + self.ss2d(x)
155
+ residual = x
156
+ x = self.ln_2(x)
157
+ x = x.permute(0, 3, 1, 2)
158
+ x = self.conv_blk(x)
159
+ x = x.permute(0, 2, 3, 1)
160
+ x = residual + x
161
+ return x
162
+
163
+ class MambaIRShadowRemoval(nn.Module):
164
+ def __init__(self, img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1], d_state=64):
165
+ super().__init__()
166
+ self.intro = nn.Conv2d(img_channel, width, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
167
+ self.ending = nn.Conv2d(width, img_channel, kernel_size=3, padding=1, stride=1, groups=1, bias=True)
168
+
169
+ self.encoders = nn.ModuleList()
170
+ self.decoders = nn.ModuleList()
171
+ self.middle_blks = nn.ModuleList()
172
+ self.ups = nn.ModuleList()
173
+ self.downs = nn.ModuleList()
174
+
175
+ chan = width
176
+ for num in enc_blk_nums:
177
+ self.encoders.append(
178
+ nn.Sequential(*[EnhancedVSSBlock(chan, d_state) for _ in range(num)])
179
+ )
180
+ self.downs.append(nn.Conv2d(chan, 2*chan, 2, 2))
181
+ chan = chan * 2
182
+
183
+ self.middle_blks = nn.Sequential(
184
+ *[EnhancedVSSBlock(chan, d_state) for _ in range(middle_blk_num)]
185
+ )
186
+
187
+ for num in dec_blk_nums:
188
+ self.ups.append(nn.Sequential(
189
+ nn.Conv2d(chan, chan * 2, 1, bias=False),
190
+ nn.PixelShuffle(2)
191
+ ))
192
+ chan = chan // 2
193
+ self.decoders.append(
194
+ nn.Sequential(*[EnhancedVSSBlock(chan, d_state) for _ in range(num)])
195
+ )
196
+
197
+ self.padder_size = 2 ** len(self.encoders)
198
+
199
+ def forward(self, inp):
200
+ B, C, H, W = inp.shape
201
+ inp = self.check_image_size(inp)
202
+ x = self.intro(inp)
203
+ x = x.permute(0, 2, 3, 1)
204
+
205
+ encs = []
206
+ for encoder, down in zip(self.encoders, self.downs):
207
+ x = encoder(x)
208
+ encs.append(x)
209
+ x = x.permute(0, 3, 1, 2)
210
+ x = down(x)
211
+ x = x.permute(0, 2, 3, 1)
212
+
213
+ x = self.middle_blks(x)
214
+
215
+ for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
216
+ x = x.permute(0, 3, 1, 2)
217
+ x = up(x)
218
+ x = x.permute(0, 2, 3, 1)
219
+ x = x + enc_skip
220
+ x = decoder(x)
221
+
222
+ x = x.permute(0, 3, 1, 2)
223
+ x = self.ending(x)
224
+ x = x + inp
225
+
226
+ return x[:, :, :H, :W]
227
+
228
+ def check_image_size(self, x):
229
+ _, _, h, w = x.size()
230
+ mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
231
+ mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
232
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
233
+ return x
234
+
235
+
236
+
237
+ # Load the model with weights
238
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
239
+ model = MambaIRShadowRemoval(img_channel=3, width=32, middle_blk_num=1, enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1], d_state=64)
240
+ model.load_state_dict(torch.load("shadow_removal_model.pth", map_location=device))
241
+ model.to(device)
242
+ model.eval()
243
+
244
+ # Define the Gradio function
245
+ transform = transforms.Compose([transforms.ToTensor()])
246
+
247
+ def remove_shadow(image):
248
+ input_tensor = transform(image).unsqueeze(0).to(device)
249
+ with torch.no_grad():
250
+ output_tensor = model(input_tensor)
251
+ output_image = transforms.ToPILImage()(output_tensor.squeeze(0).cpu())
252
+ return output_image
253
+
254
+ # Set up Gradio interface
255
+ iface = gr.Interface(
256
+ fn=remove_shadow,
257
+ inputs=gr.Image(type="pil"),
258
+ outputs=gr.Image(type="pil"),
259
+ title="Shadow Removal Model",
260
+ description="Upload an image to remove shadows using the trained model."
261
+ )
262
+
263
+ iface.launch()