jamino30 commited on
Commit
349bdfb
·
verified ·
1 Parent(s): 2d8b485

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. inference.py +15 -14
  2. u2net.py +6 -1
inference.py CHANGED
@@ -1,16 +1,9 @@
1
  import torch
2
  import torch.optim as optim
3
  import torch.nn.functional as F
4
- import matplotlib.pyplot as plt
5
  from torchvision.transforms.functional import gaussian_blur
6
 
7
- def save_mask(mask, title='mask'):
8
- plt.imshow(mask.cpu().numpy()[0], cmap='gray')
9
- plt.title(title)
10
- plt.axis('off')
11
- plt.savefig(f'{title}.png', bbox_inches='tight')
12
- plt.close()
13
-
14
  def _gram_matrix(feature):
15
  batch_size, n_feature_maps, height, width = feature.size()
16
  new_feature = feature.view(batch_size * n_feature_maps, height * width)
@@ -35,7 +28,8 @@ def _compute_loss(generated_features, content_features, style_features, resized_
35
  A = _gram_matrix(sf)
36
  style_loss += w_l * F.mse_loss(G, A)
37
 
38
- return alpha * content_loss + beta * style_loss
 
39
 
40
  def inference(
41
  *,
@@ -50,6 +44,7 @@ def inference(
50
  alpha=1,
51
  beta=1,
52
  ):
 
53
  generated_image = content_image.clone().requires_grad_(True)
54
  optimizer = optim_caller([generated_image], lr=lr)
55
  min_losses = [float('inf')] * iterations
@@ -64,12 +59,10 @@ def inference(
64
  segmentation_mask = segmentation_output.argmax(dim=1)
65
  background_mask = (segmentation_mask == 0).float()
66
  foreground_mask = 1 - background_mask
67
- save_mask(background_mask, title='background-mask')
68
 
69
  background_pixel_count = background_mask.sum().item()
70
  total_pixel_count = segmentation_mask.numel()
71
  background_ratio = background_pixel_count / total_pixel_count
72
- print(f'Background Detected: {background_ratio * 100:.2f}%')
73
 
74
  for cf in content_features:
75
  _, _, h_i, w_i = cf.shape
@@ -79,11 +72,19 @@ def inference(
79
  def closure(iter):
80
  optimizer.zero_grad()
81
  generated_features = model(generated_image)
82
- total_loss = _compute_loss(
83
  generated_features, content_features, style_features, resized_bg_masks, alpha, beta
84
  )
85
  total_loss.backward()
 
 
 
 
 
 
 
86
  min_losses[iter] = min(min_losses[iter], total_loss.item())
 
87
  return total_loss
88
 
89
  for iter in range(iterations):
@@ -94,6 +95,6 @@ def inference(
94
  foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
95
  generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
96
 
97
- if iter % 10 == 0: print(f'[{"Background" if apply_to_background else "Image"}] Loss ({iter}):', min_losses[iter])
98
-
99
  return generated_image, background_ratio
 
1
  import torch
2
  import torch.optim as optim
3
  import torch.nn.functional as F
4
+ from torch.utils.tensorboard import SummaryWriter
5
  from torchvision.transforms.functional import gaussian_blur
6
 
 
 
 
 
 
 
 
7
  def _gram_matrix(feature):
8
  batch_size, n_feature_maps, height, width = feature.size()
9
  new_feature = feature.view(batch_size * n_feature_maps, height * width)
 
28
  A = _gram_matrix(sf)
29
  style_loss += w_l * F.mse_loss(G, A)
30
 
31
+ total_loss = alpha * content_loss + beta * style_loss
32
+ return content_loss, style_loss, total_loss
33
 
34
  def inference(
35
  *,
 
44
  alpha=1,
45
  beta=1,
46
  ):
47
+ writer = SummaryWriter()
48
  generated_image = content_image.clone().requires_grad_(True)
49
  optimizer = optim_caller([generated_image], lr=lr)
50
  min_losses = [float('inf')] * iterations
 
59
  segmentation_mask = segmentation_output.argmax(dim=1)
60
  background_mask = (segmentation_mask == 0).float()
61
  foreground_mask = 1 - background_mask
 
62
 
63
  background_pixel_count = background_mask.sum().item()
64
  total_pixel_count = segmentation_mask.numel()
65
  background_ratio = background_pixel_count / total_pixel_count
 
66
 
67
  for cf in content_features:
68
  _, _, h_i, w_i = cf.shape
 
72
  def closure(iter):
73
  optimizer.zero_grad()
74
  generated_features = model(generated_image)
75
+ content_loss, style_loss, total_loss = _compute_loss(
76
  generated_features, content_features, style_features, resized_bg_masks, alpha, beta
77
  )
78
  total_loss.backward()
79
+
80
+ # log loss
81
+ writer.add_scalars(f'style-{"background" if apply_to_background else "image"}', {
82
+ 'Loss/content': content_loss.item(),
83
+ 'Loss/style': style_loss.item(),
84
+ 'Loss/total': total_loss.item()
85
+ }, iter)
86
  min_losses[iter] = min(min_losses[iter], total_loss.item())
87
+
88
  return total_loss
89
 
90
  for iter in range(iterations):
 
95
  foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
96
  generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
97
 
98
+ writer.flush()
99
+ writer.close()
100
  return generated_image, background_ratio
u2net.py CHANGED
@@ -1 +1,6 @@
1
- # initial code for u2net arch
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class U2Net(nn.Module):
4
+ def __init__(self):
5
+ super(U2Net, self).__init__()
6
+ pass