alessandro trinca tornidor commited on
Commit
9ab32d7
·
1 Parent(s): d804881

ci: hugginface space, move from docker to gradio sdk v5.6.0, add missing packages.txt with ffmpeg, pre-requirements.txt with pip, update gradio app to properly format informations to frontend, update tests

Browse files
README.md CHANGED
@@ -3,7 +3,9 @@ title: AI Pronunciation Trainer
3
  emoji: 🎤
4
  colorFrom: red
5
  colorTo: blue
6
- sdk: docker
 
 
7
  pinned: false
8
  license: mit
9
  ---
@@ -59,7 +61,8 @@ pnpm playwright test
59
 
60
  - add an updated online version on HuggingFace, Cloudflare or AWS
61
  - move from pytorch to onnxruntime (if possible)
62
- - refactor frontend with something more modern (e.g. vuejs)
 
63
  - refactor css style with tailwindcss
64
  - add more e2e tests with playwright
65
 
 
3
  emoji: 🎤
4
  colorFrom: red
5
  colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.6.0
8
+ app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
 
61
 
62
  - add an updated online version on HuggingFace, Cloudflare or AWS
63
  - move from pytorch to onnxruntime (if possible)
64
+ - refactor frontend with something more modern (e.g. vuejs, gradio)
65
+ - improve documentation, backend tests
66
  - refactor css style with tailwindcss
67
  - add more e2e tests with playwright
68
 
aip_trainer/lambdas/lambdaSpeechToScore.py CHANGED
@@ -43,12 +43,13 @@ def lambda_handler(event, context):
43
  },
44
  'body': ''
45
  }
46
- output = get_speech_to_score(real_text=real_text, file_bytes_or_audiotmpfile=file_bytes_or_audiotmpfile, language=language)
 
47
  app_logger.debug(f"output: {output} ...")
48
  return output
49
 
50
 
51
- def get_speech_to_score(real_text: str, file_bytes_or_audiotmpfile: str | dict, language: str = "en", remove_random_file: bool = True):
52
  app_logger.info(f"real_text:{real_text} ...")
53
  app_logger.debug(f"file_bytes:{file_bytes_or_audiotmpfile} ...")
54
  app_logger.info(f"language:{language} ...")
@@ -118,10 +119,12 @@ def get_speech_to_score(real_text: str, file_bytes_or_audiotmpfile: str | dict,
118
  duration = time.time() - start
119
  duration_tot = time.time() - start0
120
  app_logger.info(f'Time to post-process results: {duration}, tot_duration:{duration_tot}.')
 
 
121
 
122
- res = {'real_transcript': result['recording_transcript'],
123
- 'ipa_transcript': result['recording_ipa'],
124
- 'pronunciation_accuracy': str(int(result['pronunciation_accuracy'])),
125
  'real_transcripts': real_transcripts, 'matched_transcripts': matched_transcripts,
126
  'real_transcripts_ipa': real_transcripts_ipa, 'matched_transcripts_ipa': matched_transcripts_ipa,
127
  'pair_accuracy_category': pair_accuracy_category,
@@ -129,7 +132,15 @@ def get_speech_to_score(real_text: str, file_bytes_or_audiotmpfile: str | dict,
129
  'end_time': result['end_time'],
130
  'is_letter_correct_all_words': is_letter_correct_all_words}
131
 
132
- return json.dumps(res)
 
 
 
 
 
 
 
 
133
 
134
 
135
  # From Librosa
 
43
  },
44
  'body': ''
45
  }
46
+ output = get_speech_to_score_dict(real_text=real_text, file_bytes_or_audiotmpfile=file_bytes_or_audiotmpfile, language=language, remove_random_file=False)
47
+ output = json.dumps(output)
48
  app_logger.debug(f"output: {output} ...")
49
  return output
50
 
51
 
52
+ def get_speech_to_score_dict(real_text: str, file_bytes_or_audiotmpfile: str | dict, language: str = "en", remove_random_file: bool = True):
53
  app_logger.info(f"real_text:{real_text} ...")
54
  app_logger.debug(f"file_bytes:{file_bytes_or_audiotmpfile} ...")
55
  app_logger.info(f"language:{language} ...")
 
119
  duration = time.time() - start
120
  duration_tot = time.time() - start0
121
  app_logger.info(f'Time to post-process results: {duration}, tot_duration:{duration_tot}.')
122
+ pronunciation_accuracy = str(int(result['pronunciation_accuracy']))
123
+ ipa_transcript = result['recording_ipa']
124
 
125
+ return {'real_transcript': result['recording_transcript'],
126
+ 'ipa_transcript': ipa_transcript,
127
+ 'pronunciation_accuracy': pronunciation_accuracy,
128
  'real_transcripts': real_transcripts, 'matched_transcripts': matched_transcripts,
129
  'real_transcripts_ipa': real_transcripts_ipa, 'matched_transcripts_ipa': matched_transcripts_ipa,
130
  'pair_accuracy_category': pair_accuracy_category,
 
132
  'end_time': result['end_time'],
133
  'is_letter_correct_all_words': is_letter_correct_all_words}
134
 
135
+
136
+ def get_speech_to_score_tuple(real_text: str, file_bytes_or_audiotmpfile: str | dict, language: str = "en", remove_random_file: bool = True):
137
+ output = get_speech_to_score_dict(real_text=real_text, file_bytes_or_audiotmpfile=file_bytes_or_audiotmpfile, language=language, remove_random_file=remove_random_file)
138
+ real_transcripts = output['real_transcripts']
139
+ is_letter_correct_all_words = output['is_letter_correct_all_words']
140
+ pronunciation_accuracy = output['pronunciation_accuracy']
141
+ ipa_transcript = output['ipa_transcript']
142
+ real_transcripts_ipa = output['real_transcripts_ipa']
143
+ return real_transcripts, is_letter_correct_all_words, pronunciation_accuracy, ipa_transcript, real_transcripts_ipa, json.dumps(output)
144
 
145
 
146
  # From Librosa
aip_trainer/lambdas/routes.py DELETED
@@ -1,16 +0,0 @@
1
- import random
2
-
3
- import structlog
4
- from fastapi import APIRouter
5
-
6
-
7
- custom_structlog_logger = structlog.stdlib.get_logger(__name__)
8
- router = APIRouter()
9
-
10
-
11
- @router.get("/health")
12
- def health():
13
- import torch
14
- import torchaudio
15
- custom_structlog_logger.info(f"Still alive, torch version:{torch.__version__}, torchaudio:{torchaudio.__version__} ...")
16
- return "Still alive!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aip_trainer/models/models.py CHANGED
@@ -8,6 +8,66 @@ from silero.utils import Decoder
8
  from aip_trainer import app_logger
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def silero_stt(
12
  language="en",
13
  version="latest",
 
8
  from aip_trainer import app_logger
9
 
10
 
11
+ def silero_tts(language='en',
12
+ speaker='kseniya_16khz',
13
+ **kwargs):
14
+ """ Silero Text-To-Speech Models
15
+ language (str): language of the model, now available are ['ru', 'en', 'de', 'es', 'fr']
16
+ Returns a model and a set of utils
17
+ Please see https://github.com/snakers4/silero-models for usage examples
18
+ """
19
+ from omegaconf import OmegaConf
20
+ from silero.tts_utils import apply_tts
21
+ from silero.tts_utils import init_jit_model as init_jit_model_tts
22
+
23
+ models_list_file = os.path.join(os.path.dirname(__file__), "..", "..", "models.yml")
24
+ if not os.path.exists(models_list_file):
25
+ models_list_file = 'latest_silero_models.yml'
26
+ if not os.path.exists(models_list_file):
27
+ torch.hub.download_url_to_file('https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml',
28
+ 'latest_silero_models.yml',
29
+ progress=False)
30
+ assert os.path.exists(models_list_file)
31
+ models = OmegaConf.load(models_list_file)
32
+ available_languages = list(models.tts_models.keys())
33
+ assert language in available_languages, f'Language not in the supported list {available_languages}'
34
+ available_speakers = []
35
+ speaker_language = {}
36
+ for lang in available_languages:
37
+ speakers = list(models.tts_models.get(lang).keys())
38
+ available_speakers.extend(speakers)
39
+ for _ in speakers:
40
+ speaker_language[_] = lang
41
+ assert speaker in available_speakers, f'Speaker not in the supported list {available_speakers}'
42
+ assert language == speaker_language[speaker], f"Incorrect language '{language}' for this speaker, please specify '{speaker_language[speaker]}'"
43
+
44
+ model_conf = models.tts_models[language][speaker].latest
45
+ if '_v2' in speaker or '_v3' in speaker or 'v3_' in speaker or 'v4_' in speaker:
46
+ from torch import package
47
+ model_url = model_conf.package
48
+ model_dir = os.path.join(os.path.dirname(__file__), "model")
49
+ os.makedirs(model_dir, exist_ok=True)
50
+ model_path = os.path.join(model_dir, os.path.basename(model_url))
51
+ if not os.path.isfile(model_path):
52
+ torch.hub.download_url_to_file(model_url,
53
+ model_path,
54
+ progress=True)
55
+ imp = package.PackageImporter(model_path)
56
+ model = imp.load_pickle("tts_models", "model")
57
+ if speaker == 'multi_v2':
58
+ avail_speakers = model_conf.speakers
59
+ return model, avail_speakers
60
+ else:
61
+ example_text = model_conf.example
62
+ return model, example_text
63
+ else:
64
+ model = init_jit_model_tts(model_conf.jit)
65
+ symbols = model_conf.tokenset
66
+ example_text = model_conf.example
67
+ sample_rate = model_conf.sample_rate
68
+ return model, symbols, sample_rate, example_text, apply_tts
69
+
70
+
71
  def silero_stt(
72
  language="en",
73
  version="latest",
app.py CHANGED
@@ -1,127 +1,112 @@
1
- import logging
2
- import os
3
- import time
4
-
5
  import gradio as gr
6
- import structlog
7
- import uvicorn
8
- from aip_trainer.lambdas import lambdaSpeechToScore
9
- from asgi_correlation_id import CorrelationIdMiddleware
10
- from asgi_correlation_id.context import correlation_id
11
- from dotenv import load_dotenv
12
- from fastapi import FastAPI, Request, Response
13
- from uvicorn.protocols.utils import get_path_with_query_string
14
-
15
- from aip_trainer.utils.session_logger import setup_logging
16
- from aip_trainer.lambdas.routes import router
17
-
18
-
19
- load_dotenv()
20
-
21
- LOG_JSON_FORMAT = bool(os.getenv("LOG_JSON_FORMAT", False))
22
- LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
23
- setup_logging(json_logs=LOG_JSON_FORMAT, log_level=LOG_LEVEL)
24
- logger = structlog.stdlib.get_logger(__name__)
25
- app = FastAPI(title="Example API", version="1.0.0")
26
-
27
-
28
- @app.middleware("http")
29
- async def logging_middleware(request: Request, call_next) -> Response:
30
- structlog.contextvars.clear_contextvars()
31
- # These context vars will be added to all log entries emitted during the request
32
- request_id = correlation_id.get()
33
- # print(f"request_id:{request_id}.")
34
- structlog.contextvars.bind_contextvars(request_id=request_id)
35
 
36
- start_time = time.perf_counter_ns()
37
- # If the call_next raises an error, we still want to return our own 500 response,
38
- # so we can add headers to it (process time, request ID...)
39
- response = Response(status_code=500)
40
- try:
41
- response = await call_next(request)
42
- except Exception:
43
- # TODO: Validate that we don't swallow exceptions (unit test?)
44
- structlog.stdlib.get_logger("api.error").exception("Uncaught exception")
45
- raise
46
- finally:
47
- process_time = time.perf_counter_ns() - start_time
48
- status_code = response.status_code
49
- url = get_path_with_query_string(request.scope)
50
- client_host = request.client.host
51
- client_port = request.client.port
52
- http_method = request.method
53
- http_version = request.scope["http_version"]
54
- # Recreate the Uvicorn access log format, but add all parameters as structured information
55
- logger.info(
56
- f"""{client_host}:{client_port} - "{http_method} {url} HTTP/{http_version}" {status_code}""",
57
- http={
58
- "url": str(request.url),
59
- "status_code": status_code,
60
- "method": http_method,
61
- "request_id": request_id,
62
- "version": http_version,
63
- },
64
- network={"client": {"ip": client_host, "port": client_port}},
65
- duration=process_time,
66
- )
67
- response.headers["X-Process-Time"] = str(process_time / 10 ** 9)
68
- return response
69
-
70
-
71
- app.include_router(router)
72
- logger.info("routes included, creating gradio app")
73
- CUSTOM_GRADIO_PATH = "/"
74
-
75
-
76
- def get_gradio_app():
77
- with gr.Blocks() as gradio_app:
78
- logger.info("start gradio app building...")
79
- gr.Markdown(
80
- """
81
- # Hello World!
82
 
83
- Start typing below to _see_ the *output*.
84
 
85
- Here a [link](https://huggingface.co/spaces/aletrn/gradio_with_fastapi).
86
- """
87
- )
88
- learner_transcription = gr.Textbox(
89
- label="Learner Transcription",
90
- placeholder="It is nice to wreck a nice beach",
91
- )
92
- language = gr.Textbox(
93
- label="language",
94
- placeholder="en",
95
- )
96
- learner_recording = gr.Audio(
97
- label="Learner Recording",
98
- sources=["microphone", "upload"],
99
- type="filepath"
100
- )
101
- text_output = gr.Textbox(lines=1, placeholder=None, label="Text Output")
102
- btn = gr.Button(value="get speech score")
103
- """
104
- event = {'body': json.dumps(request.get_json(force=True))}
105
- lambda_correct_output = lambdaSpeechToScore.lambda_handler(event, [])
 
 
 
 
 
 
 
106
  """
107
- btn.click(
108
- lambdaSpeechToScore.get_speech_to_score,
109
- inputs=[learner_transcription, learner_recording, language],
110
- outputs=[text_output]
111
- )
112
- return gradio_app
113
 
114
-
115
- logger.info("mounting gradio app within FastAPI...")
116
- gradio_app_md = get_gradio_app()
117
- app.add_middleware(CorrelationIdMiddleware)
118
- app = gr.mount_gradio_app(app, gradio_app_md, path=CUSTOM_GRADIO_PATH)
119
- logger.info("gradio app mounted")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
 
122
  if __name__ == "__main__":
123
- try:
124
- uvicorn.run("app:app", host="127.0.0.1", port=7860, log_config=None, reload=True)
125
- except Exception as ex:
126
- logging.error(f"ex:{ex}.")
127
- raise ex
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ from aip_trainer import app_logger
4
+ from aip_trainer.lambdas import lambdaSpeechToScore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
 
6
 
7
+ js = """
8
+ function updateCssText(text, letters) {
9
+ let wordsArr = text.split(" ")
10
+ let lettersWordsArr = letters.split(" ")
11
+ let speechOutputContainer = document.querySelector('#speech-output');
12
+ speechOutputContainer.textContent = ""
13
+
14
+ for (let idx in wordsArr) {
15
+ let word = wordsArr[idx]
16
+ let letterIsCorrect = lettersWordsArr[idx]
17
+ for (let idx1 in word) {
18
+ let letterCorrect = letterIsCorrect[idx1] == "1"
19
+ let containerLetter = document.createElement("span")
20
+ containerLetter.style.color = letterCorrect ? 'green' : "red"
21
+ containerLetter.innerText = word[idx1];
22
+ speechOutputContainer.appendChild(containerLetter)
23
+ }
24
+ let containerSpace = document.createElement("span")
25
+ containerSpace.textContent = " "
26
+ speechOutputContainer.appendChild(containerSpace)
27
+ }
28
+ }
29
+ """
30
+
31
+ with gr.Blocks() as gradio_app:
32
+ app_logger.info("start gradio app building...")
33
+
34
+ gr.Markdown(
35
  """
36
+ # AI Pronunciation Trainer
 
 
 
 
 
37
 
38
+ See [my fork](https://github.com/trincadev/ai-pronunciation-trainer) of [AI Pronunciation Trainer](https://github.com/Thiagohgl/ai-pronunciation-trainer) repositroy
39
+ for more details.
40
+ """
41
+ )
42
+ with gr.Row():
43
+ with gr.Column(scale=4, min_width=300):
44
+ with gr.Row():
45
+ with gr.Column(scale=1, min_width=50):
46
+ language = gr.Radio(["de", "en"], label="Language", value="en")
47
+ with gr.Column(scale=7, min_width=300):
48
+ learner_transcription = gr.Textbox(
49
+ lines=3,
50
+ label="Learner Transcription",
51
+ value="Hi there, how are you?",
52
+ )
53
+ with gr.Row():
54
+ learner_recording = gr.Audio(
55
+ label="Learner Recording",
56
+ sources=["microphone", "upload"],
57
+ type="filepath",
58
+ )
59
+ with gr.Column(scale=3, min_width=300):
60
+ transcripted_text = gr.Textbox(
61
+ lines=2, placeholder=None, label="Transcripted text", visible=False
62
+ )
63
+ letter_correctness = gr.Textbox(
64
+ lines=1,
65
+ placeholder=None,
66
+ label="Letters correctness",
67
+ visible=False,
68
+ )
69
+ pronunciation_accuracy = gr.Textbox(
70
+ lines=1, placeholder=None, label="Pronunciation accuracy %"
71
+ )
72
+ recording_ipa = gr.Textbox(
73
+ lines=1, placeholder=None, label="Learner phonetic transcription"
74
+ )
75
+ ideal_ipa = gr.Textbox(
76
+ lines=1, placeholder=None, label="Ideal phonetic transcription"
77
+ )
78
+ res = gr.Textbox(lines=1, placeholder=None, label="RES", visible=False)
79
+ html_output = gr.HTML(
80
+ label="Speech accuracy output",
81
+ elem_id="speech-output",
82
+ show_label=True,
83
+ visible=True,
84
+ render=True,
85
+ value=" - ",
86
+ elem_classes="speech-output",
87
+ )
88
+ btn = gr.Button(value="Recognize speech accuracy")
89
+ # real_transcripts, is_letter_correct_all_words, pronunciation_accuracy, result['recording_ipa'], real_transcripts_ipa, res
90
+
91
+ btn.click(
92
+ lambdaSpeechToScore.get_speech_to_score_tuple,
93
+ inputs=[learner_transcription, learner_recording, language],
94
+ outputs=[
95
+ transcripted_text,
96
+ letter_correctness,
97
+ pronunciation_accuracy,
98
+ recording_ipa,
99
+ ideal_ipa,
100
+ res,
101
+ ],
102
+ )
103
+ html_output.change(
104
+ None,
105
+ inputs=[transcripted_text, letter_correctness],
106
+ outputs=[html_output],
107
+ js=js,
108
+ )
109
 
110
 
111
  if __name__ == "__main__":
112
+ gradio_app.launch()
 
 
 
 
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ ffmpeg
pre-requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pip
requirements-flask.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ audioread
2
+ dtwalign
3
+ eng_to_ipa
4
+ epitran==1.25.1
5
+ flask
6
+ flask_cors
7
+ gunicorn
8
+ omegaconf
9
+ ortools==9.11.4210
10
+ pandas
11
+ pickle-mixin
12
+ python-dotenv
13
+ requests
14
+ sentencepiece
15
+ silero==0.4.1
16
+ soundfile==0.12.1
17
+ sqlalchemy
18
+ structlog
19
+ torch
20
+ torchaudio
21
+ transformers
requirements.txt CHANGED
@@ -1,9 +1,8 @@
 
1
  audioread
2
  dtwalign
3
  eng_to_ipa
4
  epitran==1.25.1
5
- flask
6
- flask_cors
7
  gunicorn
8
  omegaconf
9
  ortools==9.11.4210
@@ -14,7 +13,6 @@ requests
14
  sentencepiece
15
  silero==0.4.1
16
  soundfile==0.12.1
17
- sqlalchemy
18
  structlog
19
  torch
20
  torchaudio
 
1
+ asgi-correlation-id
2
  audioread
3
  dtwalign
4
  eng_to_ipa
5
  epitran==1.25.1
 
 
6
  gunicorn
7
  omegaconf
8
  ortools==9.11.4210
 
13
  sentencepiece
14
  silero==0.4.1
15
  soundfile==0.12.1
 
16
  structlog
17
  torch
18
  torchaudio
tests/test_GetAccuracyFromRecordedAudio.py CHANGED
@@ -86,7 +86,7 @@ class TestGetAccuracyFromRecordedAudio(unittest.TestCase):
86
 
87
  language = "en"
88
  path = EVENTS_FOLDER / f"test_{language}.wav"
89
- output = lambdaSpeechToScore.get_speech_to_score(
90
  real_text=text_dict[language],
91
  file_bytes_or_audiotmpfile=path,
92
  language=language,
@@ -105,14 +105,14 @@ class TestGetAccuracyFromRecordedAudio(unittest.TestCase):
105
  "end_time": "0.559875 1.658125 1.14825 1.344375 1.658125",
106
  "is_letter_correct_all_words": "11 000001 111 111 1111 ",
107
  }
108
- check_output(self, json.loads(output), expected_output)
109
 
110
  def test_get_speech_to_score_de_ok(self):
111
  from aip_trainer.lambdas import lambdaSpeechToScore
112
 
113
  language = "de"
114
  path = EVENTS_FOLDER / f"test_{language}.wav"
115
- output = lambdaSpeechToScore.get_speech_to_score(
116
  real_text=text_dict[language],
117
  file_bytes_or_audiotmpfile=path,
118
  language=language,
@@ -131,7 +131,7 @@ class TestGetAccuracyFromRecordedAudio(unittest.TestCase):
131
  "end_time": "0.328 0.6458125 1.44025 2.4730625 2.15525 2.4730625",
132
  "is_letter_correct_all_words": "111 111 11111 000 1011 111 ",
133
  }
134
- check_output(self, json.loads(output), expected_output)
135
 
136
 
137
  if __name__ == "__main__":
 
86
 
87
  language = "en"
88
  path = EVENTS_FOLDER / f"test_{language}.wav"
89
+ output = lambdaSpeechToScore.get_speech_to_score_dict(
90
  real_text=text_dict[language],
91
  file_bytes_or_audiotmpfile=path,
92
  language=language,
 
105
  "end_time": "0.559875 1.658125 1.14825 1.344375 1.658125",
106
  "is_letter_correct_all_words": "11 000001 111 111 1111 ",
107
  }
108
+ check_output(self, output, expected_output)
109
 
110
  def test_get_speech_to_score_de_ok(self):
111
  from aip_trainer.lambdas import lambdaSpeechToScore
112
 
113
  language = "de"
114
  path = EVENTS_FOLDER / f"test_{language}.wav"
115
+ output = lambdaSpeechToScore.get_speech_to_score_dict(
116
  real_text=text_dict[language],
117
  file_bytes_or_audiotmpfile=path,
118
  language=language,
 
131
  "end_time": "0.328 0.6458125 1.44025 2.4730625 2.15525 2.4730625",
132
  "is_letter_correct_all_words": "111 111 11111 000 1011 111 ",
133
  }
134
+ check_output(self, output, expected_output)
135
 
136
 
137
  if __name__ == "__main__":