from fastapi import FastAPI, Request
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from huggingface_hub import login
import os

print("Google Gemma 2 Chatbot is starting...")

# Read the Hugging Face access token from the environment and log in,
# since google/gemma-2-9b-it is a gated model.
access_token = os.getenv("HF_TOKEN")
login(access_token)

model_id = "google/gemma-2-9b-it"

print("Model loading started")
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print("Model loading completed")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Selected device:", device)

app = FastAPI()


@app.get("/")
def home():
    return {"hello": "Bitfumes"}


@app.post("/ask")
async def ask(request: Request):
    data = await request.json()
    prompt = data.get("prompt")

    if not prompt:
        return {"error": "Prompt is missing"}

    print("Device of the model:", model.device)

    messages = [
        {"role": "user", "content": prompt},
    ]
    print("Messages:", messages)

    print("Tokenizer process started")
    # Apply Gemma's chat template and move the tensors to whatever device the model was loaded on,
    # so this also works on CPU-only machines instead of assuming CUDA.
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", return_dict=True
    ).to(model.device)
    print("Tokenizer process completed")

    print("Model process started")
    outputs = model.generate(**inputs, max_new_tokens=256)

    print("Tokenizer decode process started")
    # Gemma's chat template wraps the reply in "<start_of_turn>model ... <end_of_turn>",
    # so keep only the text after the model turn marker.
    answer = tokenizer.decode(outputs[0]).split("<start_of_turn>model")[1].strip()
    return {"answer": answer}
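

# Optional local runner: a minimal sketch assuming uvicorn is installed and this
# file is saved as main.py (both the package and the module name are assumptions,
# not part of the original code). With the server running you can test the /ask
# endpoint with a simple POST request, e.g.:
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain FastAPI in one sentence."}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)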