Reiner4 doctord98 commited on
Commit
3d7d81b
0 Parent(s):

Duplicate from doctord98/extras

Browse files

Co-authored-by: doctord98 <doctord98@users.noreply.huggingface.co>

Files changed (7) hide show
  1. .gitattributes +35 -0
  2. Dockerfile +21 -0
  3. README.md +11 -0
  4. constants.py +50 -0
  5. requirements.txt +18 -0
  6. server.py +854 -0
  7. tts_edge.py +34 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install -r requirements.txt
7
+
8
+ RUN mkdir /.cache && chmod -R 777 /.cache
9
+ RUN mkdir .chroma && chmod -R 777 .chroma
10
+
11
+ COPY . .
12
+
13
+
14
+ RUN chmod -R 777 /app
15
+
16
+ RUN --mount=type=secret,id=password,mode=0444,required=true \
17
+ cat /run/secrets/password > /test
18
+
19
+ EXPOSE 7860
20
+
21
+ CMD ["python", "server.py", "--cpu", "--enable-modules=caption,summarize,classify,silero-tts,edge-tts,chromadb"]
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: smut
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ duplicated_from: doctord98/extras
10
+ ---
11
+ doctord98 is your lord and savior
constants.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Constants
2
+ # Also try: 'slauw87/bart-large-cnn-samsum'
3
+ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ElectrifAi_v14"
4
+ # Also try: 'nateraw/bert-base-uncased-emotion'
5
+ DEFAULT_CLASSIFICATION_MODEL = "joeddav/distilbert-base-uncased-go-emotions-student"
6
+ # Also try: 'Salesforce/blip-image-captioning-base'
7
+ DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
8
+ # Also try: 'ckpt/anything-v4.5-vae-swapped'
9
+ DEFAULT_SD_MODEL = "sinkinai/MeinaHentai-v3-baked-vae"
10
+ DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
11
+ DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
12
+ DEFAULT_REMOTE_SD_PORT = 7860
13
+ DEFAULT_CHROMA_PORT = 8000
14
+ SILERO_SAMPLES_PATH = "tts_samples"
15
+ SILERO_SAMPLE_TEXT = "Doctor is your lord and savior"
16
+ # ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
17
+ DEFAULT_SUMMARIZE_PARAMS = {
18
+ "temperature": 1.0,
19
+ "repetition_penalty": 1.0,
20
+ "max_length": 500,
21
+ "min_length": 200,
22
+ "length_penalty": 1.5,
23
+ "bad_words": [
24
+ "\n",
25
+ '"',
26
+ "*",
27
+ "[",
28
+ "]",
29
+ "{",
30
+ "}",
31
+ ":",
32
+ "(",
33
+ ")",
34
+ "<",
35
+ ">",
36
+ "Â",
37
+ "The text ends",
38
+ "The story ends",
39
+ "The text is",
40
+ "The story is",
41
+ ],
42
+ }
43
+
44
+ PROMPT_PREFIX = "best quality, absurdres, "
45
+ NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
46
+ error hands, bad hands, error fingers, bad fingers, missing fingers
47
+ error legs, bad legs, multiple legs, missing legs, error lighting,
48
+ error shadow, error reflection, text, error, extra digit, fewer digits,
49
+ cropped, worst quality, low quality, normal quality, jpeg artifacts,
50
+ signature, watermark, username, blurry"""
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cors
3
+ flask-compress
4
+ markdown
5
+ Pillow
6
+ colorama
7
+ webuiapi
8
+ --extra-index-url https://download.pytorch.org/whl/cu117
9
+ torch==2.0.0+cu117
10
+ torchvision==0.15.1
11
+ torchaudio==2.0.1+cu117
12
+ accelerate
13
+ transformers==4.28.1
14
+ diffusers==0.16.1
15
+ silero-api-server
16
+ chromadb
17
+ sentence_transformers
18
+ edge-tts
server.py ADDED
@@ -0,0 +1,854 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ from flask import (
3
+ Flask,
4
+ jsonify,
5
+ request,
6
+ Response,
7
+ render_template_string,
8
+ abort,
9
+ send_from_directory,
10
+ send_file,
11
+ )
12
+ from flask_cors import CORS
13
+ from flask_compress import Compress
14
+ import markdown
15
+ import argparse
16
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
17
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
18
+ from transformers import BlipForConditionalGeneration
19
+ import unicodedata
20
+ import torch
21
+ import time
22
+ import os
23
+ import gc
24
+ import secrets
25
+ from PIL import Image
26
+ import base64
27
+ from io import BytesIO
28
+ from random import randint
29
+ import webuiapi
30
+ import hashlib
31
+ from constants import *
32
+ from colorama import Fore, Style, init as colorama_init
33
+
34
+ colorama_init()
35
+
36
+
37
+ class SplitArgs(argparse.Action):
38
+ def __call__(self, parser, namespace, values, option_string=None):
39
+ setattr(
40
+ namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
41
+ )
42
+
43
+
44
+ # Script arguments
45
+ parser = argparse.ArgumentParser(
46
+ prog="SillyTavern Extras", description="Web API for transformers models"
47
+ )
48
+ parser.add_argument(
49
+ "--port", type=int, help="Specify the port on which the application is hosted"
50
+ )
51
+ parser.add_argument(
52
+ "--listen", action="store_true", help="Host the app on the local network"
53
+ )
54
+ parser.add_argument(
55
+ "--share", action="store_true", help="Share the app on CloudFlare tunnel"
56
+ )
57
+ parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
58
+ parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
59
+ parser.set_defaults(cpu=True)
60
+ parser.add_argument("--summarization-model", help="Load a custom summarization model")
61
+ parser.add_argument(
62
+ "--classification-model", help="Load a custom text classification model"
63
+ )
64
+ parser.add_argument("--captioning-model", help="Load a custom captioning model")
65
+ parser.add_argument("--embedding-model", help="Load a custom text embedding model")
66
+ parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
67
+ parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
68
+ parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
69
+ parser.add_argument('--chroma-persist', help="Chromadb persistence", default=True, action=argparse.BooleanOptionalAction)
70
+ parser.add_argument(
71
+ "--secure", action="store_true", help="Enforces the use of an API key"
72
+ )
73
+
74
+ sd_group = parser.add_mutually_exclusive_group()
75
+
76
+ local_sd = sd_group.add_argument_group("sd-local")
77
+ local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
78
+ local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")
79
+
80
+ remote_sd = sd_group.add_argument_group("sd-remote")
81
+ remote_sd.add_argument(
82
+ "--sd-remote", action="store_true", help="Use a remote backend for SD"
83
+ )
84
+ remote_sd.add_argument(
85
+ "--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
86
+ )
87
+ remote_sd.add_argument(
88
+ "--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
89
+ )
90
+ remote_sd.add_argument(
91
+ "--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
92
+ )
93
+ remote_sd.add_argument(
94
+ "--sd-remote-auth",
95
+ type=str,
96
+ help="Specify the username:password for the remote SD backend (if required)",
97
+ )
98
+
99
+ parser.add_argument(
100
+ "--enable-modules",
101
+ action=SplitArgs,
102
+ default=[],
103
+ help="Override a list of enabled modules",
104
+ )
105
+
106
+ args = parser.parse_args()
107
+
108
+ port = 7860
109
+ host = "0.0.0.0"
110
+ summarization_model = (
111
+ args.summarization_model
112
+ if args.summarization_model
113
+ else DEFAULT_SUMMARIZATION_MODEL
114
+ )
115
+ classification_model = (
116
+ args.classification_model
117
+ if args.classification_model
118
+ else DEFAULT_CLASSIFICATION_MODEL
119
+ )
120
+ captioning_model = (
121
+ args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
122
+ )
123
+ embedding_model = (
124
+ args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
125
+ )
126
+
127
+ sd_use_remote = False if args.sd_model else True
128
+ sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
129
+ sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
130
+ sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
131
+ sd_remote_ssl = args.sd_remote_ssl
132
+ sd_remote_auth = args.sd_remote_auth
133
+
134
+ modules = (
135
+ args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
136
+ )
137
+
138
+ if len(modules) == 0:
139
+ print(
140
+ f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
141
+ )
142
+ print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
143
+
144
+ # Models init
145
+ device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
146
+ device = torch.device(device_string)
147
+ torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
148
+
149
+ if not torch.cuda.is_available() and not args.cpu:
150
+ print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device. Defaulting to CPU mode.{Style.RESET_ALL}")
151
+
152
+ print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
153
+
154
+ if "caption" in modules:
155
+ print("Initializing an image captioning model...")
156
+ captioning_processor = AutoProcessor.from_pretrained(captioning_model)
157
+ if "blip" in captioning_model:
158
+ captioning_transformer = BlipForConditionalGeneration.from_pretrained(
159
+ captioning_model, torch_dtype=torch_dtype
160
+ ).to(device)
161
+ else:
162
+ captioning_transformer = AutoModelForCausalLM.from_pretrained(
163
+ captioning_model, torch_dtype=torch_dtype
164
+ ).to(device)
165
+
166
+ if "summarize" in modules:
167
+ print("Initializing a text summarization model...")
168
+ summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
169
+ summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
170
+ summarization_model, torch_dtype=torch_dtype
171
+ ).to(device)
172
+
173
+ if "classify" in modules:
174
+ print("Initializing a sentiment classification pipeline...")
175
+ classification_pipe = pipeline(
176
+ "text-classification",
177
+ model=classification_model,
178
+ top_k=None,
179
+ device=device,
180
+ torch_dtype=torch_dtype,
181
+ )
182
+
183
+ if "sd" in modules and not sd_use_remote:
184
+ from diffusers import StableDiffusionPipeline
185
+ from diffusers import EulerAncestralDiscreteScheduler
186
+
187
+ print("Initializing Stable Diffusion pipeline")
188
+ sd_device_string = (
189
+ "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
190
+ )
191
+ sd_device = torch.device(sd_device_string)
192
+ sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
193
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
194
+ sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
195
+ ).to(sd_device)
196
+ sd_pipe.safety_checker = lambda images, clip_input: (images, False)
197
+ sd_pipe.enable_attention_slicing()
198
+ # pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
199
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
200
+ sd_pipe.scheduler.config
201
+ )
202
+ elif "sd" in modules and sd_use_remote:
203
+ print("Initializing Stable Diffusion connection")
204
+ try:
205
+ sd_remote = webuiapi.WebUIApi(
206
+ host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
207
+ )
208
+ if sd_remote_auth:
209
+ username, password = sd_remote_auth.split(":")
210
+ sd_remote.set_auth(username, password)
211
+ sd_remote.util_wait_for_ready()
212
+ except Exception as e:
213
+ # remote sd from modules
214
+ print(
215
+ f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
216
+ )
217
+ modules.remove("sd")
218
+
219
+ if "tts" in modules:
220
+ print("tts module is deprecated. Please use silero-tts instead.")
221
+ modules.remove("tts")
222
+ modules.append("silero-tts")
223
+
224
+
225
+ if "silero-tts" in modules:
226
+ if not os.path.exists(SILERO_SAMPLES_PATH):
227
+ os.makedirs(SILERO_SAMPLES_PATH)
228
+ print("Initializing Silero TTS server")
229
+ from silero_api_server import tts
230
+
231
+ tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
232
+ if len(os.listdir(SILERO_SAMPLES_PATH)) == 0:
233
+ print("Generating Silero TTS samples...")
234
+ tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
235
+ tts_service.generate_samples()
236
+
237
+
238
+ if "edge-tts" in modules:
239
+ print("Initializing Edge TTS client")
240
+ import tts_edge as edge
241
+
242
+
243
+ if "chromadb" in modules:
244
+ print("Initializing ChromaDB")
245
+ import chromadb
246
+ import posthog
247
+ from chromadb.config import Settings
248
+ from sentence_transformers import SentenceTransformer
249
+
250
+ # Assume that the user wants in-memory unless a host is specified
251
+ # Also disable chromadb telemetry
252
+ posthog.capture = lambda *args, **kwargs: None
253
+ if args.chroma_host is None:
254
+ if args.chroma_persist:
255
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False, persist_directory=args.chroma_folder, chroma_db_impl='duckdb+parquet'))
256
+ print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
257
+ else:
258
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
259
+ print(f"ChromaDB is running in-memory without persistence.")
260
+ else:
261
+ chroma_port=(
262
+ args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
263
+ )
264
+ chromadb_client = chromadb.Client(
265
+ Settings(
266
+ anonymized_telemetry=False,
267
+ chroma_api_impl="rest",
268
+ chroma_server_host=args.chroma_host,
269
+ chroma_server_http_port=chroma_port
270
+ )
271
+ )
272
+ print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")
273
+
274
+ chromadb_embedder = SentenceTransformer(embedding_model)
275
+ chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()
276
+
277
+ # Check if the db is connected and running, otherwise tell the user
278
+ try:
279
+ chromadb_client.heartbeat()
280
+ print("Successfully pinged ChromaDB! Your client is successfully connected.")
281
+ except:
282
+ print("Could not ping ChromaDB! If you are running remotely, please check your host and port!")
283
+
284
+ # Flask init
285
+ app = Flask(__name__)
286
+ CORS(app) # allow cross-domain requests
287
+ Compress(app) # compress responses
288
+ app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
289
+
290
+
291
+ def require_module(name):
292
+ def wrapper(fn):
293
+ @wraps(fn)
294
+ def decorated_view(*args, **kwargs):
295
+ if name not in modules:
296
+ abort(403, "Module is disabled by config")
297
+ return fn(*args, **kwargs)
298
+
299
+ return decorated_view
300
+
301
+ return wrapper
302
+
303
+
304
+ # AI stuff
305
+ def classify_text(text: str) -> list:
306
+ output = classification_pipe(
307
+ text,
308
+ truncation=True,
309
+ max_length=classification_pipe.model.config.max_position_embeddings,
310
+ )[0]
311
+ return sorted(output, key=lambda x: x["score"], reverse=True)
312
+
313
+
314
+ def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
315
+ inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
316
+ device, torch_dtype
317
+ )
318
+ outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
319
+ caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
320
+ return caption
321
+
322
+
323
+ def summarize_chunks(text: str, params: dict) -> str:
324
+ try:
325
+ return summarize(text, params)
326
+ except IndexError:
327
+ print(
328
+ "Sequence length too large for model, cutting text in half and calling again"
329
+ )
330
+ new_params = params.copy()
331
+ new_params["max_length"] = new_params["max_length"] // 2
332
+ new_params["min_length"] = new_params["min_length"] // 2
333
+ return summarize_chunks(
334
+ text[: (len(text) // 2)], new_params
335
+ ) + summarize_chunks(text[(len(text) // 2) :], new_params)
336
+
337
+
338
+ def summarize(text: str, params: dict) -> str:
339
+ # Tokenize input
340
+ inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
341
+ token_count = len(inputs[0])
342
+
343
+ bad_words_ids = [
344
+ summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
345
+ for bad_word in params["bad_words"]
346
+ ]
347
+ summary_ids = summarization_transformer.generate(
348
+ inputs["input_ids"],
349
+ num_beams=2,
350
+ max_new_tokens=max(token_count, int(params["max_length"])),
351
+ min_new_tokens=min(token_count, int(params["min_length"])),
352
+ repetition_penalty=float(params["repetition_penalty"]),
353
+ temperature=float(params["temperature"]),
354
+ length_penalty=float(params["length_penalty"]),
355
+ bad_words_ids=bad_words_ids,
356
+ )
357
+ summary = summarization_tokenizer.batch_decode(
358
+ summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
359
+ )[0]
360
+ summary = normalize_string(summary)
361
+ return summary
362
+
363
+
364
+ def normalize_string(input: str) -> str:
365
+ output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
366
+ return output
367
+
368
+
369
+ def generate_image(data: dict) -> Image:
370
+ prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
371
+
372
+ if sd_use_remote:
373
+ image = sd_remote.txt2img(
374
+ prompt=prompt,
375
+ negative_prompt=data["negative_prompt"],
376
+ sampler_name=data["sampler"],
377
+ steps=data["steps"],
378
+ cfg_scale=data["scale"],
379
+ width=data["width"],
380
+ height=data["height"],
381
+ restore_faces=data["restore_faces"],
382
+ enable_hr=data["enable_hr"],
383
+ save_images=True,
384
+ send_images=True,
385
+ do_not_save_grid=False,
386
+ do_not_save_samples=False,
387
+ ).image
388
+ else:
389
+ image = sd_pipe(
390
+ prompt=prompt,
391
+ negative_prompt=data["negative_prompt"],
392
+ num_inference_steps=data["steps"],
393
+ guidance_scale=data["scale"],
394
+ width=data["width"],
395
+ height=data["height"],
396
+ ).images[0]
397
+
398
+ image.save("./debug.png")
399
+ return image
400
+
401
+
402
+ def image_to_base64(image: Image, quality: int = 75) -> str:
403
+ buffer = BytesIO()
404
+ image.convert("RGB")
405
+ image.save(buffer, format="JPEG", quality=quality)
406
+ img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
407
+ return img_str
408
+
409
+ ignore_auth = []
410
+
411
+ api_key = os.environ.get("password")
412
+
413
+ def is_authorize_ignored(request):
414
+ view_func = app.view_functions.get(request.endpoint)
415
+
416
+ if view_func is not None:
417
+ if view_func in ignore_auth:
418
+ return True
419
+ return False
420
+
421
+ @app.before_request
422
+ def before_request():
423
+ # Request time measuring
424
+ request.start_time = time.time()
425
+
426
+ # Checks if an API key is present and valid, otherwise return unauthorized
427
+ # The options check is required so CORS doesn't get angry
428
+ try:
429
+ if request.method != 'OPTIONS' and is_authorize_ignored(request) == False and getattr(request.authorization, 'token', '') != api_key:
430
+ print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
431
+ response = jsonify({ 'error': '401: Invalid API key' })
432
+ response.status_code = 401
433
+ return "this space is only for doctord98 but you can duplicate it and enjoy"
434
+ except Exception as e:
435
+ print(f"API key check error: {e}")
436
+ return "this space is only for doctord98 but you can duplicate it and enjoy"
437
+
438
+
439
+ @app.after_request
440
+ def after_request(response):
441
+ duration = time.time() - request.start_time
442
+ response.headers["X-Request-Duration"] = str(duration)
443
+ return response
444
+
445
+
446
+ @app.route("/", methods=["GET"])
447
+ def index():
448
+ with open("./README.md", "r", encoding="utf8") as f:
449
+ content = f.read()
450
+ return render_template_string(markdown.markdown(content, extensions=["tables"]))
451
+
452
+
453
+ @app.route("/api/extensions", methods=["GET"])
454
+ def get_extensions():
455
+ extensions = dict(
456
+ {
457
+ "extensions": [
458
+ {
459
+ "name": "not-supported",
460
+ "metadata": {
461
+ "display_name": """<span style="white-space:break-spaces;">Extensions serving using Extensions API is no longer supported. Please update the mod from: <a href="https://github.com/Cohee1207/SillyTavern">https://github.com/Cohee1207/SillyTavern</a></span>""",
462
+ "requires": [],
463
+ "assets": [],
464
+ },
465
+ }
466
+ ]
467
+ }
468
+ )
469
+ return jsonify(extensions)
470
+
471
+
472
+ @app.route("/api/caption", methods=["POST"])
473
+ @require_module("caption")
474
+ def api_caption():
475
+ data = request.get_json()
476
+
477
+ if "image" not in data or not isinstance(data["image"], str):
478
+ abort(400, '"image" is required')
479
+
480
+ image = Image.open(BytesIO(base64.b64decode(data["image"])))
481
+ image = image.convert("RGB")
482
+ image.thumbnail((512, 512))
483
+ caption = caption_image(image)
484
+ thumbnail = image_to_base64(image)
485
+ print("Caption:", caption, sep="\n")
486
+ gc.collect()
487
+ return jsonify({"caption": caption, "thumbnail": thumbnail})
488
+
489
+
490
+ @app.route("/api/summarize", methods=["POST"])
491
+ @require_module("summarize")
492
+ def api_summarize():
493
+ data = request.get_json()
494
+
495
+ if "text" not in data or not isinstance(data["text"], str):
496
+ abort(400, '"text" is required')
497
+
498
+ params = DEFAULT_SUMMARIZE_PARAMS.copy()
499
+
500
+ if "params" in data and isinstance(data["params"], dict):
501
+ params.update(data["params"])
502
+
503
+ print("Summary input:", data["text"], sep="\n")
504
+ summary = summarize_chunks(data["text"], params)
505
+ print("Summary output:", summary, sep="\n")
506
+ gc.collect()
507
+ return jsonify({"summary": summary})
508
+
509
+
510
+ @app.route("/api/classify", methods=["POST"])
511
+ @require_module("classify")
512
+ def api_classify():
513
+ data = request.get_json()
514
+
515
+ if "text" not in data or not isinstance(data["text"], str):
516
+ abort(400, '"text" is required')
517
+
518
+ print("Classification input:", data["text"], sep="\n")
519
+ classification = classify_text(data["text"])
520
+ print("Classification output:", classification, sep="\n")
521
+ gc.collect()
522
+ return jsonify({"classification": classification})
523
+
524
+
525
+ @app.route("/api/classify/labels", methods=["GET"])
526
+ @require_module("classify")
527
+ def api_classify_labels():
528
+ classification = classify_text("")
529
+ labels = [x["label"] for x in classification]
530
+ return jsonify({"labels": labels})
531
+
532
+
533
+ @app.route("/api/image", methods=["POST"])
534
+ @require_module("sd")
535
+ def api_image():
536
+ required_fields = {
537
+ "prompt": str,
538
+ }
539
+
540
+ optional_fields = {
541
+ "steps": 30,
542
+ "scale": 6,
543
+ "sampler": "DDIM",
544
+ "width": 512,
545
+ "height": 512,
546
+ "restore_faces": False,
547
+ "enable_hr": False,
548
+ "prompt_prefix": PROMPT_PREFIX,
549
+ "negative_prompt": NEGATIVE_PROMPT,
550
+ }
551
+
552
+ data = request.get_json()
553
+
554
+ # Check required fields
555
+ for field, field_type in required_fields.items():
556
+ if field not in data or not isinstance(data[field], field_type):
557
+ abort(400, f'"{field}" is required')
558
+
559
+ # Set optional fields to default values if not provided
560
+ for field, default_value in optional_fields.items():
561
+ type_match = (
562
+ (int, float)
563
+ if isinstance(default_value, (int, float))
564
+ else type(default_value)
565
+ )
566
+ if field not in data or not isinstance(data[field], type_match):
567
+ data[field] = default_value
568
+
569
+ try:
570
+ print("SD inputs:", data, sep="\n")
571
+ image = generate_image(data)
572
+ base64image = image_to_base64(image, quality=90)
573
+ return jsonify({"image": base64image})
574
+ except RuntimeError as e:
575
+ abort(400, str(e))
576
+
577
+
578
+ @app.route("/api/image/model", methods=["POST"])
579
+ @require_module("sd")
580
+ def api_image_model_set():
581
+ data = request.get_json()
582
+
583
+ if not sd_use_remote:
584
+ abort(400, "Changing model for local sd is not supported.")
585
+ if "model" not in data or not isinstance(data["model"], str):
586
+ abort(400, '"model" is required')
587
+
588
+ old_model = sd_remote.util_get_current_model()
589
+ sd_remote.util_set_model(data["model"], find_closest=False)
590
+ # sd_remote.util_set_model(data['model'])
591
+ sd_remote.util_wait_for_ready()
592
+ new_model = sd_remote.util_get_current_model()
593
+
594
+ return jsonify({"previous_model": old_model, "current_model": new_model})
595
+
596
+
597
+ @app.route("/api/image/model", methods=["GET"])
598
+ @require_module("sd")
599
+ def api_image_model_get():
600
+ model = sd_model
601
+
602
+ if sd_use_remote:
603
+ model = sd_remote.util_get_current_model()
604
+
605
+ return jsonify({"model": model})
606
+
607
+
608
+ @app.route("/api/image/models", methods=["GET"])
609
+ @require_module("sd")
610
+ def api_image_models():
611
+ models = [sd_model]
612
+
613
+ if sd_use_remote:
614
+ models = sd_remote.util_get_model_names()
615
+
616
+ return jsonify({"models": models})
617
+
618
+
619
+ @app.route("/api/image/samplers", methods=["GET"])
620
+ @require_module("sd")
621
+ def api_image_samplers():
622
+ samplers = ["Euler a"]
623
+
624
+ if sd_use_remote:
625
+ samplers = [sampler["name"] for sampler in sd_remote.get_samplers()]
626
+
627
+ return jsonify({"samplers": samplers})
628
+
629
+
630
+ @app.route("/api/modules", methods=["GET"])
631
+ def get_modules():
632
+ return jsonify({"modules": modules})
633
+
634
+
635
+ @app.route("/api/tts/speakers", methods=["GET"])
636
+ @require_module("silero-tts")
637
+ def tts_speakers():
638
+ voices = [
639
+ {
640
+ "name": speaker,
641
+ "voice_id": speaker,
642
+ "preview_url": f"{str(request.url_root)}api/tts/sample/{speaker}",
643
+ }
644
+ for speaker in tts_service.get_speakers()
645
+ ]
646
+ return jsonify(voices)
647
+
648
+
649
+ @app.route("/api/tts/generate", methods=["POST"])
650
+ @require_module("silero-tts")
651
+ def tts_generate():
652
+ voice = request.get_json()
653
+ if "text" not in voice or not isinstance(voice["text"], str):
654
+ abort(400, '"text" is required')
655
+ if "speaker" not in voice or not isinstance(voice["speaker"], str):
656
+ abort(400, '"speaker" is required')
657
+ # Remove asterisks
658
+ voice["text"] = voice["text"].replace("*", "")
659
+ try:
660
+ audio = tts_service.generate(voice["speaker"], voice["text"])
661
+ return send_file(audio, mimetype="audio/x-wav")
662
+ except Exception as e:
663
+ print(e)
664
+ abort(500, voice["speaker"])
665
+
666
+
667
+ @app.route("/api/tts/sample/<speaker>", methods=["GET"])
668
+ @require_module("silero-tts")
669
+ def tts_play_sample(speaker: str):
670
+ return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
671
+
672
+
673
+ @app.route("/api/edge-tts/list", methods=["GET"])
674
+ @require_module("edge-tts")
675
+ def edge_tts_list():
676
+ voices = edge.get_voices()
677
+ return jsonify(voices)
678
+
679
+
680
+ @app.route("/api/edge-tts/generate", methods=["POST"])
681
+ @require_module("edge-tts")
682
+ def edge_tts_generate():
683
+ data = request.get_json()
684
+ if "text" not in data or not isinstance(data["text"], str):
685
+ abort(400, '"text" is required')
686
+ if "voice" not in data or not isinstance(data["voice"], str):
687
+ abort(400, '"voice" is required')
688
+ if "rate" in data and isinstance(data['rate'], int):
689
+ rate = data['rate']
690
+ else:
691
+ rate = 0
692
+ # Remove asterisks
693
+ data["text"] = data["text"].replace("*", "")
694
+ try:
695
+ audio = edge.generate_audio(text=data["text"], voice=data["voice"], rate=rate)
696
+ return Response(audio, mimetype="audio/mpeg")
697
+ except Exception as e:
698
+ print(e)
699
+ abort(500, data["voice"])
700
+
701
+
702
+ @app.route("/api/chromadb", methods=["POST"])
703
+ @require_module("chromadb")
704
+ def chromadb_add_messages():
705
+ data = request.get_json()
706
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
707
+ abort(400, '"chat_id" is required')
708
+ if "messages" not in data or not isinstance(data["messages"], list):
709
+ abort(400, '"messages" is required')
710
+
711
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
712
+ collection = chromadb_client.get_or_create_collection(
713
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
714
+ )
715
+
716
+ documents = [m["content"] for m in data["messages"]]
717
+ ids = [m["id"] for m in data["messages"]]
718
+ metadatas = [
719
+ {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
720
+ for m in data["messages"]
721
+ ]
722
+
723
+ collection.upsert(
724
+ ids=ids,
725
+ documents=documents,
726
+ metadatas=metadatas,
727
+ )
728
+
729
+ return jsonify({"count": len(ids)})
730
+
731
+
732
+ @app.route("/api/chromadb/purge", methods=["POST"])
733
+ @require_module("chromadb")
734
+ def chromadb_purge():
735
+ data = request.get_json()
736
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
737
+ abort(400, '"chat_id" is required')
738
+
739
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
740
+ collection = chromadb_client.get_or_create_collection(
741
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
742
+ )
743
+
744
+ count = collection.count()
745
+ collection.delete()
746
+ #Write deletion to persistent folder
747
+ chromadb_client.persist()
748
+ print("ChromaDB embeddings deleted", count)
749
+ return 'Ok', 200
750
+
751
+
752
+ @app.route("/api/chromadb/query", methods=["POST"])
753
+ @require_module("chromadb")
754
+ def chromadb_query():
755
+ data = request.get_json()
756
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
757
+ abort(400, '"chat_id" is required')
758
+ if "query" not in data or not isinstance(data["query"], str):
759
+ abort(400, '"query" is required')
760
+
761
+ if "n_results" not in data or not isinstance(data["n_results"], int):
762
+ n_results = 1
763
+ else:
764
+ n_results = data["n_results"]
765
+
766
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
767
+ collection = chromadb_client.get_or_create_collection(
768
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
769
+ )
770
+
771
+ n_results = min(collection.count(), n_results)
772
+ query_result = collection.query(
773
+ query_texts=[data["query"]],
774
+ n_results=n_results,
775
+ )
776
+
777
+ documents = query_result["documents"][0]
778
+ ids = query_result["ids"][0]
779
+ metadatas = query_result["metadatas"][0]
780
+ distances = query_result["distances"][0]
781
+
782
+ messages = [
783
+ {
784
+ "id": ids[i],
785
+ "date": metadatas[i]["date"],
786
+ "role": metadatas[i]["role"],
787
+ "meta": metadatas[i]["meta"],
788
+ "content": documents[i],
789
+ "distance": distances[i],
790
+ }
791
+ for i in range(len(ids))
792
+ ]
793
+
794
+ return jsonify(messages)
795
+
796
+
797
+ @app.route("/api/chromadb/export", methods=["POST"])
798
+ @require_module("chromadb")
799
+ def chromadb_export():
800
+ data = request.get_json()
801
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
802
+ abort(400, '"chat_id" is required')
803
+
804
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
805
+ collection = chromadb_client.get_or_create_collection(
806
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
807
+ )
808
+ collection_content = collection.get()
809
+ documents = collection_content.get('documents', [])
810
+ ids = collection_content.get('ids', [])
811
+ metadatas = collection_content.get('metadatas', [])
812
+
813
+ unsorted_content = [
814
+ {
815
+ "id": ids[i],
816
+ "metadata": metadatas[i],
817
+ "document": documents[i],
818
+ }
819
+ for i in range(len(ids))
820
+ ]
821
+
822
+ sorted_content = sorted(unsorted_content, key=lambda x: x['metadata']['date'])
823
+
824
+ export = {
825
+ "chat_id": data["chat_id"],
826
+ "content": sorted_content
827
+ }
828
+
829
+ return jsonify(export)
830
+
831
+ @app.route("/api/chromadb/import", methods=["POST"])
832
+ @require_module("chromadb")
833
+ def chromadb_import():
834
+ data = request.get_json()
835
+ content = data['content']
836
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
837
+ abort(400, '"chat_id" is required')
838
+
839
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
840
+ collection = chromadb_client.get_or_create_collection(
841
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
842
+ )
843
+
844
+ documents = [item['document'] for item in content]
845
+ metadatas = [item['metadata'] for item in content]
846
+ ids = [item['id'] for item in content]
847
+
848
+
849
+ collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
850
+
851
+ return jsonify({"count": len(ids)})
852
+
853
+ ignore_auth.append(tts_play_sample)
854
+ app.run(host=host, port=port)
tts_edge.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import edge_tts
3
+ import asyncio
4
+
5
+
6
+ def get_voices():
7
+ voices = asyncio.run(edge_tts.list_voices())
8
+ return voices
9
+
10
+
11
+ async def _iterate_chunks(audio):
12
+ async for chunk in audio.stream():
13
+ if chunk["type"] == "audio":
14
+ yield chunk["data"]
15
+
16
+
17
+ async def _async_generator_to_list(async_gen):
18
+ result = []
19
+ async for item in async_gen:
20
+ result.append(item)
21
+ return result
22
+
23
+
24
+ def generate_audio(text: str, voice: str, rate: int) -> bytes:
25
+ sign = '+' if rate > 0 else '-'
26
+ rate = f'{sign}{abs(rate)}%'
27
+ audio = edge_tts.Communicate(text=text, voice=voice, rate=rate)
28
+ chunks = asyncio.run(_async_generator_to_list(_iterate_chunks(audio)))
29
+ buffer = io.BytesIO()
30
+
31
+ for chunk in chunks:
32
+ buffer.write(chunk)
33
+
34
+ return buffer.getvalue()