Upload folder using huggingface_hub

Files changed:
- app.py: +5 −3
- inference.py: +44 −6
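In short: this commit adds an 'Apply to background only' option to the style-transfer demo. app.py exposes a new checkbox and threads its value through the run callback into inference(); inference.py segments the content image with DeepLabV3, restricts the style (Gram-matrix) loss to blurred background masks, and pastes the original foreground back into the generated image after every optimizer step.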
app.py
CHANGED
@@ -35,7 +35,7 @@ for style_name, style_img_path in style_options.items():
     cached_style_features[style_name] = style_features
 
 @spaces.GPU(duration=10)
-def run(content_image, style_name, style_strength=5, progress=gr.Progress(track_tqdm=True)):
+def run(content_image, style_name, style_strength=5, apply_to_background=False, progress=gr.Progress(track_tqdm=True)):
     yield None
     content_img, original_size = preprocess_img(content_image, img_size)
     content_img = content_img.to(device)
@@ -53,7 +53,8 @@ def run(content_image, style_name, style_strength=5, progress=gr.Progress(track_tqdm=True)):
         model=model,
         content_image=content_img,
         style_features=style_features,
-        lr=lrs[style_strength-1]
+        lr=lrs[style_strength-1],
+        apply_to_background=apply_to_background
     )
     et = time.time()
     print('TIME TAKEN:', et-st)
@@ -78,6 +79,7 @@ with gr.Blocks(css=css) as demo:
             style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
             with gr.Group():
                 style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=5, info='Higher values add artistic flair, lower values add a realistic feel.')
+                apply_to_background = gr.Checkbox(label='Apply to background only')
             submit_button = gr.Button('Submit', variant='primary')
 
     examples = gr.Examples(
@@ -105,7 +107,7 @@ with gr.Blocks(css=css) as demo:
 
     submit_button.click(
         fn=run,
-        inputs=[content_image, style_dropdown, style_strength_slider],
+        inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background],
         outputs=[output_image]
     ).then(
         fn=save_image,
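The app.py side of the change leans entirely on Gradio's positional mapping from the inputs list to the callback's parameters: appending the checkbox to inputs is all the wiring the new flag needs. A minimal, self-contained sketch of that pattern (the dummy run callback below just echoes its arguments; it is not the Space's real one):

import gradio as gr

def run(content_image, style_name, style_strength=5, apply_to_background=False):
    # Gradio passes the components in `inputs` to these parameters
    # positionally, so the checkbox's boolean lands in apply_to_background.
    return f'{style_name} @ {style_strength}, background only: {apply_to_background}'

with gr.Blocks() as demo:
    content_image = gr.Image()
    style_dropdown = gr.Radio(choices=['Starry Night'], value='Starry Night', label='Style')
    style_strength_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label='Style Strength')
    apply_to_background = gr.Checkbox(label='Apply to background only')
    output = gr.Textbox()
    submit_button = gr.Button('Submit')
    # Appending the checkbox to `inputs` mirrors the diff's one-line change.
    submit_button.click(fn=run, inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background], outputs=[output])

demo.launch()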
inference.py
CHANGED
@@ -3,21 +3,34 @@ from tqdm import tqdm
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
+from torchvision.transforms.functional import gaussian_blur
+from torchvision import models
 
 def _gram_matrix(feature):
     batch_size, n_feature_maps, height, width = feature.size()
     new_feature = feature.view(batch_size * n_feature_maps, height * width)
     return torch.mm(new_feature, new_feature.t())
 
-def _compute_loss(generated_features, content_features, style_features, alpha, beta):
+def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
     content_loss = 0
     style_loss = 0
     w_l = 1 / len(generated_features)
-    for gf, cf, sf in zip(generated_features, content_features, style_features):
+
+    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
         content_loss += F.mse_loss(gf, cf)
-        G = _gram_matrix(gf)
-        A = _gram_matrix(sf)
+
+        if resized_bg_masks:
+            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
+            masked_gf = gf * blurred_bg_mask
+            masked_sf = sf * blurred_bg_mask
+            G = _gram_matrix(masked_gf)
+            A = _gram_matrix(masked_sf)
+        else:
+            G = _gram_matrix(gf)
+            A = _gram_matrix(sf)
+
         style_loss += w_l * F.mse_loss(G, A)
+
     return alpha * content_loss + beta * style_loss
 
 def inference(
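To see what the new masked branch computes: multiplying a feature map by a (blurred) background mask zeroes the foreground activations, so the Gram matrix captures feature correlations from background regions only, and the blur softens the seam between the two regions. A shape-level sketch with dummy tensors (the sizes are illustrative, not the Space's real feature shapes):

import torch
from torchvision.transforms.functional import gaussian_blur

def _gram_matrix(feature):
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
    return torch.mm(new_feature, new_feature.t())

gf = torch.randn(1, 64, 32, 32)        # stand-in for one layer's generated features
bg_mask = torch.ones(1, 1, 32, 32)     # 1 = background, 0 = foreground
bg_mask[:, :, :16, :] = 0.0            # pretend the top half is foreground

blurred = gaussian_blur(bg_mask, kernel_size=5)  # soften the mask boundary
G = _gram_matrix(gf * blurred)         # correlations from background activations only
print(G.shape)                         # torch.Size([64, 64])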
@@ -25,11 +38,12 @@ def inference(
     model,
     content_image,
     style_features,
+    apply_to_background,
     lr,
     iterations=101,
     optim_caller=optim.AdamW,
     alpha=1,
-    beta=1
+    beta=1,
 ):
     generated_image = content_image.clone().requires_grad_(True)
     optimizer = optim_caller([generated_image], lr=lr)
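One thing to watch in this hunk: apply_to_background is inserted ahead of lr, so any positional call to inference() would silently shift its arguments. app.py is unaffected because it passes everything by keyword.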
@@ -37,17 +51,41 @@
 
     with torch.no_grad():
         content_features = model(content_image)
+
+    resized_bg_masks = []
+    if apply_to_background:
+        segmentation_model = models.segmentation.deeplabv3_resnet101(weights='DEFAULT').eval()
+        segmentation_model = segmentation_model.to(content_image.device)
+
+        segmentation_output = segmentation_model(content_image)['out']
+        segmentation_mask = segmentation_output.argmax(dim=1)
+
+        background_mask = (segmentation_mask == 0).float()
+        foreground_mask = (segmentation_mask != 0).float()
+
+        for cf in content_features:
+            _, _, h_i, w_i = cf.shape
+            bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
+            resized_bg_masks.append(bg_mask)
 
     def closure(iter):
         optimizer.zero_grad()
         generated_features = model(generated_image)
-        total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
+        total_loss = _compute_loss(
+            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
+        )
         total_loss.backward()
         min_losses[iter] = min(min_losses[iter], total_loss.item())
         return total_loss
 
     for iter in tqdm(range(iterations), desc='The magic is happening ✨'):
         optimizer.step(lambda: closure(iter))
+
+        if apply_to_background:
+            with torch.no_grad():
+                foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
+                generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
+
         if iter % 10 == 0: print(f'Loss ({iter}):', min_losses[iter])
 
     return generated_image
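The segmentation branch relies on DeepLabV3's Pascal VOC label map, where class 0 is 'background': every pixel not assigned to a detected object counts as background. A standalone sketch of the mask extraction with a dummy input (unlike the diff, the forward pass here is wrapped in torch.no_grad(), which avoids building an unused autograd graph):

import torch
from torchvision import models

segmentation_model = models.segmentation.deeplabv3_resnet101(weights='DEFAULT').eval()

content_image = torch.randn(1, 3, 512, 512)            # stand-in for the preprocessed content image
with torch.no_grad():
    logits = segmentation_model(content_image)['out']  # (1, 21, 512, 512) Pascal VOC class scores
labels = logits.argmax(dim=1)                          # (1, 512, 512) per-pixel class ids

background_mask = (labels == 0).float()                # 1 where no object was detected
foreground_mask = (labels != 0).float()

Note that the diff constructs this model inside inference(), so every background-only run pays the model-construction cost again; hoisting it to module level, as app.py already does for the cached style features, would be a natural follow-up.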
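Finally, the per-iteration paste-back at the end of the loop works because writing through .data bypasses autograd: the background keeps its stylized pixels while foreground pixels are reset to the content image, and the optimizer's state on the leaf tensor is untouched. A minimal sketch of that blend with dummy tensors:

import torch
import torch.nn.functional as F

generated_image = torch.rand(1, 3, 8, 8, requires_grad=True)   # the leaf being optimized
content_image = torch.rand(1, 3, 8, 8)
foreground_mask = (torch.rand(1, 8, 8) > 0.5).float()          # stand-in segmentation mask

fg = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
with torch.no_grad():
    # Stylized pixels survive where fg == 0 (background); original content
    # is restored where fg == 1 (foreground). Assigning to .data keeps the
    # tensor's autograd identity and optimizer state intact.
    generated_image.data = generated_image.data * (1 - fg) + content_image * fg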