avsolatorio commited on
Commit
df8632c
·
1 Parent(s): 3f9fe25

Fix device detection

Browse files

Signed-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>

Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -28,7 +28,7 @@ def get_model(model_name: str = None):
28
  if _MODEL.get(model_name) is None:
29
  _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
30
 
31
- if torch.cuda.is_available() and _MODEL[model_name].device.type.startswith("cuda"):
32
  _MODEL[model_name] = _MODEL[model_name].to("cuda")
33
 
34
  print(f"{datetime.now()} :: get_model :: {datetime.now() - start}")
 
28
  if _MODEL.get(model_name) is None:
29
  _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
30
 
31
+ if torch.cuda.is_available() and not next(_MODEL[model_name].parameters()).device.type.startswith("cuda"):
32
  _MODEL[model_name] = _MODEL[model_name].to("cuda")
33
 
34
  print(f"{datetime.now()} :: get_model :: {datetime.now() - start}")