jamino30 commited on
Commit
962b2f7
·
verified ·
1 Parent(s): 1b9bef7

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -40,6 +40,7 @@ def inference(content_image, style_image, style_strength, output_quality, progre
40
  print('-'*15)
41
  print('DATETIME:', datetime.datetime.now())
42
  print('STYLE:', style_image)
 
43
  img_size = 1024 if output_quality else 512
44
  content_img, original_size = preprocess_img(content_image, img_size)
45
  content_img = content_img.to(device)
@@ -58,10 +59,14 @@ def inference(content_image, style_image, style_strength, output_quality, progre
58
  st = time.time()
59
  generated_img = content_img.clone().requires_grad_(True)
60
  optimizer = optim.Adam([generated_img], lr=lr)
61
-
62
- for _ in tqdm(range(iters), desc='The magic is happening ✨'):
63
  content_features = model(content_img)
64
  style_features = model(style_img)
 
 
 
 
65
  generated_features = model(generated_img)
66
 
67
  content_loss = 0
@@ -81,7 +86,6 @@ def inference(content_image, style_image, style_strength, output_quality, progre
81
 
82
  total_loss = alpha * content_loss + beta * style_loss
83
 
84
- optimizer.zero_grad()
85
  total_loss.backward()
86
  optimizer.step()
87
 
 
40
  print('-'*15)
41
  print('DATETIME:', datetime.datetime.now())
42
  print('STYLE:', style_image)
43
+
44
  img_size = 1024 if output_quality else 512
45
  content_img, original_size = preprocess_img(content_image, img_size)
46
  content_img = content_img.to(device)
 
59
  st = time.time()
60
  generated_img = content_img.clone().requires_grad_(True)
61
  optimizer = optim.Adam([generated_img], lr=lr)
62
+
63
+ with torch.no_grad():
64
  content_features = model(content_img)
65
  style_features = model(style_img)
66
+
67
+ for _ in tqdm(range(iters), desc='The magic is happening ✨'):
68
+ optimizer.zero_grad()
69
+
70
  generated_features = model(generated_img)
71
 
72
  content_loss = 0
 
86
 
87
  total_loss = alpha * content_loss + beta * style_loss
88
 
 
89
  total_loss.backward()
90
  optimizer.step()
91