jamino30 committed on
Commit
14fd49f
1 Parent(s): f581e92

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +25 -18
app.py CHANGED
@@ -15,6 +15,7 @@ if torch.cuda.is_available(): device = 'cuda'
15
  elif torch.backends.mps.is_available(): device = 'mps'
16
  else: device = 'cpu'
17
  print('DEVICE:', device)
 
18
 
19
  model = VGG_19().to(device)
20
  for param in model.parameters():
@@ -33,6 +34,24 @@ optimal_settings = {
33
  'Watercolor': (10, False),
34
  }
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  @spaces.GPU(duration=20)
37
  def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
38
  yield None
@@ -57,34 +76,22 @@ def inference(content_image, style_image, style_strength, output_quality, progre
57
  st = time.time()
58
  generated_img = content_img.clone().requires_grad_(True)
59
  optimizer = optim.Adam([generated_img], lr=lr)
 
 
 
60
 
61
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
62
  generated_features = model(generated_img)
63
- content_features = model(content_img)
64
- style_features = model(style_img)
65
-
66
- content_loss = 0
67
- style_loss = 0
68
 
69
- for generated_feature, content_feature, style_feature in zip(generated_features, content_features, style_features):
70
- batch_size, n_feature_maps, height, width = generated_feature.size()
71
-
72
- content_loss += (torch.mean((generated_feature - content_feature) ** 2))
73
-
74
- G = torch.mm((generated_feature.view(batch_size * n_feature_maps, height * width)), (generated_feature.view(batch_size * n_feature_maps, height * width)).t())
75
- A = torch.mm((style_feature.view(batch_size * n_feature_maps, height * width)), (style_feature.view(batch_size * n_feature_maps, height * width)).t())
76
-
77
- E_l = ((G - A) ** 2)
78
- w_l = 1/5
79
- style_loss += torch.mean(w_l * E_l)
80
-
81
- total_loss = alpha * content_loss + beta * style_loss
82
  optimizer.zero_grad()
83
  total_loss.backward()
84
  optimizer.step()
85
 
86
  et = time.time()
87
  print('TIME TAKEN:', et-st)
 
88
  yield postprocess_img(generated_img, original_size)
89
 
90
 
 
15
  elif torch.backends.mps.is_available(): device = 'mps'
16
  else: device = 'cpu'
17
  print('DEVICE:', device)
18
+ if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())
19
 
20
  model = VGG_19().to(device)
21
  for param in model.parameters():
 
34
  'Watercolor': (10, False),
35
  }
36
 
37
def compute_loss(generated_features, content_features, style_features, alpha, beta):
    """Compute the combined neural-style-transfer loss (content + style).

    Args:
        generated_features: iterable of 4-D activation tensors
            (batch, channels, height, width) from the image being optimized.
        content_features: matching activations of the content image.
        style_features: matching activations of the style image.
        alpha: scalar weight of the content term.
        beta: scalar weight of the style term.

    Returns:
        A scalar tensor: ``alpha * content_loss + beta * style_loss``.
    """
    content_loss = 0
    style_loss = 0
    # Uniform per-layer style weight; loop-invariant, so hoisted out of the loop.
    # NOTE(review): 1/5 presumably assumes five feature layers -- confirm against VGG_19.
    w_l = 1 / 5

    for gen_feat, content_feat, style_feat in zip(generated_features, content_features, style_features):
        batch_size, n_feature_maps, height, width = gen_feat.size()

        # Content term: mean squared error between activations.
        content_loss += torch.mean((gen_feat - content_feat) ** 2)

        # Style term: squared difference of Gram matrices.
        # Flatten each feature map once (the original recomputed each view twice).
        gen_flat = gen_feat.view(batch_size * n_feature_maps, height * width)
        style_flat = style_feat.view(batch_size * n_feature_maps, height * width)
        G = torch.mm(gen_flat, gen_flat.t())
        A = torch.mm(style_flat, style_flat.t())
        style_loss += torch.mean(w_l * (G - A) ** 2)

    return alpha * content_loss + beta * style_loss
54
+
55
  @spaces.GPU(duration=20)
56
  def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
57
  yield None
 
76
  st = time.time()
77
  generated_img = content_img.clone().requires_grad_(True)
78
  optimizer = optim.Adam([generated_img], lr=lr)
79
+
80
+ content_features = model(content_img)
81
+ style_features = model(style_img)
82
 
83
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
84
  generated_features = model(generated_img)
 
 
 
 
 
85
 
86
+ total_loss = compute_loss(generated_features, content_features, style_features, alpha, beta)
87
+
 
 
 
 
 
 
 
 
 
 
 
88
  optimizer.zero_grad()
89
  total_loss.backward()
90
  optimizer.step()
91
 
92
  et = time.time()
93
  print('TIME TAKEN:', et-st)
94
+
95
  yield postprocess_img(generated_img, original_size)
96
 
97