jamino30 committed
Commit 3b42de6 · verified · 1 Parent(s): 17b3937

Upload folder using huggingface_hub

Files changed (2)
  1. app.py +6 -8
  2. inference.py +1 -6
app.py CHANGED

@@ -93,16 +93,15 @@ def run(content_image, style_name, style_strength=10):
      else:
          future_all = executor.submit(run_inference, False)
          future_bg = executor.submit(run_inference, True)
-         generated_img_all, _ = future_all.result()
-         generated_img_bg, bg_ratio = future_bg.result()
+         generated_img_all = future_all.result()
+         generated_img_bg = future_bg.result()

      et = time.time()
      print('TIME TAKEN:', et-st)

      yield (
          (content_image, postprocess_img(generated_img_all, original_size)),
-         (content_image, postprocess_img(generated_img_bg, original_size)),
-         f'{bg_ratio:.2f}'
+         (content_image, postprocess_img(generated_img_bg, original_size))
      )

  def set_slider(value):

@@ -116,13 +115,13 @@ css = """
  """

  with gr.Blocks(css=css) as demo:
-     gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Object Masking")
+     gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Region Preservation")
      with gr.Row(elem_id='container'):
          with gr.Column():
              content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
              style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
              with gr.Group():
-                 style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=5, info='Higher values add artistic flair, lower values add a realistic feel.')
+                 style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=10, info='Higher values add artistic flair, lower values add a realistic feel.')
              submit_button = gr.Button('Submit', variant='primary')

      examples = gr.Examples(

@@ -139,7 +138,6 @@ with gr.Blocks(css=css) as demo:
      download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
      with gr.Group():
          output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
-         bg_ratio_label = gr.Label(label='Background Ratio')
          download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)

      def save_image(img_tuple1, img_tuple2):

@@ -156,7 +154,7 @@ with gr.Blocks(css=css) as demo:
      submit_button.click(
          fn=run,
          inputs=[content_image, style_dropdown, style_strength_slider],
-         outputs=[output_image_all, output_image_background, bg_ratio_label]
+         outputs=[output_image_all, output_image_background]
      ).then(
          fn=save_image,
          inputs=[output_image_all, output_image_background],
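
With this change both futures resolve to a single styled image instead of an (image, ratio) tuple, and the yield emits only the two image pairs. A minimal sketch of the updated call pattern, using a placeholder stand-in for run_inference (the real function presumably wraps inference() from inference.py):

from concurrent.futures import ThreadPoolExecutor

def run_inference(apply_to_background):
    # placeholder stand-in: the app's real function runs style transfer and,
    # after this commit, returns only the generated image (no ratio)
    return f'styled(background_only={apply_to_background})'

with ThreadPoolExecutor() as executor:
    # style the full image and the background-only variant in parallel
    future_all = executor.submit(run_inference, False)
    future_bg = executor.submit(run_inference, True)
    generated_img_all = future_all.result()  # single return value now
    generated_img_bg = future_bg.result()
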
inference.py CHANGED

@@ -55,17 +55,12 @@ def inference(
      content_features = model(content_image)

      resized_bg_masks = []
-     salient_object_ratio = None
      if apply_to_background:
          segmentation_output = sod_model(content_image_norm)[0]
          segmentation_output = torch.sigmoid(segmentation_output)
          segmentation_mask = (segmentation_output > 0.7).float()
          background_mask = (segmentation_mask == 0).float()
          foreground_mask = 1 - background_mask
-
-         salient_object_pixel_count = foreground_mask.sum().item()
-         total_pixel_count = segmentation_mask.numel()
-         salient_object_ratio = salient_object_pixel_count / total_pixel_count

          for cf in content_features:
              _, _, h_i, w_i = cf.shape

@@ -93,4 +88,4 @@ def inference(
          foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
          generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized

-     return generated_image, salient_object_ratio
+     return generated_image
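
The removal above drops only the salient-object ratio bookkeeping; the masking logic itself is unchanged. The SOD output is passed through a sigmoid, thresholded at 0.7, and the resulting foreground mask is later upsampled and used to paste the original content pixels back over the stylized image. A rough, self-contained sketch of that compositing step with dummy tensors (shapes are illustrative assumptions, not the app's actual sizes):

import torch
import torch.nn.functional as F

# dummy stand-ins: a 1x32x32 saliency map and 1x3x128x128 images
segmentation_output = torch.randn(1, 32, 32)
generated_image = torch.rand(1, 3, 128, 128)
content_image = torch.rand(1, 3, 128, 128)

# threshold the saliency map into a binary foreground mask
segmentation_mask = (torch.sigmoid(segmentation_output) > 0.7).float()
foreground_mask = 1 - (segmentation_mask == 0).float()

# resize the mask to the output resolution and restore the foreground pixels
foreground_mask_resized = F.interpolate(
    foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest'
)
composited = (generated_image * (1 - foreground_mask_resized)
              + content_image * foreground_mask_resized)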