StormblessedKal committed on
Commit
af9920b
·
1 Parent(s): b7866b5

support s3

Browse files
Files changed (2) hide show
  1. src/predict.py +18 -7
  2. src/rp_handler.py +4 -3
src/predict.py CHANGED
@@ -58,6 +58,7 @@ from pydantic import BaseModel, HttpUrl
58
  from api import BaseSpeakerTTS, ToneColorConverter
59
 
60
  from pydub import AudioSegment
 
61
 
62
 
63
  class Predictor:
@@ -146,22 +147,32 @@ class Predictor:
146
 
147
 
148
  def createvoice(self,audio_base_64,cut_audio,process_audio):
149
- file_bytes = base64.b64decode(audio_base_64)
150
- file_buffer = io.BytesIO(file_bytes)
 
 
 
 
 
 
 
 
 
 
 
 
151
 
 
152
  header = file_buffer.read(12)
153
  print(header)
154
- file_format = None
155
- bucket_name = self.bucket_name
156
  if b'WAVE' in header:
157
  file_format = 'wav'
158
  elif b'OggS' in header:
159
  file_format = 'ogg'
160
- else:
161
- file_format = 'mp3'
162
 
163
  unique_filename = f"{uuid.uuid4()}"
164
-
165
  local_filename = f"{unique_filename}.{file_format}"
166
  with open(local_filename, 'wb') as file_out:
167
  file_out.write(file_bytes)
 
58
  from api import BaseSpeakerTTS, ToneColorConverter
59
 
60
  from pydub import AudioSegment
61
+ from urllib.parse import urlparse
62
 
63
 
64
  class Predictor:
 
147
 
148
 
149
  def createvoice(self,audio_base_64,cut_audio,process_audio):
150
+ file_bytes = None
151
+ if s3_url:
152
+ parsed_url = urlparse(s3_url)
153
+ bucket_name = parsed_url.netloc
154
+ s3_key = parsed_url.path.lstrip('/')
155
+ local_filename = f"{uuid.uuid4()}"
156
+ self.download_file_from_s3(bucket_name, s3_key, local_filename)
157
+ with open(local_filename, 'rb') as file:
158
+ file_bytes = file.read()
159
+ os.remove(local_filename)
160
+ elif audio_base_64:
161
+ file_bytes = base64.b64decode(audio_base_64)
162
+ else:
163
+ raise ValueError("Either s3_url or audio_base_64 must be provided.")
164
 
165
+ file_buffer = io.BytesIO(file_bytes)
166
  header = file_buffer.read(12)
167
  print(header)
168
+
169
+ file_format = 'mp3' # Default format
170
  if b'WAVE' in header:
171
  file_format = 'wav'
172
  elif b'OggS' in header:
173
  file_format = 'ogg'
 
 
174
 
175
  unique_filename = f"{uuid.uuid4()}"
 
176
  local_filename = f"{unique_filename}.{file_format}"
177
  with open(local_filename, 'wb') as file_out:
178
  file_out.write(file_bytes)
src/rp_handler.py CHANGED
@@ -28,9 +28,10 @@ def run_voice_clone_job(job):
28
  return {"error":"Please set method_type: available options, create_voice, voice_clone, voice_clone_with_emotions,voice_clone_with_multi_lang"}
29
 
30
  if method_type == "create_voice":
 
31
  audio_base64 = job_input.get('audio_base64')
32
- if audio_base64 is None:
33
- return {"error":"Needs audio file as base64"}
34
  cut_audio = job_input.get('cut_audio')
35
  process_audio = job_input.get('process_audio')
36
  print(process_audio)
@@ -39,7 +40,7 @@ def run_voice_clone_job(job):
39
  if cut_audio is None:
40
  cut_audio = 0
41
 
42
- processed_urls = MODEL.createvoice(audio_base64,cut_audio,process_audio)
43
  return processed_urls
44
  else:
45
  s3_url = job_input.get('s3_url')
 
28
  return {"error":"Please set method_type: available options, create_voice, voice_clone, voice_clone_with_emotions,voice_clone_with_multi_lang"}
29
 
30
  if method_type == "create_voice":
31
+ s3_url = job_input.get("s3_url")
32
  audio_base64 = job_input.get('audio_base64')
33
+ if audio_base64 is None and s3_url is None:
34
+ return {"error":"set audio_base64 or s3_url"}
35
  cut_audio = job_input.get('cut_audio')
36
  process_audio = job_input.get('process_audio')
37
  print(process_audio)
 
40
  if cut_audio is None:
41
  cut_audio = 0
42
 
43
+ processed_urls = MODEL.createvoice(s3_url,audio_base64,cut_audio,process_audio)
44
  return processed_urls
45
  else:
46
  s3_url = job_input.get('s3_url')