|
from transformers import PreTrainedModel |
|
from config_segvol import SegVolConfig |
|
from network.model import SegVol |
|
from segment_anything_volumetric import sam_model_registry |
|
|
|
|
|
class SegVolModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``network.model.SegVol``.

    The image encoder, mask decoder and prompt encoder are taken from the
    model built by ``sam_model_registry['vit']`` and passed to ``SegVol``
    together with the ``spatial_size``, ``patch_size`` and ``test_mode``
    values read from the ``SegVolConfig``.
    """

    config_class = SegVolConfig

    def __init__(self, config):
        """Build the wrapped ``SegVol`` network from *config*.

        Args:
            config: A ``SegVolConfig``; its ``patch_size``, ``spatial_size``
                and ``test_mode`` attributes are read here (via
                ``self.config``, which ``super().__init__`` sets).
        """
        super().__init__(config)

        # NOTE(review): arguments are passed positionally (patch_size first,
        # then spatial_size); confirm this order against the signature of
        # sam_model_registry['vit'].
        sam_model = sam_model_registry['vit'](
            self.config.patch_size,
            self.config.spatial_size,
        )

        # Only the three sub-modules are reused; the SAM wrapper object
        # itself is not kept on this model.
        self.model = SegVol(
            image_encoder=sam_model.image_encoder,
            mask_decoder=sam_model.mask_decoder,
            prompt_encoder=sam_model.prompt_encoder,
            # spatial_size is reused as SegVol's roi_size.
            roi_size=self.config.spatial_size,
            patch_size=self.config.patch_size,
            test_mode=self.config.test_mode,
        )

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        """Run the wrapped ``SegVol`` model.

        Args:
            image: Input passed unchanged to ``SegVol``
                (presumably a batched 3-D volume tensor -- TODO confirm shape).
            text: Optional ``text`` input forwarded to ``SegVol``
                (presumably a text prompt -- confirm against SegVol).
            boxes: Optional ``boxes`` input forwarded to ``SegVol``.
            points: Optional ``points`` input forwarded to ``SegVol``.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the wrapped ``SegVol`` model returns.
        """
        # Invoke the module instance instead of ``.forward`` directly so that
        # nn.Module.__call__ runs any forward pre-/post-hooks registered on
        # the inner model; calling ``.forward`` would silently skip them.
        return self.model(image, text=text, boxes=boxes, points=points, **kwargs)
|
|
|
|