jamino30 commited on
Commit
1b9bef7
1 Parent(s): 66586b2

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +18 -22
app.py CHANGED
@@ -34,24 +34,6 @@ optimal_settings = {
34
  'Watercolor': (10, False),
35
  }
36
 
37
- def compute_loss(generated_features, content_features, style_features, alpha, beta):
38
- content_loss = 0
39
- style_loss = 0
40
-
41
- for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
42
- batch_size, n_feature_maps, height, width = generated_feature.size()
43
-
44
- content_loss += (torch.mean((generated_feature - content_feature) ** 2))
45
-
46
- G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
47
- A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
48
-
49
- E_l = ((G - A) ** 2)
50
- w_l = 1/5
51
- style_loss += torch.mean(w_l * E_l)
52
-
53
- return alpha * content_loss + beta * style_loss
54
-
55
  @spaces.GPU(duration=20)
56
  def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
57
  yield None
@@ -76,14 +58,28 @@ def inference(content_image, style_image, style_strength, output_quality, progre
76
  st = time.time()
77
  generated_img = content_img.clone().requires_grad_(True)
78
  optimizer = optim.Adam([generated_img], lr=lr)
79
-
80
- content_features = model(content_img)
81
- style_features = model(style_img)
82
 
83
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
 
 
84
  generated_features = model(generated_img)
85
 
86
- total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  optimizer.zero_grad()
89
  total_loss.backward()
 
34
  'Watercolor': (10, False),
35
  }
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  @spaces.GPU(duration=20)
38
  def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
39
  yield None
 
58
  st = time.time()
59
  generated_img = content_img.clone().requires_grad_(True)
60
  optimizer = optim.Adam([generated_img], lr=lr)
 
 
 
61
 
62
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
63
+ content_features = model(content_img)
64
+ style_features = model(style_img)
65
  generated_features = model(generated_img)
66
 
67
+ content_loss = 0
68
+ style_loss = 0
69
+
70
+ for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
71
+ batch_size, n_feature_maps, height, width = generated_feature.size()
72
+
73
+ content_loss += (torch.mean((generated_feature - content_feature) ** 2))
74
+
75
+ G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
76
+ A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
77
+
78
+ E_l = ((G - A) ** 2)
79
+ w_l = 1/5
80
+ style_loss += torch.mean(w_l * E_l)
81
+
82
+ total_loss = alpha * content_loss + beta * style_loss
83
 
84
  optimizer.zero_grad()
85
  total_loss.backward()