import asyncio
import json
from threading import Thread

from websockets.server import serve

from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.text_generation import generate_reply

PATH = '/api/v1/stream'


async def _handle_connection(websocket, path):
    """Handle one client connection, streaming a reply for each JSON request."""

    if path != PATH:
        print(f'Streaming api: unknown path: {path}')
        return

    async for message in websocket:
        message = json.loads(message)

        prompt = message['prompt']
        generate_params = build_parameters(message)
        stopping_strings = generate_params.pop('stopping_strings')
        generate_params['stream'] = True

        generator = generate_reply(
            prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

        # Each item from the generator is the cumulative reply so far, so
        # send only the characters added since the previous message.
        skip_index = 0
        message_num = 0

        for reply in generator:
            to_send = reply[skip_index:]
            await websocket.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'text': to_send
            }))

            # The generator is synchronous, so yield control to the event
            # loop between chunks.
            await asyncio.sleep(0)

            skip_index += len(to_send)
            message_num += 1

        await websocket.send(json.dumps({
            'event': 'stream_end',
            'message_num': message_num
        }))


async def _run(host: str, port: int):
    # ping_interval=None disables the library's keepalive pings.
    async with serve(_handle_connection, host, port, ping_interval=None):
        await asyncio.Future()  # run forever


def _run_server(port: int, share: bool = False):
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'

    def on_start(public_url: str):
        public_url = public_url.replace('https://', 'wss://')
        print(f'Starting streaming server at public url {public_url}{PATH}')

    if share:
        try:
            try_start_cloudflared(port, max_attempts=3, on_start=on_start)
        except Exception as e:
            print(e)
    else:
        print(f'Starting streaming server at ws://{address}:{port}{PATH}')

    asyncio.run(_run(host=address, port=port))


def start_server(port: int, share: bool = False):
    """Start the streaming server in a background daemon thread."""
    Thread(target=_run_server, args=(port, share), daemon=True).start()