File size: 6,137 Bytes
00710e8
88b9835
a0f5d9f
0fc4e06
ad85111
 
75f2ed4
bbcd902
e16d255
75f2ed4
ce6dca2
75f2ed4
57a96a1
e6f200a
91d9343
67d69a3
a84e446
 
 
75f2ed4
14fd49f
75f2ed4
917ebd2
75f2ed4
 
bbcd902
 
 
75f2ed4
00710e8
 
764d4ab
 
88b9835
b7a47e5
 
764d4ab
980e9a0
764d4ab
b7a47e5
 
0fc4e06
9549c37
e21f7c8
57a96a1
75f2ed4
88b9835
91d9343
 
 
88b9835
d767ccb
75f2ed4
764d4ab
962b2f7
91d9343
0fc4e06
e21f7c8
 
 
0fc4e06
e21f7c8
0fc4e06
e21f7c8
 
 
 
 
 
 
 
 
 
 
0fc4e06
 
e21f7c8
 
 
 
 
 
 
 
0fc4e06
88b9835
 
14fd49f
0fc4e06
 
e21f7c8
 
0fc4e06
75f2ed4
424869b
 
ab16048
 
 
 
ce6dca2
ab16048
 
 
7b732c2
cc9f69c
ce6dca2
 
 
 
 
a706eb7
ce6dca2
 
 
 
0fc4e06
 
 
ce6dca2
0fc4e06
ce6dca2
22696bb
ce6dca2
0fc4e06
 
e21f7c8
 
 
0fc4e06
ce6dca2
0fc4e06
 
 
 
 
ce6dca2
 
0fc4e06
 
ce6dca2
ab16048
ce6dca2
 
0fc4e06
e21f7c8
ce6dca2
 
0fc4e06
 
ce6dca2
0fc4e06
 
ce6dca2
de50edd
 
 
e21f7c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
import tempfile
import time
from datetime import datetime, timezone, timedelta
from concurrent.futures import ThreadPoolExecutor

import spaces
import torch
import torchvision.models as models
import numpy as np
import gradio as gr
from gradio_imageslider import ImageSlider

from utils import preprocess_img, preprocess_img_from_path, postprocess_img
from vgg.vgg19 import VGG_19
from inference import inference

# Choose the compute device in order of preference: CUDA GPU, then Apple
# Metal (MPS), then plain CPU. The MPS check only runs if CUDA is absent.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print('DEVICE:', device)
if device == 'cuda':
    print('CUDA DEVICE:', torch.cuda.get_device_name())

# VGG-19 feature network used by run() / inference(); moved to the chosen
# device, put in eval mode, and frozen (requires_grad=False on every
# parameter) via Module.requires_grad_.
model = VGG_19().to(device).eval()
model.requires_grad_(False)

# DeepLabV3 (ResNet-101 backbone) with torchvision's default pretrained
# weights, in eval mode; handed to inference() as ``segmentation_model``.
segmentation_model = models.segmentation.deeplabv3_resnet101(weights='DEFAULT')
segmentation_model = segmentation_model.to(device).eval()

# Map a human-readable style name to its image path; e.g. a file named
# 'Starry_Night.jpg' becomes the key 'Starry Night'.
# - Files are sorted because os.listdir() order is arbitrary, and this dict's
#   key order becomes the order of the Style radio buttons.
# - Hidden files (e.g. '.DS_Store') are skipped; they would otherwise produce
#   an empty style name.
# - Only the final extension is stripped, so names containing dots survive.
style_files = sorted(
    style_file for style_file in os.listdir('./style_images')
    if not style_file.startswith('.')
)
style_options = {
    os.path.splitext(style_file)[0].replace('_', ' '): f'./style_images/{style_file}'
    for style_file in style_files
}
# Input size handed to preprocess_img / preprocess_img_from_path.
img_size = 512
# Learning rates for the 1-10 "Style Strength" slider: ten values spaced
# evenly on a log scale from 1e-3 to 1e-1, indexed as lrs[strength - 1].
lrs = np.logspace(start=np.log10(0.001), stop=np.log10(0.1), num=10).tolist()

# Run every style image through ``model`` once at startup (no autograd) and
# keep the outputs keyed by style name; run() only does a dict lookup.
cached_style_features = {}
with torch.no_grad():
    for style_name, style_img_path in style_options.items():
        style_img = preprocess_img_from_path(style_img_path, img_size)[0].to(device)
        cached_style_features[style_name] = model(style_img)

@spaces.GPU(duration=12)
def run(content_image, style_name, style_strength=5):
    """Stylize a content image twice and stream the results to the UI.

    Two ``inference`` passes run concurrently: one with
    ``apply_to_background=False`` (whole image) and one with
    ``apply_to_background=True`` (background only).

    This is a generator. It first yields three ``None`` values, clearing the
    three outputs. It then yields ``((content_image, styled_all),
    (content_image, styled_bg), bg_ratio_text)``, where ``bg_ratio_text`` is
    the background ratio formatted with two decimals.

    Args:
        content_image: Image from the ``Content`` input (``type='pil'``).
        style_name: Key of ``cached_style_features``.
        style_strength: Integer in ``1..len(lrs)``; selects the learning rate
            ``lrs[style_strength - 1]``.

    Raises:
        gr.Error: If no content image was given, the style name is unknown,
            or the strength is out of range.
    """
    yield [None] * 3

    if content_image is None:
        raise gr.Error('Please provide a content image.')
    if style_name not in cached_style_features:
        raise gr.Error(f'Unknown style: {style_name!r}')
    # The slider value may arrive as a float, and list indices must be ints.
    # A value <= 0 would also silently wrap around to the end of ``lrs``.
    style_strength = int(style_strength)
    if not 1 <= style_strength <= len(lrs):
        raise gr.Error(f'Style strength must be between 1 and {len(lrs)}.')
    lr = lrs[style_strength - 1]

    content_img, original_size = preprocess_img(content_image, img_size)
    content_img = content_img.to(device)

    print('-'*15)
    # Fixed UTC-4 offset (US Eastern *daylight* time). It is not DST-aware,
    # so it is one hour off during standard time.
    print('DATETIME:', datetime.now(timezone.utc) - timedelta(hours=4))
    print('STYLE:', style_name)
    print('CONTENT IMG SIZE:', original_size)
    print('STYLE STRENGTH:', style_strength, f'(lr={lr:.3f})')

    style_features = cached_style_features[style_name]

    def run_inference(apply_to_background):
        """Run one style-transfer pass with this request's settings."""
        return inference(
            model=model,
            segmentation_model=segmentation_model,
            content_image=content_img,
            style_features=style_features,
            lr=lr,
            apply_to_background=apply_to_background,
        )

    def run_inference_cuda(apply_to_background, stream):
        """Run one pass on a side CUDA stream and wait until it completes."""
        # Order the side stream after work already queued on the current
        # (default) stream, such as the transfer of ``content_img``.
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            result = run_inference(apply_to_background)
        # Kernels on a side stream run asynchronously. Block until they are
        # done so results are complete before the default stream reads them.
        stream.synchronize()
        return result

    st = time.time()

    # Exactly two passes, so two workers let them run concurrently.
    with ThreadPoolExecutor(max_workers=2) as executor:
        if device == 'cuda':
            stream_all = torch.cuda.Stream()
            stream_bg = torch.cuda.Stream()
            future_all = executor.submit(run_inference_cuda, False, stream_all)
            future_bg = executor.submit(run_inference_cuda, True, stream_bg)
        else:
            future_all = executor.submit(run_inference, False)
            future_bg = executor.submit(run_inference, True)
        generated_img_all, _ = future_all.result()
        generated_img_bg, bg_ratio = future_bg.result()

    et = time.time()
    print('TIME TAKEN:', et-st)

    yield (
        (content_image, postprocess_img(generated_img_all, original_size)),
        (content_image, postprocess_img(generated_img_bg, original_size)),
        f'{bg_ratio:.2f}'
    )

def set_slider(value):
    """Return a Gradio update that sets a component's value to ``value``."""
    slider_update = gr.update(value=value)
    return slider_update

# Center the main row and cap its width.
css = """
#container {
    margin: 0 auto;
    max-width: 1100px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer w/ Salient Object Detection</h1>")
    with gr.Row(elem_id='container'):
        # Left column: inputs and examples.
        with gr.Column():
            content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
            style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
            with gr.Group():
                style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=5, info='Higher values add artistic flair, lower values add a realistic feel.')
            submit_button = gr.Button('Submit', variant='primary')

            examples = gr.Examples(
                examples=[
                    ['./content_images/Bridge.jpg', 'Starry Night', 6],
                    ['./content_images/GoldenRetriever.jpg', 'Great Wave', 5],
                    ['./content_images/CameraGirl.jpg', 'Bokeh', 10]
                ],
                inputs=[content_image, style_dropdown, style_strength_slider]
            )

        # Right column: before/after sliders plus download buttons, which are
        # hidden until a result has been saved.
        with gr.Column():
            output_image_all = ImageSlider(position=0.15, label='Styled Image', type='pil', interactive=False, show_download_button=False)
            download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
            with gr.Group():
                output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
                bg_ratio_label = gr.Label(label='Background Ratio')
            download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)

    def save_image(img_tuple1, img_tuple2):
        """Save the styled image of each slider pair and return both paths.

        Each slider value is a ``(content, styled)`` pair, as yielded by
        ``run``, so index 1 is the styled image. Every call writes into its
        own fresh temporary directory. A fixed filename in the working
        directory would be shared by all sessions, letting concurrent users
        overwrite and download each other's results. The base file names are
        kept so downloads are still named 'generated-all.jpg' /
        'generated-bg.jpg'.
        """
        out_dir = tempfile.mkdtemp(prefix='nst-')
        filename1 = os.path.join(out_dir, 'generated-all.jpg')
        filename2 = os.path.join(out_dir, 'generated-bg.jpg')
        img_tuple1[1].save(filename1)
        img_tuple2[1].save(filename2)
        return filename1, filename2

    # Hide stale download buttons as soon as a new run is requested.
    submit_button.click(
        fn=lambda: [gr.update(visible=False) for _ in range(2)],
        outputs=[download_button_1, download_button_2]
    )

    # Run -> save -> reveal downloads. `.success` only fires when the previous
    # step finished without an error, so a failed run does not try to save
    # missing images or re-show the buttons.
    submit_button.click(
        fn=run,
        inputs=[content_image, style_dropdown, style_strength_slider],
        outputs=[output_image_all, output_image_background, bg_ratio_label]
    ).success(
        fn=save_image,
        inputs=[output_image_all, output_image_background],
        outputs=[download_button_1, download_button_2]
    ).success(
        fn=lambda: [gr.update(visible=True) for _ in range(2)],
        outputs=[download_button_1, download_button_2]
    )

# NOTE(review): this assigns False to the attribute ``demo.queue``, which in
# Gradio is presumably the Blocks.queue() configuration method. Rebinding it
# does not call it, and any later ``demo.queue(...)`` call would fail with
# "'bool' object is not callable". ``run`` is a generator, which Gradio
# normally streams through its queue. Confirm what this line is meant to do
# for the installed Gradio version.
demo.queue = False
# NOTE(review): writes a 'queue' key into the Blocks config dict. Whether
# anything reads this key (or whether launch() rebuilds the config) depends
# on Gradio internals — TODO confirm it has any effect.
demo.config['queue'] = False
# Start the app server; show_api=False presumably hides the auto-generated API
# page/link (per the launch() parameter name) — verify against Gradio docs.
demo.launch(show_api=False)