File size: 6,383 Bytes
00710e8
88b9835
a0f5d9f
0fc4e06
ad85111
 
75f2ed4
bbcd902
e16d255
75f2ed4
ce6dca2
75f2ed4
57a96a1
e6f200a
814e69a
91d9343
67d69a3
a84e446
 
 
75f2ed4
14fd49f
75f2ed4
814e69a
289faff
814e69a
 
 
 
 
 
 
917ebd2
75f2ed4
 
fc92636
 
75f2ed4
00710e8
 
764d4ab
 
88b9835
b7a47e5
 
764d4ab
980e9a0
764d4ab
b7a47e5
 
814e69a
fa57e87
e21f7c8
57a96a1
75f2ed4
88b9835
91d9343
 
 
88b9835
d767ccb
75f2ed4
764d4ab
962b2f7
91d9343
0fc4e06
e21f7c8
 
 
0fc4e06
e21f7c8
0fc4e06
e21f7c8
 
 
 
 
814e69a
e21f7c8
 
 
 
 
0fc4e06
 
e21f7c8
 
 
 
 
 
 
814e69a
0fc4e06
88b9835
 
14fd49f
0fc4e06
 
e21f7c8
814e69a
0fc4e06
75f2ed4
424869b
 
ab16048
 
 
 
7f82183
ab16048
 
 
7b732c2
0803fb8
ce6dca2
 
 
 
 
814e69a
ce6dca2
 
 
 
fa57e87
 
ce6dca2
fa57e87
ce6dca2
22696bb
ce6dca2
0fc4e06
 
e21f7c8
 
814e69a
0fc4e06
ce6dca2
0fc4e06
 
 
 
 
ce6dca2
 
0fc4e06
 
ce6dca2
ab16048
ce6dca2
 
0fc4e06
814e69a
ce6dca2
 
0fc4e06
 
ce6dca2
0fc4e06
 
ce6dca2
de50edd
 
 
fc92636
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import time
from datetime import datetime, timezone, timedelta
from concurrent.futures import ThreadPoolExecutor

import spaces
import torch
import torchvision.models as models
import numpy as np
import gradio as gr
from gradio_imageslider import ImageSlider

from utils import preprocess_img, preprocess_img_from_path, postprocess_img
from vgg.vgg19 import VGG_19
from u2net.model import U2Net
from inference import inference

# Choose the compute backend once at import time: CUDA is preferred, then
# MPS, and CPU is the fallback. The MPS check is only evaluated when CUDA
# is unavailable.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print('DEVICE:', device)
# The GPU model name is only reported for the CUDA backend.
if device == 'cuda':
    print('CUDA DEVICE:', torch.cuda.get_device_name())

def load_model_without_module(model, model_path):
    """Load a checkpoint into ``model``, dropping a leading ``module.`` from keys.

    The checkpoint at ``model_path`` is read with ``torch.load`` and mapped
    onto the module-level ``device``. Every key that starts with ``module.``
    has that prefix removed; all other keys are kept unchanged. The resulting
    mapping is handed to ``model.load_state_dict``.

    Args:
        model: Object exposing ``load_state_dict``.
        model_path: Path of the checkpoint file to read.

    Returns:
        None. ``model`` is updated in place.
    """
    prefix = 'module.'
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, so this should only ever be pointed at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): weights
        for key, weights in checkpoint.items()
    }
    model.load_state_dict(cleaned)
    
# Feature-extraction network, moved to the selected device and put in eval
# mode. Its parameters are frozen so no gradients are tracked for them.
model = VGG_19().to(device).eval()
for param in model.parameters():
    param.requires_grad = False
# Second network (passed to `inference` as `sod_model`), with weights loaded
# from the checkpoint on disk.
sod_model = U2Net().to(device).eval()
load_model_without_module(sod_model, 'u2net/saved_models/u2net-duts.pt')

# Discover the available styles. os.listdir returns entries in arbitrary
# order and includes dotfiles and sub-directories, so the listing is sorted
# (stable UI order) and restricted to regular, non-hidden files. Otherwise a
# stray entry such as `.DS_Store` would produce an empty style name and be
# handed to the image loader below.
style_dir = './style_images'
style_files = sorted(
    entry for entry in os.listdir(style_dir)
    if not entry.startswith('.') and os.path.isfile(os.path.join(style_dir, entry))
)
# Display name -> image path. The display name is the part of the file name
# before the first '.', with underscores replaced by spaces.
style_options = {' '.join(style_file.split('.')[0].split('_')): f'./style_images/{style_file}' for style_file in style_files}
# Ten log-spaced values from 0.001 to 0.1; `run` indexes this list with
# `style_strength - 1` and passes the value to `inference` as `lr`.
lrs = np.logspace(np.log10(0.001), np.log10(0.1), 10).tolist()
# Size argument handed to the preprocessing helpers.
img_size = 512

# Run every style image through `model` once at start-up so requests can
# reuse the features instead of recomputing them.
cached_style_features = {}
for style_name, style_img_path in style_options.items():
    # Only the first element returned by the helper is used here.
    style_img = preprocess_img_from_path(style_img_path, img_size)[0].to(device)
    with torch.no_grad():
        style_features = model(style_img)
    cached_style_features[style_name] = style_features

@spaces.GPU(duration=30)
def run(content_image, style_name, style_strength=10):
    """Stylize ``content_image`` twice: once whole, once background-only.

    This is a generator used as a Gradio event handler. It first yields three
    ``None`` values (clearing the three bound outputs) and then yields the
    final result. The two ``inference`` calls (``apply_to_background=False``
    and ``True``) are submitted to a thread pool so they can overlap; on CUDA
    each call runs on its own stream.

    Args:
        content_image: The content image as delivered by the input component.
        style_name: Key into ``cached_style_features``.
        style_strength: 1-based index into ``lrs`` (``1..len(lrs)``).

    Yields:
        ``[None, None, None]`` first, then a 3-tuple
        ``((content_image, styled_all), (content_image, styled_bg), ratio)``
        where ``ratio`` is the second value returned by the background
        inference, formatted with two decimals.

    Raises:
        gr.Error: If no content image was supplied, the style is unknown, or
            ``style_strength`` is outside ``1..len(lrs)``.
    """
    # Clear the previous outputs while the new result is being computed.
    yield [None] * 3

    # Fail with a readable message instead of an opaque error further down.
    if content_image is None:
        raise gr.Error('Please provide a content image before submitting.')
    if style_name not in cached_style_features:
        raise gr.Error('Please choose a style before submitting.')
    # The value may arrive as a float; list indices must be ints. The range
    # check also stops 0 from silently wrapping around to lrs[-1].
    strength = int(style_strength)
    if not 1 <= strength <= len(lrs):
        raise gr.Error(f'Style strength must be between 1 and {len(lrs)}.')
    lr = lrs[strength - 1]

    content_img, original_size = preprocess_img(content_image, img_size)
    content_img = content_img.to(device)

    print('-'*15)
    # Fixed UTC-4 offset, so the printed value carries the offset it was
    # shifted by (no daylight-saving handling).
    print('DATETIME:', datetime.now(timezone(timedelta(hours=-4))))
    print('STYLE:', style_name)
    print('CONTENT IMG SIZE:', original_size)
    print('STYLE STRENGTH:', style_strength, f'(lr={lr:.3f})')

    style_features = cached_style_features[style_name]

    # Monotonic clock: immune to wall-clock adjustments while timing.
    st = time.perf_counter()

    if device == 'cuda':
        # Stream on which the content image upload above was queued.
        producer_stream = torch.cuda.current_stream()
        stream_all = torch.cuda.Stream()
        stream_bg = torch.cuda.Stream()

    def run_inference(apply_to_background):
        return inference(
            model=model,
            sod_model=sod_model,
            content_image=content_img,
            style_features=style_features,
            lr=lr,
            apply_to_background=apply_to_background
        )

    def run_inference_cuda(apply_to_background, stream):
        # Side streams are not ordered against other streams by default:
        # wait for the producer stream before reading content_img, and drain
        # this stream before handing the result to code on another stream.
        stream.wait_stream(producer_stream)
        with torch.cuda.stream(stream):
            result = run_inference(apply_to_background)
        stream.synchronize()
        return result

    # Exactly two jobs are submitted, so two workers are enough.
    with ThreadPoolExecutor(max_workers=2) as executor:
        if device == 'cuda':
            future_all = executor.submit(run_inference_cuda, False, stream_all)
            future_bg = executor.submit(run_inference_cuda, True, stream_bg)
        else:
            future_all = executor.submit(run_inference, False)
            future_bg = executor.submit(run_inference, True)
        # The second value of the whole-image run is not used.
        generated_img_all, _ = future_all.result()
        generated_img_bg, bg_ratio = future_bg.result()

    et = time.perf_counter()
    print('TIME TAKEN:', et-st)

    yield (
        (content_image, postprocess_img(generated_img_all, original_size)),
        (content_image, postprocess_img(generated_img_bg, original_size)),
        f'{bg_ratio:.2f}'
    )

def set_slider(value):
    """Return a Gradio update object that sets a component's value to ``value``."""
    update = gr.update(value=value)
    return update

# Custom stylesheet handed to gr.Blocks below: horizontally centers the
# element whose id is "container" and caps its width at 1200px.
css = """
#container {
    margin: 0 auto;
    max-width: 1200px;
}
"""

# UI definition: inputs in the left column, results in the right column.
with gr.Blocks(css=css) as demo:
    # The heading tag is closed explicitly so the markup is well-formed.
    gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Object Masking</h1>")
    with gr.Row(elem_id='container'):
        with gr.Column():
            content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
            style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
            with gr.Group():
                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=5, info='Higher values add artistic flair, lower values add a realistic feel.')
            submit_button = gr.Button('Submit', variant='primary')

            examples = gr.Examples(
                examples=[
                    ['./content_images/GoldenRetriever.jpg', 'Starry Night'],
                    ['./content_images/CameraGirl.jpg', 'Bokeh']
                ],
                inputs=[content_image, style_dropdown]
            )

        with gr.Column():
            output_image_all = ImageSlider(position=0.15, label='Styled Image', type='pil', interactive=False, show_download_button=False)
            download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
            with gr.Group():
                output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
                bg_ratio_label = gr.Label(label='Background Ratio')
            download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)

    def save_image(img_tuple1, img_tuple2):
        """Save the second image of each slider pair and return both paths.

        Each call writes into its own fresh temporary directory, so
        overlapping requests cannot overwrite each other's files. The base
        names are kept so the downloaded files keep their original names.
        """
        import tempfile

        # NOTE: these directories are not removed here; they live in the
        # system temp location.
        out_dir = tempfile.mkdtemp(prefix='style-transfer-')
        filename1 = os.path.join(out_dir, 'generated-all.jpg')
        filename2 = os.path.join(out_dir, 'generated-bg.jpg')
        img_tuple1[1].save(filename1)
        img_tuple2[1].save(filename2)
        return filename1, filename2

    # Hide both download buttons as soon as a new request starts.
    submit_button.click(
        fn=lambda: [gr.update(visible=False) for _ in range(2)],
        outputs=[download_button_1, download_button_2]
    )

    # Generate the images, then save them and reveal the download buttons.
    # The follow-up steps use `.success` so they are skipped when `run`
    # fails; otherwise save_image would be called with empty outputs.
    submit_button.click(
        fn=run,
        inputs=[content_image, style_dropdown, style_strength_slider],
        outputs=[output_image_all, output_image_background, bg_ratio_label]
    ).success(
        fn=save_image,
        inputs=[output_image_all, output_image_background],
        outputs=[download_button_1, download_button_2]
    ).success(
        fn=lambda: [gr.update(visible=True) for _ in range(2)],
        outputs=[download_button_1, download_button_2]
    )

# Force the queue flag off by hand before launching.
# NOTE(review): this rebinds the `queue` attribute on the Blocks instance to
# a bool and patches the generated config directly. Whether the installed
# Gradio version honours either assignment is not visible from this file —
# confirm before changing or upgrading.
demo.queue = False
demo.config['queue'] = False
# Start the app with show_api=False (presumably hides the API docs — confirm).
demo.launch(show_api=False)