import json
import os
import re
import tempfile
from contextlib import contextmanager
from io import BytesIO

import requests
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from google.auth import exceptions
from google.cloud import storage
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

load_dotenv()

# NOTE(review): API_KEY is read here but no code in this file checks it, so
# /predict/ is unauthenticated (and can trigger arbitrary model downloads into
# the bucket). Confirm whether an API-key check was intended.
API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

# Files are fetched through the Hub's /{repo}/resolve/{revision}/{filename} endpoint.
HF_BASE_URL = "https://huggingface.co"
# (connect, read) timeouts in seconds so a stalled Hub download cannot hang a worker.
HF_REQUEST_TIMEOUT = (10, 300)
# Size of the chunks written to the local spool file while streaming a download.
_DOWNLOAD_CHUNK_SIZE = 8 * 1024 * 1024

# Tokenizer/config side files mirrored when the repo has them; a 404 just
# means this repo does not use that file.
_OPTIONAL_MODEL_FILES = (
    "generation_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
    "tokenizer.model",
)
# Weight layouts in order of preference: (single-file checkpoint, index
# file of the sharded form of the same format).
_WEIGHT_LAYOUTS = (
    ("model.safetensors", "model.safetensors.index.json"),
    ("pytorch_model.bin", "pytorch_model.bin.index.json"),
)

# Patterns are applied with fullmatch(): with re.match(r"^...$") a trailing
# newline slips through, because "$" also matches just before a final "\n".
_BUCKET_NAME_RE = re.compile(r"[a-z0-9][a-z0-9\-]*[a-z0-9]")
_REPO_NAME_PART_RE = re.compile(r"[a-zA-Z0-9_.-]+")


def validate_bucket_name(bucket_name):
    """Return *bucket_name* unchanged if it passes this service's bucket-name rules.

    Accepted names are 3-63 characters of lowercase letters, digits and
    hyphens that start and end with a letter or digit. (Dotted and
    underscore names are not accepted by this check.)

    Raises:
        ValueError: if *bucket_name* is not a string (e.g. the environment
            variable is unset) or breaks any of the rules above.
    """
    if not isinstance(bucket_name, str):
        raise ValueError(f"Invalid bucket name {bucket_name!r}. Expected a non-empty string.")
    if not 3 <= len(bucket_name) <= 63:
        raise ValueError(f"Invalid bucket name '{bucket_name}'. Must be between 3 and 63 characters long.")
    if not _BUCKET_NAME_RE.fullmatch(bucket_name):
        raise ValueError(f"Invalid bucket name '{bucket_name}'. Must start and end with a letter or number.")
    return bucket_name


def validate_huggingface_repo_name(repo_name):
    """Return *repo_name* unchanged if it is an acceptable Hugging Face repo id.

    Both bare names (``"gpt2"``) and namespaced ids
    (``"openai-community/gpt2"``) are accepted. Each ``/``-separated part
    must use only letters, digits, ``-``, ``_`` or ``.``; must not start or
    end with ``-`` or ``.``; must not contain ``..``; and must be at most 96
    characters. The value is later embedded in Hub URLs and GCS object names.

    Raises:
        ValueError: if *repo_name* is not a string or breaks any rule above.
    """
    if not isinstance(repo_name, str):
        raise ValueError(f"Invalid repository name {repo_name!r}. Expected a string.")
    parts = repo_name.split("/")
    if len(parts) > 2:
        raise ValueError(
            f"Invalid repository name '{repo_name}'. At most one '/' (namespace/name) is allowed."
        )
    for part in parts:
        # An empty part ("/gpt2", "org/") also fails here because of the "+".
        if not _REPO_NAME_PART_RE.fullmatch(part):
            raise ValueError(
                f"Invalid repository name '{repo_name}'. "
                "Must use alphanumeric characters, '-', '_', or '.'."
            )
        if part.startswith(('-', '.')) or part.endswith(('-', '.')) or '..' in part:
            raise ValueError(
                f"Invalid repository name '{repo_name}'. "
                "Cannot start or end with '-' or '.', or contain '..'."
            )
        if len(part) > 96:
            raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
    return repo_name


# Build the storage client once at import time; any configuration problem
# aborts startup with a single RuntimeError.
try:
    GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
    if not GOOGLE_APPLICATION_CREDENTIALS_JSON:
        raise ValueError("GOOGLE_APPLICATION_CREDENTIALS_JSON is not set.")
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    if not isinstance(credentials_info, dict):
        raise ValueError("GOOGLE_APPLICATION_CREDENTIALS_JSON must contain a JSON object.")
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
    raise RuntimeError(f"Error al cargar credenciales o bucket: {e}") from e

app = FastAPI()


class DownloadModelRequest(BaseModel):
    """Body of POST /predict/."""

    model_name: str  # Hub repo id; also used as the GCS object prefix
    pipeline_task: str  # task name passed to transformers.pipeline()
    input_text: str  # text fed to the pipeline


class GCSHandler:
    """Small wrapper around one GCS bucket: existence checks, uploads and downloads."""

    def __init__(self, bucket_name, client=None):
        # Defaults to the module-level client built from the service-account JSON.
        self.bucket = (storage_client if client is None else client).bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return True if object *blob_name* exists in the bucket."""
        return self.bucket.blob(blob_name).exists()

    def upload_file(self, blob_name, file_stream):
        """Upload the contents of *file_stream* to object *blob_name*."""
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(file_stream)

    def download_file(self, blob_name):
        """Return object *blob_name* as an in-memory BytesIO.

        Raises:
            HTTPException: 404 if the object does not exist.
        """
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
        return BytesIO(blob.download_as_bytes())


@contextmanager
def _hub_errors_as_http(filename):
    """Re-raise any failure in the ``with`` body as an HTTP 500 naming *filename*.

    HTTPExceptions pass through unchanged so deliberate status codes survive.
    """
    try:
        yield
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error downloading {filename} from Hugging Face: {e}"
        ) from e


def _hf_get(model_name, filename, revision):
    """Open a streaming GET for *filename* in Hub repo *model_name* at *revision*.

    The bearer token is attached only when HF_API_TOKEN is set, so public
    downloads never send a literal "Bearer None". Callers must close the
    response (use it as a context manager).
    """
    url = f"{HF_BASE_URL}/{model_name}/resolve/{revision}/{filename}"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    return requests.get(url, headers=headers, stream=True, timeout=HF_REQUEST_TIMEOUT)


def _mirror_hf_file(model_name, filename, revision):
    """Copy one Hub file to the GCS object ``{model_name}/{filename}``.

    The body is spooled to an anonymous local temp file in chunks rather
    than held in memory, so multi-gigabyte checkpoints do not exhaust RAM.

    Returns:
        True if the file was copied; False if the Hub answered 404.

    Raises:
        HTTPException: 500 for any other HTTP status or transfer failure.
    """
    with _hub_errors_as_http(filename), _hf_get(model_name, filename, revision) as response:
        if response.status_code == 404:
            return False
        # Auth failures (401/403) and server errors are reported, not skipped.
        response.raise_for_status()
        with tempfile.TemporaryFile() as spool:
            for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                spool.write(chunk)
            size = spool.tell()
            spool.seek(0)
            bucket.blob(f"{model_name}/{filename}").upload_from_file(spool, size=size)
    return True


def _mirror_shard_index(model_name, index_name, revision):
    """Mirror a sharded-checkpoint index file and return the shard filenames it lists.

    Returns:
        Sorted, de-duplicated values of the index's ``weight_map``; an empty
        list when the index does not exist on the Hub (404).

    Raises:
        HTTPException: 500 on transfer failure, malformed JSON, or a shard
            name that is absolute or contains a ``..`` segment.
    """
    with _hub_errors_as_http(index_name):
        with _hf_get(model_name, index_name, revision) as response:
            if response.status_code == 404:
                return []
            response.raise_for_status()
            payload = response.content  # index files are small JSON documents
        bucket.blob(f"{model_name}/{index_name}").upload_from_file(BytesIO(payload))
        shard_names = sorted(set(json.loads(payload).get("weight_map", {}).values()))
        # Shard names come from the repo and are reused as object names and,
        # downstream, local file paths — reject anything that could escape.
        for shard_name in shard_names:
            if shard_name.startswith("/") or ".." in shard_name.split("/"):
                raise ValueError(f"Unsafe shard filename '{shard_name}' listed in index.")
        return shard_names


def download_model_from_huggingface(model_name, revision="main"):
    """Mirror the files needed to load *model_name* from the Hugging Face Hub into GCS.

    Every file is stored as ``{model_name}/{filename}`` in the configured
    bucket. The steps are:

    1. Copy config.json.
    2. Copy whichever tokenizer/config side files exist.
    3. Copy the first weight layout found, in this order: model.safetensors,
       sharded safetensors (discovered through model.safetensors.index.json),
       pytorch_model.bin, sharded .bin (discovered through its index).

    Args:
        model_name: Hub repo id, bare or ``namespace/name``.
        revision: branch, tag or commit to download (default ``"main"``).
            It is interpolated into the URL as-is, so pass trusted values only.

    Raises:
        ValueError: if *model_name* fails validate_huggingface_repo_name.
        HTTPException: 404 if the repo has no config.json or no supported
            weights; 500 if any download or upload fails.
    """
    model_name = validate_huggingface_repo_name(model_name)

    if not _mirror_hf_file(model_name, "config.json", revision):
        raise HTTPException(
            status_code=404,
            detail=f"Model '{model_name}' has no config.json on Hugging Face (revision '{revision}').",
        )

    for filename in _OPTIONAL_MODEL_FILES:
        _mirror_hf_file(model_name, filename, revision)

    for single_file, index_file in _WEIGHT_LAYOUTS:
        if _mirror_hf_file(model_name, single_file, revision):
            return
        shard_names = _mirror_shard_index(model_name, index_file, revision)
        if shard_names:
            for shard_name in shard_names:
                if not _mirror_hf_file(model_name, shard_name, revision):
                    raise HTTPException(
                        status_code=500,
                        detail=f"Error downloading {shard_name} from Hugging Face: "
                        f"listed in {index_file} but not found.",
                    )
            return

    raise HTTPException(
        status_code=404,
        detail=f"No safetensors or PyTorch weights found for '{model_name}' on Hugging Face.",
    )
def _list_model_files(gcs_handler, model_prefix):
    """Return the names of all objects under ``{model_prefix}/``, relative to that prefix.

    Names ending in "/" (folder placeholders) are skipped because they
    cannot be written out as files.
    """
    prefix = f"{model_prefix}/"
    return [
        blob.name[len(prefix):]
        for blob in gcs_handler.bucket.list_blobs(prefix=prefix)
        if not blob.name.endswith("/")
    ]


def _is_loadable(relative_names):
    """Return True if the listing has a root config.json and at least one weight file."""
    return "config.json" in relative_names and any(
        name.endswith((".safetensors", ".bin")) for name in relative_names
    )


def _copy_model_files(gcs_handler, model_prefix, relative_names, target_dir):
    """Download each ``{model_prefix}/{name}`` object to ``target_dir/name``.

    Raises:
        HTTPException: 500 if an object name would resolve outside *target_dir*.
    """
    root = os.path.abspath(target_dir)
    for name in relative_names:
        local_path = os.path.abspath(os.path.join(root, name))
        # Object names come from the bucket listing; refuse anything that
        # would land outside the temp directory (absolute names, "..").
        if local_path == root or os.path.commonpath([root, local_path]) != root:
            raise HTTPException(
                status_code=500, detail=f"Refusing to write '{name}' outside the model directory."
            )
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        gcs_handler.bucket.blob(f"{model_prefix}/{name}").download_to_filename(local_path)


@app.post("/predict/")
def predict(request: DownloadModelRequest):
    """Run ``request.pipeline_task`` on ``request.input_text`` with a GCS-mirrored model.

    The model's files are listed under the ``{model_name}/`` prefix of the
    configured bucket. If that listing lacks a config.json or any weight file,
    the model is first mirrored from the Hub by download_model_from_huggingface.
    The files are then copied into a temporary directory, because
    from_pretrained loads from a directory path or Hub id, not from
    in-memory streams.

    The handler is a plain ``def`` rather than ``async def`` so that FastAPI
    runs it in its thread pool. GCS transfers, model loading and inference
    all block and would otherwise stall the event loop.

    NOTE(review): the model is reloaded on every request; consider caching
    if latency or memory becomes a concern.

    Raises:
        HTTPException: propagated unchanged from helpers (e.g. a 404 from the
            Hub mirror); 500 "Required model files missing." when no loadable
            model is available; 500 "Error: ..." for any other failure.
    """
    import tempfile

    try:
        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
        model_prefix = request.model_name

        # One listing call replaces per-file exists() probes. Re-mirroring when
        # config or weights are absent also repairs an earlier partial download.
        model_files = _list_model_files(gcs_handler, model_prefix)
        if not _is_loadable(model_files):
            download_model_from_huggingface(model_prefix)
            model_files = _list_model_files(gcs_handler, model_prefix)
        if not _is_loadable(model_files):
            raise HTTPException(status_code=500, detail="Required model files missing.")

        # Keep the files on disk for as long as the model and pipeline are in use.
        with tempfile.TemporaryDirectory() as model_dir:
            _copy_model_files(gcs_handler, model_prefix, model_files, model_dir)
            model = AutoModelForCausalLM.from_pretrained(model_dir)
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            task_pipeline = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
            result = task_pipeline(request.input_text)
        return {"response": result}
    except HTTPException:
        # Keep deliberate status codes instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)