radames commited on
Commit
6f6eb5f
1 Parent(s): dac7840

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +4 -3
pipeline.py CHANGED
@@ -4,14 +4,15 @@ from transformers import SamModel, SamProcessor
4
  from PIL import Image
5
  import numpy as np
6
 
 
 
7
  class PreTrainedPipeline():
8
  def __init__(self, path=""):
9
 
10
  self.device = torch.device(
11
  "cuda" if torch.cuda.is_available() else "cpu")
12
- self.processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
13
- self.model = SamModel.from_pretrained(
14
- "facebook/sam-vit-base").to(self.device)
15
  self.model.eval()
16
  self.model = self.model.to(self.device)
17
 
 
4
  from PIL import Image
5
  import numpy as np
6
 
7
+ MODEL_ID = "facebook/sam-vit-huge"
8
+
9
  class PreTrainedPipeline():
10
  def __init__(self, path=""):
11
 
12
  self.device = torch.device(
13
  "cuda" if torch.cuda.is_available() else "cpu")
14
+ self.processor = SamProcessor.from_pretrained(MODEL_ID)
15
+ self.model = SamModel.from_pretrained(MODEL_ID).to(self.device)
 
16
  self.model.eval()
17
  self.model = self.model.to(self.device)
18