# pipeline.py (author: radames)
# Last commit: 6f6eb5f - "Update pipeline.py"
# Size: 850 bytes; file-host scan status: "No virus"
from typing import List
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
# Checkpoint identifier passed to ``from_pretrained`` for both the
# processor and the model in PreTrainedPipeline.
MODEL_ID: str = "facebook/sam-vit-huge"
class PreTrainedPipeline:
    """Inference pipeline that returns SAM image embeddings for an image.

    The processor and model are both loaded from the module-level
    ``MODEL_ID`` checkpoint. The model is placed on CUDA when available,
    otherwise on CPU.
    """

    def __init__(self, path=""):
        """Load the processor and model once and prepare them for inference.

        Args:
            path: Accepted for interface compatibility but currently unused;
                weights are always loaded from ``MODEL_ID``.
                NOTE(review): presumably the hosting runtime passes the
                repository directory here - confirm whether it should be
                used instead of the hard-coded checkpoint.
        """
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SamProcessor.from_pretrained(MODEL_ID)
        # nn.Module.to() moves parameters in place, so a single call suffices.
        self.model = SamModel.from_pretrained(MODEL_ID).to(self.device)
        # eval() switches layers to inference behaviour; it does NOT disable
        # gradient tracking, which is why __call__ also uses torch.no_grad().
        self.model.eval()

    def __call__(self, inputs: "Image.Image") -> List[float]:
        """Compute the image embedding for a single PIL image.

        Args:
            inputs: A PIL image; it is converted to RGB before preprocessing.

        Returns:
            ``get_image_embeddings(...).tolist()``. For a multi-dimensional
            tensor, ``tolist()`` yields nested lists.
            NOTE(review): the ``List[float]`` annotation is therefore likely
            inaccurate - confirm the embedding shape and tighten it.
        """
        raw_image = inputs.convert("RGB")
        # Use a separate name so the caller's image is not shadowed.
        model_inputs = self.processor(
            raw_image, return_tensors="pt").to(self.device)
        # Pure inference: skip building an autograd graph, which would
        # otherwise keep intermediate activations alive.
        with torch.no_grad():
            feature_vector = self.model.get_image_embeddings(
                model_inputs["pixel_values"])
        return feature_vector.tolist()