import asyncio
import time
import json
from typing import List, Literal

from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel

from huggingface_hub import InferenceClient

app = FastAPI()

# Hugging Face Inference API client, pinned to Mistral-7B-Instruct-v0.2.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.2"
)


class Message(BaseModel):
    role: Literal["user", "assistant"]
    content: str


# OpenAI-style request body accepted by /chat/completions.
class Payload(BaseModel):
    stream: bool = False
    model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
    messages: List[Message]
    temperature: float = 0.9
    frequency_penalty: float = 1.2  # forwarded to text_generation as repetition_penalty
    top_p: float = 0.9


async def stream(iterator):
    # Drain a blocking iterator from async code without blocking the event
    # loop (uses asyncio.to_thread); not currently called by the endpoints below.
    while True:
        try:
            value = await asyncio.to_thread(iterator.__next__)
            yield value
        except StopIteration:
            break


def format_prompt(messages: List[dict]) -> str:
    # Build a Mistral-instruct prompt: user turns are wrapped in [INST] ... [/INST],
    # assistant turns are terminated with </s>. Callers pass plain dicts
    # (from payload.model_dump()), hence the dict-style access.
    prompt = "<s>"
    for message in messages:
        if message['role'] == 'user':
            prompt += f"[INST] {message['content']} [/INST]"
        else:
            prompt += f" {message['content']}</s> "
    return prompt
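
# For example, tracing the loop above, the messages
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "How are you?"}]
# are flattened into:
#   "<s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]"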


def make_chunk_obj(i, delta, fr):
    # One OpenAI-style "chat.completion.chunk" object: `i` is the choice index,
    # `delta` the newly generated text, `fr` the finish_reason (None while streaming).
    return {
        "id": str(time.time_ns()),
        "object": "chat.completion.chunk",
        "created": round(time.time()),
        "model": "mistral-7b-instruct-v0.2",
        "system_fingerprint": "wtf",
        "choices": [
            {
                "index": i,
                "delta": {
                    "content": delta
                },
                "finish_reason": fr
            }
        ]
    }
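
# A sketch of what make_chunk_obj(0, "Hello", None) returns; the id/created
# values are illustrative timestamps, everything else is fixed by the code above:
#   {"id": "1700000000000000000", "object": "chat.completion.chunk",
#    "created": 1700000000, "model": "mistral-7b-instruct-v0.2",
#    "system_fingerprint": "wtf",
#    "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}]}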


def generate(
    messages,
    temperature=0.9,
    max_new_tokens=256,
    top_p=0.95,
    repetition_penalty=1.0,
):
    # Clamp temperature to a small positive value.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=None,
    )

    formatted_prompt = format_prompt(messages)

    # Stream tokens from the Inference API (renamed from `stream` so the local
    # variable does not shadow the helper defined above).
    token_stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    for response in token_stream:
        t = response.token.text
        # Suppress the end-of-sequence token instead of yielding it as text.
        yield t if t != "</s>" else ""


def generate_norm(*args) -> str:
    # Non-streaming variant: concatenate every generated chunk into one string.
    t = ""
    for chunk in generate(*args):
        t += chunk
    return t


@app.get('/')
async def index():
    return JSONResponse({"message": "hello", "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space"})


@app.post('/chat/completions')
async def c_cmp(payload: Payload):
    if not payload.stream:
        # Non-streaming: run generation to completion and return a single
        # OpenAI-style "chat.completion" object.
        return JSONResponse(
            {
                "id": str(time.time_ns()),
                "object": "chat.completion",
                "created": round(time.time()),
                "model": payload.model,
                "system_fingerprint": "wtf",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": generate_norm(
                                payload.model_dump()['messages'],
                                payload.temperature,
                                4096,
                                payload.top_p,
                                payload.frequency_penalty
                            )
                        },
                        "finish_reason": "stop"
                    }
                ]
            }
        )

    def streamer():
        result = generate(
            payload.model_dump()['messages'],
            payload.temperature,
            4096,
            payload.top_p,
            payload.frequency_penalty,
        )
        # Emit OpenAI-style server-sent events (the format make_chunk_obj builds),
        # followed by a final "stop" chunk and the [DONE] sentinel.
        for item in result:
            yield "data: " + json.dumps(make_chunk_obj(0, item, None)) + "\n\n"
        yield "data: " + json.dumps(make_chunk_obj(0, "", "stop")) + "\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(streamer(), media_type="text/event-stream")
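
# A minimal sketch of how to call this service, assuming it is served locally
# with uvicorn (the module name `app` and port 8000 are assumptions, not part
# of the code above):
#
#   uvicorn app:app --port 8000
#
#   curl http://127.0.0.1:8000/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "mistral-7b-instruct-v0.2",
#          "messages": [{"role": "user", "content": "Hello!"}],
#          "stream": false}'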