"""FastAPI service that runs Hugging Face pipelines on models cached in GCS.

Models are looked up in a Google Cloud Storage bucket under ``<model_name>/``.
When no weights are found there, a fixed set of files is fetched from the
Hugging Face Hub and stored in the bucket first.
"""

import datetime
import io
import json
import logging
import os
import tempfile
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor

import requests
import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from google.cloud import storage
from pydantic import BaseModel
from safetensors import safe_open
from safetensors.torch import load as load_safetensors_bytes
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

load_dotenv()

API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

# Seconds to wait on any single Hugging Face HTTP request before giving up.
HF_REQUEST_TIMEOUT = 60

# Weight files tried on the Hub; a missing one is skipped, not treated as fatal.
HF_WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")
# Config/tokenizer files needed next to the weights to load from a directory.
MODEL_SUPPORT_FILES = (
    "config.json",
    "tokenizer.json",
    "generation_config.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
)
HF_MODEL_FILES = HF_WEIGHT_FILES + MODEL_SUPPORT_FILES

# Tasks whose pipeline output is returned directly in the response.
TEXT_TASKS = ("text-generation", "translation", "summarization")
# Image tasks: task name -> (file-name suffix, response key).
IMAGE_TASKS = {
    "image-generation": ("", "image_url"),
    "image-editing": ("_edited", "edited_image_url"),
    "image-to-image": ("_transformed", "transformed_image_url"),
}
TEXT_TO_3D_TASK = "text-to-3d"

logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

try:
    # json.loads(None) raises TypeError, which the handler below would not
    # catch; check for unset variables explicitly so the failure is reported
    # through the same RuntimeError path.
    if not GOOGLE_APPLICATION_CREDENTIALS_JSON:
        raise ValueError("GOOGLE_APPLICATION_CREDENTIALS_JSON is not set")
    if not GCS_BUCKET_NAME:
        raise ValueError("GCS_BUCKET_NAME is not set")
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
    logger.info("Conexión con Google Cloud Storage exitosa. Bucket: %s", GCS_BUCKET_NAME)
except (json.JSONDecodeError, KeyError, ValueError) as e:
    logger.error("Error al cargar las credenciales o bucket: %s", e)
    raise RuntimeError(f"Error al cargar las credenciales o bucket: {e}") from e

app = FastAPI()


class DownloadModelRequest(BaseModel):
    """Request body for ``POST /predict/``."""

    model_name: str
    pipeline_task: str
    input_text: str


class GCSHandler:
    """Thin wrapper around one bucket of the module-level storage client."""

    def __init__(self, bucket_name):
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return True when ``blob_name`` exists in the bucket."""
        return self.bucket.blob(blob_name).exists()

    def download_file(self, blob_name):
        """Return the blob's content as bytes.

        Raises:
            HTTPException: 404 when the blob does not exist.
        """
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
        return blob.download_as_bytes()

    def upload_file(self, blob_name, file_data):
        """Upload the readable binary file object ``file_data`` to ``blob_name``."""
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(file_data)

    def generate_signed_url(self, blob_name, expiration=3600):
        """Return a signed URL for ``blob_name``.

        Args:
            blob_name: Name of the blob to sign.
            expiration: Lifetime of the URL. An int is treated as a number of
                seconds from now; ``datetime.timedelta`` and
                ``datetime.datetime`` values are passed through unchanged.
        """
        blob = self.bucket.blob(blob_name)
        # Per the google-cloud-storage docs, an int passed to the default (V2)
        # signer is an absolute epoch timestamp, so a bare 3600 yields a URL
        # that expired in 1970. A timedelta is always relative to "now".
        if isinstance(expiration, int):
            expiration = datetime.timedelta(seconds=expiration)
        return blob.generate_signed_url(expiration=expiration)


def _relative_name(blob_name, prefix):
    """Strip ``prefix`` from ``blob_name`` when present."""
    if blob_name.startswith(prefix):
        return blob_name[len(prefix):]
    return blob_name


def _is_weights_file(relative_name):
    """Tell whether a file name (relative to the model prefix) looks like weights."""
    return "model" in relative_name and "index" not in relative_name


def load_model_from_gcs(model_name: str, model_files: list):
    """Load a causal-LM model and its tokenizer from files stored in GCS.

    The files are copied into a temporary directory and loaded from there,
    because ``from_pretrained`` expects a directory path or model id rather
    than in-memory streams.

    Args:
        model_name: Model identifier; files live under ``<model_name>/``.
        model_files: Weight files to fetch, either as full blob names (as
            returned by ``get_model_files_from_gcs``) or as names relative to
            the model prefix.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        HTTPException: 404 when ``model_files`` is empty or a listed file is
            missing from the bucket.
    """
    if not model_files:
        raise HTTPException(
            status_code=404,
            detail=f"No model files found for '{model_name}'.",
        )

    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    prefix = f"{model_name}/"
    wanted = {_relative_name(name, prefix) for name in model_files}
    # The weights listing does not include config/tokenizer files, but they are
    # needed alongside the weights; add whichever ones the bucket has.
    for support_file in MODEL_SUPPORT_FILES:
        if support_file not in wanted and gcs_handler.file_exists(prefix + support_file):
            wanted.add(support_file)

    with tempfile.TemporaryDirectory() as model_dir:
        root = os.path.realpath(model_dir)
        for relative in sorted(wanted):
            target = os.path.realpath(os.path.join(root, relative))
            # Never write outside the temporary directory.
            if not target.startswith(root + os.sep):
                logger.warning("Skipping unsafe model file name: %s", relative)
                continue
            os.makedirs(os.path.dirname(target), exist_ok=True)
            with open(target, "wb") as fh:
                fh.write(gcs_handler.download_file(prefix + relative))

        # NOTE(review): when both pytorch_model.bin and model.safetensors are
        # present, transformers decides which one to load - confirm that is
        # the desired precedence.
        model = AutoModelForCausalLM.from_pretrained(model_dir)
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
    return model, tokenizer


def load_safetensors_model(model_stream):
    """Deserialize safetensors content held in memory.

    Args:
        model_stream: The raw bytes of a ``.safetensors`` file.

    Returns:
        A dict mapping tensor names to ``torch.Tensor`` objects.
    """
    # safe_open() needs a file name, so in-memory bytes go through
    # safetensors.torch.load instead.
    return load_safetensors_bytes(bytes(model_stream))


def get_model_files_from_gcs(model_name: str):
    """List the weight blobs stored for ``model_name``.

    Returns:
        Sorted full blob names under ``<model_name>/`` whose file name contains
        "model" and does not contain "index".
    """
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    prefix = f"{model_name}/"
    # Filter on the name relative to the prefix so that a model id containing
    # "model" or "index" does not skew the match.
    return sorted(
        blob.name
        for blob in gcs_handler.bucket.list_blobs(prefix=prefix)
        if _is_weights_file(_relative_name(blob.name, prefix))
    )


def download_model_from_huggingface(model_name):
    """Copy a model's files from the Hugging Face Hub into the GCS bucket.

    Each file in ``HF_MODEL_FILES`` is fetched concurrently and stored as
    ``<model_name>/<file_name>``. Files the Hub reports as missing (HTTP 404)
    are skipped.

    Raises:
        HTTPException: 404 when the model's file tree cannot be accessed,
            500 when a download or upload fails.
    """
    # Only send the header when a token is configured; "Bearer None" would be
    # an invalid credential.
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    tree_url = f"https://huggingface.co/{model_name}/tree/main"

    def download_file(file_name):
        file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
        response = requests.get(file_url, headers=headers, timeout=HF_REQUEST_TIMEOUT)
        if response.status_code == 404:
            # Not every repository ships every file; do not store the error
            # page as if it were model data.
            logger.debug("%s not present for %s; skipping", file_name, model_name)
            return False
        response.raise_for_status()
        bucket.blob(f"{model_name}/{file_name}").upload_from_string(response.content)
        return True

    try:
        tree_response = requests.get(tree_url, headers=headers, timeout=HF_REQUEST_TIMEOUT)
        if tree_response.status_code != 200:
            raise HTTPException(
                status_code=404,
                detail="Error al acceder al árbol de archivos de Hugging Face.",
            )
        # pool.map re-raises worker exceptions here instead of losing them.
        with ThreadPoolExecutor(max_workers=len(HF_MODEL_FILES)) as pool:
            list(pool.map(download_file, HF_MODEL_FILES))
    except HTTPException:
        # Keep the intended status code instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error descargando archivos de Hugging Face: {e}",
        ) from e


def _publish_image(gcs_handler, image, blob_name):
    """Save ``image`` as a file, upload it to ``blob_name`` and return a signed URL."""
    # A throwaway directory avoids depending on a pre-existing local folder
    # and leaves nothing behind on disk.
    with tempfile.TemporaryDirectory() as work_dir:
        local_path = os.path.join(work_dir, os.path.basename(blob_name))
        image.save(local_path)
        with open(local_path, "rb") as fh:
            gcs_handler.upload_file(blob_name, fh)
    return gcs_handler.generate_signed_url(blob_name)


def _run_prediction(request):
    """Execute the requested task and build the response payload."""
    task = request.pipeline_task
    if task not in TEXT_TASKS and task not in IMAGE_TASKS and task != TEXT_TO_3D_TASK:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported pipeline_task '{task}'.",
        )

    gcs_handler = GCSHandler(GCS_BUCKET_NAME)

    if task == TEXT_TO_3D_TASK:
        # This branch only emits placeholder data and uses no model, so it is
        # handled before any model loading or pipeline construction.
        blob_name = f"3d-models/{uuid.uuid4().hex}.obj"
        gcs_handler.upload_file(blob_name, io.BytesIO(b"Simulated 3D model data"))
        return {"response": {"model_3d_url": gcs_handler.generate_signed_url(blob_name)}}

    model_prefix = request.model_name
    model_files = get_model_files_from_gcs(model_prefix)
    if not model_files:
        download_model_from_huggingface(model_prefix)
        model_files = get_model_files_from_gcs(model_prefix)

    model, tokenizer = load_model_from_gcs(model_prefix, model_files)
    # NOTE(review): the model is always loaded with AutoModelForCausalLM, and
    # some task names accepted above may not be known to transformers'
    # pipeline() - confirm which tasks are meant to be supported.
    pipe = pipeline(task, model=model, tokenizer=tokenizer)
    outputs = pipe(request.input_text)

    if task in TEXT_TASKS:
        return {"response": outputs[0]}

    suffix, response_key = IMAGE_TASKS[task]
    blob_name = f"images/{uuid.uuid4().hex}{suffix}.png"
    return {"response": {response_key: _publish_image(gcs_handler, outputs[0], blob_name)}}


@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Run ``request.pipeline_task`` with ``request.model_name`` on the input text.

    Returns:
        ``{"response": ...}`` holding the first pipeline output for text tasks,
        or a dict with a signed URL for image and 3D tasks.

    Raises:
        HTTPException: 400 for an unsupported task, 404 when model files are
            missing, 500 for any other failure.
    """
    # NOTE(review): the work below is blocking and runs on the event loop;
    # consider moving it to a worker thread if concurrent requests matter.
    try:
        return _run_prediction(request)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e


@app.on_event("startup")
async def startup_event():
    """Log that the API is starting."""
    logger.info("Iniciando la API...")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)