Update run/gradio_ootd.py
Browse files- run/gradio_ootd.py +5 -2
run/gradio_ootd.py
CHANGED
@@ -10,7 +10,6 @@ from utils_ootd import get_mask_location
|
|
10 |
PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
|
11 |
sys.path.insert(0, str(PROJECT_ROOT))
|
12 |
|
13 |
-
import time
|
14 |
from preprocess.openpose.run_openpose import OpenPose
|
15 |
from preprocess.humanparsing.run_parsing import Parsing
|
16 |
from ootd.inference_ootd_hd import OOTDiffusionHD
|
@@ -36,6 +35,10 @@ garment_hd = os.path.join(example_path, 'garment/03244_00.jpg')
|
|
36 |
model_dc = os.path.join(example_path, 'model/model_8.png')
|
37 |
garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')
|
38 |
|
|
|
|
|
|
|
|
|
39 |
def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
|
40 |
model_type = 'hd'
|
41 |
category = 0 # 0:upperbody; 1:lowerbody; 2:dress
|
@@ -257,4 +260,4 @@ with block:
|
|
257 |
ips_dc = [vton_img_dc, garm_img_dc, category_dc, n_samples_dc, n_steps_dc, image_scale_dc, seed_dc]
|
258 |
run_button_dc.click(fn=process_dc, inputs=ips_dc, outputs=[result_gallery_dc])
|
259 |
|
260 |
-
block.launch(
|
|
|
10 |
PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
|
11 |
sys.path.insert(0, str(PROJECT_ROOT))
|
12 |
|
|
|
13 |
from preprocess.openpose.run_openpose import OpenPose
|
14 |
from preprocess.humanparsing.run_parsing import Parsing
|
15 |
from ootd.inference_ootd_hd import OOTDiffusionHD
|
|
|
35 |
model_dc = os.path.join(example_path, 'model/model_8.png')
|
36 |
garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')
|
37 |
|
38 |
+
|
39 |
+
import spaces
|
40 |
+
|
41 |
+
@spaces.GPU
|
42 |
def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
|
43 |
model_type = 'hd'
|
44 |
category = 0 # 0:upperbody; 1:lowerbody; 2:dress
|
|
|
260 |
ips_dc = [vton_img_dc, garm_img_dc, category_dc, n_samples_dc, n_steps_dc, image_scale_dc, seed_dc]
|
261 |
run_button_dc.click(fn=process_dc, inputs=ips_dc, outputs=[result_gallery_dc])
|
262 |
|
263 |
+
block.launch()
|