Ll3doke / main.py
Ashrafb's picture
Update main.py
b2e6b80 verified
raw
history blame
No virus
2.35 kB
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
import json
# FastAPI application instance; the routes and the static mount in this file
# are registered on it.
app = FastAPI()
# Hugging Face inference client created for the Llama 3 8B Instruct model id;
# respond() uses it for chat completions.
client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
# System prompt placed first in every conversation sent to the model.
# It is assembled from adjacent string literals, which Python joins with no
# separator, so every fragment must end with its own space or newline.
SYSTEM_MESSAGE = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully "
    "as possible, while being safe. Your answers should not include any harmful, "
    "unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure "
    "that your responses are socially unbiased and positive in nature.\n\nIf a question "
    "does not make any sense, or is not factually coherent, explain why instead of "
    "answering something not correct. If you don't know the answer to a question, please "
    # Trailing space keeps this sentence from fusing with the next one.
    "don't share false information. "
    "Always respond in the language of user prompt for each prompt ."
)
# Generation settings that respond() forwards to client.chat_completion().
# Passed as max_tokens.
MAX_TOKENS = 2000
# Passed as temperature.
TEMPERATURE = 0.7
# Passed as top_p.
TOP_P = 0.95
def respond(message, history: list[tuple[str, str]]):
    """Stream the model's reply to ``message`` as text fragments.

    The conversation sent to the model is the system prompt, then each
    ``history`` turn in order (empty user/assistant halves are skipped),
    then ``message`` as the final user turn.

    Args:
        message: The new user prompt.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.

    Yields:
        The text of each streamed chunk, in arrival order. Chunks that carry
        no text are skipped, so every yielded value is a ``str``.
    """
    messages = [{"role": "system", "content": SYSTEM_MESSAGE}]
    for turn in history:
        # Index access (not unpacking) so list-shaped turns from JSON and
        # turns with extra trailing items keep working.
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    messages.append({"role": "user", "content": message})

    stream = client.chat_completion(
        messages,
        max_tokens=MAX_TOKENS,
        stream=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
    )
    # Use a distinct loop name: reusing ``message`` would shadow the parameter.
    for chunk in stream:
        if not chunk.choices:
            # Some chunks carry no choices at all; there is nothing to emit.
            continue
        text = chunk.choices[0].delta.content
        # ``content`` is optional on stream deltas; yielding None would break
        # callers that concatenate the fragments.
        if text is not None:
            yield text
@app.post("/generate/")
async def generate(request: Request):
    """Handle a form-encoded chat request and return the full reply as JSON.

    Form fields:
        prompt: Required user prompt.
        history: Optional JSON array of ``[user_text, assistant_text]`` pairs;
            a missing or empty value means no prior turns.

    Returns:
        ``JSONResponse`` with body ``{"response": <complete reply text>}``.

    Raises:
        HTTPException: 400 when ``prompt`` is missing or empty, or when
            ``history`` is not a valid JSON array.
    """
    # Local import: asyncio is not among this file's top-level imports.
    import asyncio

    form = await request.form()
    prompt = form.get("prompt")
    # Validate the prompt first so a missing prompt is reported as such
    # regardless of what the history field contains.
    if not prompt:
        raise HTTPException(status_code=400, detail="Prompt is required")

    raw_history = form.get("history") or "[]"  # Default to empty history
    try:
        history = json.loads(raw_history)
    except (TypeError, ValueError) as exc:
        # Malformed client input is a 400, not an unhandled server error.
        raise HTTPException(
            status_code=400, detail="History must be valid JSON"
        ) from exc
    if not isinstance(history, list):
        raise HTTPException(status_code=400, detail="History must be a JSON array")

    def collect_reply() -> str:
        # Skip fragments without text (None) so the join cannot fail.
        return "".join(part for part in respond(prompt, history) if part)

    # respond() iterates a blocking stream; run it in a worker thread so the
    # event loop can keep serving other requests meanwhile.
    final_response = await asyncio.to_thread(collect_reply)
    return JSONResponse(content={"response": final_response})
# Mount a StaticFiles app for the relative "static" directory at the site root.
# html=True turns on StaticFiles' HTML mode (per the Starlette docs, index.html
# is then served for directory requests).
# NOTE(review): this mount is registered before the GET "/" route that follows
# it, so it presumably answers "/" first — confirm whether that route is reached.
app.mount("/", StaticFiles(directory="static", html=True), name="static")
@app.get("/")
def index() -> FileResponse:
    """Return the front-end page stored at /app/static/index.html as HTML."""
    # NOTE(review): a catch-all mount at "/" is registered earlier in this
    # file and may answer this path first — confirm this route is reachable.
    index_page = "/app/static/index.html"
    return FileResponse(path=index_page, media_type="text/html")