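"""Optimization-based neural style transfer with optional background-only stylization.

`inference` iteratively updates a copy of the content image so that its features match the
content image (content loss) and its Gram matrices match precomputed style features (style
loss). When `apply_to_background` is set, a segmentation model separates foreground from
background: the style loss is restricted to the background, and the salient object's pixels
are restored from the original content image after every optimization step.
"""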
from tqdm import tqdm
import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur
def _gram_matrix(feature):
    """Return the Gram matrix of a [B, C, H, W] feature map (channel correlations used for the style loss)."""
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())
def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
    """Compute the content, style, and weighted total losses across all feature levels.

    When background masks are provided, style statistics are computed only over the
    (blurred) background region so the foreground is left unstylized.
    """
    content_loss = 0
    style_loss = 0
    w_l = 1 / len(generated_features)  # equal weight for every feature level
    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
        content_loss += F.mse_loss(gf, cf)
        if resized_bg_masks:
            # Blur the mask to soften the foreground/background boundary before masking.
            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
            masked_gf = gf * blurred_bg_mask
            masked_sf = sf * blurred_bg_mask
            G = _gram_matrix(masked_gf)
            A = _gram_matrix(masked_sf)
        else:
            G = _gram_matrix(gf)
            A = _gram_matrix(sf)
        style_loss += w_l * F.mse_loss(G, A)
    total_loss = alpha * content_loss + beta * style_loss
    return content_loss, style_loss, total_loss
def inference(
    *,
    model,
    sod_model,
    content_image,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
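    """Run optimization-based style transfer on `content_image`.

    Args:
        model: feature extractor; calling it on an image returns a list of feature maps.
        sod_model: segmentation model used to split foreground from background.
        content_image: content image tensor of shape [1, 3, H, W].
        style_features: precomputed style feature maps from the same `model`.
        apply_to_background: if True, stylize only the background and keep the salient object intact.
        lr: learning rate for the optimizer updating the generated image.
        iterations: number of optimization steps.
        optim_caller: optimizer constructor (default: AdamW).
        alpha, beta: weights for the content and style losses.

    Returns:
        (generated_image, salient_object_ratio); the ratio is the fraction of foreground
        pixels, or None when `apply_to_background` is False.
    """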
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    min_losses = [float('inf')] * iterations

    with torch.no_grad():
        content_features = model(content_image)

        resized_bg_masks = []
        salient_object_ratio = None
        if apply_to_background:
            # original: multi-class segmentation output with a class dimension
            segmentation_output = sod_model(content_image)['out']  # [1, 21, 512, 512]
            segmentation_mask = segmentation_output.argmax(dim=1)  # [1, 512, 512]
            background_mask = (segmentation_mask == 0).float()
            foreground_mask = 1 - background_mask
            # new: single-channel thresholded sigmoid output (kept commented out for reference)
            # segmentation_output = sod_model(content_image)[0]
            # segmentation_output = torch.sigmoid(segmentation_output)
            # segmentation_mask = (segmentation_output > 0.7).float()
            # background_mask = (segmentation_mask == 0).float()
            # foreground_mask = 1 - background_mask

            # Fraction of the image occupied by the salient (foreground) object.
            salient_object_pixel_count = foreground_mask.sum().item()
            total_pixel_count = segmentation_mask.numel()
            salient_object_ratio = salient_object_pixel_count / total_pixel_count

            # Resize the background mask to each feature-map resolution for the masked style loss.
            for cf in content_features:
                _, _, h_i, w_i = cf.shape
                bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
                resized_bg_masks.append(bg_mask)
    def closure(iter):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        content_loss, style_loss, total_loss = _compute_loss(
            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
        )
        total_loss.backward()
        # Track the lowest total loss seen at this iteration.
        min_losses[iter] = min(min_losses[iter], total_loss.item())
        return total_loss
    for iter in tqdm(range(iterations)):
        optimizer.step(lambda: closure(iter))

        if apply_to_background:
            # Paste the original foreground back so only the background is stylized.
            with torch.no_grad():
                foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
                generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized

    return generated_image, salient_object_ratio
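
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the app code). It assumes `model`
# is a wrapper around a CNN feature extractor that returns a list of feature maps
# when called on an image, and that the segmentation branch is a torchvision
# DeepLabV3 model whose output dict has an 'out' key (matching the shape comments
# above). `FeatureExtractor` and `load_image` are hypothetical placeholder names,
# not helpers defined in this file.
#
# if __name__ == "__main__":
#     from torchvision.models.segmentation import deeplabv3_resnet50
#
#     seg_model = deeplabv3_resnet50(weights="DEFAULT").eval()
#     feat_model = FeatureExtractor()           # hypothetical feature extractor
#     content = load_image("content.jpg")       # hypothetical loader -> [1, 3, 512, 512] tensor
#     style = load_image("style.jpg")
#     with torch.no_grad():
#         style_feats = feat_model(style)
#     stylized, ratio = inference(
#         model=feat_model,
#         sod_model=seg_model,
#         content_image=content,
#         style_features=style_feats,
#         apply_to_background=True,
#         lr=5e-2,
#     )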