hyz317 commited on
Commit
8629ffb
·
verified ·
1 Parent(s): 9d6a6e1

Update infer_api.py

Browse files
Files changed (1) hide show
  1. infer_api.py +3 -3
infer_api.py CHANGED
@@ -106,7 +106,8 @@ import torch
106
  from typing import Tuple
107
 
108
  @spaces.GPU
109
- def _warmup(glctx, device=None):
 
110
  device = 'cuda' if device is None else device
111
  #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
112
  def tensor(*args, **kwargs):
@@ -115,8 +116,7 @@ def _warmup(glctx, device=None):
115
  tri = tensor([[0, 1, 2]], dtype=torch.int32)
116
  dr.rasterize(glctx, pos, tri, resolution=[256, 256])
117
 
118
- _glctx = dr.RasterizeCudaContext(device=None)
119
- _warmup(_glctx, device)
120
 
121
  #### TEST END ####
122
 
 
106
  from typing import Tuple
107
 
108
  @spaces.GPU
109
+ def _warmup(device=None):
110
+ glctx = dr.RasterizeCudaContext(device=None)
111
  device = 'cuda' if device is None else device
112
  #windows workaround for https://github.com/NVlabs/nvdiffrast/issues/59
113
  def tensor(*args, **kwargs):
 
116
  tri = tensor([[0, 1, 2]], dtype=torch.int32)
117
  dr.rasterize(glctx, pos, tri, resolution=[256, 256])
118
 
119
+ _warmup(device)
 
120
 
121
  #### TEST END ####
122