jamino30 commited on
Commit
0fc4e06
1 Parent(s): 7e83030

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +55 -33
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import time
3
  from datetime import datetime, timezone, timedelta
 
4
 
5
  import spaces
6
  import torch
@@ -38,9 +39,9 @@ for style_name, style_img_path in style_options.items():
38
  style_features = model(style_img)
39
  cached_style_features[style_name] = style_features
40
 
41
- @spaces.GPU(duration=10)
42
- def run(content_image, style_name, style_strength=5, apply_to_background=False, progress=gr.Progress(track_tqdm=True)):
43
- yield None
44
  content_img, original_size = preprocess_img(content_image, img_size)
45
  content_img = content_img.to(device)
46
 
@@ -53,18 +54,37 @@ def run(content_image, style_name, style_strength=5, apply_to_background=False,
53
  style_features = cached_style_features[style_name]
54
 
55
  st = time.time()
56
- generated_img = inference(
57
- model=model,
58
- segmentation_model=segmentation_model,
59
- content_image=content_img,
60
- style_features=style_features,
61
- lr=lrs[style_strength-1],
62
- apply_to_background=apply_to_background
63
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  et = time.time()
65
  print('TIME TAKEN:', et-st)
66
 
67
- yield (content_image, postprocess_img(generated_img, original_size))
 
 
 
68
 
69
  def set_slider(value):
70
  return gr.update(value=value)
@@ -77,52 +97,54 @@ css = """
77
  """
78
 
79
  with gr.Blocks(css=css) as demo:
80
- gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Neural Style Transfer</h1>")
81
  with gr.Row(elem_id='container'):
82
  with gr.Column():
83
  content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
84
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
85
  with gr.Group():
86
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=5, info='Higher values add artistic flair, lower values add a realistic feel.')
87
- apply_to_background = gr.Checkbox(label='Apply to background only', info='Note: This experimental feature may not always detect desired backgrounds.')
88
  submit_button = gr.Button('Submit', variant='primary')
89
 
90
  examples = gr.Examples(
91
  examples=[
92
- ['./content_images/Bridge.jpg', 'Starry Night'],
93
- ['./content_images/GoldenRetriever.jpg', 'Great Wave'],
94
- ['./content_images/CameraGirl.jpg', 'Bokeh']
95
  ],
96
- inputs=[content_image, style_dropdown]
97
  )
98
 
99
  with gr.Column():
100
- output_image = ImageSlider(position=0.15, label='Output', show_label=True, type='pil', interactive=False, show_download_button=False)
101
- download_button = gr.DownloadButton(label='Download Image', visible=False)
 
 
102
 
103
- def save_image(img_tuple):
104
- filename = 'generated.jpg'
105
- img_tuple[1].save(filename)
106
- return filename
 
107
 
108
  submit_button.click(
109
- fn=lambda: gr.update(visible=False),
110
- outputs=[download_button]
111
  )
112
 
113
  submit_button.click(
114
  fn=run,
115
- inputs=[content_image, style_dropdown, style_strength_slider, apply_to_background],
116
- outputs=[output_image]
117
  ).then(
118
  fn=save_image,
119
- inputs=[output_image],
120
- outputs=[download_button]
121
  ).then(
122
- fn=lambda: gr.update(visible=True),
123
- outputs=[download_button]
124
  )
125
 
126
  demo.queue = False
127
  demo.config['queue'] = False
128
- demo.launch(show_api=False)
 
1
  import os
2
  import time
3
  from datetime import datetime, timezone, timedelta
4
+ from concurrent.futures import ThreadPoolExecutor
5
 
6
  import spaces
7
  import torch
 
39
  style_features = model(style_img)
40
  cached_style_features[style_name] = style_features
41
 
42
+ @spaces.GPU(duration=12)
43
+ def run(content_image, style_name, style_strength=5, progress=gr.Progress(total=2)):
44
+ yield None, None
45
  content_img, original_size = preprocess_img(content_image, img_size)
46
  content_img = content_img.to(device)
47
 
 
54
  style_features = cached_style_features[style_name]
55
 
56
  st = time.time()
57
+
58
+ stream_all = torch.cuda.Stream()
59
+ stream_bg = torch.cuda.Stream()
60
+
61
+ def run_inference(apply_to_background, stream):
62
+ with torch.cuda.stream(stream):
63
+ return inference(
64
+ model=model,
65
+ segmentation_model=segmentation_model,
66
+ content_image=content_img,
67
+ style_features=style_features,
68
+ lr=lrs[style_strength-1],
69
+ apply_to_background=apply_to_background
70
+ )
71
+
72
+ with ThreadPoolExecutor() as executor:
73
+ progress(0, desc='Styling image')
74
+ future_all = executor.submit(run_inference, False, stream_all)
75
+ progress(1, desc='Styling background')
76
+ future_bg = executor.submit(run_inference, True, stream_bg)
77
+ generated_img_all = future_all.result()
78
+ generated_img_bg = future_bg.result()
79
+ progress(2)
80
+
81
  et = time.time()
82
  print('TIME TAKEN:', et-st)
83
 
84
+ yield (
85
+ (content_image, postprocess_img(generated_img_all, original_size)),
86
+ (content_image, postprocess_img(generated_img_bg, original_size))
87
+ )
88
 
89
  def set_slider(value):
90
  return gr.update(value=value)
 
97
  """
98
 
99
  with gr.Blocks(css=css) as demo:
100
+ gr.HTML("<h1 style='text-align: center; padding: 10px'>🖼️ Dual Style Transfer: Artistic Output and Background Transformation</h1>")
101
  with gr.Row(elem_id='container'):
102
  with gr.Column():
103
  content_image = gr.Image(label='Content', type='pil', sources=['upload', 'webcam', 'clipboard'], format='jpg', show_download_button=False)
104
  style_dropdown = gr.Radio(choices=list(style_options.keys()), label='Style', value='Starry Night', type='value')
105
  with gr.Group():
106
  style_strength_slider = gr.Slider(label='Style Strength', minimum=1, maximum=10, step=1, value=5, info='Higher values add artistic flair, lower values add a realistic feel.')
 
107
  submit_button = gr.Button('Submit', variant='primary')
108
 
109
  examples = gr.Examples(
110
  examples=[
111
+ ['./content_images/Bridge.jpg', 'Starry Night', 6],
112
+ ['./content_images/GoldenRetriever.jpg', 'Great Wave', 5],
113
+ ['./content_images/CameraGirl.jpg', 'Bokeh', 10]
114
  ],
115
+ inputs=[content_image, style_dropdown, style_strength_slider]
116
  )
117
 
118
  with gr.Column():
119
+ output_image_all = ImageSlider(position=0.15, label='Styled Image', type='pil', interactive=False, show_download_button=False)
120
+ download_button_1 = gr.DownloadButton(label='Download Styled Image', visible=False)
121
+ output_image_background = ImageSlider(position=0.15, label='Styled Background', type='pil', interactive=False, show_download_button=False)
122
+ download_button_2 = gr.DownloadButton(label='Download Styled Background', visible=False)
123
 
124
+ def save_image(img_tuple1, img_tuple2):
125
+ filename1, filename2 = 'generated-all.jpg', 'generated-bg.jpg'
126
+ img_tuple1[1].save(filename1)
127
+ img_tuple2[1].save(filename2)
128
+ return filename1, filename2
129
 
130
  submit_button.click(
131
+ fn=lambda: [gr.update(visible=False) for _ in range(2)],
132
+ outputs=[download_button_1, download_button_2]
133
  )
134
 
135
  submit_button.click(
136
  fn=run,
137
+ inputs=[content_image, style_dropdown, style_strength_slider],
138
+ outputs=[output_image_all, output_image_background]
139
  ).then(
140
  fn=save_image,
141
+ inputs=[output_image_all, output_image_background],
142
+ outputs=[download_button_1, download_button_2]
143
  ).then(
144
+ fn=lambda: [gr.update(visible=True) for _ in range(2)],
145
+ outputs=[download_button_1, download_button_2]
146
  )
147
 
148
  demo.queue = False
149
  demo.config['queue'] = False
150
+ demo.launch(show_api=False)