next-playground commited on
Commit
fc2fb4c
1 Parent(s): db1e6f3

Update flask_api_full_song.py

Browse files
Files changed (1) hide show
  1. flask_api_full_song.py +86 -34
flask_api_full_song.py CHANGED
@@ -1,56 +1,108 @@
1
- import io
2
-
3
  import numpy as np
4
  import soundfile
5
- from flask import Flask, request, send_file
6
 
7
  from inference import infer_tool, slicer
8
 
9
  import requests
10
  import os
 
 
11
  from qcloud_cos import CosConfig
12
  from qcloud_cos import CosS3Client
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  app = Flask(__name__)
15
 
16
 
17
  @app.route("/wav2wav", methods=["POST"])
18
  def wav2wav():
 
 
19
  request_form = request.form
20
- audio_path = request_form.get("audio_path", None) # wav文件地址
 
 
 
 
 
21
  tran = int(float(request_form.get("tran", 0))) # 音调
22
  spk = request_form.get("spk", 0) # 说话人(id或者name都可以,具体看你的config)
23
  wav_format = request_form.get("wav_format", 'wav') # 范围文件格式
24
- infer_tool.format_wav(audio_path)
25
- chunks = slicer.cut(audio_path, db_thresh=-40)
26
- audio_data, audio_sr = slicer.chunks2audio(audio_path, chunks)
27
-
28
- audio = []
29
- for (slice_tag, data) in audio_data:
30
- print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
31
-
32
- length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample))
33
- if slice_tag:
34
- print('jump empty segment')
35
- _audio = np.zeros(length)
36
- else:
37
- # padd
38
- pad_len = int(audio_sr * 0.5)
39
- data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])])
40
- raw_path = io.BytesIO()
41
- soundfile.write(raw_path, data, audio_sr, format="wav")
42
- raw_path.seek(0)
43
- out_audio, out_sr = svc_model.infer(spk, tran, raw_path)
44
- svc_model.clear_empty()
45
- _audio = out_audio.cpu().numpy()
46
- pad_len = int(svc_model.target_sample * 0.5)
47
- _audio = _audio[pad_len:-pad_len]
48
-
49
- audio.extend(list(infer_tool.pad_array(_audio, length)))
50
- out_wav_path = io.BytesIO()
51
- soundfile.write(out_wav_path, audio, svc_model.target_sample, format=wav_format)
52
- out_wav_path.seek(0)
53
- return send_file(out_wav_path, download_name=f"temp.{wav_format}", as_attachment=True)
54
 
55
 
56
  if __name__ == '__main__':
 
 
 
1
  import numpy as np
2
  import soundfile
3
+ from flask import Flask, request, send_file, jsonify
4
 
5
  from inference import infer_tool, slicer
6
 
7
  import requests
8
  import os
9
+ import uuid
10
+ import threading
11
  from qcloud_cos import CosConfig
12
  from qcloud_cos import CosS3Client
13
 
14
+ tasks = {}
15
+ running_threads = 0
16
+ condition = threading.Condition()
17
+
18
+ def infer(audio_path, tran, spk, wav_format, task_id):
19
+ global running_threads
20
+ with condition:
21
+ while running_threads >= 1:
22
+ tasks[task_id] = {"status": "queue"}
23
+ condition.wait()
24
+ running_threads += 1
25
+ tasks[task_id] = {"status": "processing"}
26
+ try:
27
+ audio_name = audio_path.split('/')[-1]
28
+ infer_tool.format_wav(audio_path)
29
+ chunks = slicer.cut(audio_path, db_thresh=-40)
30
+ audio_data, audio_sr = slicer.chunks2audio(audio_path, chunks)
31
+
32
+ audio = []
33
+ for (slice_tag, data) in audio_data:
34
+ print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
35
+
36
+ length = int(np.ceil(len(data) / audio_sr * svc_model.target_sample))
37
+ if slice_tag:
38
+ print('jump empty segment')
39
+ _audio = np.zeros(length)
40
+ else:
41
+ # padd
42
+ pad_len = int(audio_sr * 0.5)
43
+ data = np.concatenate([np.zeros([pad_len]), data, np.zeros([pad_len])])
44
+ raw_path = io.BytesIO()
45
+ soundfile.write(raw_path, data, audio_sr, format="wav")
46
+ raw_path.seek(0)
47
+ out_audio, out_sr = svc_model.infer(spk, tran, raw_path)
48
+ svc_model.clear_empty()
49
+ _audio = out_audio.cpu().numpy()
50
+ pad_len = int(svc_model.target_sample * 0.5)
51
+ _audio = _audio[pad_len:-pad_len]
52
+
53
+ audio.extend(list(infer_tool.pad_array(_audio, length)))
54
+ out_wav_path = "/tmp/" + audio_name
55
+ soundfile.write(out_wav_path, audio, svc_model.target_sample, format=wav_format)
56
+ out_wav_path.seek(0)
57
+
58
+ # 提供文件的永久直链
59
+ result_audio_url = f"/download/{os.path.basename(out_wav_path)}"
60
+
61
+ # 更新任务状态
62
+ tasks[task_id] = {
63
+ "status": "completed",
64
+ "url": result_audio_url,
65
+ }
66
+ except Exception as e:
67
+ tasks[task_id] = {
68
+ "status": "error",
69
+ "message": str(e)
70
+ }
71
+ with condition:
72
+ running_threads -= 1
73
+ condition.notify_all()
74
+
75
  app = Flask(__name__)
76
 
77
 
78
  @app.route("/wav2wav", methods=["POST"])
79
  def wav2wav():
80
+ task_id = str(uuid.uuid4())
81
+ tasks[task_id] = {"status": "processing"}
82
  request_form = request.form
83
+ audio_result = requests.get(request_form.get("audio_path", ""))
84
+ if audio_result.status_code != 200:
85
+ raise Exception("无效的 URL")
86
+ with open("/tmp/" + request_form.get("audio_path", "").split('/')[-1], 'wb') as f:
87
+ f.write(audio_result.content)
88
+ audio_path = "/tmp/" + request_form.get("audio_path", "").split('/')[-1] # wav文件地址
89
  tran = int(float(request_form.get("tran", 0))) # 音调
90
  spk = request_form.get("spk", 0) # 说话人(id或者name都可以,具体看你的config)
91
  wav_format = request_form.get("wav_format", 'wav') # 范围文件格式
92
+ threading.Thread(target=infer, args=(audio_path, tran, spk, wav_format)).start()
93
+ return jsonify({"task_id": task_id}), 202
94
+
95
+ @app.route('/api/tasks/<task_id>', methods=['GET'])
96
+ def get_task_status(task_id):
97
+ task = tasks.get(task_id)
98
+ if task:
99
+ return jsonify(task)
100
+ else:
101
+ return jsonify({"error": "Task not found"}), 404
102
+
103
+ @app.route('/download/<filename>', methods=['GET'])
104
+ def download(filename):
105
+ return send_file("/tmp/" + filename, as_attachment=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
  if __name__ == '__main__':