Zeimoto commited on
Commit
071265e
1 Parent(s): 9d881ca

models split into files

Browse files
Files changed (1) hide show
  1. app.py +9 -82
app.py CHANGED
@@ -1,104 +1,31 @@
1
  import streamlit as st
2
  from st_audiorec import st_audiorec
3
 
4
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
5
- #from datasets import load_dataset
6
- import torch
7
- from gliner import GLiNER
8
-
9
- from resources import Lead_Labels, entity_labels, set_start, audit_elapsedtime
10
-
11
 
12
  def main ():
13
  print("------------------------------")
14
  print(f"Running main")
15
 
16
- rec = init_model_trans()
17
  ner = init_model_ner() #async
18
 
19
- labels = entity_labels
20
-
21
- # text = "I have a proposal from cgd where they want one outsystems junior developers and one senior for an estimate of three hundred euros a day, for six months."
22
- # print(f"get entities from sample text: {text}")
23
- # get_entity_labels(model=ner, text=text, labels=labels)
24
-
25
  print("Rendering UI...")
26
  start_render = set_start()
27
  wav_audio_data = st_audiorec()
28
  audit_elapsedtime(function="Rendering UI", start=start_render)
29
 
30
- if wav_audio_data is not None and rec is not None:
31
  print("Loading data...")
32
  start_loading = set_start()
33
  st.audio(wav_audio_data, format='audio/wav')
34
- text = transcribe(wav_audio_data, rec)
35
- if text is not None:
36
- get_entity_labels(labels=labels, model=ner, text=text)
37
-
38
- audit_elapsedtime(function="Loading data", start=start_loading)
39
-
40
-
41
- def init_model_trans ():
42
- print("Initiating transcription model...")
43
- start = set_start()
44
-
45
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
46
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
47
 
48
- model_id = "openai/whisper-large-v3"
49
-
50
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
51
- model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
52
- )
53
- model.to(device)
54
-
55
- processor = AutoProcessor.from_pretrained(model_id)
56
-
57
- pipe = pipeline(
58
- "automatic-speech-recognition",
59
- model=model,
60
- tokenizer=processor.tokenizer,
61
- feature_extractor=processor.feature_extractor,
62
- max_new_tokens=128,
63
- chunk_length_s=30,
64
- batch_size=16,
65
- return_timestamps=True,
66
- torch_dtype=torch_dtype,
67
- device=device,
68
- )
69
- print(f'Init model successful')
70
- audit_elapsedtime(function="Initiating transcription model", start=start)
71
- return pipe
72
-
73
- def init_model_ner():
74
- print("Initiating NER model...")
75
- start = set_start()
76
- model = GLiNER.from_pretrained("urchade/gliner_multi")
77
- audit_elapsedtime(function="Initiating NER model", start=start)
78
- return model
79
-
80
- def transcribe (audio_sample: bytes, pipe) -> str:
81
- print("Initiating transcription...")
82
- start = set_start()
83
- # dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
84
- # sample = dataset[0]["audio"]
85
- result = pipe(audio_sample)
86
- audit_elapsedtime(function="Transcription", start=start)
87
- print(result)
88
-
89
- st.write('trancription: ', result["text"])
90
- return result["text"]
91
-
92
- def get_entity_labels(model: GLiNER, text: str, labels: list): #-> Lead_labels:
93
- print("Initiating entity recognition...")
94
- start = set_start()
95
- entities = model.predict_entities(text, labels)
96
- audit_elapsedtime(function="Retreiving entity labels from text", start=start)
97
-
98
- for entity in entities:
99
- print(entity["text"], "=>", entity["label"])
100
- st.write('Entities: ', entities)
101
- # return Lead_Labels()
102
 
103
  if __name__ == "__main__":
104
  print("IN __name__")
 
1
  import streamlit as st
2
  from st_audiorec import st_audiorec
3
 
4
+ from ner import init_model_ner, get_entity_labels
5
+ from speech2text import init_model_trans, transcribe
6
+ from resources import audit_elapsedtime, set_start
 
 
 
 
7
 
8
  def main ():
9
  print("------------------------------")
10
  print(f"Running main")
11
 
12
+ s2t = init_model_trans()
13
  ner = init_model_ner() #async
14
 
 
 
 
 
 
 
15
  print("Rendering UI...")
16
  start_render = set_start()
17
  wav_audio_data = st_audiorec()
18
  audit_elapsedtime(function="Rendering UI", start=start_render)
19
 
20
+ if wav_audio_data is not None and s2t is not None:
21
  print("Loading data...")
22
  start_loading = set_start()
23
  st.audio(wav_audio_data, format='audio/wav')
24
+ text = transcribe(wav_audio_data, s2t)
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ if text is not None and ner is not None:
27
+ st.write('Entities: ', get_entity_labels(model=ner, text=text))
28
+ audit_elapsedtime(function="Loading data", start=start_loading)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  if __name__ == "__main__":
31
  print("IN __name__")