Amitai Getzler commited on
Commit
93c7837
1 Parent(s): dc7652d
Files changed (1) hide show
  1. handler.py +5 -1
handler.py CHANGED
@@ -10,8 +10,12 @@ from typing import Dict, Any
10
  class EndpointHandler:
11
  def __init__(self, path="hf-hub:Styld/marqo-fashionSigLIP"):
12
  self.model, self.preprocess_train, self.preprocess_val = (
13
- open_clip.create_model_and_transforms("hf-hub:Styld/marqo-fashionSigLIP", device="cuda" if torch.cuda.is_available() else "cpu")
14
  )
 
 
 
 
15
  self.tokenizer = open_clip.get_tokenizer("hf-hub:Styld/marqo-fashionSigLIP")
16
 
17
  def classify_image(self, candidate_labels, image):
 
10
  class EndpointHandler:
11
  def __init__(self, path="hf-hub:Styld/marqo-fashionSigLIP"):
12
  self.model, self.preprocess_train, self.preprocess_val = (
13
+ open_clip.create_model_and_transforms("hf-hub:Styld/marqo-fashionSigLIP")
14
  )
15
+
16
+ if torch.cuda.is_available():
17
+ self.model = self.model.cuda()
18
+
19
  self.tokenizer = open_clip.get_tokenizer("hf-hub:Styld/marqo-fashionSigLIP")
20
 
21
  def classify_image(self, candidate_labels, image):