import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

from util import build_parameters, try_start_cloudflared
from gpt4 import ask_gpt

# from modules import shared
# from modules.chat import generate_chat_reply
# from modules.text_generation import encode, generate_reply, stop_everything_event
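
# The commented-out imports above suggest this file was adapted from text-generation-webui's
# blocking API extension: the routes below keep that layout (/api/v1/model, /api/v1/generate,
# /api/v1/chat, /api/v1/stop-stream, /api/v1/token-count), but text generation is delegated
# to ask_gpt() instead of a local model, so the endpoints that relied on the missing modules
# return placeholder values.
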
class Handler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path == '/api/v1/model':
            self.send_response(200)
            self.end_headers()

            response = json.dumps({
                'result': 'GPT4 mindsdb OpenAI original'
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)

    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if self.path == '/api/v1/generate':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            prompt = body['prompt']
            # The sampling parameters and stopping strings are still parsed for API
            # compatibility, but ask_gpt() below only receives the prompt
            generate_params = build_parameters(body)
            stopping_strings = generate_params.pop('stopping_strings')
            generate_params['stream'] = False

            # generator = generate_reply(
            #     prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
            answer = ask_gpt(prompt)

            response = json.dumps({
                'results': [{
                    'text': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/chat':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            user_input = body['user_input']
            history = body['history']
            regenerate = body.get('regenerate', False)
            _continue = body.get('_continue', False)

            generate_params = build_parameters(body, chat=True)
            generate_params['stream'] = False

            # generator = generate_chat_reply(
            #     user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
            # generate_chat_reply() is unavailable here, so the submitted history is echoed back unchanged
            answer = history

            response = json.dumps({
                'results': [{
                    'history': answer
                }]
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/stop-stream':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            # stop_everything_event() is unavailable, so this endpoint only acknowledges the request
            # stop_everything_event()
            response = json.dumps({
                'results': 'error'
            })

            self.wfile.write(response.encode('utf-8'))

        elif self.path == '/api/v1/token-count':
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()

            # encode() is unavailable, so a placeholder is returned instead of a real token count
            # tokens = encode(body['prompt'])[0]
            response = json.dumps({
                'results': [{
                    'tokens': 'error'
                }]
            })

            self.wfile.write(response.encode('utf-8'))
        else:
            self.send_error(404)
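

# Example client calls, as a sketch only: the host and port depend on how start_server()
# is invoked, and any extra sampling keys in the request body are simply whatever
# util.build_parameters() accepts (an assumption, since that module is not shown here).
#
#     import requests
#
#     # Text completion: returns ask_gpt()'s answer for the prompt
#     r = requests.post('http://127.0.0.1:5000/api/v1/generate',
#                       json={'prompt': 'Hello, world'})
#     print(r.json()['results'][0]['text'])
#
#     # Chat: in this build the submitted history is echoed back unchanged
#     r = requests.post('http://127.0.0.1:5000/api/v1/chat',
#                       json={'user_input': 'Hi', 'history': {'internal': [], 'visible': []}})
#     print(r.json()['results'][0]['history'])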


def _run_server(port: int, share: bool = False):
    # The upstream listen flag (shared.args.listen) is unavailable, so the server always binds locally
    address = '127.0.0.1'
    server = ThreadingHTTPServer((address, port), Handler)

    def on_start(public_url: str):
        # Publish the cloudflared tunnel URL by patching it into main.md and logging it
        with open('main.md', 'r') as f:
            text = f.read()
        text = text.replace("[located in the logs of this container]", f"{public_url}/api")
        with open('main.md', 'w') as f:
            f.write(text)
        print(f'Starting non-streaming server at public url {public_url}/api')

    if share:
        try:
            try_start_cloudflared(port, max_attempts=3, on_start=on_start)
        except Exception:
            pass
    else:
        print(f'Starting API at http://{address}:{port}/api')

    server.serve_forever()


def start_server(port: int, share: bool = False):
    Thread(target=_run_server, args=[port, share], daemon=True).start()
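

# Minimal usage sketch (assumption: running this file directly is acceptable and port 5000
# is free). start_server() returns immediately because the HTTP server runs in a daemon
# thread, so the caller has to keep the process alive for the API to stay up.
if __name__ == '__main__':
    import time

    start_server(5000, share=False)
    while True:
        time.sleep(1)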