import datetime
import logging
import time

import requests

logger = logging.getLogger(__name__)


class KoboldApiServerException(Exception):
    pass


def wait_for_kai_server(koboldai_url: str, max_wait_time_seconds: int) -> None:
    '''Blocks until the KAI server is up.'''
    start_time = datetime.datetime.now()

    while True:
        try:
            requests.head(koboldai_url, timeout=(5, 5))
            break
        except requests.exceptions.ConnectionError as ex:
            if "Connection refused" not in str(ex):
                raise ex

        abort_at = start_time + datetime.timedelta(
            seconds=max_wait_time_seconds)
        if datetime.datetime.now() > abort_at:
            raise TimeoutError(
                f"Waited for {max_wait_time_seconds} seconds but KoboldAI"
                " server is still not up, aborting.")

        time.sleep(1)


def run_raw_inference_on_kai(
    koboldai_url: str,
    prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    typical_p: float,
    repetition_penalty: float,
    **kwargs,
) -> str:
    '''Runs a raw generation request against the KoboldAI API and returns the text.'''
    endpoint = f"{koboldai_url}/api/v1/generate"
    payload = {
        "prompt": prompt,

        # Incredibly low max length for reasons explained in the generation
        # loop below.
        "max_length": 32,

        # Take care of parameters which are named differently between Kobold
        # and HuggingFace.
        "sampler_full_determinism": not do_sample,
        "typical": typical_p,
        "rep_pen": repetition_penalty,

        # Disable any pre- or post-processing on the KoboldAI side; we'd rather
        # take care of things on our own.
        "frmttriminc": False,
        "frmtrmspch": False,
        "frmtrmblln": False,
        "frmtadsnsp": False,

        # Merge in any other generation parameters that we didn't handle
        # manually.
        **kwargs,
    }
    generated_text = ""

    # Currently, Kobold doesn't support custom stopping criteria, and their
    # chat mode can't handle multi-line responses. To work around both of
    # those, we use the regular adventure mode generation but keep asking for
    # more tokens until the model starts trying to talk as the user, then we
    # stop.
    attempts = 0
    max_extra_attempts = 4
    while attempts < (max_new_tokens /
                      payload["max_length"]) + max_extra_attempts:
        attempts += 1
        response = requests.post(endpoint, json=payload)
        if not response.ok:
            error_message = response.text
            raise KoboldApiServerException(
                "The KoboldAI API server returned an error"
                f" (HTTP status code {response.status_code}): {error_message}")

        inference_result = response.json()["results"][0]["text"]
        generated_text += inference_result

        # Model started to talk as us. Stop generating and return results; the
        # rest of the code will take care of trimming it properly.
        if "\nYou:" in generated_text:
            logger.debug("Hit `\nYou:`: `%s`", generated_text)
            return generated_text

        # For SFT: hit an EOS token. Trim and return.
        if generated_text.endswith("<|endoftext|>"):
            logger.debug("Got EOS token: `%s`", generated_text)
            # We add a fake generated "\nYou:" here so the trimming code
            # doesn't need to handle SFT and UFT models differently.
            return generated_text.replace("<|endoftext|>", "\nYou:")

        # Hit the configured generation limit. Note that this is a rough
        # whitespace-based word count, not an exact token count.
        if len(generated_text.split()) >= max_new_tokens:
            logger.debug("Hit max length: `%s`", generated_text)
            return generated_text

        # Model still hasn't finished what it had to say. Append its output to
        # the prompt and feed it back in.
        logger.debug("Got another %s tokens, but still not done: `%s`",
                     payload["max_length"], generated_text)
        payload["prompt"] += inference_result

    logger.debug("Exhausted generation attempts: `%s`", generated_text)
    return generated_text
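

# Example usage sketch, not part of the original module: it shows how the two
# helpers above might be wired together. The URL, prompt, and sampling values
# below are illustrative assumptions rather than values taken from any real
# deployment.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    # Assumed local KoboldAI endpoint; point this at wherever the server runs.
    koboldai_url = "http://localhost:5000"

    # Block until the server answers, then run a single generation. Extra
    # keyword arguments are forwarded to the API unchanged via **kwargs.
    wait_for_kai_server(koboldai_url, max_wait_time_seconds=120)
    generated = run_raw_inference_on_kai(
        koboldai_url,
        prompt="You: Hello there!\nBot:",
        max_new_tokens=128,
        do_sample=True,
        typical_p=0.9,
        repetition_penalty=1.1,
    )
    print(generated)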