yizhangliu
commited on
Commit
β’
0cc37e5
1
Parent(s):
f47bc1e
Update app.py
Browse files
app.py
CHANGED
@@ -8,9 +8,6 @@ import gradio as gr
|
|
8 |
|
9 |
from loguru import logger
|
10 |
|
11 |
-
# os.system("pip install diffuser==0.6.0")
|
12 |
-
# os.system("pip install transformers==4.29.1")
|
13 |
-
|
14 |
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
15 |
|
16 |
if os.environ.get('IS_MY_DEBUG') is None:
|
@@ -69,7 +66,10 @@ ckpt_repo_id = "ShilongLiu/GroundingDINO"
|
|
69 |
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
70 |
sam_checkpoint = './sam_vit_h_4b8939.pth'
|
71 |
output_dir = "outputs"
|
72 |
-
|
|
|
|
|
|
|
73 |
|
74 |
os.makedirs(output_dir, exist_ok=True)
|
75 |
groundingdino_model = None
|
@@ -77,8 +77,9 @@ sam_device = None
|
|
77 |
sam_model = None
|
78 |
sam_predictor = None
|
79 |
sam_mask_generator = None
|
80 |
-
|
81 |
lama_cleaner_model= None
|
|
|
82 |
ram_model = None
|
83 |
|
84 |
def get_sam_vit_h_4b8939():
|
@@ -165,16 +166,6 @@ def load_image(image_path):
|
|
165 |
image, _ = transform(image_pil, None) # 3, h, w
|
166 |
return image_pil, image
|
167 |
|
168 |
-
def load_model(model_config_path, model_checkpoint_path, device):
|
169 |
-
args = SLConfig.fromfile(model_config_path)
|
170 |
-
args.device = device
|
171 |
-
model = build_model(args)
|
172 |
-
checkpoint = torch.load(model_checkpoint_path, map_location=device) #"cpu")
|
173 |
-
load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
|
174 |
-
print(load_res)
|
175 |
-
_ = model.eval()
|
176 |
-
return model
|
177 |
-
|
178 |
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
|
179 |
caption = caption.lower()
|
180 |
caption = caption.strip()
|
@@ -258,18 +249,21 @@ def mix_masks(imgs):
|
|
258 |
return Image.fromarray(np.uint8(255*re_img))
|
259 |
|
260 |
def set_device():
|
261 |
-
device
|
262 |
-
|
|
|
|
|
|
|
263 |
|
264 |
def load_groundingdino_model():
|
265 |
# initialize groundingdino model
|
266 |
global groundingdino_model
|
267 |
logger.info(f"initialize groundingdino model...")
|
268 |
-
groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
|
269 |
|
270 |
def load_sam_model():
|
271 |
# initialize SAM
|
272 |
-
global sam_model, sam_predictor, sam_mask_generator, sam_device
|
273 |
logger.info(f"initialize SAM model...")
|
274 |
sam_device = device
|
275 |
sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
|
@@ -278,26 +272,26 @@ def load_sam_model():
|
|
278 |
|
279 |
def load_sd_model():
|
280 |
# initialize stable-diffusion-inpainting
|
281 |
-
global
|
282 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
283 |
-
|
284 |
if os.environ.get('IS_MY_DEBUG') is None:
|
285 |
-
|
286 |
"runwayml/stable-diffusion-inpainting",
|
287 |
revision="fp16",
|
288 |
# "stabilityai/stable-diffusion-2-inpainting",
|
289 |
torch_dtype=torch.float16,
|
290 |
)
|
291 |
-
|
292 |
|
293 |
def load_lama_cleaner_model():
|
294 |
# initialize lama_cleaner
|
295 |
-
global lama_cleaner_model
|
296 |
logger.info(f"initialize lama_cleaner...")
|
297 |
|
298 |
lama_cleaner_model = ModelManager(
|
299 |
name='lama',
|
300 |
-
device=
|
301 |
)
|
302 |
|
303 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
@@ -517,6 +511,7 @@ mask_source_segment = "type what to detect below"
|
|
517 |
|
518 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
519 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, cleaner_size_limit=1080):
|
|
|
520 |
if (task_type == 'relate anything'):
|
521 |
output_images = relate_anything(input_image['image'], num_relation)
|
522 |
return output_images, gr.Gallery.update(label='relate images')
|
@@ -566,7 +561,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
566 |
groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
|
567 |
)
|
568 |
if boxes_filt.size(0) == 0:
|
569 |
-
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]
|
570 |
return [], gr.Gallery.update(label='No objects detected, please try others.ππππ')
|
571 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
572 |
|
@@ -640,7 +635,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
|
|
640 |
# inpainting pipeline
|
641 |
image_source_for_inpaint = image_pil.resize((512, 512))
|
642 |
image_mask_for_inpaint = mask_pil.resize((512, 512))
|
643 |
-
image_inpainting =
|
644 |
else:
|
645 |
# remove from mask
|
646 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_5_')
|
@@ -707,6 +702,8 @@ def change_radio_display(task_type, mask_source_radio):
|
|
707 |
|
708 |
def get_model_device(module):
|
709 |
try:
|
|
|
|
|
710 |
if isinstance(module, torch.nn.DataParallel):
|
711 |
module = module.module
|
712 |
for submodule in module.children():
|
@@ -714,8 +711,9 @@ def get_model_device(module):
|
|
714 |
parameters = submodule._parameters
|
715 |
if "weight" in parameters:
|
716 |
return parameters["weight"].device
|
|
|
717 |
except Exception as e:
|
718 |
-
return '
|
719 |
|
720 |
if __name__ == "__main__":
|
721 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
@@ -732,10 +730,12 @@ if __name__ == "__main__":
|
|
732 |
load_lama_cleaner_model()
|
733 |
load_ram_model()
|
734 |
|
735 |
-
os.
|
|
|
|
|
736 |
print(f'groundingdino_model__{get_model_device(groundingdino_model)}')
|
737 |
print(f'sam_model__{get_model_device(sam_model)}')
|
738 |
-
print(f'sd_model__{get_model_device(
|
739 |
print(f'lama_cleaner_model__{get_model_device(lama_cleaner_model)}')
|
740 |
print(f'ram_model__{get_model_device(ram_model)}')
|
741 |
|
@@ -790,3 +790,4 @@ if __name__ == "__main__":
|
|
790 |
|
791 |
computer_info()
|
792 |
block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
|
|
|
|
8 |
|
9 |
from loguru import logger
|
10 |
|
|
|
|
|
|
|
11 |
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
12 |
|
13 |
if os.environ.get('IS_MY_DEBUG') is None:
|
|
|
66 |
ckpt_filenmae = "groundingdino_swint_ogc.pth"
|
67 |
sam_checkpoint = './sam_vit_h_4b8939.pth'
|
68 |
output_dir = "outputs"
|
69 |
+
if os.environ.get('IS_MY_DEBUG') is None:
|
70 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
71 |
+
else:
|
72 |
+
device = 'cpu'
|
73 |
|
74 |
os.makedirs(output_dir, exist_ok=True)
|
75 |
groundingdino_model = None
|
|
|
77 |
sam_model = None
|
78 |
sam_predictor = None
|
79 |
sam_mask_generator = None
|
80 |
+
sd_model = None
|
81 |
lama_cleaner_model= None
|
82 |
+
lama_cleaner_model_device = device
|
83 |
ram_model = None
|
84 |
|
85 |
def get_sam_vit_h_4b8939():
|
|
|
166 |
image, _ = transform(image_pil, None) # 3, h, w
|
167 |
return image_pil, image
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
|
170 |
caption = caption.lower()
|
171 |
caption = caption.strip()
|
|
|
249 |
return Image.fromarray(np.uint8(255*re_img))
|
250 |
|
251 |
def set_device():
|
252 |
+
global device
|
253 |
+
if os.environ.get('IS_MY_DEBUG') is None:
|
254 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
255 |
+
else:
|
256 |
+
device = 'cpu'
|
257 |
|
258 |
def load_groundingdino_model():
|
259 |
# initialize groundingdino model
|
260 |
global groundingdino_model
|
261 |
logger.info(f"initialize groundingdino model...")
|
262 |
+
groundingdino_model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae, device='cpu')
|
263 |
|
264 |
def load_sam_model():
|
265 |
# initialize SAM
|
266 |
+
global sam_model, sam_predictor, sam_mask_generator, sam_device, device
|
267 |
logger.info(f"initialize SAM model...")
|
268 |
sam_device = device
|
269 |
sam_model = build_sam(checkpoint=sam_checkpoint).to(sam_device)
|
|
|
272 |
|
273 |
def load_sd_model():
|
274 |
# initialize stable-diffusion-inpainting
|
275 |
+
global sd_model, device
|
276 |
logger.info(f"initialize stable-diffusion-inpainting...")
|
277 |
+
sd_model = None
|
278 |
if os.environ.get('IS_MY_DEBUG') is None:
|
279 |
+
sd_model = StableDiffusionInpaintPipeline.from_pretrained(
|
280 |
"runwayml/stable-diffusion-inpainting",
|
281 |
revision="fp16",
|
282 |
# "stabilityai/stable-diffusion-2-inpainting",
|
283 |
torch_dtype=torch.float16,
|
284 |
)
|
285 |
+
sd_model = sd_model.to(device)
|
286 |
|
287 |
def load_lama_cleaner_model():
|
288 |
# initialize lama_cleaner
|
289 |
+
global lama_cleaner_model, device
|
290 |
logger.info(f"initialize lama_cleaner...")
|
291 |
|
292 |
lama_cleaner_model = ModelManager(
|
293 |
name='lama',
|
294 |
+
device=lama_cleaner_model_device,
|
295 |
)
|
296 |
|
297 |
def lama_cleaner_process(image, mask, cleaner_size_limit=1080):
|
|
|
511 |
|
512 |
def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
|
513 |
iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, cleaner_size_limit=1080):
|
514 |
+
|
515 |
if (task_type == 'relate anything'):
|
516 |
output_images = relate_anything(input_image['image'], num_relation)
|
517 |
return output_images, gr.Gallery.update(label='relate images')
|
|
|
561 |
groundingdino_model, image, text_prompt, box_threshold, text_threshold, device=groundingdino_device
|
562 |
)
|
563 |
if boxes_filt.size(0) == 0:
|
564 |
+
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
|
565 |
return [], gr.Gallery.update(label='No objects detected, please try others.ππππ')
|
566 |
boxes_filt_ori = copy.deepcopy(boxes_filt)
|
567 |
|
|
|
635 |
# inpainting pipeline
|
636 |
image_source_for_inpaint = image_pil.resize((512, 512))
|
637 |
image_mask_for_inpaint = mask_pil.resize((512, 512))
|
638 |
+
image_inpainting = sd_model(prompt=inpaint_prompt, image=image_source_for_inpaint, mask_image=image_mask_for_inpaint).images[0]
|
639 |
else:
|
640 |
# remove from mask
|
641 |
logger.info(f'run_anything_task_[{file_temp}]_{task_type}_5_')
|
|
|
702 |
|
703 |
def get_model_device(module):
|
704 |
try:
|
705 |
+
if module is None:
|
706 |
+
return 'None'
|
707 |
if isinstance(module, torch.nn.DataParallel):
|
708 |
module = module.module
|
709 |
for submodule in module.children():
|
|
|
711 |
parameters = submodule._parameters
|
712 |
if "weight" in parameters:
|
713 |
return parameters["weight"].device
|
714 |
+
return 'UnKnown'
|
715 |
except Exception as e:
|
716 |
+
return 'Error'
|
717 |
|
718 |
if __name__ == "__main__":
|
719 |
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
|
|
|
730 |
load_lama_cleaner_model()
|
731 |
load_ram_model()
|
732 |
|
733 |
+
if os.environ.get('IS_MY_DEBUG') is None:
|
734 |
+
os.system("pip list")
|
735 |
+
|
736 |
print(f'groundingdino_model__{get_model_device(groundingdino_model)}')
|
737 |
print(f'sam_model__{get_model_device(sam_model)}')
|
738 |
+
print(f'sd_model__{get_model_device(sd_model)}')
|
739 |
print(f'lama_cleaner_model__{get_model_device(lama_cleaner_model)}')
|
740 |
print(f'ram_model__{get_model_device(ram_model)}')
|
741 |
|
|
|
790 |
|
791 |
computer_info()
|
792 |
block.launch(server_name='0.0.0.0', debug=args.debug, share=args.share)
|
793 |
+
|