from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur

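# Gram matrix of a layer's feature maps: pairwise correlations between channels,
# used as the style representation in the style loss below.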
def _gram_matrix(feature):
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())

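# Combined loss over all layers: content loss is the MSE between generated and
# content feature maps; style loss is the MSE between Gram matrices, each layer
# weighted equally by w_l. alpha and beta trade the two terms off against each other.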
def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
    content_loss = 0
    style_loss = 0
    w_l = 1 / len(generated_features)
    
    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
        content_loss += F.mse_loss(gf, cf)
        
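        # With background masks, style statistics are computed only over the
        # background region; blurring the mask softens the foreground boundary.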
        if resized_bg_masks:
            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
            masked_gf = gf * blurred_bg_mask
            masked_sf = sf * blurred_bg_mask
            G = _gram_matrix(masked_gf)
            A = _gram_matrix(masked_sf)
        else:
            G = _gram_matrix(gf)
            A = _gram_matrix(sf)
        style_loss += w_l * F.mse_loss(G, A)
        
    total_loss = alpha * content_loss + beta * style_loss
    return content_loss, style_loss, total_loss

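# Optimization-based style transfer: the pixels of a copy of the content image
# are updated iteratively to minimize the content/style loss. When
# apply_to_background is set, a salient-object-detection (SOD) model restricts
# stylization to the background and keeps the salient foreground intact.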
def inference(
    *,
    model,
    sod_model,
    content_image,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
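    # The generated image is a trainable clone of the content image; the optimizer
    # updates its pixels directly.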
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    min_losses = [float('inf')] * iterations

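    # Targets that stay fixed during optimization are computed once, without gradients.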
    with torch.no_grad():
        content_features = model(content_image)

        resized_bg_masks = []
        salient_object_ratio = None
        if apply_to_background:
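            # Salient-object detection: sigmoid + threshold gives a binary foreground
            # mask; its complement marks the background to be stylized.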
            segmentation_output = sod_model(content_image)[0]
            segmentation_output = torch.sigmoid(segmentation_output)
            segmentation_mask = (segmentation_output > 0.7).float()
            background_mask = (segmentation_mask == 0).float()
            foreground_mask = 1 - background_mask
            
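            # Fraction of image pixels covered by the salient object, returned to the caller.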
            salient_object_pixel_count = foreground_mask.sum().item()
            total_pixel_count = segmentation_mask.numel()
            salient_object_ratio = salient_object_pixel_count / total_pixel_count

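            # Resize the background mask to each feature map's spatial size so it can
            # mask the corresponding layer in the style loss.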
            for cf in content_features:
                _, _, h_i, w_i = cf.shape
                bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
                resized_bg_masks.append(bg_mask)
        
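    # Closure form so the loop also works with optimizers that re-evaluate the
    # loss several times per step (e.g. L-BFGS).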
    def closure(step_idx):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        content_loss, style_loss, total_loss = _compute_loss(
            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
        )
        total_loss.backward()
        
        # Track the lowest total loss reached at this step; the closure may run
        # more than once per optimizer.step.
        min_losses[step_idx] = min(min_losses[step_idx], total_loss.item())
        
        return total_loss
    
    for step_idx in tqdm(range(iterations)):
        optimizer.step(lambda: closure(step_idx))

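        # Paste the original foreground pixels back after each step so only the
        # background is stylized.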
        if apply_to_background:
            with torch.no_grad():
                foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
                generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
                
    return generated_image, salient_object_ratio