cyberspyde committed on
Commit 99d311a
1 Parent(s): c858f8e

model update

Files changed (1)
  1. main.py +13 -41
main.py CHANGED
@@ -1,35 +1,12 @@
 from flask import Flask, request, jsonify
-from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
+from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 import numpy as np
 import torch
 
 app = Flask(__name__)
-model = AutoModelForSpeechSeq2Seq.from_pretrained("GitNazarov/whisper-small-pt-3-uz")
-processor = AutoProcessor.from_pretrained("GitNazarov/whisper-small-pt-3-uz")
-
-
-USE_ONNX = False # change this to True if you want to test onnx model
-silero_vad_path = 'snakers4/silero-vad'
-vad_model, vad_utils = torch.hub.load(silero_vad_path,
-                                      model='silero_vad',
-                                      force_reload=True,
-                                      onnx=USE_ONNX)
-
-(get_speech_timestamps,
- save_audio,
- read_audio,
- VADIterator,
- collect_chunks) = vad_utils
-STT_SAMPLE_RATE = 16000
-
-
-def int2float(sound):
-    abs_max = np.abs(sound).max()
-    sound = sound.astype('float32')
-    if abs_max > 0:
-        sound *= 1/32768
-    sound = sound.squeeze() # depends on the use case
-    return sound
+processor = Wav2Vec2Processor.from_pretrained("oyqiz/uzbek_stt")
+model = Wav2Vec2ForCTC.from_pretrained("oyqiz/uzbek_stt")
+SAMPLE_RATE = 16000
 
 @app.route('/', methods=['GET'])
 def index():
@@ -38,21 +15,16 @@ def index():
 @app.route('/transcribe', methods=['POST'])
 def transcribe():
     data_frames = request.data
-    audio_data = np.frombuffer(data_frames, dtype=np.int16)
-    audio_float = int2float(audio_data)
-    final_data = torch.from_numpy(audio_float)
-    sp_timestamps = get_speech_timestamps(final_data, vad_model, sampling_rate=STT_SAMPLE_RATE)
-    try:
-        final_audio_data = collect_chunks(sp_timestamps, final_data)
-        inputs = processor(final_audio_data, return_tensors="pt", sampling_rate=16000, max_new_tokens=100)
-        input_features = inputs.input_features
-        generated_ids = model.generate(inputs=input_features)
+    audio_np = np.frombuffer(data_frames, dtype=np.int16)
+    audio_np = audio_np / np.iinfo(np.int16).max
+    inputs = processor(audio_np, sampling_rate=SAMPLE_RATE, return_tensors="pt")
+
+    with torch.no_grad():
+        logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
 
-        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
-        transcription = ''.join(transcription)
-    except Exception as e:
-        transcription = str(e)
-    return str(transcription), {'Content-Type': 'application/json'}
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = processor.decode(predicted_ids[0])
+    return transcription
 
 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=7860)
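
A minimal client sketch for the updated endpoint, added here for illustration and not part of the commit. The handler parses request.data with np.frombuffer(..., dtype=np.int16), so the POST body must be raw little-endian 16-bit mono PCM sampled at 16 kHz; the localhost URL matches app.run(host='0.0.0.0', port=7860), and the synthetic tone is a stand-in for real Uzbek speech that merely exercises the request path.

import numpy as np
import requests

SAMPLE_RATE = 16000  # must match SAMPLE_RATE in main.py

# One second of int16 audio; replace with real 16 kHz mono speech samples.
t = np.linspace(0, 1, SAMPLE_RATE, endpoint=False)
pcm = (0.3 * np.sin(2 * np.pi * 440 * t) * np.iinfo(np.int16).max).astype(np.int16)

resp = requests.post(
    "http://localhost:7860/transcribe",
    data=pcm.tobytes(),  # raw bytes, decoded server-side by np.frombuffer
    headers={"Content-Type": "application/octet-stream"},
)
print(resp.text)  # the new handler returns the transcription as plain text

Note that the rewritten handler returns the transcription as a bare string, whereas the removed version set a {'Content-Type': 'application/json'} header without actually emitting JSON.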