import datetime
import logging
import time

import requests

logger = logging.getLogger(__name__)


class KoboldApiServerException(Exception):
    pass


def wait_for_kai_server(koboldai_url: str, max_wait_time_seconds: int) -> None:
    """Blocks until the KAI server is up."""
    start_time = datetime.datetime.now()
    while True:
        try:
            requests.head(koboldai_url, timeout=(5, 5))
            break
        except requests.exceptions.ConnectionError as ex:
            # Anything other than a refused connection is unexpected, so let
            # it propagate.
            if "Connection refused" not in str(ex):
                raise

            abort_at = start_time + datetime.timedelta(
                seconds=max_wait_time_seconds)
            if datetime.datetime.now() > abort_at:
                raise TimeoutError(
                    f"Waited for {max_wait_time_seconds} seconds but KoboldAI"
                    " server is still not up, aborting.")
            time.sleep(1)


def run_raw_inference_on_kai(
    koboldai_url: str,
    prompt: str,
    max_new_tokens: int,
    do_sample: bool,
    typical_p: float,
    repetition_penalty: float,
    **kwargs,
) -> str:
    endpoint = f"{koboldai_url}/api/v1/generate"
    payload = {
        "prompt": prompt,

        # Incredibly low max length for reasons explained in the generation
        # loop below.
        "max_length": 32,

        # Take care of parameters which are named differently between Kobold
        # and HuggingFace.
        "sampler_full_determinism": not do_sample,
        "typical": typical_p,
        "rep_pen": repetition_penalty,

        # Disable any pre- or post-processing on the KoboldAI side, we'd
        # rather take care of things on our own.
        "frmttriminc": False,
        "frmtrmspch": False,
        "frmtrmblln": False,
        "frmtadsnsp": False,

        # Append any other generation parameters that we didn't handle
        # manually.
        **kwargs,
    }

    generated_text = ""

    # Currently, Kobold doesn't support custom stopping criteria, and their
    # chat mode can't handle multi-line responses. To work around both of
    # those, we use the regular adventure mode generation but keep asking for
    # more tokens until the model starts trying to talk as the user, then we
    # stop. Allow enough attempts to cover `max_new_tokens` in
    # `max_length`-sized chunks, plus a few extra for slack.
    attempts = 0
    max_extra_attempts = 4
    while attempts < (max_new_tokens / payload["max_length"]) + max_extra_attempts:
        attempts += 1
        response = requests.post(endpoint, json=payload)
        if not response.ok:
            error_message = response.text
            raise KoboldApiServerException(
                "The KoboldAI API server returned an error"
                f" (HTTP status code {response.status_code}): {error_message}")

        inference_result = response.json()["results"][0]["text"]
        generated_text += inference_result

        # Model started to talk as us. Stop generating and return results; the
        # rest of the code will take care of trimming it properly.
        if "\nYou:" in generated_text:
            logger.debug("Hit `\\nYou:`: `%s`", generated_text)
            return generated_text

        # For SFT: hit an EOS token. Trim and return.
        if generated_text.endswith("<|endoftext|>"):
            logger.debug("Got EOS token: `%s`", generated_text)
            # We add a fake generated "\nYou:" here so the trimming code
            # doesn't need to handle SFT and UFT models differently.
            return generated_text.replace("<|endoftext|>", "\nYou:")

        # Hit the configured generation limit. Note: `split()` counts
        # whitespace-delimited words, which is only a rough proxy for tokens.
        if len(generated_text.split()) >= max_new_tokens:
            logger.debug("Hit max length: `%s`", generated_text)
            return generated_text

        # Model still hasn't finished what it had to say. Append its output to
        # the prompt and feed it back in.
        logger.debug("Got another %s tokens, but still not done: `%s`",
                     payload["max_length"], generated_text)
        payload["prompt"] += inference_result

    logger.debug("Exhausted generation attempts: `%s`", generated_text)
    return generated_text
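

# A minimal usage sketch, assuming a KoboldAI server already listening
# locally. The URL, wait time, prompt, and sampling parameters below are
# illustrative assumptions, not values mandated by this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    koboldai_url = "http://127.0.0.1:5000"  # assumed server address
    wait_for_kai_server(koboldai_url, max_wait_time_seconds=60)

    reply = run_raw_inference_on_kai(
        koboldai_url,
        prompt="You: Hello there!\nCharacter:",
        max_new_tokens=128,
        do_sample=True,
        typical_p=0.9,
        repetition_penalty=1.1,
    )
    print(reply)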