sanchit-gandhi HF staff committed on
Commit
f8dd558
·
1 Parent(s): f833382

fix dtype/device

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -103,10 +103,11 @@ def transcribe(inputs):
103
  yield distil_text, distil_runtime_pipeline, text, runtime_pipeline
104
 
105
  else:
106
- input_features = processor(inputs, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt")
 
107
 
108
  # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
109
- generation_kwargs = dict(input_features, streamer=streamer, max_new_tokens=128, language="en", task="transcribe")
110
  thread = Thread(target=distilled_model.generate, kwargs=generation_kwargs)
111
 
112
  thread.start()
 
103
  yield distil_text, distil_runtime_pipeline, text, runtime_pipeline
104
 
105
  else:
106
+ input_features = processor(inputs, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt").input_features
107
+ input_features = input_features.to(device, dtype=torch_dtype)
108
 
109
  # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
110
+ generation_kwargs = dict(input_features=input_features, streamer=streamer, max_new_tokens=128, language="en", task="transcribe")
111
  thread = Thread(target=distilled_model.generate, kwargs=generation_kwargs)
112
 
113
  thread.start()