segvol / model_segvol.py
yuxin
add model
2fbf9d3
raw
history blame
908 Bytes
from transformers import PreTrainedModel
from config_segvol import SegVolConfig
from network.model import SegVol
from segment_anything_volumetric import sam_model_registry
class SegVolModel(PreTrainedModel):
    """Expose the SegVol network through the ``PreTrainedModel`` interface.

    The wrapped network is stored on ``self.model``. It is assembled from the
    image encoder, mask decoder and prompt encoder of the ``'vit'`` entry of
    ``sam_model_registry``, using the sizes found on the configuration.
    """

    # Configuration class associated with this model.
    config_class = SegVolConfig

    def __init__(self, config):
        """Build the SAM components and wrap them in a ``SegVol`` network.

        Args:
            config: Configuration object; ``patch_size``, ``spatial_size``
                and ``test_mode`` are read from it.
        """
        super().__init__(config)
        cfg = self.config

        # The registry entry is called with (patch_size, spatial_size) and
        # returns an object carrying the three sub-networks SegVol needs.
        sam_backbone = sam_model_registry['vit'](cfg.patch_size, cfg.spatial_size)

        segvol_args = {
            "image_encoder": sam_backbone.image_encoder,
            "mask_decoder": sam_backbone.mask_decoder,
            "prompt_encoder": sam_backbone.prompt_encoder,
            "roi_size": cfg.spatial_size,
            "patch_size": cfg.patch_size,
            "test_mode": cfg.test_mode,
        }
        self.model = SegVol(**segvol_args)

    def forward(self, image, text=None, boxes=None, points=None, **kwargs):
        """Delegate to the wrapped ``SegVol`` network's ``forward``.

        Args:
            image: Passed positionally to ``SegVol.forward``.
            text: Optional; forwarded as the ``text`` keyword.
            boxes: Optional; forwarded as the ``boxes`` keyword.
            points: Optional; forwarded as the ``points`` keyword.
            **kwargs: Any additional keyword arguments, forwarded unchanged.

        Returns:
            Whatever ``SegVol.forward`` returns.
        """
        prompt_inputs = {"text": text, "boxes": boxes, "points": points}
        # NOTE(review): this invokes ``.forward`` directly rather than calling
        # the module; if SegVol is an ``nn.Module`` that skips any registered
        # hooks -- confirm this is intended before changing it.
        return self.model.forward(image, **prompt_inputs, **kwargs)