# Spaces: Running
import os
import subprocess
from collections.abc import Iterator
from functools import partial

import gradio as gr
import httpx
from openai import OpenAI

from const import (
    AUTH_CMD,
    CSS,
    EXAMPLES,
    FOOTER,
    HEADER,
    LLM_BASE_URL,
    MODELS,
    PLACEHOLDER,
    SYSTEM_PROMPTS,
    ModelInfo,
)
def get_token() -> str:
    """Run ``AUTH_CMD`` and return the bearer token it prints on stdout.

    The command runs with a copy of the current process environment; its
    stderr is discarded and surrounding whitespace is stripped from stdout.

    Returns:
        The non-empty token string.

    Raises:
        ValueError: if the command cannot be run, its output cannot be
            decoded as UTF-8, or it prints nothing. Any underlying exception
            is chained as ``__cause__``.
    """
    try:
        completed = subprocess.run(
            AUTH_CMD,
            stdout=subprocess.PIPE,
            stderr=subprocess.DEVNULL,
            env=os.environ.copy(),
        )
        token = completed.stdout.decode("utf-8").strip()
    except Exception as err:
        # Covers a missing executable, a malformed AUTH_CMD and undecodable
        # output; keep the original error as the cause for debugging.
        raise ValueError("Failed to get auth token") from err
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would silently let an empty token through.
    if not token:
        raise ValueError("Failed to get auth token")
    return token
def get_headers(host: str) -> dict:
    """Build the HTTP headers for a JSON request to the LLM endpoint.

    Args:
        host: Value sent in the ``Host`` header.

    Returns:
        A new dict with a bearer token from ``get_token`` (fetched anew on
        every call), the ``Host`` override, and JSON ``Accept`` /
        ``Content-Type`` headers.
    """
    bearer = get_token()
    headers = {"Authorization": f"Bearer {bearer}"}
    headers["Host"] = host
    headers["Accept"] = "application/json"
    headers["Content-Type"] = "application/json"
    return headers
def proxy(request: httpx.Request, model_info: ModelInfo) -> httpx.Request:
    """httpx ``request`` event hook that reroutes a call to ``model_info``.

    Replaces the request's URL path with ``model_info.endpoint`` and merges
    in the headers from ``get_headers`` (auth token and ``Host`` override).
    The request is modified in place and also returned.
    """
    rerouted_url = request.url.copy_with(path=model_info.endpoint)
    request.url = rerouted_url
    extra_headers = get_headers(host=model_info.host)
    request.headers.update(extra_headers)
    return request
def _history_to_openai_messages(
    message: str, history: list, system_prompt_text: str
) -> list[dict]:
    """Convert Gradio chat history plus the new user turn to OpenAI messages.

    The system prompt is always placed first. The completions API is
    stateless, so it must be sent on every turn, not only the first one.

    ``history`` may be in Gradio's "tuples" format (``[user, assistant]``
    pairs) or "messages" format (dicts with ``role`` and ``content``).
    Turns whose content is ``None`` are skipped rather than sent as null.
    """
    messages = [{"role": "system", "content": system_prompt_text}]
    for turn in history:
        if isinstance(turn, dict):
            if turn.get("content") is not None:
                messages.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, assistant_text = turn
        if user_text is not None:
            messages.append({"role": "user", "content": user_text})
        if assistant_text is not None:
            messages.append({"role": "assistant", "content": assistant_text})
    messages.append({"role": "user", "content": message})
    return messages


def call_llm(
    message: str,
    history: list,
    model_name: str,
    system_prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
) -> Iterator[str]:
    """Stream a chat completion for ``message`` from the selected model.

    Used as the ``fn`` of ``gr.ChatInterface``; the parameters after
    ``history`` correspond, in order, to its ``additional_inputs``.

    Args:
        message: The new user message.
        history: Prior turns, either ``[user, assistant]`` pairs or
            ``{"role": ..., "content": ...}`` dicts.
        model_name: Key into ``MODELS``.
        system_prompt: Key into ``SYSTEM_PROMPTS``.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The assistant reply accumulated so far, growing with each streamed
        chunk.

    Raises:
        KeyError: if ``model_name`` or ``system_prompt`` is not a key of
            ``MODELS`` / ``SYSTEM_PROMPTS``.
    """
    messages = _history_to_openai_messages(
        message, history, SYSTEM_PROMPTS[system_prompt]
    )
    model_info = MODELS[model_name]
    # `proxy` rewrites every request to the model's own endpoint/host and
    # injects a fresh bearer token. TLS verification is disabled; presumably
    # because the Host override does not match LLM_BASE_URL's certificate.
    # NOTE(review): confirm this is intended.
    # The `with` block closes the connection pool when streaming finishes,
    # fails, or the generator is closed early (e.g. "Stop Generation").
    with httpx.Client(
        event_hooks={"request": [partial(proxy, model_info=model_info)]},
        verify=False,
    ) as http_client:
        # Placeholder key: `proxy` overwrites the Authorization header.
        client = OpenAI(api_key="", base_url=LLM_BASE_URL, http_client=http_client)
        stream = client.chat.completions.create(
            model=f"/data/cyberagent/{model_info.name}",
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            n=1,
            stream=True,
            extra_body={"repetition_penalty": 1.1},
        )
        reply = ""
        for chunk in stream:
            # Some servers emit chunks without choices (e.g. a trailing
            # usage-only chunk); skip them instead of raising IndexError.
            if not chunk.choices:
                continue
            reply += chunk.choices[0].delta.content or ""
            yield reply
def run():
    """Build the Gradio chat UI around ``call_llm`` and launch it.

    Layout: HEADER markdown, the chat interface, then FOOTER markdown. The
    model and system-prompt dropdowns are hidden (``visible=False``), so the
    UI always uses the first key of ``MODELS`` and of ``SYSTEM_PROMPTS``.
    The sampling sliders are created unrendered, presumably so the chat
    interface places them in its collapsed "Parameters" accordion.
    The queue allows up to 256 pending requests and the public API is closed.
    """
    model_names = list(MODELS.keys())
    prompt_names = list(SYSTEM_PROMPTS.keys())

    # Created outside the Blocks context and handed to ChatInterface.
    chatbot = gr.Chatbot(
        elem_id="chatbot",
        scale=1,
        show_copy_button=True,
        placeholder=PLACEHOLDER,
        layout="panel",
    )

    with gr.Blocks(fill_height=True) as demo:
        gr.Markdown(HEADER)

        # Components are created in the same order as the arguments below
        # would evaluate them: accordion, dropdowns, then sliders.
        parameters_accordion = gr.Accordion(
            label="Parameters", open=False, render=False
        )
        model_dropdown = gr.Dropdown(
            choices=model_names,
            value=model_names[0],
            label="Model",
            visible=False,
        )
        prompt_dropdown = gr.Dropdown(
            choices=prompt_names,
            value=prompt_names[0],
            label="System Prompt",
            visible=False,
        )
        max_tokens_slider = gr.Slider(
            minimum=32, maximum=4096, step=1, value=1024,
            label="Max tokens", render=False,
        )
        temperature_slider = gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.3,
            label="Temperature", render=False,
        )
        top_p_slider = gr.Slider(
            minimum=0, maximum=1, step=0.1, value=1.0,
            label="Top-p", render=False,
        )

        gr.ChatInterface(
            fn=call_llm,
            stop_btn="Stop Generation",
            examples=EXAMPLES,
            cache_examples=False,
            multimodal=False,
            chatbot=chatbot,
            additional_inputs_accordion=parameters_accordion,
            # Order matches call_llm's parameters after (message, history).
            additional_inputs=[
                model_dropdown,
                prompt_dropdown,
                max_tokens_slider,
                temperature_slider,
                top_p_slider,
            ],
            analytics_enabled=False,
        )
        gr.Markdown(FOOTER)

    demo.queue(max_size=256, api_open=False)
    demo.launch(share=False, quiet=True)
if __name__ == "__main__":
    # Start the app only when executed as a script, not when imported.
    run()