liuyizhang
commited on
Commit
•
9003ca5
1
Parent(s):
f0d812d
update app.py
Browse files
app.py
CHANGED
@@ -250,11 +250,13 @@ def set_device():
|
|
250 |
|
251 |
def load_groundingdino_model():
|
252 |
# initialize groundingdino model
|
|
|
253 |
logger.info(f"initialize groundingdino model...")
|
254 |
groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
255 |
|
256 |
def load_sam_model():
|
257 |
# initialize SAM
|
|
|
258 |
logger.info(f"initialize SAM model...")
|
259 |
sam_device = device
|
260 |
sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
|
@@ -263,6 +265,7 @@ def load_sam_model():
|
|
263 |
|
264 |
def load_sd_model():
|
265 |
# initialize stable-diffusion-inpainting
|
|
|
266 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
267 |
sd_pipe = None
|
268 |
if os.environ.get('IS_MY_DEBUG') is None:
|
@@ -276,6 +279,7 @@ def load_sd_model():
|
|
276 |
|
277 |
def load_lama_cleaner_model():
|
278 |
# initialize lama_cleaner
|
|
|
279 |
logger.info(f"initialize lama_cleaner...")
|
280 |
from lama_cleaner.helper import (
|
281 |
load_img,
|
@@ -359,6 +363,7 @@ class Ram_Predictor(RamPredictor):
|
|
359 |
|
360 |
def load_ram_model():
|
361 |
# load ram model
|
|
|
362 |
model_path = "./checkpoints/ram_epoch12.pth"
|
363 |
ram_config = dict(
|
364 |
model=dict(
|
@@ -674,23 +679,22 @@ def change_radio_display(task_type, mask_source_radio):
|
|
674 |
num_relation_visible = True
|
675 |
return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
|
676 |
|
677 |
-
set_device()
|
678 |
-
load_groundingdino_model()
|
679 |
-
load_sam_model()
|
680 |
-
load_sd_model()
|
681 |
-
load_lama_cleaner_model()
|
682 |
-
load_ram_model()
|
683 |
-
|
684 |
-
os.system("pip list")
|
685 |
-
|
686 |
if __name__ == "__main__":
|
687 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
688 |
parser.add_argument("--debug", action="store_true", help="using debug mode")
|
689 |
parser.add_argument("--share", action="store_true", help="share the app")
|
690 |
args = parser.parse_args()
|
691 |
-
|
692 |
print(f'args = {args}')
|
693 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
694 |
block = gr.Blocks().queue()
|
695 |
with block:
|
696 |
with gr.Row():
|
|
|
250 |
|
251 |
def load_groundingdino_model():
|
252 |
# initialize groundingdino model
|
253 |
+
global groundingdino_model
|
254 |
logger.info(f"initialize groundingdino model...")
|
255 |
groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
256 |
|
257 |
def load_sam_model():
|
258 |
# initialize SAM
|
259 |
+
global sam_model, sam_predictor, sam_mask_generator
|
260 |
logger.info(f"initialize SAM model...")
|
261 |
sam_device = device
|
262 |
sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
|
|
|
265 |
|
266 |
def load_sd_model():
|
267 |
# initialize stable-diffusion-inpainting
|
268 |
+
global sd_pipe
|
269 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
270 |
sd_pipe = None
|
271 |
if os.environ.get('IS_MY_DEBUG') is None:
|
|
|
279 |
|
280 |
def load_lama_cleaner_model():
|
281 |
# initialize lama_cleaner
|
282 |
+
global lama_cleaner_model
|
283 |
logger.info(f"initialize lama_cleaner...")
|
284 |
from lama_cleaner.helper import (
|
285 |
load_img,
|
|
|
363 |
|
364 |
def load_ram_model():
|
365 |
# load ram model
|
366 |
+
global ram_model
|
367 |
model_path = "./checkpoints/ram_epoch12.pth"
|
368 |
ram_config = dict(
|
369 |
model=dict(
|
|
|
679 |
num_relation_visible = True
|
680 |
return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
|
681 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
682 |
if __name__ == "__main__":
|
683 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
684 |
parser.add_argument("--debug", action="store_true", help="using debug mode")
|
685 |
parser.add_argument("--share", action="store_true", help="share the app")
|
686 |
args = parser.parse_args()
|
|
|
687 |
print(f'args = {args}')
|
688 |
|
689 |
+
set_device()
|
690 |
+
load_groundingdino_model()
|
691 |
+
load_sam_model()
|
692 |
+
load_sd_model()
|
693 |
+
load_lama_cleaner_model()
|
694 |
+
load_ram_model()
|
695 |
+
|
696 |
+
os.system("pip list")
|
697 |
+
|
698 |
block = gr.Blocks().queue()
|
699 |
with block:
|
700 |
with gr.Row():
|