asigalov61 commited on
Commit
5179cdc
1 Parent(s): 602f44d

Update pytorch_utils.py

Browse files
Files changed (1) hide show
  1. pytorch_utils.py +3 -4
pytorch_utils.py CHANGED
@@ -53,10 +53,9 @@ def forward(model, x, batch_size):
53
  batch_waveform = move_data_to_device(x[pointer : pointer + batch_size], device)
54
  pointer += batch_size
55
 
56
- with torch.no_grad():
57
- with torch.amp.autocast(device_type='cuda'):
58
- model.eval()
59
- batch_output_dict = model(batch_waveform)
60
 
61
  for key in batch_output_dict.keys():
62
  append_to_dict(output_dict, key, batch_output_dict[key].data.cpu().numpy())
 
53
  batch_waveform = move_data_to_device(x[pointer : pointer + batch_size], device)
54
  pointer += batch_size
55
 
56
+ with torch.inference_mode():
57
+ model.eval()
58
+ batch_output_dict = model(batch_waveform)
 
59
 
60
  for key in batch_output_dict.keys():
61
  append_to_dict(output_dict, key, batch_output_dict[key].data.cpu().numpy())