Hjgugugjhuhjggg
committed on
Commit
•
b5bc6a9
1
Parent(s):
1c3034c
Update app.py
Browse files
app.py
CHANGED
@@ -4,13 +4,14 @@ import json
|
|
4 |
import requests
|
5 |
from fastapi import FastAPI, HTTPException
|
6 |
from pydantic import BaseModel
|
7 |
-
from google.auth import exceptions
|
8 |
from google.cloud import storage
|
|
|
9 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
10 |
from io import BytesIO
|
11 |
from dotenv import load_dotenv
|
12 |
import uvicorn
|
13 |
|
|
|
14 |
load_dotenv()
|
15 |
|
16 |
API_KEY = os.getenv("API_KEY")
|
@@ -18,35 +19,28 @@ GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
|
|
18 |
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
|
19 |
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
20 |
|
21 |
-
def sanitize_bucket_name(bucket_name):
    """Coerce an arbitrary string into a syntactically valid GCS bucket name.

    The name is lower-cased, every character outside ``[a-z0-9.-]`` is
    replaced with ``-``, leading/trailing ``-`` and ``.`` are removed, and the
    result is limited to 63 characters.

    Args:
        bucket_name: The raw bucket name.

    Returns:
        A name that starts and ends with a lowercase letter or digit and is at
        most 63 characters long. If nothing usable remains, ``"a"`` is
        returned. Names that are already valid are returned unchanged.
    """
    name = re.sub(r"[^a-z0-9.-]", "-", bucket_name.lower())
    # Strip first so the name starts with an alphanumeric character, then
    # truncate, then strip the tail again: cutting at 63 characters can expose
    # a '-' or '.' as the new last character.
    name = name.strip("-.")[:63].rstrip("-.")
    # The previous implementation tested the last character with
    # re.match(r"[a-z0-9]$"), which is anchored at the start of the string and
    # therefore appended a spurious "a" to almost every name.
    return name or "a"
|
33 |
-
|
34 |
def validate_bucket_name(bucket_name):
|
35 |
-
"""Valida
|
36 |
-
if not re.match(r"^[a-z0-9][a-z0-9
|
37 |
-
raise ValueError(f"
|
38 |
return bucket_name
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
try:
|
41 |
-
# Sanitizar y validar el nombre del bucket
|
42 |
-
GCS_BUCKET_NAME = sanitize_bucket_name(GCS_BUCKET_NAME)
|
43 |
GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
|
44 |
-
|
45 |
-
# Cargar credenciales de Google Cloud Storage
|
46 |
credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
|
47 |
storage_client = storage.Client.from_service_account_info(credentials_info)
|
48 |
bucket = storage_client.bucket(GCS_BUCKET_NAME)
|
49 |
-
|
50 |
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
|
51 |
print(f"Error al cargar credenciales o bucket: {e}")
|
52 |
exit(1)
|
@@ -58,49 +52,34 @@ class DownloadModelRequest(BaseModel):
|
|
58 |
pipeline_task: str
|
59 |
input_text: str
|
60 |
|
61 |
-
class
|
62 |
def __init__(self, bucket_name):
|
63 |
self.bucket = storage_client.bucket(bucket_name)
|
64 |
|
65 |
def file_exists(self, blob_name):
|
66 |
return self.bucket.blob(blob_name).exists()
|
67 |
|
68 |
-
def
|
69 |
blob = self.bucket.blob(blob_name)
|
70 |
-
|
71 |
-
raise HTTPException(status_code=404, detail=f"Archivo '{blob_name}' no encontrado en GCS.")
|
72 |
-
return blob.download_as_bytes()
|
73 |
|
74 |
-
def
|
75 |
blob = self.bucket.blob(blob_name)
|
76 |
-
blob.
|
77 |
-
|
78 |
-
|
79 |
-
required_files = ["config.json", "tokenizer.json"]
|
80 |
-
for filename in required_files:
|
81 |
-
blob_name = f"{model_prefix}/{filename}"
|
82 |
-
if not self.file_exists(blob_name):
|
83 |
-
self.bucket.blob(blob_name).upload_from_string("{}", content_type="application/json")
|
84 |
-
|
85 |
-
def stream_model_files(self, model_prefix, model_patterns):
    """Download every blob under ``model_prefix/`` whose file name matches a pattern.

    Args:
        model_prefix: Blob-name prefix (without trailing slash) to list.
        model_patterns: Iterable of regular expressions, each applied with
            ``re.match`` to the last ``/``-separated component of the blob name.

    Returns:
        A dict mapping the matching file name (last path component) to a
        ``BytesIO`` holding that blob's contents.
    """
    patterns = tuple(model_patterns)
    if not patterns:
        return {}
    model_files = {}
    # List the prefix once; the listing does not depend on the pattern, and
    # each matching blob is downloaded a single time even when several
    # patterns match it.
    for blob in self.bucket.list_blobs(prefix=f"{model_prefix}/"):
        file_name = blob.name.split('/')[-1]
        if any(re.match(pattern, file_name) for pattern in patterns):
            model_files[file_name] = BytesIO(blob.download_as_bytes())
    return model_files
|
93 |
|
94 |
def download_model_from_huggingface(model_name):
|
|
|
|
|
95 |
file_patterns = [
|
96 |
"pytorch_model.bin",
|
97 |
-
"model.safetensors",
|
98 |
"config.json",
|
99 |
"tokenizer.json",
|
|
|
100 |
]
|
101 |
for i in range(1, 100):
|
102 |
-
file_patterns.
|
103 |
-
file_patterns.append(f"model-{i:05}")
|
104 |
for filename in file_patterns:
|
105 |
url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
|
106 |
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
|
@@ -108,37 +87,34 @@ def download_model_from_huggingface(model_name):
|
|
108 |
response = requests.get(url, headers=headers, stream=True)
|
109 |
if response.status_code == 200:
|
110 |
blob_name = f"{model_name}/{filename}"
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
pass
|
115 |
|
116 |
@app.post("/predict/")
|
117 |
async def predict(request: DownloadModelRequest):
|
118 |
try:
|
119 |
-
gcs_handler =
|
120 |
model_prefix = request.model_name
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
]
|
127 |
-
|
128 |
-
|
129 |
-
):
|
130 |
download_model_from_huggingface(model_prefix)
|
131 |
-
|
132 |
-
config_stream =
|
133 |
-
tokenizer_stream =
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
input_text = request.input_text
|
141 |
-
result = pipeline_(input_text)
|
142 |
return {"response": result}
|
143 |
except Exception as e:
|
144 |
raise HTTPException(status_code=500, detail=f"Error: {e}")
|
|
|
4 |
import requests
|
5 |
from fastapi import FastAPI, HTTPException
|
6 |
from pydantic import BaseModel
|
|
|
7 |
from google.cloud import storage
|
8 |
+
from google.auth import exceptions
|
9 |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
10 |
from io import BytesIO
|
11 |
from dotenv import load_dotenv
|
12 |
import uvicorn
|
13 |
|
14 |
+
# Carga de variables de entorno
|
15 |
load_dotenv()
|
16 |
|
17 |
API_KEY = os.getenv("API_KEY")
|
|
|
19 |
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
|
20 |
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
def validate_bucket_name(bucket_name):
    """Return *bucket_name* unchanged if it is a legal Google Cloud Storage bucket name.

    Rules enforced (per the GCS bucket naming requirements):
      * only lowercase letters, digits, ``-``, ``_`` and ``.``;
      * first and last characters are a letter or digit;
      * 3-63 characters, or up to 222 when the name contains dots, with each
        dot-separated component at most 63 characters.

    Args:
        bucket_name: Candidate bucket name (typically read from the environment).

    Returns:
        The same ``bucket_name``.

    Raises:
        ValueError: If the name is missing, not a string, or breaks a rule above.
    """
    # An unset environment variable arrives here as None; report it as a
    # ValueError so callers that catch ValueError handle it like any other
    # invalid name instead of crashing with a TypeError from the regex call.
    if not isinstance(bucket_name, str) or not bucket_name:
        raise ValueError(f"Invalid bucket name {bucket_name!r}. A non-empty string is required.")
    # fullmatch (rather than match + '$') so a trailing newline is rejected.
    if not re.fullmatch(r"[a-z0-9](?:[a-z0-9._-]*[a-z0-9])?", bucket_name):
        raise ValueError(f"Invalid bucket name '{bucket_name}'. Must start and end with a letter or number.")
    max_length = 222 if "." in bucket_name else 63
    too_long_component = any(len(part) > 63 for part in bucket_name.split("."))
    if not 3 <= len(bucket_name) <= max_length or too_long_component:
        raise ValueError(
            f"Invalid bucket name '{bucket_name}'. Must contain 3-63 characters "
            "(up to 222 with dots, each dot-separated component at most 63)."
        )
    return bucket_name
|
27 |
|
28 |
+
def validate_huggingface_repo_name(repo_name):
    """Return *repo_name* unchanged if it is a well-formed Hugging Face repo id.

    Accepts either a bare name (``gpt2``) or a namespaced id
    (``namespace/name``). Each ``/``-separated segment must:
      * consist only of letters, digits, ``-``, ``_`` and ``.``;
      * not start or end with ``-`` or ``.`` and not contain ``..``;
      * be at most 96 characters long.

    Args:
        repo_name: Candidate repository id.

    Returns:
        The same ``repo_name``.

    Raises:
        ValueError: If the id is not a string or breaks a rule above.
    """
    if not isinstance(repo_name, str):
        raise ValueError(f"Invalid repository name {repo_name!r}. Must be a string.")
    segments = repo_name.split("/")
    # At most one '/' is allowed: it separates the owner namespace from the name.
    if len(segments) > 2:
        raise ValueError(f"Invalid repository name '{repo_name}'. Expected 'name' or 'namespace/name'.")
    for segment in segments:
        # fullmatch (rather than match + '$') so a trailing newline is rejected;
        # an empty segment ("", "org/", "/name") also fails here.
        if not re.fullmatch(r"[a-zA-Z0-9_.-]+", segment):
            raise ValueError(f"Invalid repository name '{repo_name}'. Must use alphanumeric characters, '-', '_', or '.'.")
        if segment.startswith(('-', '.')) or segment.endswith(('-', '.')) or '..' in segment:
            raise ValueError(f"Invalid repository name '{repo_name}'. Cannot start or end with '-' or '.', or contain '..'.")
        if len(segment) > 96:
            raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
    return repo_name
|
37 |
+
|
38 |
+
# Validate the bucket name and build the GCS client from the service-account
# JSON held in GOOGLE_APPLICATION_CREDENTIALS_JSON. Any configuration problem
# is reported and terminates the process with exit status 1.
try:
    GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, TypeError, ValueError) as e:
    # TypeError is what json.loads / re.match raise when an environment
    # variable is unset and the value is None.
    print(f"Error al cargar credenciales o bucket: {e}")
    # Raise SystemExit directly: the exit() helper is injected by the `site`
    # module and is not guaranteed to exist in every runtime.
    raise SystemExit(1)
|
|
|
52 |
pipeline_task: str
|
53 |
input_text: str
|
54 |
|
55 |
+
class GCSHandler:
    """Convenience wrapper around one bucket of the module-level ``storage_client``."""

    def __init__(self, bucket_name):
        """Bind the handler to the bucket called *bucket_name*."""
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return the result of the blob's ``exists()`` check for *blob_name*."""
        target = self.bucket.blob(blob_name)
        return target.exists()

    def upload_file(self, blob_name, file_stream):
        """Upload the contents of the file-like *file_stream* to *blob_name*."""
        self.bucket.blob(blob_name).upload_from_file(file_stream)

    def download_file(self, blob_name):
        """Return the contents of *blob_name* wrapped in a ``BytesIO``.

        Raises:
            HTTPException: With status 404 when the blob does not exist.
        """
        source = self.bucket.blob(blob_name)
        if source.exists():
            return BytesIO(source.download_as_bytes())
        raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
def download_model_from_huggingface(model_name):
    """Copy a model's files from the Hugging Face Hub into the GCS bucket.

    Each candidate file name is requested from
    ``https://huggingface.co/<model_name>/resolve/main/<file>``; every file
    that answers with HTTP 200 is uploaded to the module-level ``bucket`` as
    ``<model_name>/<file>``. Failures for individual files are printed and do
    not stop the remaining downloads (best effort).

    Args:
        model_name: Hugging Face repository id; checked with
            ``validate_huggingface_repo_name``.

    Raises:
        ValueError: If ``model_name`` fails validation.
    """
    import tempfile  # local import: only needed to spool downloads

    model_name = validate_huggingface_repo_name(model_name)
    file_patterns = [
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "model.safetensors",
    ]
    # NOTE(review): these shard names carry no file extension and a fixed
    # "-of-00001" suffix, which does not look like the Hub's usual sharded
    # checkpoint naming — confirm. They are kept unchanged because predict()
    # looks the files up under exactly these names.
    for i in range(1, 100):
        file_patterns.extend([f"pytorch_model-{i:05}-of-00001", f"model-{i:05}"])

    # Send credentials only when a token is configured, so the header is never
    # the literal string "Bearer None".
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}

    for filename in file_patterns:
        url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
        try:
            # The timeout keeps a stalled connection from hanging the request
            # forever; the context manager releases the streamed connection.
            with requests.get(url, headers=headers, stream=True, timeout=(10, 300)) as response:
                if response.status_code != 200:
                    continue
                blob_name = f"{model_name}/{filename}"
                # Spool to disk beyond 64 MiB instead of holding a
                # multi-gigabyte checkpoint in memory via response.content.
                with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as spool:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        spool.write(chunk)
                    size = spool.tell()
                    spool.seek(0)
                    bucket.blob(blob_name).upload_from_file(spool, size=size)
        except Exception as e:
            print(f"Error downloading {filename} from Hugging Face: {e}")
|
|
|
93 |
|
94 |
@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Run a transformers pipeline on ``request.input_text``.

    The model files are looked up in GCS under ``<model_name>/``. If none are
    present they are first fetched from the Hugging Face Hub. The files that
    exist are then written to a temporary directory, from which the model and
    tokenizer are loaded.

    Args:
        request: Carries ``model_name``, ``pipeline_task`` and ``input_text``.

    Returns:
        ``{"response": <pipeline output>}``.

    Raises:
        HTTPException: 500 when ``config.json``/``tokenizer.json`` are missing
            or on any unexpected error; an HTTPException raised further down
            (e.g. the 404 from ``GCSHandler.download_file``) is passed through
            unchanged.
    """
    import tempfile  # local import: only needed for the on-disk model copy

    try:
        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
        model_prefix = request.model_name
        model_files = [
            "pytorch_model.bin",
            "config.json",
            "tokenizer.json",
            "model.safetensors",
        ]
        for i in range(1, 100):
            model_files.extend([f"pytorch_model-{i:05}-of-00001", f"model-{i:05}"])

        def existing_files():
            # One existence check per candidate file.
            return [name for name in model_files if gcs_handler.file_exists(f"{model_prefix}/{name}")]

        available = existing_files()
        if not available:
            download_model_from_huggingface(model_prefix)
            available = existing_files()

        if "config.json" not in available or "tokenizer.json" not in available:
            raise HTTPException(status_code=500, detail="Required model files missing.")

        # from_pretrained() takes a model id or a local directory, not an
        # in-memory stream, so stage the files on disk first.
        with tempfile.TemporaryDirectory() as model_dir:
            for name in available:
                stream = gcs_handler.download_file(f"{model_prefix}/{name}")
                with open(os.path.join(model_dir, name), "wb") as target:
                    target.write(stream.read())
            model = AutoModelForCausalLM.from_pretrained(model_dir)
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
            result = pipeline_(request.input_text)
        return {"response": result}
    except HTTPException:
        # Keep deliberate status codes (e.g. 404) instead of rewrapping as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}")
|