Hjgugugjhuhjggg commited on
Commit
efa228b
1 Parent(s): b5bc6a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -8
app.py CHANGED
@@ -11,7 +11,6 @@ from io import BytesIO
11
  from dotenv import load_dotenv
12
  import uvicorn
13
 
14
- # Carga de variables de entorno
15
  load_dotenv()
16
 
17
  API_KEY = os.getenv("API_KEY")
@@ -20,13 +19,11 @@ GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_
20
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
21
 
22
  def validate_bucket_name(bucket_name):
23
- """Valida que el nombre del bucket cumpla con las restricciones de Google Cloud."""
24
  if not re.match(r"^[a-z0-9][a-z0-9\-]*[a-z0-9]$", bucket_name):
25
  raise ValueError(f"Invalid bucket name '{bucket_name}'. Must start and end with a letter or number.")
26
  return bucket_name
27
 
28
  def validate_huggingface_repo_name(repo_name):
29
- """Valida que el nombre del repositorio cumpla con las restricciones de Hugging Face."""
30
  if not re.match(r"^[a-zA-Z0-9_.-]+$", repo_name):
31
  raise ValueError(f"Invalid repository name '{repo_name}'. Must use alphanumeric characters, '-', '_', or '.'.")
32
  if repo_name.startswith(('-', '.')) or repo_name.endswith(('-', '.')) or '..' in repo_name:
@@ -35,15 +32,13 @@ def validate_huggingface_repo_name(repo_name):
35
  raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
36
  return repo_name
37
 
38
- # Validar y configurar cliente de GCS
39
  try:
40
  GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
41
  credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
42
  storage_client = storage.Client.from_service_account_info(credentials_info)
43
  bucket = storage_client.bucket(GCS_BUCKET_NAME)
44
  except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
45
- print(f"Error al cargar credenciales o bucket: {e}")
46
- exit(1)
47
 
48
  app = FastAPI()
49
 
@@ -70,7 +65,6 @@ class GCSHandler:
70
  return BytesIO(blob.download_as_bytes())
71
 
72
  def download_model_from_huggingface(model_name):
73
- """Descarga un modelo desde Hugging Face y lo sube a GCS."""
74
  model_name = validate_huggingface_repo_name(model_name)
75
  file_patterns = [
76
  "pytorch_model.bin",
@@ -89,7 +83,7 @@ def download_model_from_huggingface(model_name):
89
  blob_name = f"{model_name}/{filename}"
90
  bucket.blob(blob_name).upload_from_file(BytesIO(response.content))
91
  except Exception as e:
92
- print(f"Error downloading {filename} from Hugging Face: {e}")
93
 
94
  @app.post("/predict/")
95
  async def predict(request: DownloadModelRequest):
 
11
  from dotenv import load_dotenv
12
  import uvicorn
13
 
 
14
  load_dotenv()
15
 
16
  API_KEY = os.getenv("API_KEY")
 
19
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
20
 
21
  def validate_bucket_name(bucket_name):
 
22
  if not re.match(r"^[a-z0-9][a-z0-9\-]*[a-z0-9]$", bucket_name):
23
  raise ValueError(f"Invalid bucket name '{bucket_name}'. Must start and end with a letter or number.")
24
  return bucket_name
25
 
26
  def validate_huggingface_repo_name(repo_name):
 
27
  if not re.match(r"^[a-zA-Z0-9_.-]+$", repo_name):
28
  raise ValueError(f"Invalid repository name '{repo_name}'. Must use alphanumeric characters, '-', '_', or '.'.")
29
  if repo_name.startswith(('-', '.')) or repo_name.endswith(('-', '.')) or '..' in repo_name:
 
32
  raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
33
  return repo_name
34
 
 
35
  try:
36
  GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
37
  credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
38
  storage_client = storage.Client.from_service_account_info(credentials_info)
39
  bucket = storage_client.bucket(GCS_BUCKET_NAME)
40
  except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
41
+ raise RuntimeError(f"Error al cargar credenciales o bucket: {e}")
 
42
 
43
  app = FastAPI()
44
 
 
65
  return BytesIO(blob.download_as_bytes())
66
 
67
  def download_model_from_huggingface(model_name):
 
68
  model_name = validate_huggingface_repo_name(model_name)
69
  file_patterns = [
70
  "pytorch_model.bin",
 
83
  blob_name = f"{model_name}/{filename}"
84
  bucket.blob(blob_name).upload_from_file(BytesIO(response.content))
85
  except Exception as e:
86
+ raise HTTPException(status_code=500, detail=f"Error downloading {filename} from Hugging Face: {e}")
87
 
88
  @app.post("/predict/")
89
  async def predict(request: DownloadModelRequest):