File size: 3,557 Bytes
40fac91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
#maxlen[0]=int(input("Maximum text length:"))
if os.path.exists("/usr/lib/x86_64-linux-gnu/libtcmalloc.so"):
    try:
        os.environ["LD_PRELOAD"] = "/usr/lib/x86_64-linux-gnu/libtcmalloc.so"
        import ctypes
        ctypes.CDLL("libtcmalloc.so", mode=ctypes.RTLD_GLOBAL)
        print("tcmalloc.so loaded.")
    except Exception as e:
        print(e)
        print("Failed to load tcmalloc.so.")
else:
    print("Cannot locate TCMalloc.")
from fastapi import FastAPI,Body,Request
from fastapi.responses import JSONResponse,Response,StreamingResponse
from starlette.responses import FileResponse
import uvicorn
import logging
from pydantic import BaseModel
import vits
import torch
import re
import threading
import cmd

# Client IPs (one per line in blacklist.txt) that the POST handler rejects
# with HTTP 403. Surrounding whitespace is stripped, and blank lines are
# skipped so that a spacer or trailing line does not add an empty entry.
with open("blacklist.txt", "r", encoding="utf-8") as f:
    blacklist = [entry for entry in (line.strip() for line in f) if entry]

# Maximum accepted --text length. Kept as a one-element list holding the
# file's stripped contents because the handler reads it as int(maxlen[0]).
with open("maxlen.txt", "r", encoding="utf-8") as f:
    maxlen = [f.read().strip()]

# Pick the inference backend once, at import time: run_old when CUDA is
# available, run_new (CPU) otherwise. ``gpu`` stays an int flag (1/0)
# because the request handler compares it with ==.
gpu = 1 if torch.cuda.is_available() else 0
if gpu:
    import run_old
else:
    print("Use CPU.")
    import run_new


# ASGI application that the POST "/" route is registered on.
app = FastAPI()
# Root logger at WARNING level (basicConfig is a no-op if handlers already exist).
logging.basicConfig(level="WARNING")

# Request body for POST "/". The handler extracts --text, --character and the
# script path (./vits/ or ./vits_bh3/) from this single string.
class item(BaseModel):
    # Pseudo command line supplied by the client; expected to start with "python".
    command: str

@app.post("/")
def getwav(command: item, request: Request):
    """Synthesize speech for the pseudo command line in the request body.

    ``command.command`` must start with ``python``. ``--text=`` (up to the
    next whitespace) and ``--character=`` (digits) are extracted from it.
    The presence of ``./vits/`` or ``./vits_bh3/`` selects the ``ys`` or
    ``bh3`` synthesizer of the backend module chosen at import time
    (``run_old`` when ``gpu == 1``, otherwise ``run_new``). The command
    string itself is never executed.

    Returns:
        On success, a ``StreamingResponse`` carrying the synthesizer result's
        ``getvalue()`` bytes as the attachment ``example.wav``. Otherwise a
        JSON error response:

        * 403 - the client IP is blacklisted, or the text is longer than
          ``int(maxlen[0])``;
        * 400 - the command does not start with ``python``;
        * 404 - ``--text``, ``--character`` or the script path is missing;
        * 500 - the synthesizer raised.
    """
    # request.client can be None (e.g. when the server has no peer info);
    # treat that as "not blacklisted" instead of crashing with AttributeError.
    client = request.client
    host = client.host if client is not None else None
    if host in blacklist:
        return JSONResponse(status_code=403, content={"message": "IP banned."})

    print(command)
    # Read the validated field directly. Slicing str(command) exposed the
    # pydantic repr, which escapes quotes, backslashes and control
    # characters inside the user's text.
    s = command.command
    if not s.startswith("python"):
        # Previously this fell through and returned None (HTTP 200 "null").
        return JSONResponse(status_code=400, content={"message": "invalid command."})

    text_match = re.search(r"--text=(\S+)", s)
    if text_match is None:
        return JSONResponse(status_code=404, content={"message": "missing text."})
    text = text_match.group(1)
    if len(text) > int(maxlen[0]):
        return JSONResponse(status_code=403, content={"message": "The text is too long."})

    character_match = re.search(r"--character=(\d+)", s)
    if character_match is None:
        return JSONResponse(status_code=404, content={"message": "missing character."})
    character = int(character_match.group(1))

    # Only the backend imported at startup exists as a name; the conditional
    # expression evaluates just the branch that is taken.
    runner = run_old if gpu == 1 else run_new
    try:
        # "./vits_bh3/" does not contain "./vits/", so the test order is safe.
        if "./vits/" in s:
            result = runner.ys(text, character)
        elif "./vits_bh3/" in s:
            result = runner.bh3(text, character)
        else:
            return JSONResponse(status_code=404, content={"message": "missing py"})
    except Exception:
        logging.exception("Speech synthesis failed")
        return JSONResponse(status_code=500, content={"message": "Internal Server Error."})

    response = StreamingResponse(iter([result.getvalue()]), media_type="application/octet-stream")
    response.headers["Content-Disposition"] = "attachment; filename=example.wav"
    return response


#uvicorn.run(app=app, host="0.0.0.0", port=7860, log_level="debug")