Zhibinhong commited on
Commit
6b683e4
1 Parent(s): f1779a1

Update visual_chatgpt.py

Browse files
Files changed (1) hide show
  1. visual_chatgpt.py +4 -4
visual_chatgpt.py CHANGED
@@ -797,7 +797,7 @@ class Segmenting:
797
  print(f"Inintializing Segmentation to {device}")
798
  self.device = device
799
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
800
- self.model_checkpoint_path = "/repository/checkpoints/sam"
801
 
802
  self.download_parameters()
803
  self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
@@ -813,9 +813,9 @@ class Segmenting:
813
  print("finddir",os.system("find /repository -type d -iname 'checkpoints'"))
814
  url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
815
  if not os.path.exists(path):
816
- print("我进来了!")
817
  # wget.download(url,out=self.model_checkpoint_path)
818
- wget.download(url,out=path)
819
 
820
  def show_mask(self, mask, ax, random_color=False):
821
  if random_color:
@@ -917,7 +917,7 @@ class Text2Box:
917
  print(f"Initializing ObjectDetection to {device}")
918
  self.device = device
919
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
920
- self.model_checkpoint_path = "repository/checkpoints/groundingdino"
921
  self.model_config_path = "repository/checkpoints/grounding_config.py"
922
  self.download_parameters()
923
  self.box_threshold = 0.3
 
797
  print(f"Inintializing Segmentation to {device}")
798
  self.device = device
799
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
800
+ self.model_checkpoint_path = os.path.abspath("/repository/checkpoints/sam")
801
 
802
  self.download_parameters()
803
  self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
 
813
  print("finddir",os.system("find /repository -type d -iname 'checkpoints'"))
814
  url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
815
  if not os.path.exists(path):
816
+ print("I'm in!")
817
  # wget.download(url,out=self.model_checkpoint_path)
818
+ wget.download(url,out=self.model_checkpoint_path)
819
 
820
  def show_mask(self, mask, ax, random_color=False):
821
  if random_color:
 
917
  print(f"Initializing ObjectDetection to {device}")
918
  self.device = device
919
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
920
+ self.model_checkpoint_path = os.path.abspath("repository/checkpoints/groundingdino")
921
  self.model_config_path = "repository/checkpoints/grounding_config.py"
922
  self.download_parameters()
923
  self.box_threshold = 0.3