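"""Helpers for talking to a KoboldAI API server: waiting for it to come up and
running raw text generation against its HTTP API."""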
import datetime
import logging
import time

import requests

logger = logging.getLogger(__name__)


class KoboldApiServerException(Exception):
    pass


def wait_for_kai_server(koboldai_url: str, max_wait_time_seconds: int) -> None:
    '''Blocks until the KAI server is up.'''
    start_time = datetime.datetime.now()

    while True:
        try:
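            # timeout is (connect timeout, read timeout) in seconds; any
            # response at all means the server is reachable.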
            requests.head(koboldai_url, timeout=(5, 5))
            break
        except requests.exceptions.ConnectionError as ex:
            if "Connection refused" not in str(ex):
                raise ex

            abort_at = start_time + datetime.timedelta(
                seconds=max_wait_time_seconds)

            if datetime.datetime.now() > abort_at:
                raise TimeoutError(
                    f"Waited for {max_wait_time_seconds} seconds but KoboldAI"
                    " server is still not up, aborting.")

            time.sleep(1)


def run_raw_inference_on_kai(
    koboldai_url: str,
    prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    typical_p: float,
    repetition_penalty: float,
    **kwargs,
) -> str:
    endpoint = f"{koboldai_url}/api/v1/generate"
    payload = {
        "prompt": prompt,

        # Incredibly low per-request max length for reasons explained in the
        # generation loop below.
        "max_length": 32,

        # Take care of parameters which are named differently between Kobold and
        # HuggingFace.
        "sampler_full_determinism": not do_sample,
        "typical": typical_p,
        "rep_pen": repetition_penalty,

        # Disable any pre or post-processing on the KoboldAI side, we'd rather
        # take care of things on our own.
        "frmttriminc": False,
        "frmtrmspch": False,
        "frmtrmblln": False,
        "frmtadsnsp": False,

        # Append any other generation parameters that we didn't handle manually.
        **kwargs,
    }
    generated_text = ""

    # Currently, Kobold doesn't support custom stopping criteria, and their chat
    # mode can't handle multi-line responses. To work around both of those, we
    # use the regular adventure mode generation but keep asking for more tokens
    # until the model starts trying to talk as the user, then we stop.
    attempts = 0
    max_extra_attempts = 4
    while attempts < (max_new_tokens /
                      payload["max_length"]) + max_extra_attempts:
        attempts += 1
        response = requests.post(endpoint, json=payload)
        if not response.ok:
            error_message = response.text
            raise KoboldApiServerException(
                "The KoboldAI API server returned an error"
                f" (HTTP status code {response.status_code}): {error_message}")

        inference_result = response.json()["results"][0]["text"]
        generated_text += inference_result

        # Model started to talk as us. Stop generating and return results, the
        # rest of the code will take care of trimming it properly.
        if "\nYou:" in generated_text:
            logger.debug("Hit `\nYou:`: `%s`", generated_text)
            return generated_text

        # For SFT: hit an EOS token. Trim and return.
        if generated_text.endswith("<|endoftext|>"):
            logger.debug("Got EOS token: `%s`", generated_text)

            # We add a fake generated "\nYou:" here so the trimming code doesn't
            # need to handle SFT and UFT models differently.
            return generated_text.replace("<|endoftext|>", "\nYou:")

        # Hit the configured generation limit.
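        # Note: the whitespace-split word count below is only a rough proxy for
        # the actual token count seen by the model.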
        if len(generated_text.split()) >= max_new_tokens:
            logger.debug("Hit max length: `%s`", generated_text)
            return generated_text

        # Model still hasn't finished what it had to say. Append its output to
        # the prompt and feed it back in.
        logger.debug("Got another %s tokens, but still not done: `%s`",
                     payload["max_length"], generated_text)
        payload["prompt"] += inference_result

    logger.debug("Exhausted generation attempts: `%s`", generated_text)
    return generated_text
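

# Minimal usage sketch (assumptions: a KoboldAI server listening on the URL
# below, plus illustrative wait time and sampling values). Not part of the
# module's API; shown only to demonstrate how the two helpers fit together.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    kai_url = "http://127.0.0.1:5000"  # assumed local KoboldAI instance

    # Block until the server answers, then run a single generation request.
    wait_for_kai_server(kai_url, max_wait_time_seconds=60)
    reply = run_raw_inference_on_kai(
        koboldai_url=kai_url,
        prompt="You: Hello there!\nBot:",
        max_new_tokens=128,
        do_sample=True,
        typical_p=0.9,
        repetition_penalty=1.1,
    )
    print(reply)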