File size: 5,259 Bytes
319a292 f84a20c 1c3034c 319a292 b5bc6a9 f173552 319a292 b5bc6a9 319a292 f84a20c 319a292 09dbcd2 b5bc6a9 09dbcd2 b5bc6a9 319a292 f84a20c 319a292 09dbcd2 319a292 b5bc6a9 319a292 b5bc6a9 319a292 b5bc6a9 319a292 b5bc6a9 319a292 b5bc6a9 319a292 f173552 b5bc6a9 f173552 b5bc6a9 f173552 b5bc6a9 f173552 b5bc6a9 f173552 319a292 b5bc6a9 f173552 b5bc6a9 319a292 b5bc6a9 f173552 b5bc6a9 319a292 d84cd10 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import json
import os
import re
import sys
import tempfile
from io import BytesIO

import requests
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from google.auth import exceptions
from google.cloud import storage
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Load configuration from the environment. load_dotenv() (python-dotenv) first
# populates os.environ from a .env file if one is found; any variable that is
# not set is read back as None by os.getenv.
# NOTE(review): API_KEY is read here but no visible code uses it — confirm
# whether the endpoints are meant to enforce it.
load_dotenv()
(
    API_KEY,
    GCS_BUCKET_NAME,
    GOOGLE_APPLICATION_CREDENTIALS_JSON,
    HF_API_TOKEN,
) = map(
    os.getenv,
    (
        "API_KEY",
        "GCS_BUCKET_NAME",
        "GOOGLE_APPLICATION_CREDENTIALS_JSON",
        "HF_API_TOKEN",
    ),
)
def validate_bucket_name(bucket_name):
    """Validate a Google Cloud Storage bucket name and return it unchanged.

    Checks the core GCS naming rules:
    - 3-63 characters; up to 222 when the name contains dots, with each
      dot-separated component at most 63 characters.
    - Only lowercase letters, digits, '-', '_' and '.'.
    - Starts and ends with a letter or digit.

    Rules such as the "goog" prefix ban are not checked.

    Args:
        bucket_name: Candidate bucket name. It typically comes from an
            environment variable, so it may be ``None`` when unset.

    Returns:
        The same ``bucket_name`` when it is valid.

    Raises:
        ValueError: If ``bucket_name`` is missing, empty or breaks a rule above.
    """
    if not isinstance(bucket_name, str) or not bucket_name:
        # Report a missing value as ValueError (not TypeError from re) so the
        # caller's error handling sees it.
        raise ValueError(f"Invalid bucket name {bucket_name!r}. A non-empty string is required.")
    # fullmatch, unlike match() with a trailing "$", rejects a trailing newline.
    if not re.fullmatch(r"[a-z0-9][a-z0-9._-]*[a-z0-9]", bucket_name):
        raise ValueError(
            f"Invalid bucket name '{bucket_name}'. Must use lowercase letters, digits, '-', '_' or '.', "
            "and start and end with a letter or number."
        )
    max_length = 222 if "." in bucket_name else 63
    if not 3 <= len(bucket_name) <= max_length:
        raise ValueError(
            f"Invalid bucket name '{bucket_name}'. Must be between 3 and {max_length} characters long."
        )
    if any(len(component) > 63 for component in bucket_name.split(".")):
        raise ValueError(
            f"Invalid bucket name '{bucket_name}'. Each dot-separated component must be at most 63 characters."
        )
    return bucket_name
def validate_huggingface_repo_name(repo_name):
    """Validate a Hugging Face Hub repository id and return it unchanged.

    Accepts either ``name`` or ``namespace/name`` (the usual ``org/model``
    form). Each part:
    - may only use ASCII letters, digits, '-', '_' and '.';
    - must not start or end with '-' or '.';
    - must not contain '..' or '--';
    - is limited to 96 characters.

    Args:
        repo_name: Candidate repository id, e.g. ``"openai-community/gpt2"``.

    Returns:
        The same ``repo_name`` when it is valid.

    Raises:
        ValueError: If ``repo_name`` is not a non-empty string or breaks a rule above.
    """
    if not isinstance(repo_name, str) or not repo_name:
        raise ValueError(f"Invalid repository name {repo_name!r}. A non-empty string is required.")
    parts = repo_name.split("/")
    if len(parts) > 2:
        raise ValueError(f"Invalid repository name '{repo_name}'. Expected 'name' or 'namespace/name'.")
    for part in parts:
        # fullmatch rejects empty parts (e.g. "/gpt2", "gpt2/") and a trailing newline.
        if not re.fullmatch(r"[a-zA-Z0-9_.-]+", part):
            raise ValueError(f"Invalid repository name '{repo_name}'. Must use alphanumeric characters, '-', '_', or '.'.")
        if part.startswith(('-', '.')) or part.endswith(('-', '.')) or '..' in part:
            raise ValueError(f"Invalid repository name '{repo_name}'. Cannot start or end with '-' or '.', or contain '..'.")
        if '--' in part:
            raise ValueError(f"Invalid repository name '{repo_name}'. Cannot contain '--'.")
        if len(part) > 96:
            raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
    return repo_name
# Validate the configuration and build the GCS client at import time. The
# service cannot work without it, so any failure aborts the process.
try:
    GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
    if not GOOGLE_APPLICATION_CREDENTIALS_JSON:
        # json.loads(None) would otherwise raise an uncaught TypeError.
        raise ValueError("GOOGLE_APPLICATION_CREDENTIALS_JSON is not set.")
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
    # TypeError is included so that a None or wrongly-typed value (e.g. an unset
    # GCS_BUCKET_NAME reaching re.match) exits cleanly instead of with a traceback.
    print(f"Error al cargar credenciales o bucket: {e}", file=sys.stderr)
    # sys.exit rather than the site-provided exit(), which is not guaranteed to exist.
    sys.exit(1)
app: FastAPI = FastAPI()


# JSON body accepted by the POST /predict/ endpoint.
class DownloadModelRequest(BaseModel):
    # Hugging Face repository id of the model to load (also its GCS prefix).
    model_name: str
    # Task name handed to transformers' pipeline(), e.g. "text-generation".
    pipeline_task: str
    # Text passed to the constructed pipeline.
    input_text: str
class GCSHandler:
    """Small helper around a single bucket of the module-level GCS client."""

    def __init__(self, bucket_name):
        # Reuse the shared client created at import time.
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return whether an object named ``blob_name`` exists in the bucket."""
        target = self.bucket.blob(blob_name)
        return target.exists()

    def upload_file(self, blob_name, file_stream):
        """Upload ``file_stream`` to the object ``blob_name``."""
        self.bucket.blob(blob_name).upload_from_file(file_stream)

    def download_file(self, blob_name):
        """Return the object ``blob_name`` as an in-memory ``BytesIO``.

        Raises:
            HTTPException: 404 when the object does not exist.
        """
        target = self.bucket.blob(blob_name)
        if target.exists():
            return BytesIO(target.download_as_bytes())
        raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
def download_model_from_huggingface(model_name, revision="main", timeout=60):
    """Mirror a model repository from the Hugging Face Hub into the GCS bucket.

    Every file is fetched from
    ``https://huggingface.co/<model_name>/resolve/<revision>/<file>`` and
    uploaded to the module-level ``bucket`` as ``<model_name>/<file>``.
    Config and tokenizer files are tried by name. Files the repository does
    not have (non-200 responses) are skipped.

    Weights are tried in this order, stopping at the first that works:
    1. A single ``model.safetensors`` file.
    2. A sharded safetensors checkpoint, with shards listed in the
       ``weight_map`` of ``model.safetensors.index.json``.
    3. The PyTorch ``.bin`` equivalents of the two options above.

    Args:
        model_name: Repository id; validated with ``validate_huggingface_repo_name``.
        revision: Branch, tag or commit to download (defaults to ``"main"``).
        timeout: Connect/read timeout in seconds for each HTTP request.

    Returns:
        List of file names (relative to the repository) uploaded to GCS.

    Raises:
        ValueError: If ``model_name`` fails validation. Errors for individual
            files are printed and that file is skipped (best effort).
    """
    model_name = validate_huggingface_repo_name(model_name)
    base_url = f"https://huggingface.co/{model_name}/resolve/{revision}"
    # Only authenticate when a token is configured. "Bearer None" is an invalid
    # credential that the Hub may reject even for public repositories.
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    uploaded = []

    def fetch(filename):
        """Stream one repository file into GCS; return True if it was uploaded."""
        try:
            # ``with`` releases the connection even when the status is not 200.
            with requests.get(
                f"{base_url}/{filename}", headers=headers, stream=True, timeout=timeout
            ) as response:
                if response.status_code != 200:
                    if response.status_code != 404:  # 404: file not in this repo
                        print(f"Skipping {filename}: HTTP {response.status_code}")
                    return False
                # Spool to disk instead of RAM: weight files can be several GB.
                with tempfile.TemporaryFile() as spool:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        spool.write(chunk)
                    spool.seek(0)
                    bucket.blob(f"{model_name}/{filename}").upload_from_file(spool)
        except Exception as e:  # best effort: report and continue with other files
            print(f"Error downloading {filename} from Hugging Face: {e}")
            return False
        uploaded.append(filename)
        return True

    def fetch_shards(index_name):
        """Fetch a sharded checkpoint via its index file; True if any shard was uploaded."""
        if not fetch(index_name):
            return False
        try:
            raw_index = bucket.blob(f"{model_name}/{index_name}").download_as_bytes()
            shard_names = sorted(set(json.loads(raw_index)["weight_map"].values()))
        except Exception as e:
            print(f"Error reading {index_name} for {model_name}: {e}")
            return False
        fetched_any = False
        for shard_name in shard_names:
            # Shard names come from downloaded data: accept plain file names only.
            if (
                not isinstance(shard_name, str)
                or "/" in shard_name
                or "\\" in shard_name
                or shard_name.startswith(".")
            ):
                continue
            fetched_any = fetch(shard_name) or fetched_any
        return fetched_any

    for filename in (
        "config.json",
        "generation_config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "tokenizer.model",
    ):
        fetch(filename)

    # NOTE(security): ``.bin`` checkpoints are pickle-based; loading one from an
    # untrusted repository can execute arbitrary code. safetensors is tried first.
    for weights_file, index_file in (
        ("model.safetensors", "model.safetensors.index.json"),
        ("pytorch_model.bin", "pytorch_model.bin.index.json"),
    ):
        if fetch(weights_file) or fetch_shards(index_file):
            break
    return uploaded
def _download_model_dir(gcs_handler, model_prefix, target_dir):
    """Copy every object stored under ``<model_prefix>/`` in GCS into ``target_dir``.

    Object names keep their path relative to the prefix, so the directory is
    laid out like the original repository and can be passed to ``from_pretrained``.

    Returns:
        Set of relative paths that were written.
    """
    prefix = f"{model_prefix}/"
    root = os.path.realpath(target_dir)
    written = set()
    for blob in gcs_handler.bucket.list_blobs(prefix=prefix):
        relative = blob.name[len(prefix):]
        if not relative or relative.endswith("/"):
            continue  # names ending in "/" are directory placeholders, not files
        destination = os.path.realpath(os.path.join(root, relative))
        # Object names are arbitrary strings: never write outside target_dir.
        if not destination.startswith(root + os.sep):
            continue
        os.makedirs(os.path.dirname(destination), exist_ok=True)
        blob.download_to_filename(destination)
        written.add(relative)
    return written


def _run_prediction(model_prefix, request):
    """Blocking part of ``/predict/``: materialize the model locally, load it and run it."""
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    tokenizer_files = ("tokenizer.json", "tokenizer_config.json", "vocab.json", "tokenizer.model")
    # from_pretrained() needs a directory on disk, not in-memory streams, so the
    # model files are copied into a temporary directory for this request.
    with tempfile.TemporaryDirectory() as model_dir:
        files = _download_model_dir(gcs_handler, model_prefix, model_dir)
        if not files:
            # Nothing cached in GCS yet: mirror the model from the Hub, then retry.
            download_model_from_huggingface(model_prefix)
            files = _download_model_dir(gcs_handler, model_prefix, model_dir)
        if "config.json" not in files or not any(name in files for name in tokenizer_files):
            raise HTTPException(status_code=500, detail="Required model files missing.")
        model = AutoModelForCausalLM.from_pretrained(model_dir)
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
        # Run inference while the files still exist on disk.
        return pipeline_(request.input_text)


@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Run ``request.input_text`` through the requested model and pipeline task.

    The model is read from the GCS cache (objects under ``<model_name>/``). It
    is mirrored from the Hugging Face Hub first when nothing is cached.

    Returns:
        ``{"response": <pipeline output>}``.

    Raises:
        HTTPException: 400 for an invalid model name; 500 when required model
            files are missing or downloading/loading/inference fails.
    """
    try:
        model_prefix = validate_huggingface_repo_name(request.model_name)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    try:
        # Downloading and loading a model is blocking I/O and CPU work. Running it
        # in a worker thread keeps the event loop free for other requests.
        result = await run_in_threadpool(_run_prediction, model_prefix, request)
    except HTTPException:
        raise  # already carries the intended status code and detail
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e
    return {"response": result}
if __name__ == "__main__":
    # When executed as a script, serve the app on all interfaces at port 7860.
    server_settings = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_settings)
|