"""OpenAI-style chat-completions proxy in front of a Gradio-hosted Mistral 7B Space.

Exposes ``POST /chat/completions`` (non-streaming JSON, or server-sent events
when ``stream`` is true) and forwards every request to the ``/chat`` endpoint
of the ``AWeirdDev/mistral-7b-instruct-v0.2`` Space via ``gradio_client``.
"""

import asyncio
import json
import time
from typing import List, Literal

from fastapi import FastAPI
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from gradio_client import Client

# Value sent to the Space's 'Max new tokens' slider.
# NOTE(review): the gradio_client-generated hint for that slider says its range
# is 0-1048; confirm the Space accepts 4096 instead of rejecting/clamping it.
MAX_NEW_TOKENS = 4096

# Marker returned by next() once the wrapped iterator is exhausted (see stream()).
_STREAM_DONE = object()

app = FastAPI()
# Connects to the remote Space at import time.
client = Client("AWeirdDev/mistral-7b-instruct-v0.2")


class Message(BaseModel):
    """One chat message in OpenAI format."""

    role: Literal["user", "assistant", "system"]
    content: str


class Payload(BaseModel):
    """Request body for ``POST /chat/completions`` (a subset of OpenAI's schema)."""

    stream: bool = False
    model: Literal["mistral-7b-instruct-v0.2"] = "mistral-7b-instruct-v0.2"
    messages: List[Message]
    temperature: float = 0.9
    # Forwarded to the Space's 'Repetition penalty' slider (hint range 1.0-2.0),
    # so it does NOT carry OpenAI's frequency_penalty semantics.
    frequency_penalty: float = 1.2
    top_p: float = 0.9


async def stream(iter):
    """Adapt a blocking iterator into an async generator.

    Each step runs in a worker thread via ``asyncio.to_thread`` so the event
    loop is not blocked while waiting for the next item.

    Args:
        iter: an iterator (must support ``next()``); the name shadows the
            builtin but is kept for caller compatibility.

    Yields:
        Each item produced by ``iter``, in order.
    """
    while True:
        # StopIteration cannot be propagated through an asyncio Future (it is
        # turned into TypeError/RuntimeError), so ask next() for a sentinel instead.
        value = await asyncio.to_thread(next, iter, _STREAM_DONE)
        if value is _STREAM_DONE:
            break
        yield value


def make_chunk_obj(i, delta, fr, chunk_id=None, created=None):
    """Build one OpenAI ``chat.completion.chunk`` object.

    Args:
        i: value for ``choices[0].index`` (the choice index, 0 for a single choice).
        delta: text appended since the previous chunk.
        fr: ``finish_reason`` (``None`` while streaming, e.g. ``"stop"`` at the end).
        chunk_id: completion id shared by all chunks of one stream; a fresh
            ``time.time_ns()``-based id is generated when omitted.
        created: Unix timestamp shared by all chunks; current time when omitted.

    Returns:
        A JSON-serialisable dict.
    """
    return {
        "id": chunk_id if chunk_id is not None else str(time.time_ns()),
        "object": "chat.completion.chunk",
        "created": created if created is not None else round(time.time()),
        "model": "mistral-7b-instruct-v0.2",
        "system_fingerprint": "wtf",
        "choices": [
            {
                "index": i,
                "delta": {"content": delta},
                "finish_reason": fr,
            }
        ],
    }


def _sse(obj):
    """Encode one object as a complete server-sent-events ``data:`` event."""
    # The blank line ("\n\n") terminates the event; without it clients never
    # see where one chunk ends and the next begins.
    return "data: " + json.dumps(obj) + "\n\n"


@app.get('/')
async def index():
    """Health/info endpoint pointing at the backing Space."""
    return JSONResponse({
        "message": "hello",
        "url": "https://aweirddev-mistral-7b-instruct-v0-2-leicht.hf.space",
    })


@app.post('/chat/completions')
async def c_cmp(payload: Payload):
    """OpenAI-compatible chat completion backed by the Space's ``/chat`` endpoint.

    Returns a single ``chat.completion`` JSON object, or, when
    ``payload.stream`` is true, a ``text/event-stream`` of
    ``chat.completion.chunk`` events terminated by ``data: [DONE]``.
    """
    # Positional arguments of the Space's /chat endpoint: messages, temperature,
    # max new tokens, top-p, repetition penalty.
    args = (
        payload.model_dump()["messages"],
        payload.temperature,
        MAX_NEW_TOKENS,
        payload.top_p,
        payload.frequency_penalty,
    )

    if not payload.stream:
        # client.predict blocks until the Space answers; run it off the event
        # loop so concurrent requests are not stalled.
        content = await asyncio.to_thread(client.predict, *args, api_name="/chat")
        return JSONResponse(
            {
                "id": str(time.time_ns()),
                "object": "chat.completion",
                "created": round(time.time()),
                "model": payload.model,
                "system_fingerprint": "wtf",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": content,
                        },
                        "finish_reason": "stop",
                    }
                ],
            }
        )

    # One id/timestamp for the whole stream, as OpenAI clients expect.
    completion_id = str(time.time_ns())
    created = round(time.time())

    def streamer():
        text = ""
        job = client.submit(*args, api_name="/chat")
        for item in job:
            # Each item is treated as the cumulative text so far; emit only the
            # newly appended suffix.
            delta = item[len(text):]
            text = item
            yield _sse(make_chunk_obj(0, delta, None, completion_id, created))
        yield _sse(make_chunk_obj(0, "", "stop", completion_id, created))
        # OpenAI-compatible clients stop reading on the literal "[DONE]".
        yield "data: [DONE]\n\n"

    # Starlette iterates this sync generator in a threadpool, so the blocking
    # job iteration does not block the event loop.
    return StreamingResponse(streamer(), media_type="text/event-stream")