Refactor image post-processing in gradio_demo_rgb2x.py
Browse files
rgb2x/gradio_demo_rgb2x.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import spaces
|
|
|
2 |
import os
|
3 |
from typing import cast
|
4 |
import gradio as gr
|
@@ -94,7 +95,11 @@ def generate(
|
|
94 |
|
95 |
generated_image = (generated_image, f"Generated {aov_name} {i}")
|
96 |
return_list.append(generated_image)
|
97 |
-
|
|
|
|
|
|
|
|
|
98 |
return return_list
|
99 |
|
100 |
|
|
|
1 |
import spaces
|
2 |
+
import numpy as np
|
3 |
import os
|
4 |
from typing import cast
|
5 |
import gradio as gr
|
|
|
95 |
|
96 |
generated_image = (generated_image, f"Generated {aov_name} {i}")
|
97 |
return_list.append(generated_image)
|
98 |
+
|
99 |
+
def post_process_image(img, **kwargs):
|
100 |
+
return (img.cpu().numpy().transpose(1, 2, 0), kwargs.get("label", "Image"))
|
101 |
+
|
102 |
+
return_list.append(post_process_image(photo, label="Input Image"))
|
103 |
return return_list
|
104 |
|
105 |
|