jamino30 committed
Commit a9077eb
1 Parent(s): a706eb7

Upload folder using huggingface_hub

Files changed (2):
  1. app.py +5 -3
  2. inference.py +44 -6
app.py CHANGED
@@ -35,7 +35,7 @@ for style_name, style_img_path in style_options.items():
     cached_style_features[style_name] = style_features
 
 @spaces.GPU(duration=10)
-def run(content_image, style_name, style_strength=5, progress=gr.Progress(track_tqdm=True)):
+def run(content_image, style_name, style_strength=5, apply_to_background=False, progress=gr.Progress(track_tqdm=True)):
     yield None
     content_img, original_size = preprocess_img(content_image, img_size)
     content_img = content_img.to(device)
@@ -53,7 +53,8 @@ def run(content_image, style_name, style_strength=5, progress=gr.Progress(track_tqdm=True)):
         model=model,
         content_image=content_img,
         style_features=style_features,
-        lr=lrs[style_strength-1]
+        lr=lrs[style_strength-1],
+        apply_to_background=apply_to_background
     )
     et = time.time()
     print('TIME TAKEN:', et-st)
@@ -78,6 +79,7 @@ with gr.Blocks(css=css) as demo:
             style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
             with gr.Group():
                 style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=5, info='Higher values add artistic flair, lower values add a realistic feel.')
+                apply_to_background = gr.Checkbox(label='Apply to background only')
             submit_button = gr.Button('Submit', variant='primary')
 
     examples = gr.Examples(
@@ -105,7 +107,7 @@ with gr.Blocks(css=css) as demo:
 
     submit_button.click(
         fn=run,
-        inputs=[content_image, style_dropdown, style_strength_slider],
+        inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background],
        outputs=[output_image]
     ).then(
         fn=save_image,
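
Aside from the new checkbox, the app.py change is plumbing: Gradio binds the components listed in inputs to the handler's parameters positionally, so the checkbox has to be appended both to the inputs list and to run's signature. A minimal sketch of the same wiring, with a stub handler standing in for the real style-transfer call (component values here are illustrative placeholders, not taken from this repo):

import gradio as gr

def run(content_image, style_name, style_strength=5, apply_to_background=False):
    # Arguments arrive positionally, in the order of the `inputs` list below.
    return f'{style_name} @ strength {style_strength}, background only: {apply_to_background}'

with gr.Blocks() as demo:
    content_image = gr.Image(label='Content')
    style_dropdown = gr.Radio(choices=['Starry Night'], value='Starry Night', label='Style')
    style_strength_slider = gr.Slider(minimum=1, maximum=10, step=1, value=5, label='Style Strength')
    apply_to_background = gr.Checkbox(label='Apply to background only')
    output = gr.Textbox(label='Result')
    gr.Button('Submit').click(
        fn=run,
        inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background],
        outputs=[output],
    )

demo.launch()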
inference.py CHANGED
@@ -3,21 +3,34 @@ from tqdm import tqdm
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
+from torchvision.transforms.functional import gaussian_blur
+from torchvision import models
 
 def _gram_matrix(feature):
     batch_size, n_feature_maps, height, width = feature.size()
     new_feature = feature.view(batch_size * n_feature_maps, height * width)
     return torch.mm(new_feature, new_feature.t())
 
-def _compute_loss(generated_features, content_features, style_features, alpha, beta):
+def _compute_loss(generated_features, content_features, style_features, resized_bg_masks, alpha, beta):
     content_loss = 0
     style_loss = 0
     w_l = 1 / len(generated_features)
-    for gf, cf, sf in zip(generated_features, content_features, style_features):
+
+    for i, (gf, cf, sf) in enumerate(zip(generated_features, content_features, style_features)):
         content_loss += F.mse_loss(gf, cf)
-        G = _gram_matrix(gf)
-        A = _gram_matrix(sf)
+
+        if resized_bg_masks:
+            blurred_bg_mask = gaussian_blur(resized_bg_masks[i], kernel_size=5)
+            masked_gf = gf * blurred_bg_mask
+            masked_sf = sf * blurred_bg_mask
+            G = _gram_matrix(masked_gf)
+            A = _gram_matrix(masked_sf)
+        else:
+            G = _gram_matrix(gf)
+            A = _gram_matrix(sf)
+        style_loss += w_l * F.mse_loss(G, A)
         style_loss += w_l * F.mse_loss(G, A)
+
     return alpha * content_loss + beta * style_loss
 
 def inference(
@@ -25,11 +38,12 @@ def inference(
     model,
     content_image,
     style_features,
+    apply_to_background,
     lr,
     iterations=101,
     optim_caller=optim.AdamW,
     alpha=1,
-    beta=1
+    beta=1,
 ):
     generated_image = content_image.clone().requires_grad_(True)
     optimizer = optim_caller([generated_image], lr=lr)
@@ -37,17 +51,41 @@ def inference(
 
     with torch.no_grad():
         content_features = model(content_image)
+
+    resized_bg_masks = []
+    if apply_to_background:
+        segmentation_model = models.segmentation.deeplabv3_resnet101(weights='DEFAULT').eval()
+        segmentation_model = segmentation_model.to(content_image.device)
+
+        segmentation_output = segmentation_model(content_image)['out']
+        segmentation_mask = segmentation_output.argmax(dim=1)
+
+        background_mask = (segmentation_mask == 0).float()
+        foreground_mask = (segmentation_mask != 0).float()
+
+        for cf in content_features:
+            _, _, h_i, w_i = cf.shape
+            bg_mask = F.interpolate(background_mask.unsqueeze(1), size=(h_i, w_i), mode='bilinear', align_corners=False)
+            resized_bg_masks.append(bg_mask)
 
     def closure(iter):
         optimizer.zero_grad()
         generated_features = model(generated_image)
-        total_loss = _compute_loss(generated_features, content_features, style_features, alpha, beta)
+        total_loss = _compute_loss(
+            generated_features, content_features, style_features, resized_bg_masks, alpha, beta
+        )
         total_loss.backward()
         min_losses[iter] = min(min_losses[iter], total_loss.item())
         return total_loss
 
     for iter in tqdm(range(iterations), desc='The magic is happening ✨'):
         optimizer.step(lambda: closure(iter))
+
+        if apply_to_background:
+            with torch.no_grad():
+                foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
+                generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
+
         if iter % 10 == 0: print(f'Loss ({iter}):', min_losses[iter])
 
     return generated_image
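
Two notes on the inference.py side. First, in the resulting _compute_loss the line style_loss += w_l * F.mse_loss(G, A) appears twice in a row (the newly added copy plus the pre-existing context line), so each layer's style term is accumulated twice; unless intentional, this doubles the effective beta. Second, the background-only path boils down to: segment the content image with DeepLabV3 (class 0 is background in torchvision's 21-class label set), downsample and soften the binary mask to each feature map before forming Gram matrices, and paste the original foreground back after every optimizer step. A standalone sketch of that pipeline, with random tensors standing in for the preprocessed content image and a VGG-style feature map (shapes and normalization are assumptions, not taken from this repo):

import torch
import torch.nn.functional as F
from torchvision import models
from torchvision.transforms.functional import gaussian_blur

image = torch.rand(1, 3, 512, 512)  # placeholder for the preprocessed content image

# 1) Per-pixel class labels; class 0 is 'background' for this model family.
seg = models.segmentation.deeplabv3_resnet101(weights='DEFAULT').eval()
with torch.no_grad():
    labels = seg(image)['out'].argmax(dim=1)   # 1 x H x W class ids
background = (labels == 0).float()
foreground = 1.0 - background

# 2) Resize and blur the background mask to a (hypothetical) feature-map size,
#    then mask the features before the Gram matrix, as the diff does per layer.
feat = torch.rand(1, 64, 128, 128)             # placeholder feature map
bg = F.interpolate(background.unsqueeze(1), size=feat.shape[2:],
                   mode='bilinear', align_corners=False)
masked = feat * gaussian_blur(bg, kernel_size=5)
flat = masked.view(masked.size(1), -1)         # C x (H*W), batch of 1
G = flat @ flat.t()                            # Gram matrix of masked features

# 3) After each optimizer step, keep the original foreground pixels.
fg = F.interpolate(foreground.unsqueeze(1), size=image.shape[2:], mode='nearest')
stylized = torch.rand_like(image)              # placeholder for the optimized image
composited = stylized * (1 - fg) + image * fg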