Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files
app.py
CHANGED
@@ -15,6 +15,7 @@ if torch.cuda.is_available(): device = 'cuda'
|
|
15 |
elif torch.backends.mps.is_available(): device = 'mps'
|
16 |
else: device = 'cpu'
|
17 |
print('DEVICE:', device)
|
|
|
18 |
|
19 |
model = VGG_19().to(device)
|
20 |
for param in model.parameters():
|
@@ -33,6 +34,24 @@ optimal_settings = {
|
|
33 |
'Watercolor': (10, False),
|
34 |
}
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
@spaces.GPU(duration=20)
|
37 |
def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
|
38 |
yield None
|
@@ -57,34 +76,22 @@ def inference(content_image, style_image, style_strength, output_quality, progre
|
|
57 |
st = time.time()
|
58 |
generated_img = content_img.clone().requires_grad_(True)
|
59 |
optimizer = optim.Adam([generated_img], lr=lr)
|
|
|
|
|
|
|
60 |
|
61 |
for _ in tqdm(range(iters), desc='The magic is happening ✨'):
|
62 |
generated_features = model(generated_img)
|
63 |
-
content_features = model(content_img)
|
64 |
-
style_features = model(style_img)
|
65 |
-
|
66 |
-
content_loss = 0
|
67 |
-
style_loss = 0
|
68 |
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
content_loss += (torch.mean((generated_feature - content_feature) ** 2))
|
73 |
-
|
74 |
-
G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
|
75 |
-
A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
|
76 |
-
|
77 |
-
E_l = ((G - A) ** 2)
|
78 |
-
w_l = 1/5
|
79 |
-
style_loss += torch.mean(w_l * E_l)
|
80 |
-
|
81 |
-
total_loss = alpha * content_loss + beta * style_loss
|
82 |
optimizer.zero_grad()
|
83 |
total_loss.backward()
|
84 |
optimizer.step()
|
85 |
|
86 |
et = time.time()
|
87 |
print('TIME TAKEN:', et-st)
|
|
|
88 |
yield postprocess_img(generated_img, original_size)
|
89 |
|
90 |
|
|
|
15 |
elif torch.backends.mps.is_available(): device = 'mps'
|
16 |
else: device = 'cpu'
|
17 |
print('DEVICE:', device)
|
18 |
+
if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
|
19 |
|
20 |
model = VGG_19().to(device)
|
21 |
for param in model.parameters():
|
|
|
34 |
'Watercolor': (10, False),
|
35 |
}
|
36 |
|
37 |
+
def compute_loss(generated_features, content_features, style_features, alpha, beta):
|
38 |
+
content_loss = 0
|
39 |
+
style_loss = 0
|
40 |
+
|
41 |
+
for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
|
42 |
+
batch_size, n_feature_maps, height, width = generated_feature.size()
|
43 |
+
|
44 |
+
content_loss += (torch.mean((generated_feature - content_feature) ** 2))
|
45 |
+
|
46 |
+
G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
|
47 |
+
A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
|
48 |
+
|
49 |
+
E_l = ((G - A) ** 2)
|
50 |
+
w_l = 1/5
|
51 |
+
style_loss += torch.mean(w_l * E_l)
|
52 |
+
|
53 |
+
return alpha * content_loss + beta * style_loss
|
54 |
+
|
55 |
@spaces.GPU(duration=20)
|
56 |
def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
|
57 |
yield None
|
|
|
76 |
st = time.time()
|
77 |
generated_img = content_img.clone().requires_grad_(True)
|
78 |
optimizer = optim.Adam([generated_img], lr=lr)
|
79 |
+
|
80 |
+
content_features = model(content_img)
|
81 |
+
style_features = model(style_img)
|
82 |
|
83 |
for _ in tqdm(range(iters), desc='The magic is happening ✨'):
|
84 |
generated_features = model(generated_img)
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
+
total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
|
87 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
optimizer.zero_grad()
|
89 |
total_loss.backward()
|
90 |
optimizer.step()
|
91 |
|
92 |
et = time.time()
|
93 |
print('TIME TAKEN:', et-st)
|
94 |
+
|
95 |
yield postprocess_img(generated_img, original_size)
|
96 |
|
97 |
|