Hjgugugjhuhjggg
commited on
Commit
•
efa228b
1
Parent(s):
b5bc6a9
Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,6 @@ from io import BytesIO
|
|
11 |
from dotenv import load_dotenv
|
12 |
import uvicorn
|
13 |
|
14 |
-
# Carga de variables de entorno
|
15 |
load_dotenv()
|
16 |
|
17 |
API_KEY = os.getenv("API_KEY")
|
@@ -20,13 +19,11 @@ GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_
|
|
20 |
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
21 |
|
22 |
def validate_bucket_name(bucket_name):
|
23 |
-
"""Valida que el nombre del bucket cumpla con las restricciones de Google Cloud."""
|
24 |
if not re.match(r"^[a-z0-9][a-z0-9\-]*[a-z0-9]$", bucket_name):
|
25 |
raise ValueError(f"Invalid bucket name '{bucket_name}'. Must start and end with a letter or number.")
|
26 |
return bucket_name
|
27 |
|
28 |
def validate_huggingface_repo_name(repo_name):
|
29 |
-
"""Valida que el nombre del repositorio cumpla con las restricciones de Hugging Face."""
|
30 |
if not re.match(r"^[a-zA-Z0-9_.-]+$", repo_name):
|
31 |
raise ValueError(f"Invalid repository name '{repo_name}'. Must use alphanumeric characters, '-', '_', or '.'.")
|
32 |
if repo_name.startswith(('-', '.')) or repo_name.endswith(('-', '.')) or '..' in repo_name:
|
@@ -35,15 +32,13 @@ def validate_huggingface_repo_name(repo_name):
|
|
35 |
raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
|
36 |
return repo_name
|
37 |
|
38 |
-
# Validar y configurar cliente de GCS
|
39 |
try:
|
40 |
GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
|
41 |
credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
|
42 |
storage_client = storage.Client.from_service_account_info(credentials_info)
|
43 |
bucket = storage_client.bucket(GCS_BUCKET_NAME)
|
44 |
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
|
45 |
-
|
46 |
-
exit(1)
|
47 |
|
48 |
app = FastAPI()
|
49 |
|
@@ -70,7 +65,6 @@ class GCSHandler:
|
|
70 |
return BytesIO(blob.download_as_bytes())
|
71 |
|
72 |
def download_model_from_huggingface(model_name):
|
73 |
-
"""Descarga un modelo desde Hugging Face y lo sube a GCS."""
|
74 |
model_name = validate_huggingface_repo_name(model_name)
|
75 |
file_patterns = [
|
76 |
"pytorch_model.bin",
|
@@ -89,7 +83,7 @@ def download_model_from_huggingface(model_name):
|
|
89 |
blob_name = f"{model_name}/{filename}"
|
90 |
bucket.blob(blob_name).upload_from_file(BytesIO(response.content))
|
91 |
except Exception as e:
|
92 |
-
|
93 |
|
94 |
@app.post("/predict/")
|
95 |
async def predict(request: DownloadModelRequest):
|
|
|
11 |
from dotenv import load_dotenv
|
12 |
import uvicorn
|
13 |
|
|
|
14 |
load_dotenv()
|
15 |
|
16 |
API_KEY = os.getenv("API_KEY")
|
|
|
19 |
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
20 |
|
21 |
def validate_bucket_name(bucket_name):
|
|
|
22 |
if not re.match(r"^[a-z0-9][a-z0-9\-]*[a-z0-9]$", bucket_name):
|
23 |
raise ValueError(f"Invalid bucket name '{bucket_name}'. Must start and end with a letter or number.")
|
24 |
return bucket_name
|
25 |
|
26 |
def validate_huggingface_repo_name(repo_name):
|
|
|
27 |
if not re.match(r"^[a-zA-Z0-9_.-]+$", repo_name):
|
28 |
raise ValueError(f"Invalid repository name '{repo_name}'. Must use alphanumeric characters, '-', '_', or '.'.")
|
29 |
if repo_name.startswith(('-', '.')) or repo_name.endswith(('-', '.')) or '..' in repo_name:
|
|
|
32 |
raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
|
33 |
return repo_name
|
34 |
|
|
|
35 |
try:
|
36 |
GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
|
37 |
credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
|
38 |
storage_client = storage.Client.from_service_account_info(credentials_info)
|
39 |
bucket = storage_client.bucket(GCS_BUCKET_NAME)
|
40 |
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
|
41 |
+
raise RuntimeError(f"Error al cargar credenciales o bucket: {e}")
|
|
|
42 |
|
43 |
app = FastAPI()
|
44 |
|
|
|
65 |
return BytesIO(blob.download_as_bytes())
|
66 |
|
67 |
def download_model_from_huggingface(model_name):
|
|
|
68 |
model_name = validate_huggingface_repo_name(model_name)
|
69 |
file_patterns = [
|
70 |
"pytorch_model.bin",
|
|
|
83 |
blob_name = f"{model_name}/{filename}"
|
84 |
bucket.blob(blob_name).upload_from_file(BytesIO(response.content))
|
85 |
except Exception as e:
|
86 |
+
raise HTTPException(status_code=500, detail=f"Error downloading {filename} from Hugging Face: {e}")
|
87 |
|
88 |
@app.post("/predict/")
|
89 |
async def predict(request: DownloadModelRequest):
|