Yuxiang Wang commited on
Commit
438a207
·
1 Parent(s): 97f07be

debug:cpu/gpu device settings

Browse files
Files changed (4) hide show
  1. .gitignore +3 -0
  2. env.py +3 -1
  3. inference_resnet.py +1 -0
  4. inference_sam.py +2 -3
.gitignore CHANGED
@@ -8,3 +8,6 @@ images/
8
  *.pyd
9
  *.swp
10
  *.__pycache__
 
 
 
 
8
  *.pyd
9
  *.swp
10
  *.__pycache__
11
+
12
+ model/
13
+ model_classification/
env.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import subprocess
3
  import importlib.metadata
4
 
@@ -20,6 +21,7 @@ def config_env():
20
 
21
  name_to_command = {'segment_anything':'git+https://github.com/facebookresearch/segment-anything.git',
22
  'panopticapi':'git+https://github.com/cocodataset/panopticapi.git',
 
23
  'torch':'torch --index-url https://download.pytorch.org/whl/cu118',
24
  'torchvision':'torchvision --index-url https://download.pytorch.org/whl/cu118',
25
  }
@@ -28,7 +30,7 @@ def config_env():
28
  if env_name == 'fossil': # in case pkgs installed to unexpected env during local dev
29
  for package, version in packages_to_install:
30
  package_spec = f"{package}=={version}" if version else package
31
- package_spec = name_to_command[package_spec] if package_spec in name_to_command
32
  if not is_pkg_installed(package):
33
  #TODO
34
  if package=='torch' or 'torchvision' or 'tensorflow':
 
1
  import os
2
+ import sys
3
  import subprocess
4
  import importlib.metadata
5
 
 
21
 
22
  name_to_command = {'segment_anything':'git+https://github.com/facebookresearch/segment-anything.git',
23
  'panopticapi':'git+https://github.com/cocodataset/panopticapi.git',
24
+ #TODO
25
  'torch':'torch --index-url https://download.pytorch.org/whl/cu118',
26
  'torchvision':'torchvision --index-url https://download.pytorch.org/whl/cu118',
27
  }
 
30
  if env_name == 'fossil': # in case pkgs installed to unexpected env during local dev
31
  for package, version in packages_to_install:
32
  package_spec = f"{package}=={version}" if version else package
33
+ package_spec = name_to_command[package_spec] if package_spec in name_to_command else package_spec
34
  if not is_pkg_installed(package):
35
  #TODO
36
  if package=='torch' or 'torchvision' or 'tensorflow':
inference_resnet.py CHANGED
@@ -74,6 +74,7 @@ def get_triplet_model(input_shape = (600, 600, 3),
74
  backbone = backbone_class(input_shape=input_shape, include_top=False)
75
  if load_weights:
76
  model = get_model(backbone_name,input_shape=input_shape)
 
77
  model.load_weights('/users/irodri15/data/irodri15/Fossils/Models/pretrained-herbarium/Resnet50v2_NO_imagenet_None_best_1600.h5')
78
  trw = model.layers[0].get_weights()
79
  backbone.set_weights(trw)
 
74
  backbone = backbone_class(input_shape=input_shape, include_top=False)
75
  if load_weights:
76
  model = get_model(backbone_name,input_shape=input_shape)
77
+ #TODO
78
  model.load_weights('/users/irodri15/data/irodri15/Fossils/Models/pretrained-herbarium/Resnet50v2_NO_imagenet_None_best_1600.h5')
79
  trw = model.layers[0].get_weights()
80
  backbone.set_weights(trw)
inference_sam.py CHANGED
@@ -1,9 +1,8 @@
1
  import torch
2
 
3
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4
- if torch.cuda.is_available():
5
- device = "cuda"
6
- torch.cuda.set_per_process_memory_fraction(0.3, device=device)
7
  else:
8
  device = "cpu"
9
  print(f"Torch device: {device}")
 
1
  import torch
2
 
3
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4
+ if device.type == "cuda":
5
+ torch.cuda.set_per_process_memory_fraction(0.3, device=device.index if device.index is not None else 0)
 
6
  else:
7
  device = "cpu"
8
  print(f"Torch device: {device}")