ovshake commited on
Commit
7f50dc5
1 Parent(s): 45007eb

fix gpu error

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -85,7 +85,8 @@ def process_image(args, inpainting_pipeline, net):
85
  img_with_green_bg = img
86
  image_tensor = transform_rgb(img_with_green_bg)
87
  image_tensor = image_tensor.unsqueeze(0)
88
- output_tensor = net(image_tensor.to(device))
 
89
  output_tensor = F.log_softmax(output_tensor[0], dim=1)
90
  output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
91
  output_tensor = torch.squeeze(output_tensor, dim=0)
 
85
  img_with_green_bg = img
86
  image_tensor = transform_rgb(img_with_green_bg)
87
  image_tensor = image_tensor.unsqueeze(0)
88
+ with torch.autocast(device_type=device):
89
+ output_tensor = net(image_tensor.to(device))
90
  output_tensor = F.log_softmax(output_tensor[0], dim=1)
91
  output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
92
  output_tensor = torch.squeeze(output_tensor, dim=0)