"""Minimal FastAPI front end for a hosted Mistral-7B-Instruct text-generation model."""

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient

app = FastAPI()
templates = Jinja2Templates(directory="templates")

client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)

# Temperatures below this floor are raised to it before being sent to the model.
_MIN_TEMPERATURE = 1e-2


async def format_prompt(message, history) -> str:
    """Build an ``[INST] ... [/INST]`` prompt from past turns plus a new message.

    Args:
        message: The new user message; it becomes the final ``[INST]`` section.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest first.

    Returns:
        Every past turn rendered as ``"[INST] user [/INST] bot "``, followed by
        ``"[INST] message [/INST]"``.
    """
    parts = [
        f"[INST] {user_prompt} [/INST] {bot_response} "
        for user_prompt, bot_response in history
    ]
    parts.append(f"[INST] {message} [/INST]")
    # Join once instead of growing the string with += on every turn.
    return "".join(parts)


async def generate(
    prompt: str,
    temperature: float = 0.9,
    max_new_tokens: int = 256,
    top_p: float = 0.95,
    repetition_penalty: float = 1.0,
):
    """Send ``prompt`` to the model and return the text-generation result.

    Args:
        prompt: User message; it is wrapped by ``format_prompt`` with an empty
            history before being sent.
        temperature: Sampling temperature; values below ``1e-2`` are raised
            to ``1e-2``.
        max_new_tokens: Passed through to the text-generation request.
        top_p: Passed through to the text-generation request (coerced to float).
        repetition_penalty: Passed through to the text-generation request.

    Returns:
        Whatever ``client.text_generation`` returns for a non-streaming request
        without details. ``return_full_text=True`` is requested, so the prompt
        is expected to be part of the returned text.
    """
    generate_kwargs = {
        "temperature": max(float(temperature), _MIN_TEMPERATURE),
        "max_new_tokens": max_new_tokens,
        "top_p": float(top_p),
        "repetition_penalty": repetition_penalty,
        "do_sample": True,
        # Fixed seed is sent with every request.
        "seed": 42,
    }

    formatted_prompt = await format_prompt(prompt, [])

    # InferenceClient.text_generation is a synchronous, network-bound call.
    # Running it in the threadpool keeps the event loop free to serve other
    # requests while this one waits on the model.
    return await run_in_threadpool(
        client.text_generation,
        formatted_prompt,
        **generate_kwargs,
        stream=False,
        details=False,
        return_full_text=True,
    )


@app.get("/")
async def index(request: Request):
    """Render the ``index.html`` template."""
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/generate/")
async def chatbot_response(prompt: str):
    """Generate a reply for ``prompt`` and return it as ``{"response": ...}``."""
    response = await generate(prompt)
    return {"response": response}