File size: 3,425 Bytes
91d9343
 
 
 
 
a9077eb
 
91d9343
 
 
 
 
 
a9077eb
91d9343
 
 
a9077eb
 
91d9343
a9077eb
 
 
 
 
 
 
 
 
 
 
91d9343
a9077eb
91d9343
 
 
 
 
 
 
a9077eb
91d9343
ce6dca2
d3ca146
91d9343
a9077eb
91d9343
06894c7
d3ca146
c5d8238
91d9343
 
 
a9077eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246dd82
b9f6209
91d9343
06894c7
a9077eb
 
 
91d9343
b9f6209
246dd82
 
b9f6209
fa762f9
a9077eb
 
 
 
 
 
c59d0bc
91d9343
06894c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from tqdm import tqdm

import torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur
from torchvision import models

def _gram_matrix(feature):
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())

def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
    """Return the weighted content + style loss for one optimisation step.

    Args:
        generated_features: per-layer feature maps of the image being
            optimised. Each is passed to ``_gram_matrix``, so it must be a 4-D
            tensor (assumed ``(batch, channels, height, width)``).
        content_features: per-layer feature maps of the content image, paired
            layer-by-layer with ``generated_features``.
        style_features: per-layer feature maps of the style image, paired
            layer-by-layer with ``generated_features``.
        resized_bg_masks: per-layer masks already resized to each layer's
            spatial size. When non-empty, both the generated and style features
            are multiplied by a Gaussian-blurred copy of the mask before their
            Gram matrices are compared. When empty, the style loss uses the
            full feature maps.
        alpha: weight of the content loss.
        beta: weight of the style loss.

    Returns:
        ``alpha * content_loss + beta * style_loss``. ``content_loss`` is the
        sum of per-layer MSEs. ``style_loss`` is the per-layer Gram-matrix MSE,
        each weighted by ``1 / len(generated_features)``.
    """
    content_loss = 0
    style_loss = 0
    w_l = 1 / len(generated_features)

    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
        content_loss += F.mse_loss(gf, cf)

        if resized_bg_masks:
            # Blur the mask so the style region fades out gradually instead of
            # stopping at a hard edge.
            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
            gf = gf * blurred_bg_mask
            sf = sf * blurred_bg_mask

        G = _gram_matrix(gf)
        A = _gram_matrix(sf)
        # Accumulate exactly once per layer, for both the masked and the
        # unmasked path, so the style weight is the same in either mode.
        style_loss += w_l * F.mse_loss(G, A)

    return alpha * content_loss + beta * style_loss

def inference(
    *,
    model,
    content_image,
    style_features,
    apply_to_background,
    lr,
    iterations=101,
    optim_caller=optim.AdamW,
    alpha=1,
    beta=1,
):
    """Optimise a copy of ``content_image`` towards the given style features.

    Args:
        model: callable mapping an image tensor to an iterable of per-layer
            4-D feature maps. It is used for both the content image and the
            image being optimised.
        content_image: image tensor to stylise; it is also the starting point
            of the optimisation. Assumed to be ``(batch, channels, H, W)`` on
            the target device, preprocessed as both ``model`` and the
            segmentation network expect (TODO confirm).
        style_features: per-layer feature maps of the style image, aligned
            with the outputs of ``model``.
        apply_to_background: if true, segment ``content_image`` with
            DeepLabV3-ResNet101 and treat pixels predicted as class 0 as
            background. The style loss is then restricted to the background,
            and after every step foreground pixels are reset to the content
            image.
        lr: learning rate passed to the optimiser.
        iterations: number of optimiser steps.
        optim_caller: optimiser factory, called as
            ``optim_caller([param], lr=lr)``.
        alpha: content-loss weight.
        beta: style-loss weight.

    Returns:
        The optimised image tensor (a leaf tensor with ``requires_grad=True``).
    """
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    # Lowest loss seen at each step. The closure may be evaluated more than
    # once per step by some optimisers (e.g. LBFGS), hence the min().
    min_losses = [float('inf')] * iterations

    resized_bg_masks = []
    foreground_mask_resized = None
    with torch.no_grad():
        content_features = model(content_image)

        if apply_to_background:
            segmentation_model = models.segmentation.deeplabv3_resnet101(weights='DEFAULT').eval()
            segmentation_model = segmentation_model.to(content_image.device)

            segmentation_output = segmentation_model(content_image)['out']
            segmentation_mask = segmentation_output.argmax(dim=1)
            # The segmentation network is only needed for this single forward
            # pass; release it so it is not kept alive during the loop below.
            del segmentation_model, segmentation_output

            background_mask = (segmentation_mask == 0).float()
            foreground_mask = (segmentation_mask != 0).float()

            # One background mask per feature layer, resized to that layer's
            # spatial size for the masked style loss.
            for cf in content_features:
                _, _, h_i, w_i = cf.shape
                bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
                resized_bg_masks.append(bg_mask)

            # generated_image never changes shape, so the foreground mask is
            # resized once here rather than on every iteration.
            foreground_mask_resized = F.interpolate(
                foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest'
            )

    def closure(step):
        optimizer.zero_grad()
        generated_features = model(generated_image)
        total_loss = _compute_loss(
            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
        )
        total_loss.backward()
        min_losses[step] = min(min_losses[step], total_loss.item())
        return total_loss

    for step in tqdm(range(iterations), desc='The magic is happening ✨'):
        # Bind the current step explicitly so the closure never sees a later value.
        optimizer.step(lambda step=step: closure(step))

        if apply_to_background:
            with torch.no_grad():
                # Pin foreground pixels back to the content image so only the
                # background is stylised.
                generated_image.copy_(
                    generated_image * (1 - foreground_mask_resized)
                    + content_image * foreground_mask_resized
                )

        if step % 10 == 0:
            # tqdm.write prints without corrupting the active progress bar.
            tqdm.write(f'Loss ({step}): {min_losses[step]}')

    return generated_image