Reiner4 superdup95 committed on
Commit
e79d812
0 Parent(s):

Duplicate from superdup95/extras_test

Browse files

Co-authored-by: su <superdup95@users.noreply.huggingface.co>

Files changed (7) hide show
  1. .gitattributes +35 -0
  2. Dockerfile +21 -0
  3. README.md +11 -0
  4. constants.py +50 -0
  5. requirements-complete.txt +19 -0
  6. server.py +966 -0
  7. tts_edge.py +34 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11

WORKDIR /app

# Install dependencies first so this layer is reused when only sources change.
COPY requirements-complete.txt .
RUN pip install -r requirements-complete.txt

# Writable cache / persistence folders for the non-root runtime user.
RUN mkdir /.cache && chmod -R 777 /.cache
RUN mkdir .chroma && chmod -R 777 .chroma

COPY . .


RUN chmod -R 777 /app

# Fail the build when the "password" secret is missing or empty, but do NOT
# copy it into the image: the previous `cat ... > /test` stored the secret in a
# world-readable file inside an image layer.
# NOTE(review): server.py reads the password from the environment; confirm
# nothing else expects the /test file before merging.
RUN --mount=type=secret,id=password,mode=0444,required=true \
    test -s /run/secrets/password

EXPOSE 7860

CMD ["python", "server.py", "--cpu", "--enable-modules=caption,summarize,classify,silero-tts,edge-tts,chromadb,sd", "--sd-remote", "--sd-remote-port=443", "--sd-remote-ssl"]
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: extras
3
+ emoji: 🧊
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ duplicated_from: superdup95/extras_test
10
+ ---
11
+ Fixed Server.JS Latest 2023/08/16
constants.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Constants: default model names, ports and generation parameters.
DEFAULT_CUDA_DEVICE = "cuda:0"
# Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ElectrifAi_v14"
# Default text (emotion) classification model.
DEFAULT_CLASSIFICATION_MODEL = "joeddav/distilbert-base-uncased-go-emotions-student"
# Also try: 'Salesforce/blip-image-captioning-base'
DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
DEFAULT_REMOTE_SD_PORT = 7860
DEFAULT_CHROMA_PORT = 8000
SILERO_SAMPLES_PATH = "tts_samples"
SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
# ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
# Default generation parameters for the summarization endpoint; a request may
# override individual keys.
DEFAULT_SUMMARIZE_PARAMS = {
    "temperature": 1.0,
    "repetition_penalty": 1.0,
    "max_length": 500,
    "min_length": 200,
    "length_penalty": 1.5,
    # Strings whose token ids are passed to generate() as bad_words_ids.
    "bad_words": [
        "\n",
        '"',
        "*",
        "[",
        "]",
        "{",
        "}",
        ":",
        "(",
        ")",
        "<",
        ">",
        "Â",
        "The text ends",
        "The story ends",
        "The text is",
        "The story is",
    ],
}

PROMPT_PREFIX = "best quality, absurdres, "
# Every line ends with a comma so tags stay separated across line breaks
# (the "missing fingers" line previously lacked one).
NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
error hands, bad hands, error fingers, bad fingers, missing fingers,
error legs, bad legs, multiple legs, missing legs, error lighting,
error shadow, error reflection, text, error, extra digit, fewer digits,
cropped, worst quality, low quality, normal quality, jpeg artifacts,
signature, watermark, username, blurry"""
requirements-complete.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cloudflared
3
+ flask-cors
4
+ flask-compress
5
+ markdown
6
+ Pillow
7
+ colorama
8
+ webuiapi
9
+ --extra-index-url https://download.pytorch.org/whl/cu117
10
+ torch==2.0.0+cu117
11
+ torchvision==0.15.1
12
+ torchaudio==2.0.1+cu117
13
+ accelerate
14
+ transformers==4.28.1
15
+ diffusers==0.16.1
16
+ silero-api-server
17
+ chromadb
18
+ sentence_transformers
19
+ edge-tts
server.py ADDED
@@ -0,0 +1,966 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ from flask import (
3
+ Flask,
4
+ jsonify,
5
+ request,
6
+ Response,
7
+ render_template_string,
8
+ abort,
9
+ send_from_directory,
10
+ send_file,
11
+ )
12
+ from flask_cors import CORS
13
+ from flask_compress import Compress
14
+ import markdown
15
+ import argparse
16
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
17
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
18
+ from transformers import BlipForConditionalGeneration
19
+ import unicodedata
20
+ import torch
21
+ import time
22
+ import os
23
+ import gc
24
+ import sys
25
+ import secrets
26
+ from PIL import Image
27
+ import base64
28
+ from io import BytesIO
29
+ from random import randint
30
+ import webuiapi
31
+ import hashlib
32
+ from constants import *
33
+ from colorama import Fore, Style, init as colorama_init
34
+
35
# Initialise colorama before any coloured output is printed.
colorama_init()

# Older interpreters still run; the user only gets a notice and a short pause.
if sys.version_info < (3, 11):
    print(f"{Fore.BLUE}{Style.BRIGHT}Python 3.11 or newer is recommended to run this program.{Style.RESET_ALL}")
    time.sleep(2)
40
+
41
class SplitArgs(argparse.Action):
    """argparse action that stores a comma-separated option value as a list.

    Single and double quote characters are removed from the raw value before
    it is split on commas.
    """

    # Translation table that deletes both quote characters in one pass.
    _QUOTES = str.maketrans("", "", "\"'")

    def __call__(self, parser, namespace, values, option_string=None):
        unquoted = values.translate(self._QUOTES)
        setattr(namespace, self.dest, unquoted.split(","))
46
+
47
# Setting root folders for Silero generations so it is compatible with STSL;
# should not affect regular runs. - Rolyat
parent_dir = os.path.dirname(os.path.abspath(__file__))
SILERO_SAMPLES_PATH = os.path.join(parent_dir, "tts_samples")
# SILERO_SAMPLE_TEXT is deliberately NOT reassigned here: it is the sentence
# spoken in the generated voice samples (star-imported from constants), and
# overwriting it with a directory path made the samples read out that path.

# Create the samples directory if it doesn't exist
os.makedirs(SILERO_SAMPLES_PATH, exist_ok=True)
57
+
58
# Script arguments
parser = argparse.ArgumentParser(
    prog="SillyTavern Extras", description="Web API for transformers models"
)
parser.add_argument(
    "--port", type=int, help="Specify the port on which the application is hosted"
)
parser.add_argument(
    "--listen", action="store_true", help="Host the app on the local network"
)
parser.add_argument(
    "--share", action="store_true", help="Share the app on CloudFlare tunnel"
)
# --cpu / --cuda / --mps all write to the same "cpu" destination; the default
# set below is CPU.
parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
parser.add_argument("--cuda-device", help="Specify the CUDA device to use")
parser.add_argument("--mps", "--apple", "--m1", "--m2", action="store_false", dest="cpu", help="Run the models on Apple Silicon")
parser.set_defaults(cpu=True)
parser.add_argument("--summarization-model", help="Load a custom summarization model")
parser.add_argument(
    "--classification-model", help="Load a custom text classification model"
)
parser.add_argument("--captioning-model", help="Load a custom captioning model")
parser.add_argument("--embedding-model", help="Load a custom text embedding model")
parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=True, action=argparse.BooleanOptionalAction)
parser.add_argument(
    "--secure", action="store_true", help="Enforces the use of an API key"
)

# The SD option groups are attached directly to the parser. Nesting argument
# groups inside a mutually exclusive group never enforced exclusivity, is
# deprecated since Python 3.11, and hid these options from the --help sections.
local_sd = parser.add_argument_group("sd-local")
local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")

remote_sd = parser.add_argument_group("sd-remote")
remote_sd.add_argument(
    "--sd-remote", action="store_true", help="Use a remote backend for SD"
)
remote_sd.add_argument(
    "--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
)
remote_sd.add_argument(
    "--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
)
remote_sd.add_argument(
    "--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
)
remote_sd.add_argument(
    "--sd-remote-auth",
    type=str,
    help="Specify the username:password for the remote SD backend (if required)",
)

# Comma-separated list, e.g. --enable-modules=caption,summarize
parser.add_argument(
    "--enable-modules",
    action=SplitArgs,
    default=[],
    help="Override a list of enabled modules",
)
120
+
121
args = parser.parse_args()
# [HF, Huggingface] Set port to 7860, set host to remote.
port = 7860
host = "0.0.0.0"

# Each model name falls back to its default when the option is absent/empty.
summarization_model = args.summarization_model or DEFAULT_SUMMARIZATION_MODEL
classification_model = args.classification_model or DEFAULT_CLASSIFICATION_MODEL
captioning_model = args.captioning_model or DEFAULT_CAPTIONING_MODEL
embedding_model = args.embedding_model or DEFAULT_EMBEDDING_MODEL

# Remote SD host is taken from the environment, not from the command line.
sd_remote_gradio = os.environ.get("sd_remote_gradio")

# The remote backend is used whenever no local SD model was requested.
sd_use_remote = not args.sd_model
sd_model = args.sd_model or DEFAULT_SD_MODEL
sd_remote_host = sd_remote_gradio or DEFAULT_REMOTE_SD_HOST
sd_remote_port = args.sd_remote_port or DEFAULT_REMOTE_SD_PORT
sd_remote_ssl = args.sd_remote_ssl
sd_remote_auth = args.sd_remote_auth

modules = args.enable_modules or []

if not modules:
    print(
        f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
    )
    print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
160
+
161
# Models init
cuda_device = args.cuda_device or DEFAULT_CUDA_DEVICE

# Preference order when CPU is not forced: CUDA, then MPS, then CPU.
if torch.cuda.is_available() and not args.cpu:
    device_string = cuda_device
elif torch.backends.mps.is_available() and not args.cpu:
    device_string = 'mps'
else:
    device_string = 'cpu'
device = torch.device(device_string)
# Half precision is used only on the CUDA device.
torch_dtype = torch.float16 if device_string == cuda_device else torch.float32

# Tell the user which accelerators were requested but are unavailable.
if not args.cpu:
    if not torch.cuda.is_available():
        print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device.{Style.RESET_ALL}")
    if not torch.backends.mps.is_available():
        print(f"{Fore.YELLOW}{Style.BRIGHT}torch-mps is not supported on this device.{Style.RESET_ALL}")


print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
174
+
175
# Each model is loaded only when its module was enabled on the command line.
if "caption" in modules:
    print("Initializing an image captioning model...")
    captioning_processor = AutoProcessor.from_pretrained(captioning_model)
    # Model names containing "blip" use the BLIP class, everything else the
    # generic causal-LM class.
    _captioning_cls = (
        BlipForConditionalGeneration
        if "blip" in captioning_model
        else AutoModelForCausalLM
    )
    captioning_transformer = _captioning_cls.from_pretrained(
        captioning_model, torch_dtype=torch_dtype
    ).to(device)

if "summarize" in modules:
    print("Initializing a text summarization model...")
    summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
    _summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
        summarization_model, torch_dtype=torch_dtype
    )
    summarization_transformer = _summarization_model.to(device)

if "classify" in modules:
    print("Initializing a sentiment classification pipeline...")
    # top_k=None asks the pipeline for a score for every label.
    classification_pipe = pipeline(
        "text-classification",
        model=classification_model,
        top_k=None,
        device=device,
        torch_dtype=torch_dtype,
    )
203
+
204
# Stable Diffusion: either a local diffusers pipeline or a remote WebUI API.
if "sd" in modules and not sd_use_remote:
    from diffusers import StableDiffusionPipeline
    from diffusers import EulerAncestralDiscreteScheduler

    print("Initializing Stable Diffusion pipeline...")
    # --sd-cpu forces the CPU (the flag was previously parsed but ignored);
    # otherwise prefer CUDA, then MPS, then CPU.
    if args.sd_cpu:
        sd_device_string = 'cpu'
    elif torch.cuda.is_available():
        sd_device_string = cuda_device
    elif torch.backends.mps.is_available():
        sd_device_string = 'mps'
    else:
        sd_device_string = 'cpu'
    sd_device = torch.device(sd_device_string)
    sd_torch_dtype = torch.float32 if sd_device_string != cuda_device else torch.float16
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
    ).to(sd_device)
    # Replace the safety checker with a pass-through.
    sd_pipe.safety_checker = lambda images, clip_input: (images, False)
    sd_pipe.enable_attention_slicing()
    # pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
    sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        sd_pipe.scheduler.config
    )
elif "sd" in modules and sd_use_remote:
    print("Initializing Stable Diffusion connection")
    try:
        sd_remote = webuiapi.WebUIApi(
            host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
        )
        if sd_remote_auth:
            # Split on the first ':' only so passwords may contain colons.
            username, password = sd_remote_auth.split(":", 1)
            sd_remote.set_auth(username, password)
        sd_remote.util_wait_for_ready()
    except Exception as e:
        # Remove sd from modules so its endpoints report the module as disabled.
        print(
            f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
        )
        print(f"Remote SD connection error: {e}")
        modules.remove("sd")
237
+
238
# Legacy module name: swap "tts" for "silero-tts".
if "tts" in modules:
    print("tts module is deprecated. Please use silero-tts instead.")
    modules.remove("tts")
    modules.append("silero-tts")


if "silero-tts" in modules:
    if not os.path.exists(SILERO_SAMPLES_PATH):
        os.makedirs(SILERO_SAMPLES_PATH)
    print("Initializing Silero TTS server")
    from silero_api_server import tts

    tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
    # Samples are generated only when the samples folder is empty.
    if not os.listdir(SILERO_SAMPLES_PATH):
        print("Generating Silero TTS samples...")
        tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
        tts_service.generate_samples()


if "edge-tts" in modules:
    print("Initializing Edge TTS client")
    import tts_edge as edge
260
+
261
+
262
if "chromadb" in modules:
    print("Initializing ChromaDB")
    import chromadb
    import posthog
    from chromadb.config import Settings
    from sentence_transformers import SentenceTransformer

    # Assume that the user wants in-memory unless a host is specified
    # Also disable chromadb telemetry
    posthog.capture = lambda *args, **kwargs: None
    if args.chroma_host is None:
        if args.chroma_persist:
            chromadb_client = chromadb.PersistentClient(path=args.chroma_folder, settings=Settings(anonymized_telemetry=False))
            print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
        else:
            chromadb_client = chromadb.EphemeralClient(Settings(anonymized_telemetry=False))
            print("ChromaDB is running in-memory without persistence.")
    else:
        chroma_port = (
            args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
        )
        chromadb_client = chromadb.HttpClient(host=args.chroma_host, port=chroma_port, settings=Settings(anonymized_telemetry=False))
        print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")

    chromadb_embedder = SentenceTransformer(embedding_model, device=device_string)

    def chromadb_embed_fn(*args, **kwargs):
        """Encode with the sentence-transformer and return plain Python lists."""
        return chromadb_embedder.encode(*args, **kwargs).tolist()

    # Check if the db is connected and running, otherwise tell the user.
    # Only ordinary errors are caught so Ctrl-C / SystemExit still propagate.
    try:
        chromadb_client.heartbeat()
        print("Successfully pinged ChromaDB! Your client is successfully connected.")
    except Exception as e:
        print("Could not ping ChromaDB! If you are running remotely, please check your host and port!")
        print(f"ChromaDB heartbeat error: {e}")
295
+
296
# Flask init
app = Flask(__name__)
CORS(app)  # allow cross-domain requests
Compress(app)  # compress responses
# Cap request bodies at 100 MiB.
app.config.update(MAX_CONTENT_LENGTH=100 * 1024 * 1024)
301
+
302
+
303
def require_module(name):
    """Return a decorator that only lets a view run when module *name* is enabled.

    The check happens on every call against the module-level ``modules`` list;
    a disabled module aborts the request with HTTP 403.
    """

    def decorator(view):
        @wraps(view)
        def guarded_view(*view_args, **view_kwargs):
            enabled = name in modules
            if not enabled:
                abort(403, "Module is disabled by config")
            return view(*view_args, **view_kwargs)

        return guarded_view

    return decorator
314
+
315
+
316
+ # AI stuff
317
def classify_text(text: str) -> list:
    """Classify *text* and return the pipeline's label entries, best score first.

    The input is truncated to the model's ``max_position_embeddings``.
    """
    token_limit = classification_pipe.model.config.max_position_embeddings
    results = classification_pipe(text, truncation=True, max_length=token_limit)
    label_scores = results[0]
    return sorted(label_scores, key=lambda entry: entry["score"], reverse=True)
324
+
325
+
326
def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
    """Generate a caption for *raw_image* with the loaded captioning model.

    The image is converted to RGB, run through the processor, and the first
    generated sequence is decoded without special tokens.
    """
    rgb_image = raw_image.convert("RGB")
    inputs = captioning_processor(rgb_image, return_tensors="pt")
    inputs = inputs.to(device, torch_dtype)
    generated = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
    return captioning_processor.decode(generated[0], skip_special_tokens=True)
333
+
334
+
335
def summarize_chunks(text: str, params: dict) -> str:
    """Summarize *text*, halving it recursively when ``summarize`` raises IndexError.

    On IndexError the text is cut at its midpoint, ``max_length`` and
    ``min_length`` are halved, and both halves are summarized separately.

    Raises:
        IndexError: re-raised when the text is shorter than two characters
            and therefore cannot be split further (previously this recursed
            without end).
    """
    try:
        return summarize(text, params)
    except IndexError:
        if len(text) < 2:
            # Splitting would hand the same text back to ourselves forever.
            raise
        print(
            "Sequence length too large for model, cutting text in half and calling again"
        )
        new_params = params.copy()
        new_params["max_length"] = new_params["max_length"] // 2
        new_params["min_length"] = new_params["min_length"] // 2
        midpoint = len(text) // 2
        halves = (
            summarize_chunks(text[:midpoint], new_params),
            summarize_chunks(text[midpoint:], new_params),
        )
        # Join with a space so the last word of the first summary is not
        # fused with the first word of the second one.
        return " ".join(part for part in halves if part)
348
+
349
+
350
def summarize(text: str, params: dict) -> str:
    """Summarize *text* with the loaded seq2seq model using *params*.

    *params* must provide ``bad_words``, ``max_length``, ``min_length``,
    ``repetition_penalty``, ``temperature`` and ``length_penalty``.
    The decoded summary is whitespace/Unicode-normalized before returning.
    """
    # Tokenize input
    inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
    token_count = len(inputs[0])

    # Token-id sequences that generation is told not to emit.
    bad_words_ids = []
    for bad_word in params["bad_words"]:
        encoded = summarization_tokenizer(bad_word, add_special_tokens=False)
        bad_words_ids.append(encoded.input_ids)

    summary_ids = summarization_transformer.generate(
        inputs["input_ids"],
        num_beams=2,
        max_new_tokens=max(token_count, int(params["max_length"])),
        min_new_tokens=min(token_count, int(params["min_length"])),
        repetition_penalty=float(params["repetition_penalty"]),
        temperature=float(params["temperature"]),
        length_penalty=float(params["length_penalty"]),
        bad_words_ids=bad_words_ids,
    )
    decoded = summarization_tokenizer.batch_decode(
        summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return normalize_string(decoded[0])
374
+
375
+
376
def normalize_string(input: str) -> str:
    """Apply NFKC normalization and collapse every whitespace run to one space.

    Leading and trailing whitespace is removed as part of the collapse.
    """
    canonical = unicodedata.normalize("NFKC", input)
    words = canonical.split()
    return " ".join(words)
379
+
380
+
381
def generate_image(data: dict) -> Image:
    """Generate an image from the request *data* via the remote or local SD backend.

    The prompt is ``prompt_prefix`` + ``prompt``, normalized. The result is
    also written to ``./debug.png`` before being returned.
    """
    full_prompt = f'{data["prompt_prefix"]} {data["prompt"]}'
    prompt = normalize_string(full_prompt)

    if not sd_use_remote:
        # Local diffusers pipeline: takes the first generated image.
        result = sd_pipe(
            prompt=prompt,
            negative_prompt=data["negative_prompt"],
            num_inference_steps=data["steps"],
            guidance_scale=data["scale"],
            width=data["width"],
            height=data["height"],
        )
        image = result.images[0]
    else:
        result = sd_remote.txt2img(
            prompt=prompt,
            negative_prompt=data["negative_prompt"],
            sampler_name=data["sampler"],
            steps=data["steps"],
            cfg_scale=data["scale"],
            width=data["width"],
            height=data["height"],
            restore_faces=data["restore_faces"],
            enable_hr=data["enable_hr"],
            save_images=True,
            send_images=True,
            do_not_save_grid=False,
            do_not_save_samples=False,
        )
        image = result.image

    image.save("./debug.png")
    return image
412
+
413
+
414
def image_to_base64(image: Image, quality: int = 75) -> str:
    """Encode *image* as a JPEG and return it as a base64 string.

    Args:
        image: image object to encode.
        quality: JPEG quality passed to ``save``.

    Returns:
        The base64 text (UTF-8 decoded) of the JPEG bytes.
    """
    buffer = BytesIO()
    # convert() returns a new image; the result was previously discarded, so
    # non-RGB inputs were handed to the JPEG encoder unconverted.
    rgb_image = image.convert("RGB")
    rgb_image.save(buffer, format="JPEG", quality=quality)
    img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return img_str
420
+
421
+
422
# View functions listed here skip the API-key check.
ignore_auth = []
# [HF, Huggingface] Get password instead of text file.
api_key = os.environ.get("password")


def is_authorize_ignored(request):
    """Return True when the request's view function is listed in ``ignore_auth``."""
    view_func = app.view_functions.get(request.endpoint)
    return view_func is not None and view_func in ignore_auth
433
+
434
@app.before_request
def before_request():
    """Stamp the request start time and enforce the API key.

    Returns:
        ``None`` to let the request proceed, or a JSON 401 response when the
        token in the Authorization header does not match ``api_key``.
        Previously the 401 response was built and then discarded in favour
        of a placeholder URL string served with status 200.
    """
    # Request time measuring
    request.start_time = time.time()

    try:
        # The OPTIONS check is required so CORS doesn't get angry
        if request.method == 'OPTIONS' or is_authorize_ignored(request):
            return None

        token = getattr(request.authorization, 'token', None)
        # Fail closed: an unset/empty password or a missing token never
        # authorizes (None == None used to slip through). compare_digest
        # keeps the comparison constant-time.
        authorized = (
            isinstance(api_key, str)
            and api_key != ""
            and isinstance(token, str)
            and secrets.compare_digest(token.encode("utf-8"), api_key.encode("utf-8"))
        )
    except Exception as e:
        print(f"API key check error: {e}")
        authorized = False

    if authorized:
        return None

    # The Authorization header itself is deliberately not logged: it would
    # write credentials into the server log.
    print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
    response = jsonify({ 'error': '401: Invalid API key' })
    response.status_code = 401
    return response
452
+
453
+
454
@app.after_request
def after_request(response):
    """Attach the elapsed handling time as the ``X-Request-Duration`` header."""
    elapsed = time.time() - request.start_time
    response.headers["X-Request-Duration"] = str(elapsed)
    return response
459
+
460
+
461
@app.route("/", methods=["GET"])
def index():
    """Serve ./README.md rendered from Markdown (with the tables extension)."""
    with open("./README.md", "r", encoding="utf8") as readme:
        html = markdown.markdown(readme.read(), extensions=["tables"])
    return render_template_string(html)
466
+
467
+
468
@app.route("/api/extensions", methods=["GET"])
def get_extensions():
    """Return a single placeholder extension telling clients this API is retired."""
    notice = {
        "name": "not-supported",
        "metadata": {
            "display_name": """<span style="white-space:break-spaces;">Extensions serving using Extensions API is no longer supported. Please update the mod from: <a href="https://github.com/Cohee1207/SillyTavern">https://github.com/Cohee1207/SillyTavern</a></span>""",
            "requires": [],
            "assets": [],
        },
    }
    return jsonify({"extensions": [notice]})
485
+
486
+
487
@app.route("/api/caption", methods=["POST"])
@require_module("caption")
def api_caption():
    """Caption a base64-encoded image.

    Expects JSON with a string ``image`` field; responds with the caption and
    a base64 JPEG of the image after it was thumbnailed to 512x512.
    """
    data = request.get_json()

    if not ("image" in data and isinstance(data["image"], str)):
        abort(400, '"image" is required')

    raw_bytes = base64.b64decode(data["image"])
    image = Image.open(BytesIO(raw_bytes)).convert("RGB")
    image.thumbnail((512, 512))
    caption = caption_image(image)
    thumbnail = image_to_base64(image)
    print("Caption:", caption, sep="\n")
    gc.collect()
    return jsonify({"caption": caption, "thumbnail": thumbnail})
503
+
504
+
505
@app.route("/api/summarize", methods=["POST"])
@require_module("summarize")
def api_summarize():
    """Summarize the ``text`` field of the JSON body.

    An optional ``params`` dict overrides entries of DEFAULT_SUMMARIZE_PARAMS.
    """
    data = request.get_json()

    if not ("text" in data and isinstance(data["text"], str)):
        abort(400, '"text" is required')

    # Shallow copy so request overrides never touch the shared defaults dict.
    params = {**DEFAULT_SUMMARIZE_PARAMS}
    overrides = data["params"] if "params" in data else None
    if isinstance(overrides, dict):
        params.update(overrides)

    print("Summary input:", data["text"], sep="\n")
    summary = summarize_chunks(data["text"], params)
    print("Summary output:", summary, sep="\n")
    gc.collect()
    return jsonify({"summary": summary})
523
+
524
+
525
@app.route("/api/classify", methods=["POST"])
@require_module("classify")
def api_classify():
    """Classify the ``text`` field of the JSON body and return the sorted scores."""
    data = request.get_json()

    if not ("text" in data and isinstance(data["text"], str)):
        abort(400, '"text" is required')

    text = data["text"]
    print("Classification input:", text, sep="\n")
    classification = classify_text(text)
    print("Classification output:", classification, sep="\n")
    gc.collect()
    return jsonify({"classification": classification})
538
+
539
+
540
@app.route("/api/classify/labels", methods=["GET"])
@require_module("classify")
def api_classify_labels():
    """List the classifier's labels, obtained by classifying an empty string."""
    labels = []
    for entry in classify_text(""):
        labels.append(entry["label"])
    return jsonify({"labels": labels})
546
+
547
+
548
@app.route("/api/image", methods=["POST"])
@require_module("sd")
def api_image():
    """Generate an image from a JSON request and return it as base64 JPEG.

    ``prompt`` (str) is mandatory. Every optional field that is missing or of
    the wrong type is replaced by its default. A RuntimeError raised during
    generation is reported as HTTP 400.
    """
    required_fields = {"prompt": str}

    optional_fields = {
        "steps": 30,
        "scale": 6,
        "sampler": "DDIM",
        "width": 512,
        "height": 512,
        "restore_faces": False,
        "enable_hr": False,
        "prompt_prefix": PROMPT_PREFIX,
        "negative_prompt": NEGATIVE_PROMPT,
    }

    data = request.get_json()

    # Check required fields
    for field, field_type in required_fields.items():
        if not (field in data and isinstance(data[field], field_type)):
            abort(400, f'"{field}" is required')

    # Set optional fields to default values if not provided
    for field, default_value in optional_fields.items():
        # Numeric defaults accept either int or float from the client.
        if isinstance(default_value, (int, float)):
            type_match = (int, float)
        else:
            type_match = type(default_value)
        if field in data and isinstance(data[field], type_match):
            continue
        data[field] = default_value

    try:
        print("SD inputs:", data, sep="\n")
        image = generate_image(data)
        encoded = image_to_base64(image, quality=90)
        return jsonify({"image": encoded})
    except RuntimeError as e:
        abort(400, str(e))
591
+
592
+
593
@app.route("/api/image/model", methods=["POST"])
@require_module("sd")
def api_image_model_set():
    """Switch the remote SD backend to the model named in the JSON body.

    Responds with the model reported before and after the switch. Only
    available with a remote backend.
    """
    data = request.get_json()

    if not sd_use_remote:
        abort(400, "Changing model for local sd is not supported.")
    if not ("model" in data and isinstance(data["model"], str)):
        abort(400, '"model" is required')

    previous = sd_remote.util_get_current_model()
    sd_remote.util_set_model(data["model"], find_closest=False)
    # sd_remote.util_set_model(data['model'])
    sd_remote.util_wait_for_ready()
    current = sd_remote.util_get_current_model()

    return jsonify({"previous_model": previous, "current_model": current})
610
+
611
+
612
@app.route("/api/image/model", methods=["GET"])
@require_module("sd")
def api_image_model_get():
    """Return the active SD model: the remote backend's, or the local model name."""
    model = sd_remote.util_get_current_model() if sd_use_remote else sd_model
    return jsonify({"model": model})
621
+
622
+
623
@app.route("/api/image/models", methods=["GET"])
@require_module("sd")
def api_image_models():
    """List the models of the remote backend, or just the local model name."""
    if sd_use_remote:
        models = sd_remote.util_get_model_names()
    else:
        models = [sd_model]
    return jsonify({"models": models})
632
+
633
+
634
@app.route("/api/image/samplers", methods=["GET"])
@require_module("sd")
def api_image_samplers():
    """List sampler names: the remote backend's, or "Euler a" for local SD."""
    if sd_use_remote:
        samplers = []
        for sampler in sd_remote.get_samplers():
            samplers.append(sampler["name"])
    else:
        samplers = ["Euler a"]
    return jsonify({"samplers": samplers})
643
+
644
+
645
@app.route("/api/modules", methods=["GET"])
def get_modules():
    """Return the list of currently enabled modules."""
    return jsonify(modules=modules)
648
+
649
+
650
@app.route("/api/tts/speakers", methods=["GET"])
@require_module("silero-tts")
def tts_speakers():
    """List Silero speakers, each with a preview URL pointing at its sample."""
    base_url = str(request.url_root)
    voices = []
    for speaker in tts_service.get_speakers():
        voices.append(
            {
                "name": speaker,
                "voice_id": speaker,
                "preview_url": f"{base_url}api/tts/sample/{speaker}",
            }
        )
    return jsonify(voices)
662
+
663
+ # Added fix for Silero not working as new files were unable to be created if one already existed. - Rolyat 7/7/23
664
@app.route("/api/tts/generate", methods=["POST"])
@require_module("silero-tts")
def tts_generate():
    """Synthesize speech with the Silero service and return it as a WAV file.

    Expects a JSON body with string fields ``"text"`` and ``"speaker"``.
    Asterisks are stripped from the text before synthesis.

    Responds with 400 when either field is missing or not a string, and with
    500 (carrying the speaker name) when synthesis or file handling fails.
    """
    voice = request.get_json()
    if "text" not in voice or not isinstance(voice["text"], str):
        abort(400, '"text" is required')
    if "speaker" not in voice or not isinstance(voice["speaker"], str):
        abort(400, '"speaker" is required')
    # Remove asterisks
    voice["text"] = voice["text"].replace("*", "")
    try:
        # Clear a stale test.wav from the working directory before generating.
        # Attempting the removal directly avoids the check-then-remove race.
        try:
            os.remove('test.wav')
        except FileNotFoundError:
            pass

        # The service result is treated as a path to the generated file.
        audio = tts_service.generate(voice["speaker"], voice["text"])
        audio_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.basename(audio))

        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename raises FileExistsError on Windows in that case.
        os.replace(audio, audio_file_path)
        return send_file(audio_file_path, mimetype="audio/x-wav")
    except Exception as e:
        print(e)
        abort(500, voice["speaker"])
687
+
688
+
689
@app.route("/api/tts/sample/<speaker>", methods=["GET"])
@require_module("silero-tts")
def tts_play_sample(speaker: str):
    """Serve ``<speaker>.wav`` from the ``SILERO_SAMPLES_PATH`` directory."""
    sample_filename = f"{speaker}.wav"
    return send_from_directory(SILERO_SAMPLES_PATH, sample_filename)
693
+
694
+
695
@app.route("/api/edge-tts/list", methods=["GET"])
@require_module("edge-tts")
def edge_tts_list():
    """Respond with the JSON-encoded result of ``edge.get_voices()``."""
    return jsonify(edge.get_voices())
700
+
701
+
702
@app.route("/api/edge-tts/generate", methods=["POST"])
@require_module("edge-tts")
def edge_tts_generate():
    """Synthesize speech through the edge module and return it as MPEG audio.

    Expects a JSON body with string fields ``"text"`` and ``"voice"`` and an
    optional integer ``"rate"`` (0 is used when it is absent or not an int).
    Asterisks are stripped from the text before synthesis.

    Responds with 400 when a required field is missing or not a string, and
    with 500 (carrying the voice name) when generation raises.
    """
    data = request.get_json()
    if not ("text" in data and isinstance(data["text"], str)):
        abort(400, '"text" is required')
    if not ("voice" in data and isinstance(data["voice"], str)):
        abort(400, '"voice" is required')

    rate = data['rate'] if "rate" in data and isinstance(data['rate'], int) else 0

    # Remove asterisks
    data["text"] = data["text"].replace("*", "")
    try:
        audio_bytes = edge.generate_audio(text=data["text"], voice=data["voice"], rate=rate)
        return Response(audio_bytes, mimetype="audio/mpeg")
    except Exception as e:
        print(e)
        abort(500, data["voice"])
722
+
723
+
724
@app.route("/api/chromadb", methods=["POST"])
@require_module("chromadb")
def chromadb_add_messages():
    """Upsert chat messages into the ChromaDB collection for a chat.

    Expects a JSON body with a string ``"chat_id"`` and a list ``"messages"``.
    Each message supplies ``"content"``, ``"id"``, ``"role"`` and ``"date"``;
    ``"meta"`` is optional and stored as an empty string when absent. The
    collection is named ``chat-<md5 of chat_id>`` and created on demand.

    Responds with ``{"count": <number of messages upserted>}``, or 400 when a
    required field is missing or of the wrong type.
    """
    data = request.get_json()
    if not ("chat_id" in data and isinstance(data["chat_id"], str)):
        abort(400, '"chat_id" is required')
    if not ("messages" in data and isinstance(data["messages"], list)):
        abort(400, '"messages" is required')

    digest = hashlib.md5(data["chat_id"].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{digest}", embedding_function=chromadb_embed_fn
    )

    incoming = data["messages"]
    documents = [message["content"] for message in incoming]
    ids = [message["id"] for message in incoming]
    metadatas = [
        {
            "role": message["role"],
            "date": message["date"],
            "meta": message.get("meta", ""),
        }
        for message in incoming
    ]

    collection.upsert(ids=ids, documents=documents, metadatas=metadatas)

    return jsonify({"count": len(ids)})
752
+
753
+
754
@app.route("/api/chromadb/purge", methods=["POST"])
@require_module("chromadb")
def chromadb_purge():
    """Delete the stored embeddings of one chat's ChromaDB collection.

    Expects a JSON body with a string ``"chat_id"`` (400 otherwise). The
    collection ``chat-<md5 of chat_id>`` is fetched (or created), its entry
    count is read for the log line, and ``collection.delete()`` is called.
    Responds with the plain text ``Ok``.
    """
    data = request.get_json()
    if not ("chat_id" in data and isinstance(data["chat_id"], str)):
        abort(400, '"chat_id" is required')

    digest = hashlib.md5(data["chat_id"].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{digest}", embedding_function=chromadb_embed_fn
    )

    # Read the size first so the log reports how many entries were present.
    entries_before = collection.count()
    collection.delete()
    print("ChromaDB embeddings deleted", entries_before)
    return 'Ok', 200
770
+
771
+
772
@app.route("/api/chromadb/query", methods=["POST"])
@require_module("chromadb")
def chromadb_query():
    """Query one chat's ChromaDB collection for the messages nearest a text.

    Expects a JSON body with string fields ``"chat_id"`` and ``"query"`` and
    an optional integer ``"n_results"`` (1 when absent or not an int). The
    requested count is capped at the number of entries in the collection.

    Responds with a JSON list of messages (``id``, ``date``, ``role``,
    ``meta``, ``content``, ``distance``), an empty list when the collection
    has no entries, or 400 when a required field is missing or not a string.
    """
    data = request.get_json()
    if not ("chat_id" in data and isinstance(data["chat_id"], str)):
        abort(400, '"chat_id" is required')
    if not ("query" in data and isinstance(data["query"], str)):
        abort(400, '"query" is required')

    if "n_results" in data and isinstance(data["n_results"], int):
        n_results = data["n_results"]
    else:
        n_results = 1

    digest = hashlib.md5(data["chat_id"].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{digest}", embedding_function=chromadb_embed_fn
    )

    if collection.count() == 0:
        print(f"Queried empty/missing collection for {repr(data['chat_id'])}.")
        return jsonify([])

    # Never ask for more results than the collection holds.
    n_results = min(collection.count(), n_results)
    query_result = collection.query(
        query_texts=[data["query"]],
        n_results=n_results,
    )

    # One query text was sent, so only the first result row is read.
    documents = query_result["documents"][0]
    ids = query_result["ids"][0]
    metadatas = query_result["metadatas"][0]
    distances = query_result["distances"][0]

    messages = []
    for position, message_id in enumerate(ids):
        messages.append(
            {
                "id": message_id,
                "date": metadatas[position]["date"],
                "role": metadatas[position]["role"],
                "meta": metadatas[position]["meta"],
                "content": documents[position],
                "distance": distances[position],
            }
        )

    return jsonify(messages)
820
+
821
@app.route("/api/chromadb/multiquery", methods=["POST"])
@require_module("chromadb")
def chromadb_multiquery():
    """Query several chats' ChromaDB collections and merge the nearest hits.

    Expects a JSON body with a list ``"chat_list"``, a string ``"query"`` and
    an optional integer ``"n_results"`` (1 when absent or not an int).
    Non-string chat ids are ignored. Any error raised while handling a single
    chat is printed and that chat is skipped.

    The merged hits are de-duplicated by ``content`` (first occurrence wins),
    sorted by ascending ``distance`` and truncated to ``n_results``. Responds
    with that list as JSON, or 400 when a required field is invalid.
    """
    data = request.get_json()
    if not ("chat_list" in data and isinstance(data["chat_list"], list)):
        abort(400, '"chat_list" is required and should be a list')
    if not ("query" in data and isinstance(data["query"], str)):
        abort(400, '"query" is required')

    if "n_results" in data and isinstance(data["n_results"], int):
        n_results = data["n_results"]
    else:
        n_results = 1

    messages = []

    for chat_id in data["chat_list"]:
        if not isinstance(chat_id, str):
            continue

        try:
            digest = hashlib.md5(chat_id.encode()).hexdigest()
            collection = chromadb_client.get_collection(
                name=f"chat-{digest}", embedding_function=chromadb_embed_fn
            )

            # Skip this chat if the collection is empty
            if collection.count() == 0:
                continue

            per_chat_limit = min(collection.count(), n_results)
            query_result = collection.query(
                query_texts=[data["query"]],
                n_results=per_chat_limit,
            )
            documents = query_result["documents"][0]
            ids = query_result["ids"][0]
            metadatas = query_result["metadatas"][0]
            distances = query_result["distances"][0]

            chat_hits = []
            for position, message_id in enumerate(ids):
                chat_hits.append(
                    {
                        "id": message_id,
                        "date": metadatas[position]["date"],
                        "role": metadatas[position]["role"],
                        "meta": metadatas[position]["meta"],
                        "content": documents[position],
                        "distance": distances[position],
                    }
                )

            messages.extend(chat_hits)
        except Exception as e:
            print(e)

    # Remove duplicate messages, keeping the first occurrence of each content.
    seen_contents = set()
    unique_messages = []
    for message in messages:
        content = message['content']
        if content in seen_contents:
            continue
        seen_contents.add(content)
        unique_messages.append(message)

    # Closest hits first, then keep only the requested number.
    unique_messages.sort(key=lambda hit: hit['distance'])
    messages = unique_messages[0:n_results]

    return jsonify(messages)
883
+
884
+
885
@app.route("/api/chromadb/export", methods=["POST"])
@require_module("chromadb")
def chromadb_export():
    """Export every entry of one chat's ChromaDB collection, ordered by date.

    Expects a JSON body with a string ``"chat_id"``. Responds with
    ``{"chat_id": ..., "content": [...]}`` where each content item has
    ``id``, ``metadata`` and ``document`` and items are sorted by
    ``metadata["date"]``. Responds with 400 when ``chat_id`` is invalid or
    the collection cannot be fetched.
    """
    data = request.get_json()
    if not ("chat_id" in data and isinstance(data["chat_id"], str)):
        abort(400, '"chat_id" is required')

    digest = hashlib.md5(data["chat_id"].encode()).hexdigest()
    try:
        collection = chromadb_client.get_collection(
            name=f"chat-{digest}", embedding_function=chromadb_embed_fn
        )
    except Exception as e:
        print(e)
        abort(400, "Chat collection not found in chromadb")

    snapshot = collection.get()
    documents = snapshot.get('documents', [])
    ids = snapshot.get('ids', [])
    metadatas = snapshot.get('metadatas', [])

    entries = [
        {
            "id": entry_id,
            "metadata": metadatas[position],
            "document": documents[position],
        }
        for position, entry_id in enumerate(ids)
    ]
    entries.sort(key=lambda entry: entry['metadata']['date'])

    return jsonify({"chat_id": data["chat_id"], "content": entries})
923
+
924
@app.route("/api/chromadb/import", methods=["POST"])
@require_module("chromadb")
def chromadb_import():
    """Import exported entries into one chat's ChromaDB collection.

    Expects a JSON body with a string ``"chat_id"`` and a list ``"content"``
    whose items each provide ``"document"``, ``"metadata"`` and ``"id"``.
    Entries are upserted into the collection ``chat-<md5 of chat_id>``,
    which is created on demand.

    Responds with ``{"count": <number of entries imported>}``, or 400 when
    ``chat_id`` or ``content`` is missing or of the wrong type.
    """
    data = request.get_json()
    if "chat_id" not in data or not isinstance(data["chat_id"], str):
        abort(400, '"chat_id" is required')
    # Validate before use: reading data['content'] unchecked raised KeyError
    # (a 500) for a request that is simply malformed.
    if "content" not in data or not isinstance(data["content"], list):
        abort(400, '"content" is required')
    content = data["content"]

    chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
    collection = chromadb_client.get_or_create_collection(
        name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
    )

    documents = [item['document'] for item in content]
    metadatas = [item['metadata'] for item in content]
    ids = [item['id'] for item in content]

    collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
    print(f"Imported {len(ids)} (total {collection.count()}) content entries into {repr(data['chat_id'])}")

    return jsonify({"count": len(ids)})
946
+
947
+
948
# Optionally expose the server through a cloudflared tunnel, then start Flask.
if args.share:
    from flask_cloudflared import _run_cloudflared
    import inspect

    # Inspect the signature to decide how to call _run_cloudflared: when it
    # takes more than one positional-or-keyword parameter, a random metrics
    # port is passed as the second argument.
    sig = inspect.signature(_run_cloudflared)
    # Named so it does not rebind the builtin ``sum`` at module level.
    positional_param_count = sum(
        1
        for param in sig.parameters.values()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    )
    if positional_param_count > 1:
        metrics_port = randint(8100, 9000)
        cloudflare = _run_cloudflared(port, metrics_port)
    else:
        cloudflare = _run_cloudflared(port)
    print("Running on", cloudflare)

# Register the sample-playback route in ignore_auth before serving.
ignore_auth.append(tts_play_sample)
app.run(host=host, port=port)
tts_edge.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import edge_tts
3
+ import asyncio
4
+
5
+
6
def get_voices():
    """Return the result of ``edge_tts.list_voices()``.

    The coroutine is run to completion with ``asyncio.run`` so callers can
    use this function synchronously.
    """
    return asyncio.run(edge_tts.list_voices())
9
+
10
+
11
async def _iterate_chunks(audio):
    """Yield the ``"data"`` payload of every audio chunk from ``audio.stream()``.

    Chunks whose ``"type"`` is not ``"audio"`` are skipped.
    """
    async for event in audio.stream():
        if event["type"] != "audio":
            continue
        yield event["data"]
15
+
16
+
17
async def _async_generator_to_list(async_gen):
    """Consume an async iterable and return its items as a list, in order."""
    return [item async for item in async_gen]
22
+
23
+
24
def generate_audio(text: str, voice: str, rate: int) -> bytes:
    """Synthesize ``text`` with edge-tts and return the audio as one bytes object.

    ``rate`` is turned into a signed percentage string for
    ``edge_tts.Communicate``: ``+N%`` for values above zero, ``-N%``
    otherwise (so zero becomes ``-0%``). The audio chunks from the stream
    are collected and concatenated in order.
    """
    if rate > 0:
        sign = '+'
    else:
        sign = '-'
    rate_spec = f'{sign}{abs(rate)}%'

    communicate = edge_tts.Communicate(text=text, voice=voice, rate=rate_spec)
    pieces = asyncio.run(_async_generator_to_list(_iterate_chunks(communicate)))
    return b"".join(pieces)