Dan Bochman commited on
Commit
9a2f042
1 Parent(s): 2233e64

attempt #31235 to work with ZERO

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -112,10 +112,9 @@ model.to("cuda")
112
 
113
  @torch.inference_mode()
114
  def run_model(input_tensor, height, width):
115
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
116
- output = model(input_tensor)
117
- output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
118
- _, preds = torch.max(output, 1)
119
  return preds
120
 
121
 
@@ -131,7 +130,7 @@ transform_fn = transforms.Compose(
131
 
132
  @spaces.GPU
133
  def segment(image: Image.Image) -> Image.Image:
134
- input_tensor = transform_fn(image).unsqueeze(0)
135
  preds = run_model(input_tensor, height=image.height, width=image.width)
136
  mask = preds.squeeze(0).cpu().numpy()
137
  mask_image = Image.fromarray(mask.astype("uint8"))
 
112
 
113
  @torch.inference_mode()
114
  def run_model(input_tensor, height, width):
115
+ output = model(input_tensor)
116
+ output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
117
+ _, preds = torch.max(output, 1)
 
118
  return preds
119
 
120
 
 
130
 
131
  @spaces.GPU
132
  def segment(image: Image.Image) -> Image.Image:
133
+ input_tensor = transform_fn(image).unsqueeze(0).to("cuda")
134
  preds = run_model(input_tensor, height=image.height, width=image.width)
135
  mask = preds.squeeze(0).cpu().numpy()
136
  mask_image = Image.fromarray(mask.astype("uint8"))