hf-llm-api / apis /chat_api.py
Hansimov's picture
:boom: [Fix] Ignore invalid HF Token
8ab8ca6
raw
history blame
5.22 kB
import argparse
import os
import sys
import uvicorn
from fastapi import FastAPI, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse, ServerSentEvent
from utils.logger import logger
from networks.message_streamer import MessageStreamer
from messagers.message_composer import MessageComposer
from mocks.stream_chat_mocker import stream_chat_mock
class ChatAPIApp:
    """FastAPI application exposing a chat-completions API backed by HF models.

    Routes (each registered both at the root and under ``/v1``):

    - ``GET  /models``            -> list of supported model ids
    - ``POST /chat/completions``  -> chat completion, streamed (SSE) or not

    The Swagger UI docs are served at ``/``.
    """

    def __init__(self):
        self.app = FastAPI(
            docs_url="/",
            title="HuggingFace LLM API",
            # Swagger UI option: -1 hides the models/schemas section.
            swagger_ui_parameters={"defaultModelsExpandDepth": -1},
            version="1.0",
        )
        self.setup_routes()

    def get_available_models(self):
        """Return the model ids (with descriptions) accepted by this API."""
        # ANCHOR[id=available-models]: Available models
        self.available_models = [
            {
                "id": "mixtral-8x7b",
                "description": "[mistralai/Mixtral-8x7B-Instruct-v0.1]: https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1",
            },
            {
                "id": "mistral-7b",
                "description": "[mistralai/Mistral-7B-Instruct-v0.2]: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
            },
            {
                "id": "openchat-3.5",
                "description": "[openchat/openchat-3.5-1210]: https://huggingface.co/openchat/openchat-3.5-1210",
            },
        ]
        return self.available_models

    def extract_api_key(
        credentials: HTTPAuthorizationCredentials = Depends(
            HTTPBearer(auto_error=False)
        ),
    ):
        """FastAPI dependency resolving the HF token for the upstream request.

        Intentionally has no ``self``: ``chat_completions`` refers to it as
        ``Depends(extract_api_key)`` while the class body is executing, so
        FastAPI receives the plain function rather than a bound method.

        Candidates are tried in order, and the first one that starts with
        ``hf_`` is returned:

        1. the bearer token from the ``Authorization`` header
           (``HTTPBearer(auto_error=False)`` yields ``None`` instead of
           rejecting the request when the header is missing);
        2. the ``HF_TOKEN`` environment variable.

        A non-HF bearer value (e.g. a placeholder key sent by an OpenAI-style
        client) therefore no longer hides a valid server-side ``HF_TOKEN``.

        Returns:
            The token string, or ``None`` if no candidate looks like an HF
            token. A warning is logged in that case.
        """
        bearer_key = credentials.credentials if credentials else None
        env_key = os.getenv("HF_TOKEN")
        for api_key in (bearer_key, env_key):
            if api_key and api_key.startswith("hf_"):
                return api_key
        if bearer_key or env_key:
            # Something was supplied, but none of it has the HF token prefix.
            logger.warn("Invalid HF Token")
        else:
            logger.warn("Not provide HF Token!")
        return None

    # Request body for the chat-completions route. `max_tokens` is forwarded
    # to the streamer as `max_new_tokens` (see chat_completions).
    class ChatCompletionsPostItem(BaseModel):
        model: str = Field(
            default="mixtral-8x7b",
            description="(str) `mixtral-8x7b`",
        )
        messages: list = Field(
            default=[{"role": "user", "content": "Hello, who are you?"}],
            description="(list) Messages",
        )
        temperature: float = Field(
            default=0.01,
            description="(float) Temperature",
        )
        max_tokens: int = Field(
            default=4096,
            description="(int) Max tokens",
        )
        stream: bool = Field(
            default=True,
            description="(bool) Stream",
        )

    def chat_completions(
        self, item: ChatCompletionsPostItem, api_key: str = Depends(extract_api_key)
    ):
        """Generate a chat completion for the posted conversation.

        The messages are merged into a single prompt for the selected model.
        With ``stream=true`` the output is sent as Server-Sent Events.
        Otherwise the whole result is returned in a single response.
        """
        streamer = MessageStreamer(model=item.model)
        composer = MessageComposer(model=item.model)
        composer.merge(messages=item.messages)
        # NOTE(review): debugging toggle for the imported mock. It patches
        # `.chat`, while the call below uses `.chat_response`; confirm the mock
        # still takes effect before relying on it.
        # streamer.chat = stream_chat_mock

        stream_response = streamer.chat_response(
            prompt=composer.merged_str,
            temperature=item.temperature,
            max_new_tokens=item.max_tokens,
            api_key=api_key,
        )
        if item.stream:
            event_source_response = EventSourceResponse(
                streamer.chat_return_generator(stream_response),
                media_type="text/event-stream",
                # NOTE(review): sse_starlette's `ping` is an interval in
                # seconds, so 2000 means keep-alive pings roughly every 33
                # minutes, effectively none. Confirm this is intended.
                ping=2000,
                # Keep-alive pings are sent as empty SSE comment lines.
                ping_message_factory=lambda: ServerSentEvent(comment=""),
            )
            return event_source_response
        else:
            data_response = streamer.chat_return_dict(stream_response)
            return data_response

    def setup_routes(self):
        """Register every route under both the bare and the `/v1` prefix."""
        for prefix in ["", "/v1"]:
            self.app.get(
                prefix + "/models",
                summary="Get available models",
            )(self.get_available_models)

            self.app.post(
                prefix + "/chat/completions",
                summary="Chat completions in conversation session",
            )(self.chat_completions)
class ArgParser(argparse.ArgumentParser):
    """Command-line parser for the HF LLM Chat API server.

    Construction immediately parses ``sys.argv[1:]``. The resulting namespace
    is stored on ``self.args`` with the attributes ``server`` (str), ``port``
    (int) and ``dev`` (bool).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # (flags, add_argument options) pairs, registered in this order so the
        # generated --help text lists them in the same sequence.
        option_specs = (
            (
                ("-s", "--server"),
                {
                    "type": str,
                    "default": "0.0.0.0",
                    "help": "Server IP for HF LLM Chat API",
                },
            ),
            (
                ("-p", "--port"),
                {
                    "type": int,
                    "default": 23333,
                    "help": "Server Port for HF LLM Chat API",
                },
            ),
            (
                ("-d", "--dev"),
                {
                    "default": False,
                    "action": "store_true",
                    "help": "Run in dev mode",
                },
            ),
        )
        for flags, options in option_specs:
            self.add_argument(*flags, **options)

        cli_tokens = sys.argv[1:]
        self.args = self.parse_args(cli_tokens)
# Module-level ASGI app; uvicorn loads it by the import string "__main__:app".
app = ChatAPIApp().app

if __name__ == "__main__":
    args = ArgParser().args
    # --dev enables uvicorn's auto-reload; without it, reload stays off.
    uvicorn.run("__main__:app", host=args.server, port=args.port, reload=args.dev)

# Usage:
#   python -m apis.chat_api      # production mode (used by the Docker image)
#   python -m apis.chat_api -d   # development mode (auto-reload)