import datetime
import logging
import time

import requests

logger = logging.getLogger(__name__)


class KoboldApiServerException(Exception):
    pass


def wait_for_kai_server(koboldai_url: str, max_wait_time_seconds: int) -> None:
    '''Blocks until the KAI server is up.'''
    start_time = datetime.datetime.now()
    while True:
        try:
            requests.head(koboldai_url, timeout=(5, 5))
            break
        except requests.exceptions.ConnectionError as ex:
            if "Connection refused" not in str(ex):
                raise ex

            abort_at = start_time + datetime.timedelta(
                seconds=max_wait_time_seconds)
            if datetime.datetime.now() > abort_at:
                raise TimeoutError(
                    f"Waited for {max_wait_time_seconds} seconds but KoboldAI"
                    " server is still not up, aborting.")
            time.sleep(1)
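

# A minimal usage sketch, assuming a locally-launched KoboldAI instance
# (it listens on port 5000 by default; the URL below is illustrative):
#
#     wait_for_kai_server("http://127.0.0.1:5000", max_wait_time_seconds=60)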


def run_raw_inference_on_kai(
    koboldai_url: str,
    prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    typical_p: float,
    repetition_penalty: float,
    **kwargs,
) -> str:
    endpoint = f"{koboldai_url}/api/v1/generate"
    payload = {
        "prompt": prompt,
        # Incredibly low max length per request, for reasons explained in the
        # generation loop below.
        "max_length": 32,
        # Take care of parameters which are named differently between Kobold
        # and HuggingFace.
        "sampler_full_determinism": not do_sample,
        "typical": typical_p,
        "rep_pen": repetition_penalty,
        # Disable any pre- or post-processing on the KoboldAI side; we'd
        # rather take care of things on our own.
        "frmttriminc": False,
        "frmtrmspch": False,
        "frmtrmblln": False,
        "frmtadsnsp": False,
        # Append any other generation parameters that we didn't handle
        # manually.
        **kwargs,
    }

    generated_text = ""

    # Currently, Kobold doesn't support custom stopping criteria, and its
    # chat mode can't handle multi-line responses. To work around both of
    # those, we use the regular adventure mode generation but keep asking for
    # more tokens until the model starts trying to talk as the user, then we
    # stop.
    attempts = 0
    max_extra_attempts = 4
    while attempts < (max_new_tokens /
                      payload["max_length"]) + max_extra_attempts:
        attempts += 1

        response = requests.post(endpoint, json=payload)
        if not response.ok:
            error_message = response.text
            raise KoboldApiServerException(
                "The KoboldAI API server returned an error"
                f" (HTTP status code {response.status_code}): {error_message}")

        inference_result = response.json()["results"][0]["text"]
        generated_text += inference_result

        # Model started to talk as us. Stop generating and return results;
        # the rest of the code will take care of trimming it properly.
        if "\nYou:" in generated_text:
            logger.debug("Hit `\nYou:`: `%s`", generated_text)
            return generated_text

        # For SFT: hit an EOS token. Trim and return.
        if generated_text.endswith("<|endoftext|>"):
            logger.debug("Got EOS token: `%s`", generated_text)
            # We add a fake generated "\nYou:" here so the trimming code
            # doesn't need to handle SFT and UFT models differently.
            return generated_text.replace("<|endoftext|>", "\nYou:")
        # Hit the configured generation limit. Note: `split()` gives a rough
        # whitespace-based word count, which only approximates the actual
        # token count.
        if len(generated_text.split()) >= max_new_tokens:
            logger.debug("Hit max length: `%s`", generated_text)
            return generated_text
        # Model still hasn't finished what it had to say. Append its output
        # to the prompt and feed it back in.
        logger.debug("Got another %s tokens, but still not done: `%s`",
                     payload["max_length"], generated_text)
        payload["prompt"] += inference_result

    logger.debug("Exhausted generation attempts: `%s`", generated_text)
    return generated_text
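

if __name__ == "__main__":
    # A minimal smoke test, not part of the original module. The URL and
    # sampling values are illustrative assumptions; adjust them to match
    # your local KoboldAI setup.
    logging.basicConfig(level=logging.DEBUG)

    kai_url = "http://127.0.0.1:5000"
    wait_for_kai_server(kai_url, max_wait_time_seconds=60)

    try:
        reply = run_raw_inference_on_kai(
            koboldai_url=kai_url,
            prompt="You: Hello there!\nBot:",
            max_new_tokens=128,
            do_sample=True,
            typical_p=0.9,
            repetition_penalty=1.05,
        )
        print(reply)
    except KoboldApiServerException as ex:
        logger.error("Generation failed: %s", ex)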