jamino30 commited on
Commit
59c9938
1 Parent(s): 57ebd4f

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -21,7 +21,6 @@ model = VGG_19().to(device)
21
  for param in model.parameters():
22
  param.requires_grad = False
23
 
24
-
25
  style_files = os.listdir('./style_images')
26
  style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
27
  optimal_settings = {
@@ -34,6 +33,15 @@ optimal_settings = {
34
  'Watercolor': (10, False),
35
  }
36
 
 
 
 
 
 
 
 
 
 
37
  def compute_loss(generated_features, content_features, style_features, alpha, beta):
38
  content_loss = 0
39
  style_loss = 0
@@ -53,23 +61,21 @@ def compute_loss(generated_features, content_features, style_features, alpha, be
53
  return alpha * content_loss + beta * style_loss
54
 
55
  @spaces.GPU(duration=20)
56
- def inference(content_image, style_image, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
57
  yield None
58
  print('-'*15)
59
  print('DATETIME:', datetime.datetime.now())
60
- print('STYLE:', style_image)
61
 
62
  img_size = 1024 if output_quality else 512
63
  content_img, original_size = preprocess_img(content_image, img_size)
64
  content_img = content_img.to(device)
65
- style_img = preprocess_img_from_path(style_options[style_image], img_size)[0].to(device)
66
 
67
  print('CONTENT IMG SIZE:', original_size)
68
  print('STYLE STRENGTH:', style_strength)
69
  print('HIGH QUALITY:', output_quality)
70
 
71
  iters = 50
72
- # learning rate determined by input
73
  lr = 0.001 + (0.099 / 99) * (style_strength - 1)
74
  alpha = 1
75
  beta = 1
@@ -80,7 +86,9 @@ def inference(content_image, style_image, style_strength, output_quality, progre
80
 
81
  with torch.no_grad():
82
  content_features = model(content_img)
83
- style_features = model(style_img)
 
 
84
 
85
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
86
  optimizer.zero_grad()
 
21
  for param in model.parameters():
22
  param.requires_grad = False
23
 
 
24
  style_files = os.listdir('./style_images')
25
  style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
26
  optimal_settings = {
 
33
  'Watercolor': (10, False),
34
  }
35
 
36
+ cached_style_features = {}
37
+ for style_name, style_img_path in style_options.items():
38
+ style_img_512 = preprocess_img_from_path(style_img_path, 512)[0].to(device)
39
+ style_img_1024 = preprocess_img_from_path(style_img_path, 1024)[0].to(device)
40
+ with torch.no_grad():
41
+ style_features = (model(style_img_512), model(style_img_1024))
42
+ cached_style_features[style_name] = style_features
43
+
44
+
45
  def compute_loss(generated_features, content_features, style_features, alpha, beta):
46
  content_loss = 0
47
  style_loss = 0
 
61
  return alpha * content_loss + beta * style_loss
62
 
63
  @spaces.GPU(duration=20)
64
+ def inference(content_image, style_name, style_strength, output_quality, progress=gr.Progress(track_tqdm=True)):
65
  yield None
66
  print('-'*15)
67
  print('DATETIME:', datetime.datetime.now())
68
+ print('STYLE:', style_name)
69
 
70
  img_size = 1024 if output_quality else 512
71
  content_img, original_size = preprocess_img(content_image, img_size)
72
  content_img = content_img.to(device)
 
73
 
74
  print('CONTENT IMG SIZE:', original_size)
75
  print('STYLE STRENGTH:', style_strength)
76
  print('HIGH QUALITY:', output_quality)
77
 
78
  iters = 50
 
79
  lr = 0.001 + (0.099 / 99) * (style_strength - 1)
80
  alpha = 1
81
  beta = 1
 
86
 
87
  with torch.no_grad():
88
  content_features = model(content_img)
89
+ style_features = cached_style_features[style_name]
90
+ if img_size == 512: style_features = style_features[0]
91
+ else: style_features = style_features[1]
92
 
93
  for _ in tqdm(range(iters), desc='The magic is happening ✨'):
94
  optimizer.zero_grad()