Peijie committed on
Commit a410a68
1 Parent(s): ef7d4dd

add gpu support

Files changed (3):
  1. app.py +1 -0
  2. utils/load_model.py +2 -0
  3. utils/predict.py +2 -1
app.py CHANGED
@@ -1,3 +1,4 @@
+import spaces
 import io
 import os
 debug = False
utils/load_model.py CHANGED
@@ -1,10 +1,12 @@
 
 
+import spaces
 import torch
 from transformers import OwlViTProcessor, OwlViTForObjectDetection
 
 from .model import OwlViTForClassification
 
+@spaces.GPU
 def load_xclip(device: str = "cuda:0",
                n_classes: int = 183,
                use_teacher_logits: bool = False,
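For context: spaces.GPU comes from Hugging Face's spaces package and is how ZeroGPU Spaces request a GPU, meaning the decorated function is given a CUDA device only while it runs and everything else stays on CPU. Below is a minimal sketch of the pattern this commit applies; the checkpoint and the detect function are illustrative assumptions, not this repo's actual code.

import spaces
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection

# Illustrative sketch only: load the model once at import time, then move it
# to CUDA inside the GPU-decorated call. Checkpoint/function names are assumptions.
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32")

@spaces.GPU  # the Space runtime attaches a GPU for the duration of this call
def detect(image, texts):
    model.to("cuda")
    inputs = processor(text=texts, images=image, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs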
utils/predict.py CHANGED
@@ -1,3 +1,4 @@
+import spaces
 import PIL
 import torch
 
@@ -29,7 +30,7 @@ def encode_descs_xclip(owlvit_det_processor: callable, model: callable, descs: l
     # text_embeds = torch.cat(text_embeds, dim=0)
     # text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)
     # return text_embeds.to(device)
-
+@spaces.GPU
 def xclip_pred(new_desc: dict,
                new_part_mask: dict,
                new_class: str,
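Decorating xclip_pred the same way puts the actual inference call behind a GPU request, while app.py only needs the import because the decorated functions live under utils/. As a hedged usage sketch (the predict handler and its signature below are assumptions, not this Space's real app code), a Gradio entry point on a ZeroGPU Space typically looks like this; spaces.GPU also accepts an optional duration argument to cap how long the GPU is held per call.

import gradio as gr
import spaces

@spaces.GPU(duration=60)  # optional per-call cap (seconds) on the GPU allocation
def predict(image):
    # Placeholder body: the real Space would run the OwlViT-based classifier
    # on `image` here and return per-class confidence scores.
    return {"example_class": 1.0}

demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label())
demo.launch()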