"""Python Application Script for AI chatbot using LLAMA CPP."""
import logging
import os

import gradio as gr
from llama_cpp import Llama

# Setting up environment
log_level = os.environ.get("LOG_LEVEL", "WARNING")
logging.basicConfig(encoding='utf-8', level=log_level)
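# Logging verbosity is controlled by the LOG_LEVEL environment variable above;
# it accepts the standard level names (DEBUG, INFO, WARNING, ...), so setting
# LOG_LEVEL=DEBUG exposes the prompt/response debug output emitted below.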
# Default System Prompt
DEFAULT_SYSTEM_PROMPT = os.environ.get("DEFAULT_SYSTEM", "You are Dolphin, a helpful AI assistant.")
# Model Path
model_path = "model.gguf"
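# The GGUF model file is assumed to sit next to this script; how it gets there
# (e.g. downloaded at build time) is deployment-specific.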
logging.debug("Model Path: %s", model_path)
logging.info("Loading Moddel")
llm = Llama(model_path=model_path, n_ctx=4000, n_threads=2, chat_format="chatml")


def generate(
    message: str,
    history: list[tuple[str, str]],
    system_prompt: str,
    temperature: float = 0.1,
    max_tokens: int = 512,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
):
"""Function to generate text.
:param message: The new user prompt.
:param history: The history of the chat session.
:param system: The system prompt of the model.
:param temperature: The temperature parameter for the model.
:param max_tokens: The maximum amount of tokens to use for the model.
:param top_p: The top p value for the model.
:param repetition_penalty: The repetition penalty for the model.
"""
logging.info("Generating Text")
logging.debug("message: %s", message)
logging.debug("history: %s", history)
logging.debug("system: %s", system)
logging.debug("temperature: %s", temperature)
logging.debug("max_tokens: %s", max_tokens)
logging.debug("top_p: %s", top_p)
logging.debug("repetion_penalty: %s", repetition_penalty)

    # Formatting Prompt
    logging.info("Formatting Prompt")
    formatted_prompt = [{"role": "system", "content": system_prompt}]
    for user_prompt, bot_response in history:
        formatted_prompt.append({"role": "user", "content": user_prompt})
        formatted_prompt.append({"role": "assistant", "content": bot_response})
    formatted_prompt.append({"role": "user", "content": message})
    logging.debug("Formatted Prompt: %s", formatted_prompt)

    # Generating Response
    logging.info("Generating Response")
    stream_response = llm.create_chat_completion(
        messages=formatted_prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        repeat_penalty=repetition_penalty,
        stream=True,
    )
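    # With stream=True, create_chat_completion() returns an iterator of
    # OpenAI-style chunks; newly generated text arrives incrementally under
    # chunk["choices"][0]["delta"]["content"].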

    # Parsing Response
    logging.info("Parsing Response")
    response = ""
    for chunk in stream_response:
        if (
            len(chunk["choices"][0]["delta"]) != 0
            and "content" in chunk["choices"][0]["delta"]
        ):
            response += chunk["choices"][0]["delta"]["content"]
            logging.debug("Response: %s", response)
            yield response
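
# Illustrative sketch (not executed): generate() is a plain generator, so it
# can also be consumed directly, outside of Gradio, e.g.:
#
#     for partial in generate("Hello!", [], DEFAULT_SYSTEM_PROMPT):
#         print(partial)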

additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
        value=DEFAULT_SYSTEM_PROMPT,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
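# Note: gr.ChatInterface passes these controls to generate() positionally after
# (message, history), so their order must match the function signature:
# system_prompt, temperature, max_tokens, top_p, repetition_penalty.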
examples = []
logging.info("Creating Chatbot")
mychatbot = gr.Chatbot(
    avatar_images=["user.png", "botsc.png"],
    bubble_full_width=False,
    show_label=False,
    show_copy_button=True,
    likeable=True,
)
logging.info("Creating Chat Interface")
iface = gr.ChatInterface(
    fn=generate,
    chatbot=mychatbot,
    additional_inputs=additional_inputs,
    examples=examples,
    concurrency_limit=20,
    title="LLAMA CPP Template",
)
logging.info("Starting Application")
iface.launch(show_api=False, server_name="0.0.0.0")