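"""Gradio chat demo backed by an OpenAI-compatible streaming endpoint.

Requests from the OpenAI client are intercepted by an httpx event hook that
rewrites the URL path and injects per-model auth headers. Endpoint metadata,
system prompts, and UI text are defined in `const`.
"""
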
from functools import partial
import gradio as gr
import httpx
import subprocess
import os
from openai import OpenAI
from const import (
    LLM_BASE_URL,
    AUTH_CMD,
    SYSTEM_PROMPTS,
    EXAMPLES,
    CSS,
    HEADER,
    FOOTER,
    PLACEHOLDER,
    ModelInfo,
    MODELS,
)


def get_token() -> str:
    """Fetch a bearer token by running AUTH_CMD (defined in const)."""
    try:
        t = (
            subprocess.run(
                AUTH_CMD,
                stdout=subprocess.PIPE,
                stderr=subprocess.DEVNULL,
                env=os.environ.copy(),
            )
            .stdout.decode("utf-8")
            .strip()
        )
        assert t, "Failed to get auth token"
        return t
    except Exception:
        raise ValueError("Failed to get auth token")


def get_headers(host: str) -> dict:
    """Build auth and routing headers for a request to the given backend host."""
    return {
        "Authorization": f"Bearer {get_token()}",
        "Host": host,
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
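

# httpx "request" event hook: rewrites the outgoing OpenAI request so that it
# targets the selected model's endpoint path and carries freshly fetched
# auth headers.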
def proxy(request: httpx.Request, model_info: ModelInfo) -> httpx.Request:
    request.url = request.url.copy_with(path=model_info.endpoint)
    request.headers.update(get_headers(host=model_info.host))
    return request
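

# Generator used as the ChatInterface callback: converts the Gradio chat
# history into OpenAI chat messages and yields the growing assistant reply
# as it streams in.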
def call_llm(
    message: str,
    history: list[tuple[str, str]],
    model_name: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    # Rebuild the conversation in OpenAI chat format: system prompt first,
    # then the prior turns, then the new user message.
    system_prompt_text = SYSTEM_PROMPTS[system_prompt]
    history_openai_format = [{"role": "system", "content": system_prompt_text}]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    model_info = MODELS[model_name]
    # Route the OpenAI client through an httpx client whose request hook points
    # it at the selected model's backend; TLS verification is disabled for the
    # internal endpoint.
    client = OpenAI(
        api_key="",
        base_url=LLM_BASE_URL,
        http_client=httpx.Client(
            event_hooks={
                "request": [partial(proxy, model_info=model_info)],
            },
            verify=False,
        ),
    )
    stream = client.chat.completions.create(
        model=f"/data/cyberagent/{model_info.name}",
        messages=history_openai_format,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        n=1,
        stream=True,
        extra_body={"repetition_penalty": 1.1},
    )
    # Stream partial responses back to the ChatInterface as they arrive.
    partial_message = ""
    for chunk in stream:
        content = chunk.choices[0].delta.content or ""
        partial_message += content
        yield partial_message
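

# Assemble the Gradio UI: a ChatInterface wired to call_llm, with hidden
# model / system-prompt selectors and sampling sliders grouped under a
# "Parameters" accordion.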
def run():
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        scale=1,
        show_copy_button=True,
        placeholder=PLACEHOLDER,
        layout="panel",
    )
    with gr.Blocks(fill_height=True) as demo:
        gr.Markdown(HEADER)
        gr.ChatInterface(
            fn=call_llm,
            stop_btn="Stop Generation",
            examples=EXAMPLES,
            cache_examples=False,
            multimodal=False,
            chatbot=chatbot,
            additional_inputs_accordion=gr.Accordion(
                label="Parameters", open=False, render=False
            ),
            additional_inputs=[
                gr.Dropdown(
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    label="Model",
                    visible=False,
                ),
                gr.Dropdown(
                    choices=list(SYSTEM_PROMPTS.keys()),
                    value=list(SYSTEM_PROMPTS.keys())[0],
                    label="System Prompt",
                    visible=False,
                ),
                gr.Slider(
                    minimum=32,
                    maximum=4096,
                    step=1,
                    value=1024,
                    label="Max tokens",
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.3,
                    label="Temperature",
                    render=False,
                ),
                gr.Slider(
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=1.0,
                    label="Top-p",
                    render=False,
                ),
            ],
            analytics_enabled=False,
        )
        gr.Markdown(FOOTER)
    demo.queue(max_size=256, api_open=False)
    demo.launch(share=False, quiet=True)


if __name__ == "__main__":
    run()