Upload folder using huggingface_hub

Files changed:
- inference.py +15 -14
- u2net.py +6 -1
inference.py
CHANGED
@@ -1,16 +1,9 @@
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
-import matplotlib.pyplot as plt
+from torch.utils.tensorboard import SummaryWriter
 from torchvision.transforms.functional import gaussian_blur
 
-def save_mask(mask, title='mask'):
-    plt.imshow(mask.cpu().numpy()[0], cmap='gray')
-    plt.title(title)
-    plt.axis('off')
-    plt.savefig(f'{title}.png', bbox_inches='tight')
-    plt.close()
-
 def _gram_matrix(feature):
     batch_size, n_feature_maps, height, width = feature.size()
     new_feature = feature.view(batch_size * n_feature_maps, height * width)
@@ -35,7 +28,8 @@ def _compute_loss(generated_features, content_features, style_features, resized_
         A = _gram_matrix(sf)
         style_loss += w_l * F.mse_loss(G, A)
 
-    return alpha * content_loss + beta * style_loss
+    total_loss = alpha * content_loss + beta * style_loss
+    return content_loss, style_loss, total_loss
 
 def inference(
     *,
@@ -50,6 +44,7 @@ def inference(
     alpha=1,
     beta=1,
 ):
+    writer = SummaryWriter()
     generated_image = content_image.clone().requires_grad_(True)
     optimizer = optim_caller([generated_image], lr=lr)
     min_losses = [float('inf')] * iterations
@@ -64,12 +59,10 @@ def inference(
     segmentation_mask = segmentation_output.argmax(dim=1)
     background_mask = (segmentation_mask == 0).float()
     foreground_mask = 1 - background_mask
-    save_mask(background_mask, title='background-mask')
 
     background_pixel_count = background_mask.sum().item()
     total_pixel_count = segmentation_mask.numel()
     background_ratio = background_pixel_count / total_pixel_count
-    print(f'Background Detected: {background_ratio * 100:.2f}%')
 
     for cf in content_features:
         _, _, h_i, w_i = cf.shape
@@ -79,11 +72,19 @@ def inference(
     def closure(iter):
         optimizer.zero_grad()
         generated_features = model(generated_image)
-        total_loss = _compute_loss(
+        content_loss, style_loss, total_loss = _compute_loss(
             generated_features, content_features, style_features, resized_bg_masks, alpha, beta
         )
         total_loss.backward()
+
+        # log loss
+        writer.add_scalars(f'style-{"background" if apply_to_background else "image"}', {
+            'Loss/content': content_loss.item(),
+            'Loss/style': style_loss.item(),
+            'Loss/total': total_loss.item()
+        }, iter)
         min_losses[iter] = min(min_losses[iter], total_loss.item())
+
         return total_loss
 
     for iter in range(iterations):
@@ -94,6 +95,6 @@ def inference(
     foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
     generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
 
-
-
+    writer.flush()
+    writer.close()
     return generated_image, background_ratio
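Note: the logging flow this commit introduces (SummaryWriter created up front, writer.add_scalars per optimizer step, flush()/close() before returning) can be exercised in isolation. The sketch below is not the Space's actual code: dummy tensors stand in for the VGG feature maps, and compute_loss is a simplified stand-in for _compute_loss that keeps the same three-value return shape.

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

def compute_loss(generated, target, alpha=1, beta=1):
    # Stand-in for _compute_loss: the real content/style terms come from
    # VGG feature maps and gram matrices, masked for the background case.
    content_loss = F.mse_loss(generated, target)
    style_loss = F.mse_loss(generated.mean(), target.mean())
    total_loss = alpha * content_loss + beta * style_loss
    return content_loss, style_loss, total_loss  # same tuple as the new _compute_loss

writer = SummaryWriter()  # event files land in ./runs/<timestamp> by default
generated = torch.rand(1, 3, 64, 64, requires_grad=True)
target = torch.rand(1, 3, 64, 64)
optimizer = torch.optim.Adam([generated], lr=0.05)

for it in range(10):
    optimizer.zero_grad()
    content_loss, style_loss, total_loss = compute_loss(generated, target)
    total_loss.backward()
    # Mirrors the add_scalars call in the diff: three curves grouped under one tag.
    writer.add_scalars('style-image', {
        'Loss/content': content_loss.item(),
        'Loss/style': style_loss.item(),
        'Loss/total': total_loss.item(),
    }, it)
    optimizer.step()

writer.flush()
writer.close()
# Inspect with: tensorboard --logdir runs

add_scalars writes each key to its own sub-run so the curves overlay in a single chart, which is also why the explicit flush()/close() added at the end of inference() matters: it ensures buffered events hit disk before the process exits.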
u2net.py
CHANGED
@@ -1 +1,6 @@
-
+import torch.nn as nn
+
+class U2Net(nn.Module):
+    def __init__(self):
+        super(U2Net, self).__init__()
+        pass