from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from huggingface_hub import InferenceClient
import ast
import random

API_URL = "https://api-inference.huggingface.co/models/"  # unused here; kept from the original template
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1")

app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

def format_prompt(message, history):
    # Build a Mistral-Instruct style prompt: each past (user, bot) turn is
    # wrapped in [INST] ... [/INST], followed by the bot reply and an
    # end-of-sequence token, then the new message as a final [INST] block.
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt
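
# Example (a sketch of the function's output): with history = [("Hi", "Hello!")]
# and message = "How are you?", format_prompt returns:
#   <s>[INST] Hi [/INST] Hello!</s> [INST] How are you? [/INST]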

def generate(prompt, history, temperature=0.9, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0):
    # Clamp temperature away from zero; the inference API rejects
    # non-positive values.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=random.randint(0, 10**7),
    )

    formatted_prompt = format_prompt(prompt, history)
    # With stream=False and details=False, text_generation returns the
    # generated text as a plain string, so there is no token stream to
    # iterate over (the original token-by-token loop would have raised
    # an AttributeError).
    output = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=False)
    return output
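
# Quick sanity check (a sketch; assumes network access to the Hugging Face
# Inference API and a valid token available to InferenceClient):
#   print(generate("What is FastAPI?", history=[]))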

@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    # Serve the chat UI from templates/index.html.
    return templates.TemplateResponse("index.html", {"request": request})

@app.post("/generate/")
async def generate_chat(
    request: Request,
    prompt: str = Form(...),
    history: str = Form(...),
    temperature: float = Form(0.9),
    max_new_tokens: int = Form(512),
    top_p: float = Form(0.95),
    repetition_penalty: float = Form(1.0),
):
    # Parse the history string back into a list of (user, bot) pairs.
    # ast.literal_eval only accepts Python literals, unlike the original
    # eval(), which would execute arbitrary code sent by the client.
    history = ast.literal_eval(history)
    response = generate(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty)
    return {"response": response}