File size: 5,706 Bytes
42c635f
 
1f4e6d7
 
fc2fb4c
1f4e6d7
 
 
c2f16dd
 
fc2fb4c
 
6316aee
8de23a7
 
c2f16dd
86e147f
 
 
 
 
 
1f919a0
 
86e147f
 
 
240ec1d
86e147f
240ec1d
86e147f
240ec1d
86e147f
240ec1d
86e147f
240ec1d
86e147f
240ec1d
86e147f
1f919a0
 
 
c565b19
86e147f
 
 
 
 
fc2fb4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5d92c5
 
fc2fb4c
 
 
 
 
 
a5d92c5
fc2fb4c
a5d92c5
fc2fb4c
60cb5bc
fc2fb4c
 
 
 
 
a5d92c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc2fb4c
 
 
 
 
 
 
 
 
 
 
 
6316aee
fc2fb4c
 
 
 
 
 
 
 
1f4e6d7
 
 
5e08dde
1f4e6d7
fc2fb4c
 
5e08dde
fc2fb4c
8205d13
fc2fb4c
5e08dde
fc2fb4c
5e08dde
 
 
 
cc5d47e
fc2fb4c
 
 
 
 
 
 
 
 
 
 
 
1f919a0
 
 
 
1f4e6d7
 
 
8de23a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import io
import os
import threading
import traceback
import uuid
from urllib.parse import urlparse

import numpy as np
import requests
import soundfile
from flask import Flask, jsonify, request, send_file, send_from_directory
from qcloud_cos import CosConfig
from qcloud_cos import CosS3Client

from inference import infer_tool, slicer


# Tencent Cloud COS credentials come from the environment.
# NOTE(review): both may be None if unset; nothing here validates them.
secret_id = os.getenv('SECRET_ID')
secret_key = os.getenv('SECRET_KEY')
region = 'na-siliconvalley'
bucket_name = 'xiaohei-cat-ai-1304646510'

# Remote "directory" in the bucket that holds the model artifacts.
_COS_MODEL_PREFIX = "models/So-VITS-SVC/Koxia-Full"

# Local destinations for the downloaded artifacts. They are defined once so
# that the download step and the model loader always agree on the paths.
model_name = "/tmp/G_full.pth"  # model weights path
config_name = "/tmp/config.json"  # model config path

print("Starting download the model and config...")

config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key)
client = CosS3Client(config)

response1 = client.get_object(
    Bucket=bucket_name,
    Key=f"{_COS_MODEL_PREFIX}/G_full.pth"
)
response1['Body'].get_stream_to_file(model_name)

response2 = client.get_object(
    Bucket=bucket_name,
    Key=f"{_COS_MODEL_PREFIX}/config.json"
)
response2['Body'].get_stream_to_file(config_name)

print("Download complete!")

print("Starting service...")

# Load the voice-conversion model once at import time; it is shared by all
# conversion tasks.
svc_model = infer_tool.Svc(model_name, config_name)


# In-memory task registry: task_id -> status dict served by /api/tasks/<id>.
tasks = {}
# Number of conversions currently running; read and written only under `condition`.
running_threads = 0
condition = threading.Condition()

# Longest slice, in seconds, that is fed to the model in a single call.
_MAX_SEGMENT_SECONDS = 30


def _split_segment(data, sample_rate, max_seconds=_MAX_SEGMENT_SECONDS):
    """Split *data* into near-equal pieces of at most *max_seconds* each.

    Segments no longer than *max_seconds* are returned whole. ``np.array_split``
    spreads the remainder across the pieces, so every input sample lands in
    exactly one piece. Fixed-length slicing would drop the trailing samples.
    """
    duration = len(data) / sample_rate
    if duration <= max_seconds:
        return [data]
    return np.array_split(data, int(np.ceil(duration / max_seconds)))


def _convert_chunk(chunk_data, audio_sr, slice_tag, spk, tran):
    """Convert one chunk with ``svc_model`` and return it at the model's rate.

    Chunks whose ``slice_tag`` is truthy are emitted as zeros of the matching
    duration instead of being sent to the model. The result is passed through
    ``infer_tool.pad_array`` with the expected output length.
    """
    length = int(np.ceil(len(chunk_data) / audio_sr * svc_model.target_sample))
    if slice_tag:
        print('jump empty segment')
        return infer_tool.pad_array(np.zeros(length), length)

    # Surround the chunk with 0.5 s of silence on both sides. The matching
    # 0.5 s is trimmed from the model output below. The padding is presumably
    # there to avoid edge artifacts.
    pad_len = int(audio_sr * 0.5)
    padded = np.concatenate([np.zeros([pad_len]), chunk_data, np.zeros([pad_len])])
    wav_buffer = io.BytesIO()
    soundfile.write(wav_buffer, padded, audio_sr, format="wav")
    wav_buffer.seek(0)
    out_audio, _out_shape, _out_sr = svc_model.infer(spk, tran, wav_buffer)
    svc_model.clear_empty()
    converted = out_audio.cpu().numpy()
    out_pad = int(svc_model.target_sample * 0.5)
    if out_pad:  # x[0:-0] would be empty, so only trim a non-zero pad
        converted = converted[out_pad:-out_pad]
    return infer_tool.pad_array(converted, length)


def infer(audio_path, tran, spk, wav_format, task_id):
    """Convert *audio_path* with ``svc_model`` and record progress in ``tasks``.

    Only one conversion runs at a time. While waiting for the slot, the task is
    reported with status ``"queue"``; once running, its status is ``"processing"``.
    On success the task becomes ``{"status": "completed", "url": ...}``. Any
    ``Exception`` is printed and stored as ``{"status": "error", "message": ...}``.

    Args:
        audio_path: Local path of the input audio file.
        tran: Pitch shift forwarded to ``svc_model.infer``.
        spk: Speaker id or name forwarded to ``svc_model.infer``.
        wav_format: ``soundfile`` format name used for the output file.
        task_id: Key in ``tasks`` to update.
    """
    global running_threads
    with condition:
        while running_threads >= 1:
            tasks[task_id] = {"status": "queue"}
            condition.wait()
        running_threads += 1
    try:
        tasks[task_id] = {"status": "processing"}
        audio_name = os.path.basename(audio_path)
        infer_tool.format_wav(audio_path)
        chunks = slicer.cut(audio_path, db_thresh=-40)
        audio_data, audio_sr = slicer.chunks2audio(audio_path, chunks)

        pieces = []
        for slice_tag, data in audio_data:
            print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
            for chunk_data in _split_segment(data, audio_sr):
                pieces.append(_convert_chunk(chunk_data, audio_sr, slice_tag, spk, tran))
        audio = np.concatenate(pieces) if pieces else np.zeros(0)

        # NOTE(review): when audio_path is itself under /tmp (as /wav2wav
        # arranges), this is the input path, so the input file is overwritten
        # by the converted audio.
        out_wav_path = "/tmp/" + audio_name
        soundfile.write(out_wav_path, audio, svc_model.target_sample, format=wav_format)

        # Link served by the /download/<filename> route.
        tasks[task_id] = {
            "status": "completed",
            "url": f"/download/{os.path.basename(out_wav_path)}",
        }
    except Exception as e:
        traceback.print_exc()
        tasks[task_id] = {
            "status": "error",
            "message": str(e)
        }
    finally:
        # Always release the slot, even on BaseException, so that queued
        # tasks are never blocked forever.
        with condition:
            running_threads -= 1
            condition.notify_all()

app = Flask(__name__)

# Upper bound (seconds) on each network wait while fetching the source audio.
# Without it, a stalled remote server would block the request indefinitely.
_DOWNLOAD_TIMEOUT_SECONDS = 60


@app.route("/wav2wav", methods=["GET"])
def wav2wav():
    """Download the audio at ``audio_path`` and start a background conversion.

    Query parameters:
        audio_path: URL of the source audio (required).
        tran: Pitch shift, parsed as a float and truncated to int (default 0).
        spk: Speaker id or name as defined in the model config (default 0).
        wav_format: Output format passed to ``soundfile`` (default ``wav``).

    Returns:
        ``{"task_id": ...}`` with HTTP 202; poll ``/api/tasks/<task_id>``.
        HTTP 400 with ``{"error": ...}`` when ``audio_path`` is missing or
        cannot be fetched, or when ``tran`` is not a finite number.
    """
    audio_url = request.args.get("audio_path")
    if not audio_url:
        return jsonify({"error": "missing audio_path"}), 400
    try:
        tran = int(float(request.args.get("tran", 0)))  # pitch shift
    except (ValueError, OverflowError):
        return jsonify({"error": "invalid tran"}), 400
    spk = request.args.get("spk", 0)  # speaker id or name, per the model config
    wav_format = request.args.get("wav_format", 'wav')  # output file format

    try:
        audio_result = requests.get(audio_url, timeout=_DOWNLOAD_TIMEOUT_SECONDS)
    except requests.RequestException as err:
        print("audio download failed: ", err)
        return jsonify({"error": "无效的 URL"}), 400
    if audio_result.status_code != 200:
        print("audio result status code as ", audio_result.status_code, " because of ", str(audio_result.content))
        return jsonify({"error": "无效的 URL"}), 400

    task_id = str(uuid.uuid4())
    # Take only the URL's path component, dropping any query string or
    # fragment. Prefix it with the task id so that concurrent requests for
    # identically named files cannot overwrite each other's files in /tmp.
    source_name = os.path.basename(urlparse(audio_url).path) or "audio.wav"
    audio_path = os.path.join("/tmp", f"{task_id}_{source_name}")
    with open(audio_path, 'wb') as f:
        f.write(audio_result.content)

    # Register the task only once its input exists, so that a failed download
    # never leaves an orphan "processing" entry behind.
    tasks[task_id] = {"status": "processing"}
    threading.Thread(target=infer, args=(audio_path, tran, spk, wav_format, task_id)).start()
    return jsonify({"task_id": task_id}), 202

@app.route('/api/tasks/<task_id>', methods=['GET'])
def get_task_status(task_id):
    """Report the current status record of *task_id*.

    Responds with the stored dict as JSON. For an unknown (or empty) record,
    responds with a JSON error and HTTP 404.
    """
    record = tasks.get(task_id)
    if not record:
        return jsonify({"error": "Task not found"}), 404
    return jsonify(record)

@app.route('/download/<filename>', methods=['GET'])
def download(filename):
    """Send ``/tmp/<filename>`` to the client as an attachment.

    The downloaded model weights and config also live in /tmp, so their file
    names are refused explicitly. A missing file gets the same JSON 404
    response; previously ``send_file`` raised ``FileNotFoundError``, which
    surfaced as an HTTP 500.
    """
    protected = (os.path.basename(model_name), os.path.basename(config_name))
    if filename in protected or not os.path.isfile(os.path.join("/tmp", filename)):
        return jsonify({"error": "File not found"}), 404
    # send_from_directory refuses any path that would escape /tmp.
    return send_from_directory("/tmp", filename, as_attachment=True)


if __name__ == '__main__':
    # Listen on all interfaces on port 1145. With threaded=False, the
    # development server handles one HTTP request at a time; conversions still
    # run on the background threads that /wav2wav starts.
    app.run(host="0.0.0.0", port=1145, threaded=False, debug=False)