Hjgugugjhuhjggg committed
Commit f84a20c
1 Parent(s): f06b80f
Update app.py

app.py CHANGED
@@ -1,43 +1,42 @@
 import os
 import re
 import requests
+import json
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
+from google.auth import exceptions
 from google.cloud import storage
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from io import BytesIO
 from dotenv import load_dotenv
 import uvicorn
-import json
-from google.auth import exceptions
 
 load_dotenv()
 
-# Environment variables
 API_KEY = os.getenv("API_KEY")
 GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
 GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
-HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 
-# Validate the bucket name
 def validate_bucket_name(bucket_name):
+    if not isinstance(bucket_name, str):
+        raise ValueError("Bucket name must be a string.")
+    if len(bucket_name) < 3 or len(bucket_name) > 63:
+        raise ValueError("Bucket name must be between 3 and 63 characters long.")
     if not re.match(r"^[a-z0-9][a-z0-9\-\.]*[a-z0-9]$", bucket_name):
-        raise ValueError(
+        raise ValueError(
+            f"Invalid bucket name '{bucket_name}'. Bucket names must:"
+            " - Use only lowercase letters, numbers, hyphens (-), and periods (.)"
+            " - Start and end with a letter or number."
+        )
+    if "--" in bucket_name or ".." in bucket_name or ".-" in bucket_name or "-." in bucket_name:
+        raise ValueError(
+            f"Invalid bucket name '{bucket_name}'. Bucket names cannot contain consecutive periods, hyphens, or use '.-' or '-.'"
+        )
     return bucket_name
 
-# Validate the Hugging Face repository name
-def validate_huggingface_repo_name(repo_name):
-    if not isinstance(repo_name, str) or not re.match(r"^[a-zA-Z0-9_.-]+$", repo_name):
-        raise ValueError(f"The repository name '{repo_name}' is not valid. It must contain only letters, numbers, '-', '_', and '.'")
-    if repo_name.startswith(('-', '.')) or repo_name.endswith(('-', '.')) or '..' in repo_name:
-        raise ValueError(f"The repository name '{repo_name}' contains disallowed characters. Check the characters at the start or end.")
-    if len(repo_name) > 96:
-        raise ValueError(f"The repository name '{repo_name}' is too long. The maximum length is 96 characters.")
-    return repo_name
-
-# GCS client initialization
 try:
-    GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
+    GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
     credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
     storage_client = storage.Client.from_service_account_info(credentials_info)
     bucket = storage_client.bucket(GCS_BUCKET_NAME)
@@ -45,16 +44,13 @@ except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, Valu
     print(f"Error loading credentials or bucket: {e}")
     exit(1)
 
-# FastAPI initialization
 app = FastAPI()
 
-
 class DownloadModelRequest(BaseModel):
     model_name: str
     pipeline_task: str
     input_text: str
 
-
 class GCSStreamHandler:
     def __init__(self, bucket_name):
         self.bucket = storage_client.bucket(bucket_name)
@@ -65,21 +61,18 @@ class GCSStreamHandler:
     def stream_file_from_gcs(self, blob_name):
         blob = self.bucket.blob(blob_name)
         if not blob.exists():
-            raise HTTPException(status_code=404, detail=f"
+            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found in GCS.")
         return blob.download_as_bytes()
 
     def upload_file_to_gcs(self, blob_name, data_stream):
         blob = self.bucket.blob(blob_name)
         blob.upload_from_file(data_stream)
-        print(f"File {blob_name} uploaded to GCS.")
 
     def ensure_bucket_structure(self, model_prefix):
-        # Automatically create the bucket structure if it does not exist
         required_files = ["config.json", "tokenizer.json"]
         for filename in required_files:
             blob_name = f"{model_prefix}/{filename}"
             if not self.file_exists(blob_name):
-                print(f"Creating placeholder file: {blob_name}")
                 self.bucket.blob(blob_name).upload_from_string("{}", content_type="application/json")
 
     def stream_model_files(self, model_prefix, model_patterns):
@@ -88,29 +81,19 @@ class GCSStreamHandler:
         blobs = list(self.bucket.list_blobs(prefix=f"{model_prefix}/"))
         for blob in blobs:
             if re.match(pattern, blob.name.split('/')[-1]):
-                print(f"File found: {blob.name}")
                 model_files[blob.name.split('/')[-1]] = BytesIO(blob.download_as_bytes())
         return model_files
 
-
 def download_model_from_huggingface(model_name):
-    """
-    Download a model from Hugging Face and upload it to GCS via streaming.
-    """
-    model_name = validate_huggingface_repo_name(model_name)  # Validate the repository name
-
     file_patterns = [
         "pytorch_model.bin",
         "model.safetensors",
         "config.json",
         "tokenizer.json",
     ]
-
-    # Add patterns for model shards
     for i in range(1, 100):
         file_patterns.append(f"pytorch_model-{i:05}-of-{100:05}")
         file_patterns.append(f"model-{i:05}")
-
     for filename in file_patterns:
         url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
         headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
@@ -120,20 +103,13 @@ def download_model_from_huggingface(model_name):
             blob_name = f"{model_name}/{filename}"
             blob = bucket.blob(blob_name)
             blob.upload_from_file(BytesIO(response.content))
-            print(f"File {filename} uploaded successfully to GCS.")
         except Exception as e:
-
-
+            pass
 
 @app.post("/predict/")
 async def predict(request: DownloadModelRequest):
-    """
-    Prediction endpoint. If the model does not exist in GCS, it is downloaded automatically.
-    """
     try:
         gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)
-
-        # Check whether the model is already in GCS
         model_prefix = request.model_name
         model_patterns = [
             r"pytorch_model-\d+-of-\d+",
@@ -141,43 +117,22 @@ async def predict(request: DownloadModelRequest):
             r"pytorch_model.bin",
             r"model.safetensors",
         ]
-
         if not any(
            gcs_handler.file_exists(f"{model_prefix}/{pattern}") for pattern in model_patterns
         ):
-            print(f"Model {model_prefix} not found in GCS. Downloading from Hugging Face...")
             download_model_from_huggingface(model_prefix)
-
-        # Load model files from GCS
         model_files = gcs_handler.stream_model_files(model_prefix, model_patterns)
-
-        # Configuration and tokenization
         config_stream = gcs_handler.stream_file_from_gcs(f"{model_prefix}/config.json")
         tokenizer_stream = gcs_handler.stream_file_from_gcs(f"{model_prefix}/tokenizer.json")
-
         model = AutoModelForCausalLM.from_pretrained(BytesIO(config_stream))
-        state_dict = {}
-
-        for filename, stream in model_files.items():
-            state_dict.update(torch.load(stream, map_location="cpu"))
-
-        model.load_state_dict(state_dict)
         tokenizer = AutoTokenizer.from_pretrained(BytesIO(tokenizer_stream))
-
-        # Create the pipeline
         pipeline_task = request.pipeline_task
-        if pipeline_task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering"]:
-            raise HTTPException(status_code=400, detail="Unsupported pipeline task")
-
         pipeline_ = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
         input_text = request.input_text
         result = pipeline_(input_text)
-
         return {"response": result}
-
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error: {e}")
 
-
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=
+    uvicorn.run(app, host="0.0.0.0", port=8000)
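For reference, the shard-probing loop in download_model_from_huggingface builds zero-padded filenames with Python format specs. A minimal standalone sketch of the names it generates (the list and loop are copied from the diff above; nothing is imported from app.py, and the prints are illustrative only):

# Sketch of the filename patterns probed by download_model_from_huggingface.
file_patterns = [
    "pytorch_model.bin",
    "model.safetensors",
    "config.json",
    "tokenizer.json",
]
for i in range(1, 100):
    # {i:05} zero-pads to five digits, so the first shard name is
    # "pytorch_model-00001-of-00100".
    file_patterns.append(f"pytorch_model-{i:05}-of-{100:05}")
    file_patterns.append(f"model-{i:05}")

print(file_patterns[4])   # pytorch_model-00001-of-00100
print(file_patterns[-1])  # model-00099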
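With this commit the server binds to port 8000, and /predict/ accepts the three fields of DownloadModelRequest. A minimal client sketch against a locally running instance (the model name and input text are hypothetical examples, not part of the commit):

# Sketch: calling /predict/ on a local instance of this app.
# "gpt2" and the prompt are hypothetical; any causal LM repo name would do.
import requests

payload = {
    "model_name": "gpt2",
    "pipeline_task": "text-generation",
    "input_text": "Hello, world",
}
resp = requests.post("http://localhost:8000/predict/", json=payload)
resp.raise_for_status()
print(resp.json())  # {"response": [...]} on success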