ArkanDash committed
Commit a1f1896
1 Parent(s): 450be50

feat: update config.py

Files changed (1)
  1. config.py +20 -9
config.py CHANGED
@@ -1,4 +1,5 @@
 import argparse
+import sys
 import torch
 from multiprocessing import cpu_count
 
@@ -21,11 +22,10 @@ class Config:
 
     @staticmethod
     def arg_parse() -> tuple:
+        exe = sys.executable or "python"
         parser = argparse.ArgumentParser()
         parser.add_argument("--port", type=int, default=7865, help="Listen port")
-        parser.add_argument(
-            "--pycmd", type=str, default="python", help="Python command"
-        )
+        parser.add_argument("--pycmd", type=str, default=exe, help="Python command")
         parser.add_argument("--colab", action="store_true", help="Launch in colab")
         parser.add_argument(
             "--noparallel", action="store_true", help="Disable parallel processing"
@@ -49,6 +49,18 @@ class Config:
             cmd_opts.api
         )
 
+    # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
+    # check `getattr` and try it for compatibility
+    @staticmethod
+    def has_mps() -> bool:
+        if not torch.backends.mps.is_available():
+            return False
+        try:
+            torch.zeros(1).to(torch.device("mps"))
+            return True
+        except Exception:
+            return False
+
     def device_config(self) -> tuple:
         if torch.cuda.is_available():
             i_device = int(self.device.split(":")[-1])
@@ -60,11 +72,10 @@ class Config:
                 or "1070" in self.gpu_name
                 or "1080" in self.gpu_name
            ):
-                print("16系/10系显卡和P40强制单精度")
+                print("Found GPU", self.gpu_name, ", force to fp32")
                 self.is_half = False
-
             else:
-                self.gpu_name = None
+                print("Found GPU", self.gpu_name)
                 self.gpu_mem = int(
                     torch.cuda.get_device_properties(i_device).total_memory
                     / 1024
@@ -72,12 +83,12 @@ class Config:
                     / 1024
                     + 0.4
                 )
-        elif torch.backends.mps.is_available():
-            print("没有发现支持的N卡, 使用MPS进行推理")
+        elif self.has_mps():
+            print("No supported Nvidia GPU found, use MPS instead")
             self.device = "mps"
             self.is_half = False
         else:
-            print("没有发现支持的N卡, 使用CPU进行推理")
+            print("No supported Nvidia GPU found, use CPU instead")
             self.device = "cpu"
             self.is_half = False
 
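For context on the probe added above: torch.backends.mps only exists on recent PyTorch builds and only works on macOS 12.3+, so has_mps() checks availability and then tries a tiny allocation before trusting the backend. The sketch below is a hypothetical standalone version of that pattern, not code from this repository; the names mps_usable and pycmd are illustrative, and it also shows the sys.executable fallback now used for the --pycmd default.

import sys
import torch

def mps_usable() -> bool:
    # Standalone sketch of the same idea as Config.has_mps(): older PyTorch
    # builds have no torch.backends.mps at all, and on macOS < 12.3 the
    # backend can report available yet still fail, so probe it with a
    # tiny allocation before committing to the "mps" device.
    mps = getattr(torch.backends, "mps", None)
    if mps is None or not mps.is_available():
        return False
    try:
        torch.zeros(1).to(torch.device("mps"))
        return True
    except Exception:
        return False

# Same spirit as the --pycmd change: prefer the interpreter that is actually
# running (e.g. a venv's python3) and only fall back to a bare "python".
pycmd = sys.executable or "python"

device = torch.device("mps" if mps_usable() else "cpu")
print("pycmd:", pycmd, "| device:", device)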