Image Classification
timm
drhead commited on
Commit
2c1eb35
1 Parent(s): 5cb75da

Update inference_gradio.py

Browse files
Files changed (1) hide show
  1. inference_gradio.py +1 -1
inference_gradio.py CHANGED
@@ -137,7 +137,7 @@ def create_tags(image, threshold):
137
  tensor = transform(img).unsqueeze(0) # type: torch.Tensor
138
 
139
  if torch.cuda.is_available():
140
- tensor.cuda()
141
  if torch.cuda.get_device_capability()[0] >= 7:
142
  tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)
143
 
 
137
  tensor = transform(img).unsqueeze(0) # type: torch.Tensor
138
 
139
  if torch.cuda.is_available():
140
+ tensor = tensor.cuda()
141
  if torch.cuda.get_device_capability()[0] >= 7:
142
  tensor = tensor.to(dtype=torch.float16, memory_format=torch.channels_last)
143