Spaces:
Running
Running
File size: 1,533 Bytes
ff9863c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
import os
print("Google Gemma 2 Chatbot is starting...")
# Read the Hugging Face access token from the environment. It is presumably
# needed to download gated repos such as Gemma — confirm for the deployment.
access_token = os.getenv('HF_TOKEN')
if access_token:
    login(access_token)
else:
    # login(None) falls back to an interactive token prompt, which cannot be
    # answered in a headless server process; skip it and let the download
    # below surface any authentication problem instead.
    print("HF_TOKEN is not set; skipping Hugging Face login")
model_id = "google/gemma-2-9b-it"
print("Model loading started")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print("Model loading completed")
# Informational only: device_map="auto" above already decided where the
# weights live, so request handlers should move inputs to model.device
# rather than relying on this value.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Selected device:", device)
app = FastAPI()
@app.get('/')
def home():
    """Root endpoint: respond with a fixed greeting payload."""
    greeting = dict(hello="Bitfumes")
    return greeting
@app.post('/ask')
async def ask(request: Request):
    """Generate a chat reply from the loaded model for a single user prompt.

    Expects a JSON object body with a "prompt" field.

    Returns:
        ``{"answer": <generated text>}`` on success, or ``{"error": ...}``
        when the body is not valid JSON or carries no non-empty "prompt".
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON (json.JSONDecodeError is a ValueError) would
        # otherwise escape as an unhandled 500.
        return {"error": "Invalid JSON body"}
    # A JSON array/string/number has no .get(); treat it as a missing prompt.
    prompt = data.get("prompt") if isinstance(data, dict) else None
    if not prompt:
        return {"error": "Prompt is missing"}
    print("Device of the model:", model.device)
    messages = [
        {"role": "user", "content": f"{prompt}"},
    ]
    print("Messages:", messages)
    print("Tokenizer process started")
    # add_generation_prompt=True appends the assistant-turn header so the
    # model answers rather than continuing the user's turn. Inputs are moved
    # to model.device instead of a hard-coded "cuda", which failed on
    # CPU-only hosts.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    print("Tokenizer process completed")
    print("Model process started")
    # NOTE(review): generate() is synchronous and blocks the event loop for
    # the whole generation; consider offloading it if concurrency matters.
    outputs = model.generate(**inputs, max_new_tokens=256)
    print("Tokenizer decode process started")
    # Decode only the newly generated tokens and drop special tokens (turn
    # markers), rather than splitting the full transcript on "<end_of_turn>",
    # which leaked "<start_of_turn>model" into the answer and broke when the
    # prompt itself contained that marker.
    prompt_length = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()
    return {"answer": answer}