jamino30 committed
Commit 814e69a
Parent: 5464cad

Upload folder using huggingface_hub

app.py CHANGED
@@ -12,6 +12,7 @@ from gradio_imageslider import ImageSlider

from utils import preprocess_img, preprocess_img_from_path, postprocess_img
from vgg.vgg19 import VGG_19
+from u2net.model import U2Net
from inference import inference

if torch.cuda.is_available(): device = 'cuda'
@@ -20,10 +21,21 @@ else: device = 'cpu'
print('DEVICE:', device)
if device == 'cuda': print('CUDA DEVICE:', torch.cuda.get_device_name())

+def load_model_without_module(model, model_path):
+    state_dict = torch.load(model_path, map_location=device, weights_only=True)
+
+    new_state_dict = {}
+    for k, v in state_dict.items():
+        name = k[7:] if k.startswith('module.') else k
+        new_state_dict[name] = v
+    model.load_state_dict(new_state_dict)
+
model = VGG_19().to(device).eval()
for param in model.parameters():
    param.requires_grad = False
-segmentation_model = models.segmentation.deeplabv3_resnet101(
+# sod_model = U2Net().to(device).eval()
+# load_model_without_module(sod_model, 'u2net/saved_models/u2net-duts.pt')
+sod_model = models.segmentation.deeplabv3_resnet101(
    weights='DEFAULT'
).to(device).eval()

@@ -39,7 +51,7 @@ for style_name, style_img_path in style_options.items():
    style_features = model(style_img)
    cached_style_features[style_name] = style_features

-@spaces.GPU(duration=20)
+@spaces.GPU(duration=30)
def run(content_image, style_name, style_strength=10):
    yield [None] * 3
    content_img, original_size = preprocess_img(content_image, img_size)
@@ -66,7 +78,7 @@ def run(content_image, style_name, style_strength=10):
    def run_inference(apply_to_background):
        return inference(
            model=model,
-            segmentation_model=segmentation_model,
+            sod_model=sod_model,
            content_image=content_img,
            style_features=style_features,
            lr=lrs[style_strength-1],
@@ -81,7 +93,7 @@ def run(content_image, style_name, style_strength=10):
        future_all = executor.submit(run_inference, False)
        future_bg = executor.submit(run_inference, True)
        generated_img_all, _ = future_all.result()
-        generated_img_bg, salient_object_ratio = future_bg.result()
+        generated_img_bg, bg_ratio = future_bg.result()

    et = time.time()
    print('TIME TAKEN:', et-st)
@@ -89,7 +101,7 @@ def run(content_image, style_name, style_strength=10):
    yield (
        (content_image, postprocess_img(generated_img_all, original_size)),
        (content_image, postprocess_img(generated_img_bg, original_size)),
-        f'{salient_object_ratio:.2f}'
+        f'{bg_ratio:.2f}'
    )

def set_slider(value):
@@ -109,7 +121,7 @@ with gr.Blocks(css=css) as demo:
            content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
            style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
            with gr.Group():
-                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=10, info='Higher values add artistic flair, lower values add a realistic feel.')
+                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=5, info='Higher values add artistic flair, lower values add a realistic feel.')
            submit_button = gr.Button('Submit', variant='primary')

    examples = gr.Examples(
@@ -125,7 +137,7 @@ with gr.Blocks(css=css) as demo:
            download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
            with gr.Group():
                output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
-                salient_object_ratio_label = gr.Label(label='Salient Object Ratio')
+                bg_ratio_label = gr.Label(label='Background Ratio')
            download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)

    def save_image(img_tuple1, img_tuple2):
@@ -142,7 +154,7 @@ with gr.Blocks(css=css) as demo:
    submit_button.click(
        fn=run,
        inputs=[content_image, style_dropdown, style_strength_slider],
-        outputs=[output_image_all, output_image_background, salient_object_ratio_label]
+        outputs=[output_image_all, output_image_background, bg_ratio_label]
    ).then(
        fn=save_image,
        inputs=[output_image_all, output_image_background],
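
Note on the new load_model_without_module helper: checkpoints saved from a model wrapped in torch.nn.DataParallel prefix every state-dict key with 'module.', which a bare model refuses to load. A minimal, self-contained sketch of the failure mode and the fix (the toy nn.Sequential model below is hypothetical, for illustration only):

import torch
import torch.nn as nn

# Hypothetical stand-in for U2Net; any module reproduces the issue.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
wrapped = nn.DataParallel(net)

# DataParallel nests the model under .module, so every key gains a prefix.
print(list(wrapped.state_dict())[:1])  # ['module.0.weight']

# Strip the prefix exactly as load_model_without_module does above.
clean = {k[7:] if k.startswith('module.') else k: v
         for k, v in wrapped.state_dict().items()}
net.load_state_dict(clean)  # loads without missing/unexpected-key errors
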
content_images/228.jpg ADDED
content_images/24.jpg ADDED
content_images/baseball.jpg ADDED
content_images/bleachers.jpg ADDED
content_images/dancers.jpg ADDED
content_images/glassesman.jpg ADDED
content_images/ladies.jpg ADDED
content_images/messi.jpg ADDED
content_images/motorcycle.jpg ADDED
inference.py CHANGED
@@ -6,9 +6,6 @@ import torch.optim as optim
import torch.nn.functional as F
from torchvision.transforms.functional import gaussian_blur

-DEV_MODE = os.environ.get('DEV_MODE', None)
-print('DEV MODE:', True if DEV_MODE else False)
-
def _gram_matrix(feature):
    batch_size, n_feature_maps, height, width = feature.size()
    new_feature = feature.view(batch_size * n_feature_maps, height * width)
@@ -39,7 +36,7 @@ def _compute_loss(generated_features, content_features, style_features, resized_
def inference(
    *,
    model,
-    segmentation_model,
+    sod_model,
    content_image,
    style_features,
    apply_to_background,
@@ -49,9 +46,6 @@
    alpha=1,
    beta=1,
):
-    if DEV_MODE:
-        from torch.utils.tensorboard import SummaryWriter
-        writer = SummaryWriter()
    generated_image = content_image.clone().requires_grad_(True)
    optimizer = optim_caller([generated_image], lr=lr)
    min_losses = [float('inf')] * iterations
@@ -61,12 +55,20 @@

    resized_bg_masks = []
    salient_object_ratio = None
-    if apply_to_background:
-        segmentation_output = segmentation_model(content_image)['out']
-        segmentation_mask = segmentation_output.argmax(dim=1)
+    if apply_to_background:
+        # original
+        segmentation_output = sod_model(content_image)['out']  # [1, 21, 512, 512]
+        segmentation_mask = segmentation_output.argmax(dim=1)  # [1, 512, 512]
        background_mask = (segmentation_mask == 0).float()
        foreground_mask = 1 - background_mask

+        # new
+        # segmentation_output = sod_model(content_image)[0]
+        # segmentation_output = torch.sigmoid(segmentation_output)
+        # segmentation_mask = (segmentation_output > 0.7).float()
+        # background_mask = (segmentation_mask == 0).float()
+        # foreground_mask = 1 - background_mask
+
        salient_object_pixel_count = foreground_mask.sum().item()
        total_pixel_count = segmentation_mask.numel()
        salient_object_ratio = salient_object_pixel_count / total_pixel_count
@@ -85,12 +87,6 @@
        total_loss.backward()

        # log loss
-        if DEV_MODE:
-            writer.add_scalars(f'style-{"background" if apply_to_background else "image"}', {
-                'Loss/content': content_loss.item(),
-                'Loss/style': style_loss.item(),
-                'Loss/total': total_loss.item()
-            }, iter)
        min_losses[iter] = min(min_losses[iter], total_loss.item())

        return total_loss
@@ -102,8 +98,5 @@
    with torch.no_grad():
        foreground_mask_resized = F.interpolate(foreground_mask.unsqueeze(1), size=generated_image.shape[2:], mode='nearest')
        generated_image.data = generated_image.data * (1 - foreground_mask_resized) + content_image.data * foreground_mask_resized
-
-    if DEV_MODE:
-        writer.flush()
-        writer.close()
-    return generated_image, salient_object_ratio
+
+    return generated_image, salient_object_ratio
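
For reference, a sketch (not code from this repo) contrasting the two masking strategies in the hunk above: the active DeepLabV3 path takes an argmax over 21 semantic-class logits and treats class 0 as background, while the commented-out U2Net path thresholds a sigmoid saliency map at 0.7. Shapes assume a 1x3x512x512 input, per the comments in the diff.

import torch

# DeepLabV3 path ('original'): 21 class logits, class index 0 = background.
seg_logits = torch.randn(1, 21, 512, 512)      # stand-in for sod_model(x)['out']
seg_mask = seg_logits.argmax(dim=1)            # [1, 512, 512] class indices
foreground_mask = 1 - (seg_mask == 0).float()  # 1 marks the salient object

# U2Net path ('new'): a single saliency map, squashed to [0, 1], thresholded.
saliency = torch.randn(1, 1, 512, 512)         # stand-in for sod_model(x)[0]
foreground_u2 = (torch.sigmoid(saliency) > 0.7).float()

# Either mask then gates the final blend at the end of inference():
# generated = generated * (1 - foreground) + content * foreground
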
 
 
 
requirements.txt CHANGED
@@ -1,9 +1,102 @@
---extra-index-url https://download.pytorch.org/whl/cu113
-torch
-torchvision
-pillow
-gradio
-gradio_imageslider
-spaces
-tqdm
-tensorboard
+absl-py==2.1.0
+aiofiles==23.2.1
+annotated-types==0.7.0
+anyio==4.6.0
+appnope==0.1.4
+asttokens==2.4.1
+certifi==2024.8.30
+charset-normalizer==3.3.2
+click==8.1.7
+comm==0.2.2
+contourpy==1.3.0
+cycler==0.12.1
+debugpy==1.8.7
+decorator==5.1.1
+executing==2.1.0
+fastapi==0.115.0
+ffmpy==0.4.0
+filelock==3.16.1
+fonttools==4.54.1
+fsspec==2024.9.0
+gradio==4.44.0
+gradio_client==1.3.0
+gradio_imageslider==0.0.20
+grpcio==1.66.1
+h11==0.14.0
+httpcore==1.0.5
+httpx==0.27.2
+huggingface-hub==0.25.1
+idna==3.10
+importlib_resources==6.4.5
+ipykernel==6.29.5
+ipython==8.28.0
+jedi==0.19.1
+Jinja2==3.1.4
+joblib==1.4.2
+jupyter_client==8.6.3
+jupyter_core==5.7.2
+kiwisolver==1.4.7
+Markdown==3.7
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.9.2
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mpmath==1.3.0
+nest-asyncio==1.6.0
+networkx==3.3
+numpy==2.1.1
+opencv-python==4.10.0.84
+orjson==3.10.7
+packaging==24.1
+pandas==2.2.3
+parso==0.8.4
+pexpect==4.9.0
+pillow==10.4.0
+platformdirs==4.3.6
+prompt_toolkit==3.0.48
+protobuf==5.28.2
+psutil==5.9.8
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pydantic==2.9.2
+pydantic_core==2.23.4
+pydub==0.25.1
+Pygments==2.18.0
+pyparsing==3.1.4
+python-dateutil==2.9.0.post0
+python-multipart==0.0.10
+pytz==2024.2
+PyYAML==6.0.2
+pyzmq==26.2.0
+requests==2.32.3
+rich==13.8.1
+ruff==0.6.8
+scikit-learn==1.5.2
+scipy==1.14.1
+semantic-version==2.10.0
+setuptools==75.1.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.1
+spaces==0.30.2
+stack-data==0.6.3
+starlette==0.38.6
+sympy==1.13.3
+tensorboard==2.18.0
+tensorboard-data-server==0.7.2
+threadpoolctl==3.5.0
+tomlkit==0.12.0
+torch==2.4.1
+torchvision==0.19.1
+tornado==6.4.1
+tqdm==4.66.5
+traitlets==5.14.3
+typer==0.12.5
+typing_extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+uvicorn==0.30.6
+wcwidth==0.2.13
+websockets==12.0
+Werkzeug==3.0.4
utils.py CHANGED
@@ -12,7 +12,8 @@ def preprocess_img(img: Image, img_size):

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
-        transforms.ToTensor()
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    img = transform(img).unsqueeze(0)
    return img, original_size
@@ -20,9 +21,11 @@ def preprocess_img(img: Image, img_size):
def postprocess_img(img, original_size):
    img = img.detach().cpu().squeeze(0)

-    # address tensor value scaling and quantization
+    # Denormalize the image
+    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
+    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
+    img = img * std + mean
    img = torch.clamp(img, 0, 1)
-    img = img.mul(255).byte()

    img = transforms.ToPILImage()(img)
    img = img.resize(original_size, Image.Resampling.LANCZOS)
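
A quick round-trip check (not part of the commit) that the new Normalize in preprocess_img and the denormalization in postprocess_img are inverses up to float rounding, given the same ImageNet statistics:

import torch
from torchvision import transforms

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

x = torch.rand(3, 64, 64)  # stand-in for a [0, 1] image tensor
normalized = transforms.Normalize(mean=mean, std=std)(x)

# postprocess_img's denormalization step
m = torch.tensor(mean).view(3, 1, 1)
s = torch.tensor(std).view(3, 1, 1)
restored = normalized * s + m

print(torch.allclose(restored, x, atol=1e-6))  # True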