import asyncio
import json
import time
from typing import List, Literal

from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from huggingface_hub import InferenceClient

app = FastAPI()
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.2")


class Message(BaseModel):
    role: Literal["user", "assistant"]
    content: str


class Payload(BaseModel):
    stream: bool = False
    model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
    messages: List[Message]
    temperature: float = 0.9
    frequency_penalty: float = 1.2
    top_p: float = 0.9


async def iterate_in_thread(it):
    # Drain a blocking iterator from a worker thread so the event loop stays responsive.
    while True:
        try:
            value = await asyncio.to_thread(it.__next__)
            yield value
        except StopIteration:
            break


def format_prompt(messages: List[dict]) -> str:
    # Wrap user turns in Mistral's [INST] ... [/INST] tags; assistant turns are appended verbatim.
    prompt = ""
    for message in messages:
        if message["role"] == "user":
            prompt += f"[INST] {message['content']} [/INST]"
        else:
            prompt += f" {message['content']} "
    return prompt


def make_chunk_obj(i: int, delta: str, fr):
    # One OpenAI-style "chat.completion.chunk" object for the streaming response.
    return {
        "id": str(time.time_ns()),
        "object": "chat.completion.chunk",
        "created": round(time.time()),
        "model": "mistral-7b-instruct-v0.2",
        "system_fingerprint": "wtf",
        "choices": [
            {
                "index": i,
                "delta": {"content": delta},
                "finish_reason": fr,
            }
        ],
    }


def generate(
    messages,
    temperature: float = 0.9,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
):
    # text-generation-inference rejects temperature <= 0, so clamp to a small positive value.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=None,
    )
    formatted_prompt = format_prompt(messages)
    token_stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    for response in token_stream:
        yield response.token.text


def generate_norm(*args) -> str:
    # Non-streaming path: collect the whole generation into one string.
    return "".join(generate(*args))


@app.get("/")
async def index():
    return JSONResponse(
        {
            "message": "hello",
            "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space",
        }
    )


@app.post("/chat/completions")
async def c_cmp(payload: Payload):
    if not payload.stream:
        return JSONResponse(
            {
                "id": str(time.time_ns()),
                "object": "chat.completion",
                "created": round(time.time()),
                "model": payload.model,
                "system_fingerprint": "wtf",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": generate_norm(
                                payload.model_dump()["messages"],
                                payload.temperature,
                                4096,  # max_new_tokens
                                payload.top_p,
                                payload.frequency_penalty,  # passed through as repetition_penalty
                            ),
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
        )

    def streamer():
        result = generate(
            payload.model_dump()["messages"],
            payload.temperature,
            4096,  # max_new_tokens
            payload.top_p,
            payload.frequency_penalty,  # passed through as repetition_penalty
        )
        # Emit OpenAI-style server-sent events: one JSON chunk per token, then [DONE].
        for item in result:
            yield f"data: {json.dumps(make_chunk_obj(0, item, None))}\n\n"
        yield f"data: {json.dumps(make_chunk_obj(0, '', 'stop'))}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(streamer(), media_type="text/event-stream")
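
# --- Usage sketch (not part of the original file) ---
# A minimal way to serve and query this app, assuming the module is named `main`
# and using the Hugging Face Spaces default port 7860; adjust both as needed.
#
#   uvicorn main:app --host 0.0.0.0 --port 7860
#
#   curl -N http://localhost:7860/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"stream": true, "messages": [{"role": "user", "content": "Hi"}]}'

if __name__ == "__main__":
    import uvicorn

    # Run directly with `python main.py` as an alternative to the uvicorn CLI.
    uvicorn.run(app, host="0.0.0.0", port=7860)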