from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, TextStreamer
import torch
import os

app = FastAPI()

# Define the request schema
class PromptRequest(BaseModel):
    prompt: str
    history: Optional[List[Dict[str, Any]]] = None
    parameters: Optional[Dict[str, Any]] = None
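
# Example request body for POST /generate/ (illustrative values only, not taken from the original file):
# {
#     "prompt": "Summarize today's meeting notes.",
#     "history": [
#         {"role": "user", "content": "Hello"},
#         {"role": "assistant", "content": "Hi, how can I help?"}
#     ],
#     "parameters": {"max_new_tokens": 128, "temperature": 0.5}
# }
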
@app.on_event("startup")
def load_model():
    """Load the tokenizer, model, and text-generation pipeline once at startup."""
    global model, tokenizer, pipe
    os.environ["TRANSFORMERS_CACHE"] = "./cache"
    # Path to a local snapshot of meta-llama/Llama-3.2-3B-Instruct
    model_path = "model/models--meta-llama--Llama-3.2-3B-Instruct/snapshots/0cb88a4f764b7a12671c53f0838cd831a0843b95"
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Print generated tokens to stdout as they are produced (prompt tokens are skipped)
    streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, cache_dir="./cache")
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, streamer=streamer)

@app.post("/generate/")
async def generate_response(request: PromptRequest):
    # Format the prompt with message history
    history_text = ""
    if request.history:
        for message in request.history:
            role = message.get("role", "user")
            content = message.get("content", "")
            history_text += f"{role}: {content}\n"

    # Combine the history with the current prompt
    full_prompt = f"{history_text}\nUser: {request.prompt}\nAssistant:"

    # Set default generation parameters and override them with any provided by the client
    gen_params = {
        "max_new_tokens": 256,
        "temperature": 0.7,
        "top_p": 0.9,
        "do_sample": True,  # without sampling enabled, temperature and top_p are ignored
    }
    if request.parameters:
        gen_params.update(request.parameters)

    # Generate the response
    try:
        result = pipe(full_prompt, **gen_params)
        # Note: generated_text includes the prompt unless return_full_text=False is passed
        return {"response": result[0]["generated_text"]}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
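
# Minimal sketch of a local entry point, assuming this file is saved as main.py and uvicorn
# is installed; the host and port below are illustrative choices, and the app may equally
# be launched externally with `uvicorn main:app`.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)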