radinhas commited on
Commit
d37853e
·
1 Parent(s): b08001d

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +33 -2
apis/chat_api.py CHANGED
@@ -2,6 +2,8 @@ import argparse
2
  import uvicorn
3
  import sys
4
  import json
 
 
5
 
6
 
7
  from fastapi import FastAPI
@@ -13,6 +15,7 @@ from utils.logger import logger
13
  from networks.message_streamer import MessageStreamer
14
  from messagers.message_composer import MessageComposer
15
  from googletrans import Translator
 
16
 
17
 
18
  class ChatAPIApp:
@@ -71,7 +74,7 @@ class ChatAPIApp:
71
  class DetectLanguagePostItem(BaseModel):
72
  input_text: str = Field(
73
  default="Hello",
74
- description="(str) `Text for translate`",
75
  )
76
 
77
  def detect_language(self, item: DetectLanguagePostItem):
@@ -84,8 +87,31 @@ class ChatAPIApp:
84
  }
85
  json_compatible_item_data = jsonable_encoder(item_response)
86
  return JSONResponse(content=json_compatible_item_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
-
 
 
 
 
 
89
  def setup_routes(self):
90
  for prefix in ["", "/v1"]:
91
  self.app.get(
@@ -103,6 +129,11 @@ class ChatAPIApp:
103
  summary="detect language",
104
  )(self.detect_language)
105
 
 
 
 
 
 
106
 
107
  class ArgParser(argparse.ArgumentParser):
108
  def __init__(self, *args, **kwargs):
 
2
  import uvicorn
3
  import sys
4
  import json
5
+ import string
6
+ import random
7
 
8
 
9
  from fastapi import FastAPI
 
15
  from networks.message_streamer import MessageStreamer
16
  from messagers.message_composer import MessageComposer
17
  from googletrans import Translator
18
+ from gtts import gTTS
19
 
20
 
21
  class ChatAPIApp:
 
74
  class DetectLanguagePostItem(BaseModel):
75
  input_text: str = Field(
76
  default="Hello",
77
+ description="(str) `Text for detection`",
78
  )
79
 
80
  def detect_language(self, item: DetectLanguagePostItem):
 
87
  }
88
  json_compatible_item_data = jsonable_encoder(item_response)
89
  return JSONResponse(content=json_compatible_item_data)
90
+
91
+ class TTSPostItem(BaseModel):
92
+ input_text: str = Field(
93
+ default="Hello",
94
+ description="(str) `Text for TTS`",
95
+ )
96
+ from_language: str = Field(
97
+ default="en",
98
+ description="(str) `TTS language`",
99
+ )
100
+
101
+ def text_to_speech(self, item: TTSPostItem):
102
+ audioobj = gTTS(text = item.input_text,
103
+ lang = item.from_language,
104
+ slow = False)
105
+ fileName = ''.join(random.SystemRandom().choice(string.ascii_uppercase + string.digits) for _ in range(10));
106
+ fileName = fileName + ".mp3";
107
+ audioobj.save(fileName)
108
 
109
+ item_response = {
110
+ "src": fileName
111
+ }
112
+ json_compatible_item_data = jsonable_encoder(item_response)
113
+ return JSONResponse(content=json_compatible_item_data)
114
+
115
  def setup_routes(self):
116
  for prefix in ["", "/v1"]:
117
  self.app.get(
 
129
  summary="detect language",
130
  )(self.detect_language)
131
 
132
+ self.app.post(
133
+ prefix + "/tts",
134
+ summary="text to speech",
135
+ )(self.text_to_speech)
136
+
137
 
138
  class ArgParser(argparse.ArgumentParser):
139
  def __init__(self, *args, **kwargs):