import os
import json
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.cloud import storage
from google.auth import exceptions
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from io import BytesIO
from dotenv import load_dotenv
import uvicorn
import tempfile
load_dotenv()
API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

# Initialise the GCS client from the service-account JSON supplied via the environment.
try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
    raise RuntimeError(f"Error loading credentials or bucket: {e}")

app = FastAPI()

class DownloadModelRequest(BaseModel):
    model_name: str
    pipeline_task: str
    input_text: str

class GCSHandler:
    """Thin wrapper around the GCS bucket used to check, upload and download model files."""

    def __init__(self, bucket_name):
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        return self.bucket.blob(blob_name).exists()

    def upload_file(self, blob_name, file_stream):
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(file_stream)

    def download_file(self, blob_name):
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
        return BytesIO(blob.download_as_bytes())

def download_model_from_huggingface(model_name):
    """Fetch the core model files from the Hugging Face Hub and mirror them into the GCS bucket."""
    url = f"https://huggingface.co/{model_name}/tree/main"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
    try:
        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            # Core files that make up the model on the Hub
            model_files = [
                "pytorch_model.bin",
                "config.json",
                "tokenizer.json",
                "model.safetensors",
            ]
            for file_name in model_files:
                file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
                file_response = requests.get(file_url, headers=headers)
                if file_response.status_code != 200:
                    # Not every model publishes every file (e.g. .bin vs .safetensors); skip missing ones.
                    continue
                blob_name = f"{model_name}/{file_name}"
                bucket.blob(blob_name).upload_from_file(BytesIO(file_response.content))
        else:
            raise HTTPException(status_code=404, detail="Could not access the Hugging Face file tree for this model.")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error downloading files from Hugging Face: {e}")

@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    try:
        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
        model_prefix = request.model_name
        model_files = [
            "pytorch_model.bin",
            "config.json",
            "tokenizer.json",
            "model.safetensors",
        ]
        # Check whether the model files are already mirrored in GCS
        model_files_exist = all(gcs_handler.file_exists(f"{model_prefix}/{file}") for file in model_files)
        if not model_files_exist:
            # Download the model from Hugging Face if it is not in the bucket yet
            download_model_from_huggingface(model_prefix)
        # Download whichever of the expected files are present
        model_files_streams = {
            file: gcs_handler.download_file(f"{model_prefix}/{file}")
            for file in model_files
            if gcs_handler.file_exists(f"{model_prefix}/{file}")
        }
        # Make sure the essential files are available
        config_stream = model_files_streams.get("config.json")
        tokenizer_stream = model_files_streams.get("tokenizer.json")
        model_stream = model_files_streams.get("pytorch_model.bin")
        if not config_stream or not tokenizer_stream or not model_stream:
            raise HTTPException(status_code=500, detail="Required model files missing.")
        # Write the streams into a temporary directory so transformers can load them from disk
        with tempfile.TemporaryDirectory() as tmp_dir:
            config_path = os.path.join(tmp_dir, "config.json")
            tokenizer_path = os.path.join(tmp_dir, "tokenizer.json")
            model_path = os.path.join(tmp_dir, "pytorch_model.bin")
            with open(config_path, "wb") as f:
                f.write(config_stream.read())
            with open(tokenizer_path, "wb") as f:
                f.write(tokenizer_stream.read())
            with open(model_path, "wb") as f:
                f.write(model_stream.read())
            # Load the model and tokenizer from the temporary directory
            model = AutoModelForCausalLM.from_pretrained(tmp_dir)
            tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
            # Build a pipeline for the requested task
            pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
            # Run the prediction
            result = pipeline_(request.input_text)
        return {"response": result}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
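
# Example invocation (a sketch only; assumes the service is running locally on port 7860, and
# the model name and task below are illustrative placeholders, not values required by this app):
#
#   curl -X POST http://localhost:7860/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "gpt2", "pipeline_task": "text-generation", "input_text": "Hello, world"}'
#
# The response is the raw pipeline output wrapped as {"response": ...}.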