# evf-sam2 / inference.py
# wondervictor's picture
# Update inference.py
# 112d224 verified
import argparse
import os
import sys
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer, BitsAndBytesConfig
from model.segment_anything.utils.transforms import ResizeLongestSide
def parse_args(args):
    """Parse the command-line options of the EVF inference script.

    Args:
        args: list of argument strings (``main`` passes ``sys.argv[1:]``).

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    cli = argparse.ArgumentParser(description="EVF infer")

    # One (flag, add_argument options) entry per option, registered in this
    # order so the generated ``--help`` output keeps the same listing.
    option_table = [
        ("--version", dict(required=True)),
        ("--vis_save_path", dict(default="./infer", type=str)),
        (
            "--precision",
            dict(
                default="fp16",
                type=str,
                choices=["fp32", "bf16", "fp16"],
                help="precision for inference",
            ),
        ),
        ("--image_size", dict(default=224, type=int, help="image size")),
        ("--model_max_length", dict(default=512, type=int)),
        ("--local-rank", dict(default=0, type=int, help="node rank")),
        ("--load_in_8bit", dict(action="store_true", default=False)),
        ("--load_in_4bit", dict(action="store_true", default=False)),
        ("--model_type", dict(default="ori", choices=["ori", "effi", "sam2"])),
        ("--image_path", dict(type=str, default="assets/zebra.jpg")),
        ("--prompt", dict(type=str, default="zebra top left")),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args(args)
def sam_preprocess(
    x: np.ndarray,
    pixel_mean=torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1),
    pixel_std=torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1),
    img_size=1024,
    model_type="ori",
) -> "tuple[torch.Tensor, tuple[int, int] | None]":
    """Preprocess an image for the SAM-style visual backbone.

    Two paths exist, selected by ``model_type``:

    * ``"ori"``: resize with ``ResizeLongestSide(img_size).apply_image``,
      normalize with ``pixel_mean`` / ``pixel_std``, then zero-pad on the
      bottom and right up to ``img_size`` x ``img_size``.
    * anything else (the script uses ``"effi"`` and ``"sam2"``): bilinearly
      resize straight to ``img_size`` x ``img_size`` and normalize; no padding.

    Args:
        x: image array laid out as (height, width, channels); the channel
            axis is moved to the front before normalization.
        pixel_mean: per-channel mean, shaped (C, 1, 1).
        pixel_std: per-channel standard deviation, shaped (C, 1, 1).
        img_size: target side length; only 1024 is accepted.
        model_type: which preprocessing path to use (see above).

    Returns:
        ``(image, resize_shape)`` where ``image`` is the normalized
        (C, img_size, img_size) tensor and ``resize_shape`` is the
        ``(height, width)`` of the resized image before padding for the
        ``"ori"`` path, or ``None`` for the other path.

    Raises:
        AssertionError: if ``img_size`` is not 1024.
    """
    if img_size != 1024:
        # Raised explicitly rather than through ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError(
            "both SAM and Effi-SAM receive images of size 1024^2, don't change this setting unless you're sure that your employed model works well with another size."
        )

    if model_type == "ori":
        # Resize first, remembering the pre-padding size for the caller.
        x = ResizeLongestSide(img_size).apply_image(x)
        h, w = resize_shape = x.shape[:2]
        x = torch.from_numpy(x).permute(2, 0, 1).contiguous()
        # Normalize colors
        x = (x - pixel_mean) / pixel_std
        # Pad the right and bottom edges up to a square of side img_size
        padh = img_size - h
        padw = img_size - w
        x = F.pad(x, (0, padw, 0, padh))
    else:
        x = torch.from_numpy(x).permute(2, 0, 1).contiguous()
        # Stretch directly to the square target size (aspect ratio is not
        # preserved on this path), then normalize colors.
        x = F.interpolate(
            x.unsqueeze(0),
            (img_size, img_size),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)
        x = (x - pixel_mean) / pixel_std
        resize_shape = None
    return x, resize_shape
def beit3_preprocess(x: np.ndarray, img_size=224) -> torch.Tensor:
    """Preprocess an image for the BEIT-3 model.

    The image is converted to a tensor, resized to ``img_size`` x ``img_size``
    with bicubic interpolation, and normalized with mean 0.5 and std 0.5 on
    each of three channels.

    Args:
        x: input image as an ndarray.
        img_size: side length of the square output.

    Returns:
        The preprocessed image tensor.
    """
    target_hw = (img_size, img_size)
    steps = (
        transforms.ToTensor(),
        transforms.Resize(target_hw, interpolation=InterpolationMode.BICUBIC),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    )

    # Apply the transforms one after another, in the listed order.
    result = x
    for step in steps:
        result = step(result)
    return result
def init_models(args):
    """Load the tokenizer and the EVF model selected by ``args``.

    Args:
        args: parsed options; this function reads ``version``, ``precision``,
            ``load_in_4bit``, ``load_in_8bit`` and ``model_type``.

    Returns:
        ``(tokenizer, model)`` with the model switched to eval mode.

    Raises:
        ValueError: if ``args.model_type`` is not ``"ori"``, ``"effi"`` or
            ``"sam2"``. Checked before anything is loaded so a bad value
            fails immediately instead of surfacing later as an unbound
            ``model`` variable.
    """
    import importlib  # only needed here, for the lazy model-class lookup

    # model_type -> (module path, class name). The module is imported only
    # for the selected entry, so unused model variants are never imported.
    model_registry = {
        "ori": ("model.evf_sam", "EvfSamModel"),
        "effi": ("model.evf_effisam", "EvfEffiSamModel"),
        "sam2": ("model.evf_sam2", "EvfSam2Model"),
    }
    if args.model_type not in model_registry:
        raise ValueError(
            "unknown model_type {!r}; expected one of {}".format(
                args.model_type, sorted(model_registry)
            )
        )

    tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        padding_side="right",
        use_fast=False,
    )

    # Any precision other than bf16/fp16 falls back to float32.
    dtype_by_precision = {"bf16": torch.bfloat16, "fp16": torch.half}
    torch_dtype = dtype_by_precision.get(args.precision, torch.float32)
    kwargs = {"torch_dtype": torch_dtype}

    # Quantized loading overrides the requested precision with fp16 and keeps
    # the "visual_model" module out of quantization. 4-bit takes priority if
    # both flags are set.
    if args.load_in_4bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                ),
            }
        )
    elif args.load_in_8bit:
        kwargs.update(
            {
                "torch_dtype": torch.half,
                "quantization_config": BitsAndBytesConfig(
                    llm_int8_skip_modules=["visual_model"],
                    load_in_8bit=True,
                ),
            }
        )

    module_name, class_name = model_registry[args.model_type]
    model_cls = getattr(importlib.import_module(module_name), class_name)
    model = model_cls.from_pretrained(
        args.version, low_cpu_mem_usage=True, **kwargs
    )

    # Only non-quantized models are moved to the GPU here.
    # NOTE(review): presumably quantized weights are already placed on the
    # device during loading -- confirm.
    if not (args.load_in_4bit or args.load_in_8bit):
        model = model.cuda()
    model.eval()

    return tokenizer, model
def main(args):
    """Run EVF inference on one image and save a mask visualization.

    Steps: parse options, read the image, load tokenizer and model,
    preprocess the image for both encoders, run ``model.inference`` with the
    text prompt, then write the image with the predicted mask tinted to
    ``<vis_save_path>/<image name without extension>_vis.png``.

    Args:
        args: list of command-line argument strings (without the program
            name).

    Exits with status 1 (after printing a message) when the image path does
    not exist, the image cannot be decoded, or the result cannot be written.
    """
    args = parse_args(args)

    # clarify IO
    image_path = args.image_path
    if not os.path.exists(image_path):
        print("File not found in {}".format(image_path))
        # sys.exit instead of the site-provided exit(); non-zero status so
        # callers can detect the failure.
        sys.exit(1)
    prompt = args.prompt

    os.makedirs(args.vis_save_path, exist_ok=True)
    # splitext drops only the final extension, so "a.b.jpg" -> "a.b".
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    save_path = os.path.join(args.vis_save_path, "{}_vis.png".format(image_stem))

    # Read the image before loading the model so an unreadable file fails
    # fast. cv2.imread returns None rather than raising when decoding fails.
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        print("Failed to read image from {}".format(image_path))
        sys.exit(1)
    image_np = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    original_size_list = [image_np.shape[:2]]

    # initialize model and tokenizer
    tokenizer, model = init_models(args)

    # preprocess
    image_beit = beit3_preprocess(image_np, args.image_size).to(
        dtype=model.dtype, device=model.device
    )
    image_sam, resize_shape = sam_preprocess(image_np, model_type=args.model_type)
    image_sam = image_sam.to(dtype=model.dtype, device=model.device)
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(
        device=model.device
    )

    # infer (gradients are not needed, so skip autograd bookkeeping)
    with torch.no_grad():
        pred_mask = model.inference(
            image_sam.unsqueeze(0),
            image_beit.unsqueeze(0),
            input_ids,
            resize_list=[resize_shape],
            original_size_list=original_size_list,
        )
    pred_mask = pred_mask.detach().cpu().numpy()[0]
    # Positive values are treated as foreground.
    pred_mask = pred_mask > 0

    # save visualization: blend masked pixels 50/50 with a fixed RGB tint
    save_img = image_np.copy()
    tint_rgb = np.array([50, 120, 220])
    save_img[pred_mask] = image_np[pred_mask] * 0.5 + tint_rgb * 0.5
    save_img = cv2.cvtColor(save_img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(save_path, save_img):
        print("Failed to write visualization to {}".format(save_path))
        sys.exit(1)
if __name__ == "__main__":
    # Script entry point: hand everything after the program name to main().
    cli_arguments = sys.argv[1:]
    main(cli_arguments)