Hjgugugjhuhjggg
committed on
Commit
•
319a292
1
Parent(s):
51669c7
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import re
import json
import requests
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.cloud import storage
from google.auth import exceptions
# BUG FIX: `pipeline` is used in the /predict/ endpoint but was never imported.
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# BUG FIX: `transformers.hf_api` was removed from modern transformers releases;
# fall back to `huggingface_hub` (a hard dependency of transformers) so that
# HfApi/HfFolder stay importable for the login code below.
try:
    from transformers.hf_api import HfApi, HfFolder, HfLoginManager
except ImportError:  # recent transformers: the hub client lives in huggingface_hub
    from huggingface_hub import HfApi, HfFolder
from io import BytesIO
from dotenv import load_dotenv
import uvicorn
|
15 |
+
|
16 |
+
# Load variables from a local .env file into the environment (no-op if absent).
load_dotenv()

# Environment variables
API_KEY = os.getenv("API_KEY")  # NOTE(review): read but never used in this file — confirm whether it is needed
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")  # target Google Cloud Storage bucket name
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")  # service-account key as a JSON string
HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # Hugging Face access token
|
23 |
+
|
24 |
+
# Initialise the GCS client from the service-account JSON in the environment.
try:
    # BUG FIX: json.loads(None) raises an *uncaught* TypeError when the env var
    # is missing; fail fast with a clear message instead.
    if not GOOGLE_APPLICATION_CREDENTIALS_JSON:
        raise ValueError("GOOGLE_APPLICATION_CREDENTIALS_JSON is not set.")
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
    print(f"Error al cargar credenciales o bucket: {e}")
    # `exit()` is meant for the interactive REPL; SystemExit is the explicit form.
    raise SystemExit(1)
|
32 |
+
|
33 |
+
# FastAPI application instance.
app = FastAPI()

# Authenticate against Hugging Face so gated/private model files can be fetched.
try:
    if not HF_API_TOKEN:
        raise ValueError("El token de Hugging Face no está definido en las variables de entorno.")
    # BUG FIX: the legacy `HfApi` class has no `set_access_token` method, so the
    # original call raised AttributeError at startup.  Persist the token with
    # `HfFolder.save_token` (already imported by this file), which the hub
    # client honours for subsequent downloads.
    HfFolder.save_token(HF_API_TOKEN)
    print("Inicio de sesión en Hugging Face exitoso.")
except Exception as e:
    print(f"Error al iniciar sesión en Hugging Face: {e}")
    raise SystemExit(1)
|
45 |
+
|
46 |
+
|
47 |
+
class DownloadModelRequest(BaseModel):
    """Request payload for /predict/: which model to load and what to run on it."""

    # GCS prefix / Hugging Face repo id identifying the model.
    model_name: str
    # Transformers pipeline task name (e.g. "text-generation").
    pipeline_task: str
    # Text fed to the pipeline.
    input_text: str
|
51 |
+
|
52 |
+
|
53 |
+
class GCSStreamHandler:
    """Thin wrapper around a GCS bucket for checking, streaming and uploading model files."""

    def __init__(self, bucket_name):
        # Reuses the module-level authenticated `storage_client`.
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return True if *blob_name* exists in the bucket."""
        return self.bucket.blob(blob_name).exists()

    def stream_file_from_gcs(self, blob_name):
        """Download *blob_name* fully into memory; raise 404 if it is missing."""
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"Archivo '{blob_name}' no encontrado en GCS.")
        return blob.download_as_bytes()

    def upload_file_to_gcs(self, blob_name, data_stream):
        """Upload a file-like object to *blob_name* in the bucket."""
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(data_stream)
        print(f"Archivo {blob_name} subido a GCS.")

    def ensure_bucket_structure(self, model_prefix):
        """Create empty-JSON placeholders for required files under *model_prefix* if absent."""
        required_files = ["config.json", "tokenizer.json"]
        for filename in required_files:
            # BUG FIX: the blob name must interpolate the actual file name
            # (the original string was corrupted to a literal placeholder).
            blob_name = f"{model_prefix}/{filename}"
            if not self.file_exists(blob_name):
                print(f"Creando archivo ficticio: {blob_name}")
                self.bucket.blob(blob_name).upload_from_string("{}", content_type="application/json")

    def stream_model_files(self, model_prefix, model_patterns):
        """Download every blob under *model_prefix*/ whose basename matches one of *model_patterns*.

        Returns a dict mapping basename -> BytesIO with the blob contents.
        """
        model_files = {}
        # PERF FIX: list the prefix once instead of once per pattern — the
        # original issued one list_blobs API call per pattern and could
        # download the same blob multiple times for overlapping patterns.
        blobs = list(self.bucket.list_blobs(prefix=f"{model_prefix}/"))
        for blob in blobs:
            basename = blob.name.split('/')[-1]
            if basename not in model_files and any(re.match(p, basename) for p in model_patterns):
                print(f"Archivo encontrado: {blob.name}")
                model_files[basename] = BytesIO(blob.download_as_bytes())
        return model_files
|
89 |
+
|
90 |
+
|
91 |
+
@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Load a model stored in GCS, build a transformers pipeline and run it on the input text.

    Returns {"response": <pipeline output>}; raises 400 for unsupported tasks,
    404 for missing files, 500 for anything else.
    """
    # Local imports keep this fix self-contained even if the module header is untouched.
    import tempfile
    from transformers import pipeline  # BUG FIX: `pipeline` was never imported at module level

    try:
        gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)

        # Make sure config/tokenizer placeholders exist so downloads don't 404.
        gcs_handler.ensure_bucket_structure(request.model_name)

        # Validate the task *before* the expensive model load.
        pipeline_task = request.pipeline_task
        if pipeline_task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering"]:
            raise HTTPException(status_code=400, detail="Unsupported pipeline task")

        # File-name patterns covering sharded and single-file checkpoint layouts.
        model_patterns = [
            r"pytorch_model-\d+-of-\d+",
            r"model-\d+",
            r"pytorch_model.bin",
            r"model.safetensors"
        ]

        # Pull the weight files and the config/tokenizer from the bucket.
        model_files = gcs_handler.stream_model_files(request.model_name, model_patterns)
        config_stream = gcs_handler.stream_file_from_gcs(f"{request.model_name}/config.json")
        tokenizer_stream = gcs_handler.stream_file_from_gcs(f"{request.model_name}/tokenizer.json")

        # BUG FIX: from_pretrained() expects a path or repo id, not a BytesIO —
        # the original call could never succeed.  Materialise the streamed
        # files into a temporary directory and load the model from there; this
        # also loads the weights directly, replacing the manual torch.load /
        # load_state_dict dance (which equally cannot handle .safetensors).
        with tempfile.TemporaryDirectory() as tmpdir:
            with open(os.path.join(tmpdir, "config.json"), "wb") as f:
                f.write(config_stream)
            with open(os.path.join(tmpdir, "tokenizer.json"), "wb") as f:
                f.write(tokenizer_stream)
            for filename, stream in model_files.items():
                with open(os.path.join(tmpdir, filename), "wb") as f:
                    f.write(stream.getvalue())

            model = AutoModelForCausalLM.from_pretrained(tmpdir)
            tokenizer = AutoTokenizer.from_pretrained(tmpdir)

        pipeline_ = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
        result = pipeline_(request.input_text)

        return {"response": result}

    except HTTPException:
        # BUG FIX: re-raise deliberate HTTP errors (400/404) instead of
        # letting the generic handler mask them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}")
|
136 |
+
|
137 |
+
|
138 |
+
@app.post("/upload/")
async def upload_model_to_gcs(model_name: str):
    """
    Download a model from Hugging Face and upload it to GCS.

    Best-effort: every candidate file name is probed and files that do not
    exist upstream are skipped with a log line rather than failing the request.
    """
    try:
        gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)

        # Common single-file artefacts.
        file_patterns = [
            "pytorch_model.bin",
            "model.safetensors",
            "config.json",
            "tokenizer.json",
        ]

        # Candidate sharded-checkpoint names (fixed upper bound of 100 shards).
        for i in range(1, 100):
            file_patterns.append(f"pytorch_model-{i:05}-of-{100:05}")
            file_patterns.append(f"model-{i:05}")

        # Hoisted out of the loop: identical for every request.
        headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

        for filename in file_patterns:
            # BUG FIX: the URL, blob name and log messages must interpolate the
            # file name (the original strings were corrupted to literal placeholders).
            url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
            try:
                response = requests.get(url, headers=headers, stream=True)
                if response.status_code == 200:
                    blob_name = f"{model_name}/{filename}"
                    blob = bucket.blob(blob_name)
                    blob.upload_from_file(BytesIO(response.content))
                    print(f"Archivo {filename} subido correctamente a GCS.")
            except Exception as e:
                print(f"Archivo {filename} no encontrado: {e}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error al subir modelo: {e}")
|
173 |
+
|
174 |
+
|
175 |
+
if __name__ == "__main__":
    # Run the API with uvicorn when this file is executed directly.
    uvicorn.run(app, host="0.0.0.0", port=8000)
|