"""FastAPI service that mirrors Hugging Face causal-LM checkpoints into a GCS
bucket and serves ``transformers`` pipelines over them.

A model's files live in the bucket as ``models/<model_name>/<file>``. For a
prediction, the files are copied from GCS into a local cache directory and
loaded from there. If GCS does not have them, they are first fetched from
Hugging Face.
"""

import os
import re
import json
import logging
import datetime
import tempfile
import threading
from concurrent.futures import ThreadPoolExecutor

from google.cloud import storage
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from pydantic import BaseModel
from fastapi import FastAPI, HTTPException
import requests
import uvicorn
from dotenv import load_dotenv

load_dotenv()

API_KEY = os.getenv("API_KEY")
# NOTE(review): API_KEY is read but never checked by any endpoint below. If
# /predict/ is meant to require it, that enforcement is missing.
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
# Local directory into which model files are copied from GCS before loading.
MODEL_CACHE_DIR = os.getenv(
    "MODEL_CACHE_DIR", os.path.join(tempfile.gettempdir(), "gcs_model_cache")
)

logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "DEBUG").upper(),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Bucket prefix under which every model's files are stored.
MODELS_PREFIX = "models/"
HF_BASE_URL = "https://huggingface.co"
# (connect, read) timeouts in seconds for every Hugging Face request. Without
# them, a stalled connection would hang the calling thread indefinitely.
HTTP_TIMEOUT = (10, 120)
# Bytes per chunk when streaming a file from Hugging Face to a temp file.
DOWNLOAD_CHUNK_SIZE = 1024 * 1024
# Upper bound on concurrent model checks at startup.
STARTUP_WORKERS = 4

CONFIG_FILE = "config.json"
# Single-file weight formats. A usable model needs config.json plus at least
# ONE of these (repos normally ship only one). Sharded checkpoints
# (*.index.json + shard files) are not handled here.
WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")
# Tokenizer/generation files copied when the repo has them. Besides
# tokenizer.json, these let tokenizers that ship vocab/merges files load too.
EXTRA_FILES = (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
    "generation_config.json",
)
MODEL_FILES = (CONFIG_FILE, *WEIGHT_FILES, *EXTRA_FILES)

# Hugging Face repo id: "name" or "namespace/name". Every segment must start
# with an alphanumeric, which rejects "." / ".." segments. model_name is
# user input and becomes part of a local filesystem path and of URLs.
_MODEL_NAME_RE = re.compile(
    r"[A-Za-z0-9][A-Za-z0-9._-]*(?:/[A-Za-z0-9][A-Za-z0-9._-]*)?"
)


def _require_env(name, value):
    """Return *value*, or raise a clear error if env var *name* is unset/empty."""
    if not value:
        raise RuntimeError(f"Environment variable {name} must be set.")
    return value


credentials_info = json.loads(
    _require_env(
        "GOOGLE_APPLICATION_CREDENTIALS_JSON", GOOGLE_APPLICATION_CREDENTIALS_JSON
    )
)
storage_client = storage.Client.from_service_account_info(credentials_info)
bucket = storage_client.bucket(_require_env("GCS_BUCKET_NAME", GCS_BUCKET_NAME))

app = FastAPI()


class DownloadModelRequest(BaseModel):
    """Request body for POST /predict/."""

    model_name: str  # Hugging Face repo id, e.g. "gpt2" or "org/model"
    pipeline_task: str  # task name passed to transformers.pipeline
    input_text: str  # text fed to the pipeline


class GCSHandler:
    """Small convenience wrapper around one GCS bucket."""

    def __init__(self, bucket_name, client=None):
        # Defaults to the module-level client built from the service-account JSON.
        self.bucket = (client or storage_client).bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return True if an object named *blob_name* exists in the bucket."""
        return self.bucket.blob(blob_name).exists()

    def get_blob(self, blob_name):
        """Return the blob with its metadata loaded, or None if it does not exist."""
        return self.bucket.get_blob(blob_name)

    def create_folder_if_not_exists(self, folder_name):
        """Create a zero-byte ``<folder_name>/`` marker object if it is absent.

        Args:
            folder_name: Folder path without or with a trailing slash. An empty
                name (an object at the bucket root) is a no-op.
        """
        if not folder_name:
            return
        # Check the same name we create; checking "folder" but writing
        # "folder/" would re-upload the marker every time.
        marker = folder_name.rstrip("/") + "/"
        if not self.file_exists(marker):
            self.bucket.blob(marker).upload_from_string("")

    def upload_file(self, blob_name, file_stream):
        """Upload the binary *file_stream* to *blob_name*, creating its folder marker."""
        self.create_folder_if_not_exists(os.path.dirname(blob_name))
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(file_stream)

    def download_file(self, blob_name):
        """Open *blob_name* for binary reading.

        Raises:
            HTTPException: 404 if the object does not exist.
        """
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
        return blob.open("rb")

    def generate_signed_url(self, blob_name, expiration=3600):
        """Return a signed URL for *blob_name*.

        Args:
            expiration: Lifetime of the URL. A number is a duration in seconds.
                A ``timedelta`` or ``datetime`` is passed through unchanged.
        """
        # With the library's default (v2) signing, a bare int is read as an
        # absolute Unix timestamp. 3600 would then mean 1970-01-01T01:00Z and
        # the URL would already be expired. A timedelta is relative to now.
        if isinstance(expiration, (int, float)) and not isinstance(expiration, bool):
            expiration = datetime.timedelta(seconds=expiration)
        blob = self.bucket.blob(blob_name)
        return blob.generate_signed_url(expiration=expiration)


def _validate_model_name(model_name):
    """Return *model_name* if it looks like a HF repo id, else raise HTTP 400."""
    if not isinstance(model_name, str) or not _MODEL_NAME_RE.fullmatch(model_name):
        raise HTTPException(status_code=400, detail=f"Invalid model name '{model_name}'.")
    return model_name


def _hf_headers():
    """Authorization header for Hugging Face, or no header when no token is set."""
    # Without this guard a missing token is sent as the literal "Bearer None".
    return {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}


def _list_hf_repo_files(model_name):
    """Return the set of file names in the Hugging Face repo *model_name*.

    Raises:
        HTTPException: 404 if the repo cannot be listed (missing, gated, or
            request failure).
    """
    url = f"{HF_BASE_URL}/api/models/{model_name}"
    try:
        response = requests.get(url, headers=_hf_headers(), timeout=HTTP_TIMEOUT)
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=404, detail="Error accessing Hugging Face model files."
        ) from exc
    if response.status_code != 200:
        raise HTTPException(status_code=404, detail="Error accessing Hugging Face model files.")
    siblings = response.json().get("siblings") or []
    return {entry.get("rfilename") for entry in siblings}


def _copy_hf_file_to_gcs(model_name, file_name):
    """Stream one repo file from Hugging Face into ``models/<model_name>/<file_name>``.

    The file is spooled to a local temp file, never held fully in memory.

    Raises:
        HTTPException: 502 if the download fails or returns a non-2xx status.
            The status check keeps an HTML error page from being stored as a
            model file.
    """
    file_url = f"{HF_BASE_URL}/{model_name}/resolve/main/{file_name}"
    try:
        with requests.get(
            file_url, headers=_hf_headers(), stream=True, timeout=HTTP_TIMEOUT
        ) as response:
            response.raise_for_status()
            with tempfile.TemporaryFile() as spool:
                for chunk in response.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    spool.write(chunk)
                spool.seek(0)
                blob_name = f"{MODELS_PREFIX}{model_name}/{file_name}"
                bucket.blob(blob_name).upload_from_file(spool)
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Error downloading '{file_name}' for model '{model_name}' from Hugging Face.",
        ) from exc


def download_model_from_huggingface(model_name):
    """Copy the known model files of *model_name* from Hugging Face into GCS.

    Only files in MODEL_FILES that the repo actually contains are copied.

    Raises:
        HTTPException: 400 for an invalid name. 404 if the repo is inaccessible
            or lacks config.json / a single-file weight file. 502 if a file
            download fails.
    """
    _validate_model_name(model_name)
    available = _list_hf_repo_files(model_name)
    wanted = [file_name for file_name in MODEL_FILES if file_name in available]
    if CONFIG_FILE not in wanted or not any(w in wanted for w in WEIGHT_FILES):
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' has no {CONFIG_FILE} or single-file weights on Hugging Face.",
        )
    for file_name in wanted:
        _copy_hf_file_to_gcs(model_name, file_name)


def _model_complete_in_gcs(gcs_handler, model_name):
    """True if GCS holds config.json and at least one weight file for *model_name*."""
    prefix = f"{MODELS_PREFIX}{model_name}/"
    return gcs_handler.file_exists(prefix + CONFIG_FILE) and any(
        gcs_handler.file_exists(prefix + weights) for weights in WEIGHT_FILES
    )


def download_and_verify_model(model_name):
    """Ensure GCS holds a loadable copy of *model_name*, fetching it if not.

    Raises:
        HTTPException: Propagated from download_model_from_huggingface.
    """
    _validate_model_name(model_name)
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    if not _model_complete_in_gcs(gcs_handler, model_name):
        download_model_from_huggingface(model_name)


def _sync_blob_to_file(blob, path):
    """Copy *blob* to local *path*, skipping the transfer when the cached copy matches.

    The cached copy counts as current when its size and mtime equal the blob's
    ``size`` and ``updated`` time. The mtime is set to that value on every
    download. Data goes to a temp file first and is then moved over *path* with
    ``os.replace`` (atomic on POSIX), so concurrent loaders should not see a
    half-written file.
    """
    remote_mtime = blob.updated.timestamp() if blob.updated is not None else None
    if (
        remote_mtime is not None
        and os.path.exists(path)
        and os.path.getsize(path) == blob.size
        and abs(os.path.getmtime(path) - remote_mtime) < 1
    ):
        return
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path), suffix=".part")
    os.close(fd)
    try:
        blob.download_to_filename(tmp_path)
        if remote_mtime is not None:
            os.utime(tmp_path, (remote_mtime, remote_mtime))
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_model_from_gcs(model_name):
    """Load *model_name* from the GCS bucket.

    Every file in MODEL_FILES that exists under ``models/<model_name>/`` is
    copied into a per-model directory in MODEL_CACHE_DIR. The model is loaded
    from that directory, because ``from_pretrained`` takes a repo id or local
    path, not open file streams.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        HTTPException: 400 for an invalid name. 404 if config.json or every
            weight file is missing from GCS.
    """
    _validate_model_name(model_name)
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    local_dir = os.path.join(MODEL_CACHE_DIR, *model_name.split("/"))
    os.makedirs(local_dir, exist_ok=True)

    fetched = set()
    for file_name in MODEL_FILES:
        local_path = os.path.join(local_dir, file_name)
        blob = gcs_handler.get_blob(f"{MODELS_PREFIX}{model_name}/{file_name}")
        if blob is None:
            # Drop a stale cached copy of a file that is no longer in GCS.
            try:
                os.remove(local_path)
            except FileNotFoundError:
                pass
            continue
        _sync_blob_to_file(blob, local_path)
        fetched.add(file_name)

    if CONFIG_FILE not in fetched or fetched.isdisjoint(WEIGHT_FILES):
        raise HTTPException(
            status_code=404, detail=f"Model '{model_name}' is not available in GCS."
        )

    model = AutoModelForCausalLM.from_pretrained(local_dir)
    tokenizer = AutoTokenizer.from_pretrained(local_dir)
    return model, tokenizer


def load_model(model_name):
    """Load *model_name* from GCS, mirroring it from Hugging Face first if absent.

    Returns:
        ``(model, tokenizer)``.
    """
    try:
        return load_model_from_gcs(model_name)
    except HTTPException as exc:
        if exc.status_code != 404:
            raise  # e.g. 400 invalid name: fetching from Hugging Face cannot help
        download_and_verify_model(model_name)
        return load_model_from_gcs(model_name)


def _model_names_in_bucket():
    """Return the model names that have at least one object under MODELS_PREFIX.

    The name is everything between the prefix and the last "/", so namespaced
    ids such as ``org/model`` are kept whole.
    """
    names = set()
    for blob in bucket.list_blobs(prefix=MODELS_PREFIX):
        relative = blob.name[len(MODELS_PREFIX):]
        model_name, sep, _ = relative.rpartition("/")
        if sep and model_name:
            names.add(model_name)
    return names


@app.on_event("startup")
async def startup():
    """Verify every model already in the bucket, re-fetching incomplete ones.

    This blocks until all checks finish, so the server starts serving only
    afterwards. Per-model failures are logged and do not abort startup.
    """
    model_names = _model_names_in_bucket()

    def download_model_thread(model_name):
        try:
            download_and_verify_model(model_name)
        except Exception:
            logger.exception("Error downloading model '%s'", model_name)

    with ThreadPoolExecutor(max_workers=STARTUP_WORKERS) as executor:
        list(executor.map(download_model_thread, sorted(model_names)))


@app.post("/predict/")
def predict(request: DownloadModelRequest):
    """Run ``request.pipeline_task`` on ``request.input_text`` with the requested model.

    Declared as a plain ``def`` so FastAPI runs it in its worker threadpool.
    Model loading and inference are blocking and would otherwise stall the
    event loop for every other request.

    Raises:
        HTTPException: 400 for an invalid model name or unknown pipeline task.
            404/502 if the model cannot be obtained.
    """
    # NOTE(review): the model is always loaded with AutoModelForCausalLM, so
    # tasks needing another head type may fail or behave unexpectedly.
    model, tokenizer = load_model(request.model_name)
    try:
        pipe = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
    except KeyError as exc:
        # transformers reports an unknown task name as KeyError.
        raise HTTPException(
            status_code=400, detail=f"Unsupported pipeline task '{request.pipeline_task}'."
        ) from exc
    result = pipe(request.input_text)
    return {"result": result}


def download_all_models_in_background():
    """Mirror every model returned by the Hugging Face model listing into GCS.

    A failure on one model is logged and the loop moves on.
    """
    # NOTE(review): this mirrors whatever page /api/models returns, which can
    # be a very large amount of data. Confirm this is really intended.
    models_url = f"{HF_BASE_URL}/api/models"
    try:
        response = requests.get(models_url, headers=_hf_headers(), timeout=HTTP_TIMEOUT)
    except requests.RequestException:
        logger.exception("Error listing Hugging Face models")
        return
    if response.status_code != 200:
        logger.error("Error listing Hugging Face models: HTTP %s", response.status_code)
        return
    for model in response.json():
        model_id = model.get("id")
        try:
            download_model_from_huggingface(model_id)
        except Exception:
            logger.exception("Error mirroring model '%s'", model_id)


def run_in_background():
    """Start download_all_models_in_background on a daemon thread."""
    threading.Thread(target=download_all_models_in_background, daemon=True).start()


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)