|
import numpy as np |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import cv2 |
|
|
|
def show_anns(anns): |
|
if len(anns) == 0: |
|
return |
|
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) |
|
ax = plt.gca() |
|
ax.set_autoscale_on(False) |
|
|
|
img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4)) |
|
img[:,:,3] = 0 |
|
for ann in sorted_anns: |
|
m = ann['segmentation'] |
|
color_mask = np.concatenate([np.random.random(3), [0.35]]) |
|
img[m] = color_mask |
|
ax.imshow(img) |
|
|
|
import sys |
|
sys.path.append("..") |
|
from tinysam import sam_model_registry, SamHierarchicalMaskGenerator |
|
|
|
model_type = "vit_t" |
|
sam = sam_model_registry[model_type](checkpoint="./weights/tinysam.pth") |
|
sam.eval() |
|
mask_generator = SamHierarchicalMaskGenerator(sam) |
|
|
|
|
|
|
|
image = cv2.imread('fig/picture3.jpg') |
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) |
|
|
|
masks = mask_generator.hierarchical_generate(image) |
|
|
|
plt.figure(figsize=(20,20)) |
|
plt.imshow(image) |
|
show_anns(masks) |
|
plt.axis('off') |
|
plt.savefig("test_everthing.png") |
|
|
|
|