Hjgugugjhuhjggg committed on
Commit
b5bc6a9
1 Parent(s): 1c3034c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -72
app.py CHANGED
@@ -4,13 +4,14 @@ import json
4
  import requests
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
- from google.auth import exceptions
8
  from google.cloud import storage
 
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
  from io import BytesIO
11
  from dotenv import load_dotenv
12
  import uvicorn
13
 
 
14
  load_dotenv()
15
 
16
  API_KEY = os.getenv("API_KEY")
@@ -18,35 +19,28 @@ GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
18
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
19
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
20
 
def sanitize_bucket_name(bucket_name):
    """Coerce an arbitrary string into the shape of a GCS bucket name.

    The name is lower-cased, every character outside ``a-z``, ``0-9``, ``-``
    and ``.`` is replaced by ``-``, the result is limited to 63 characters,
    and leading/trailing ``-``/``.`` are removed so the name starts and ends
    with a letter or digit.

    Args:
        bucket_name: Raw bucket name, e.g. read from an environment variable.

    Returns:
        The sanitized name. An input with no usable characters yields ``"a"``.

    Raises:
        ValueError: If ``bucket_name`` is not a string (e.g. the variable is unset).
    """
    if not isinstance(bucket_name, str):
        raise ValueError(
            f"Nombre de bucket inválido: {bucket_name!r}. Debe ser una cadena de texto."
        )
    cleaned = re.sub(r"[^a-z0-9.-]", "-", bucket_name.lower()).strip("-.")
    # Strip again after truncating: the cut at 63 characters may expose a
    # trailing '-' or '.' that was in the middle of the original name.
    cleaned = cleaned[:63].rstrip("-.")
    # After stripping, a non-empty name always starts and ends with [a-z0-9];
    # only the empty case needs a placeholder.
    return cleaned or "a"
33
-
def validate_bucket_name(bucket_name):
    """Check that ``bucket_name`` is well-formed and return it unchanged.

    A valid name consists of lowercase letters, digits, ``-`` and ``.``, and
    starts and ends with a letter or digit.

    Args:
        bucket_name: Candidate bucket name.

    Returns:
        ``bucket_name`` itself when it is valid.

    Raises:
        ValueError: If the name is not a string or does not match the pattern.
    """
    # Non-strings (e.g. None from an unset environment variable) are rejected
    # with ValueError so the start-up error handler can report them.
    # fullmatch is used because '$' would accept a trailing newline.
    if not isinstance(bucket_name, str) or not re.fullmatch(
        r"[a-z0-9][a-z0-9\-\.]*[a-z0-9]", bucket_name
    ):
        raise ValueError(f"Nombre de bucket inválido: '{bucket_name}'. Debe cumplir con las reglas de GCS.")
    return bucket_name
39
 
 
 
 
 
 
 
 
 
 
 
 
# Build the module-level GCS client and bucket handle; abort start-up with a
# readable message when the configuration is missing or malformed.
try:
    # Sanitize, then validate, the configured bucket name.
    GCS_BUCKET_NAME = sanitize_bucket_name(GCS_BUCKET_NAME)
    GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)

    # Load the service-account credentials from their JSON text.
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (
    exceptions.DefaultCredentialsError,
    json.JSONDecodeError,
    KeyError,
    ValueError,
    # Unset environment variables arrive as None: json.loads(None) raises
    # TypeError and calling .lower() on None raises AttributeError.
    TypeError,
    AttributeError,
) as e:
    print(f"Error al cargar credenciales o bucket: {e}")
    # exit() is injected by the site module and is not guaranteed to exist;
    # raising SystemExit is the equivalent that always works.
    raise SystemExit(1)
@@ -58,49 +52,34 @@ class DownloadModelRequest(BaseModel):
58
  pipeline_task: str
59
  input_text: str
60
 
class GCSStreamHandler:
    """Helper bound to one GCS bucket for reading and writing model files."""

    def __init__(self, bucket_name):
        """Bind the handler to ``bucket_name`` via the module-level ``storage_client``."""
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return whether an object named ``blob_name`` exists in the bucket."""
        return self.bucket.blob(blob_name).exists()

    def stream_file_from_gcs(self, blob_name):
        """Return the full contents of ``blob_name`` as bytes.

        Raises:
            HTTPException: With status 404 when the object does not exist.
        """
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"Archivo '{blob_name}' no encontrado en GCS.")
        return blob.download_as_bytes()

    def upload_file_to_gcs(self, blob_name, data_stream):
        """Upload the file-like ``data_stream`` to the object ``blob_name``."""
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(data_stream)

    def ensure_bucket_structure(self, model_prefix):
        """Create empty-JSON placeholders for missing config/tokenizer files."""
        required_files = ["config.json", "tokenizer.json"]
        for filename in required_files:
            blob_name = f"{model_prefix}/{filename}"
            if not self.file_exists(blob_name):
                self.bucket.blob(blob_name).upload_from_string("{}", content_type="application/json")

    def stream_model_files(self, model_prefix, model_patterns):
        """Download every object under ``model_prefix/`` whose file name matches a pattern.

        Args:
            model_prefix: Object-name prefix (without the trailing slash).
            model_patterns: Regular expressions applied with ``re.match`` to
                the last path component of each object name.

        Returns:
            Dict mapping file name to a ``BytesIO`` holding its contents.
        """
        # The listing does not depend on the pattern, so fetch it once
        # instead of once per pattern.
        blobs = list(self.bucket.list_blobs(prefix=f"{model_prefix}/"))
        model_files = {}
        for pattern in model_patterns:
            for blob in blobs:
                filename = blob.name.split("/")[-1]
                # Skip files already fetched through an earlier pattern so
                # each object is downloaded at most once.
                if filename not in model_files and re.match(pattern, filename):
                    model_files[filename] = BytesIO(blob.download_as_bytes())
        return model_files
93
 
def download_model_from_huggingface(model_name):
    """Copy a model's files from the Hugging Face Hub into the GCS bucket.

    The single-file checkpoints, ``config.json`` and ``tokenizer.json`` are
    always attempted. Sharded checkpoints are discovered through the
    checkpoint index files, whose ``weight_map`` lists the real shard file
    names, instead of guessing shard names. Each file is best-effort: a
    missing file is skipped and a failure is reported without aborting the
    remaining downloads.

    Args:
        model_name: Hub repository id; also used as the object-name prefix
            in the bucket (``<model_name>/<filename>``).
    """
    import tempfile  # local import: only this function needs it

    base_url = f"https://huggingface.co/{model_name}/resolve/main"
    # Send the header only when a token is configured; a literal
    # "Bearer None" is not a valid credential.
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    filenames = [
        "pytorch_model.bin",
        "model.safetensors",
        "config.json",
        "tokenizer.json",
    ]

    # Resolve shard names from the checkpoint index files, when present.
    for index_name in ("pytorch_model.bin.index.json", "model.safetensors.index.json"):
        try:
            index_response = requests.get(f"{base_url}/{index_name}", headers=headers, timeout=60)
            if index_response.status_code != 200:
                continue
            weight_map = index_response.json().get("weight_map", {})
            shard_names = sorted(set(weight_map.values()))
        except (requests.RequestException, ValueError, AttributeError) as e:
            print(f"Error al leer {index_name} desde Hugging Face: {e}")
            continue
        filenames.append(index_name)
        filenames.extend(shard_names)

    for filename in filenames:
        try:
            with requests.get(
                f"{base_url}/{filename}", headers=headers, stream=True, timeout=60
            ) as response:
                if response.status_code != 200:
                    continue
                # Spool to a temporary file so multi-gigabyte checkpoints are
                # never held in memory in one piece.
                with tempfile.TemporaryFile() as spool:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        spool.write(chunk)
                    spool.seek(0)
                    bucket.blob(f"{model_name}/{filename}").upload_from_file(spool)
        except Exception as e:
            # Best-effort per file, but report the failure instead of hiding it.
            print(f"Error al descargar {filename} desde Hugging Face: {e}")
115
 
@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Run a Transformers pipeline for the requested model and return its output.

    The model files are looked up in GCS under ``<model_name>/``. If no
    weight file is stored there yet they are first fetched from the Hugging
    Face Hub. The files are then copied to a temporary directory, because
    ``from_pretrained`` loads from a model id or a local directory, not from
    in-memory streams.

    Args:
        request: Body with ``model_name``, ``pipeline_task`` and ``input_text``.

    Returns:
        ``{"response": <pipeline output>}``.

    Raises:
        HTTPException: 400 for an unsupported task, 404 when ``config.json``
            or ``tokenizer.json`` is missing, 500 for any other failure.
    """
    import os  # local imports keep this block self-contained
    import tempfile

    # Validate the cheap input first, before any download or model load.
    supported_tasks = ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering"]
    if request.pipeline_task not in supported_tasks:
        raise HTTPException(status_code=400, detail="Tarea no soportada")

    try:
        gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)
        model_prefix = request.model_name
        prefix = f"{model_prefix}/"
        model_patterns = [
            r"pytorch_model-\d+-of-\d+",
            r"model-\d+",
            r"pytorch_model.bin",
            r"model.safetensors",
        ]

        def list_model_blobs():
            """Map file name -> blob for objects stored directly under the prefix."""
            found = {}
            for blob in gcs_handler.bucket.list_blobs(prefix=prefix):
                relative_name = blob.name[len(prefix):]
                if relative_name and "/" not in relative_name:
                    found[relative_name] = blob
            return found

        def has_weights(names):
            """Return whether any stored file name matches a weight pattern."""
            return any(re.match(p, name) for p in model_patterns for name in names)

        blobs = list_model_blobs()
        # The patterns are regular expressions, so they are matched against
        # the listing rather than probed as literal object names.
        if not has_weights(blobs):
            download_model_from_huggingface(model_prefix)
            blobs = list_model_blobs()

        for required in ("config.json", "tokenizer.json"):
            if required not in blobs:
                raise HTTPException(
                    status_code=404,
                    detail=f"Archivo '{model_prefix}/{required}' no encontrado en GCS.",
                )

        with tempfile.TemporaryDirectory() as model_dir:
            for filename, blob in blobs.items():
                blob.download_to_filename(os.path.join(model_dir, filename))
            model = AutoModelForCausalLM.from_pretrained(model_dir)
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
            result = pipeline_(request.input_text)
        return {"response": result}
    except HTTPException:
        # Keep deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}")
 
4
  import requests
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
 
7
  from google.cloud import storage
8
+ from google.auth import exceptions
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
  from io import BytesIO
11
  from dotenv import load_dotenv
12
  import uvicorn
13
 
14
+ # Load environment variables
15
  load_dotenv()
16
 
17
  API_KEY = os.getenv("API_KEY")
 
19
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
20
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def validate_bucket_name(bucket_name):
    """Check that ``bucket_name`` is well-formed and return it unchanged.

    A valid name consists of lowercase letters, digits and ``-``, and starts
    and ends with a letter or digit.

    Args:
        bucket_name: Candidate bucket name.

    Returns:
        ``bucket_name`` itself when it is valid.

    Raises:
        ValueError: If the name is not a string or does not match the pattern.
    """
    # Non-strings (e.g. None from an unset environment variable) are rejected
    # with ValueError so the start-up error handler can report them.
    # fullmatch is used because '$' would accept a trailing newline.
    if not isinstance(bucket_name, str) or not re.fullmatch(
        r"[a-z0-9][a-z0-9\-]*[a-z0-9]", bucket_name
    ):
        raise ValueError(f"Invalid bucket name '{bucket_name}'. Must start and end with a letter or number.")
    return bucket_name
27
 
def validate_huggingface_repo_name(repo_name):
    """Check that ``repo_name`` is a well-formed Hugging Face repository id.

    Both bare names (``gpt2``) and namespaced ids (``org/model``) are
    accepted. Each segment may contain letters, digits, ``-``, ``_`` and
    ``.``, must not start or end with ``-`` or ``.``, and must not contain
    ``..``. The repository name (the last segment) is limited to 96
    characters.

    Args:
        repo_name: Repository id to validate.

    Returns:
        ``repo_name`` unchanged when it is valid.

    Raises:
        ValueError: If the id is not a string or breaks any rule above.
    """
    invalid_chars = f"Invalid repository name '{repo_name}'. Must use alphanumeric characters, '-', '_', or '.'."
    if not isinstance(repo_name, str):
        raise ValueError(invalid_chars)
    # At most one '/' is allowed: an optional namespace followed by the name.
    segments = repo_name.split("/")
    if len(segments) > 2:
        raise ValueError(invalid_chars)
    for segment in segments:
        # fullmatch also rejects empty segments and a trailing newline.
        if not re.fullmatch(r"[a-zA-Z0-9_.-]+", segment):
            raise ValueError(invalid_chars)
        if segment.startswith(('-', '.')) or segment.endswith(('-', '.')) or '..' in segment:
            raise ValueError(f"Invalid repository name '{repo_name}'. Cannot start or end with '-' or '.', or contain '..'.")
    if len(segments[-1]) > 96:
        raise ValueError(f"Repository name '{repo_name}' exceeds max length of 96 characters.")
    return repo_name
37
+
# Validate the bucket name and build the module-level GCS client; abort
# start-up with a readable message when the configuration is unusable.
try:
    GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (
    exceptions.DefaultCredentialsError,
    json.JSONDecodeError,
    KeyError,
    ValueError,
    # Unset environment variables arrive as None; json.loads(None) and
    # re.match(pattern, None) both raise TypeError.
    TypeError,
) as e:
    print(f"Error al cargar credenciales o bucket: {e}")
    # exit() is injected by the site module and is not guaranteed to exist;
    # raising SystemExit is the equivalent that always works.
    raise SystemExit(1)
 
52
  pipeline_task: str
53
  input_text: str
54
 
class GCSHandler:
    """Helper bound to one GCS bucket for checking, uploading and reading objects."""

    def __init__(self, bucket_name):
        """Bind the handler to ``bucket_name`` via the module-level ``storage_client``."""
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return whether an object named ``blob_name`` exists in the bucket."""
        target = self.bucket.blob(blob_name)
        return target.exists()

    def upload_file(self, blob_name, file_stream):
        """Upload the file-like ``file_stream`` to the object ``blob_name``."""
        self.bucket.blob(blob_name).upload_from_file(file_stream)

    def download_file(self, blob_name):
        """Return the contents of ``blob_name`` wrapped in a ``BytesIO``.

        Raises:
            HTTPException: With status 404 when the object does not exist.
        """
        source = self.bucket.blob(blob_name)
        if source.exists():
            payload = source.download_as_bytes()
            return BytesIO(payload)
        raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
def download_model_from_huggingface(model_name):
    """Download a model from the Hugging Face Hub and upload it to GCS.

    The single-file checkpoints, ``config.json`` and ``tokenizer.json`` are
    always attempted. Sharded checkpoints are discovered through the
    checkpoint index files, whose ``weight_map`` lists the real shard file
    names, instead of guessing shard names. Each file is best-effort: a
    missing file is skipped and a failure is reported without aborting the
    remaining downloads.

    Args:
        model_name: Hub repository id; also used as the object-name prefix
            in the bucket (``<model_name>/<filename>``).

    Raises:
        ValueError: If ``model_name`` fails ``validate_huggingface_repo_name``.
    """
    import tempfile  # local import: only this function needs it

    model_name = validate_huggingface_repo_name(model_name)
    base_url = f"https://huggingface.co/{model_name}/resolve/main"
    # Send the header only when a token is configured; a literal
    # "Bearer None" is not a valid credential.
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    filenames = [
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "model.safetensors",
    ]

    # Resolve shard names from the checkpoint index files, when present.
    for index_name in ("pytorch_model.bin.index.json", "model.safetensors.index.json"):
        try:
            index_response = requests.get(f"{base_url}/{index_name}", headers=headers, timeout=60)
            if index_response.status_code != 200:
                continue
            weight_map = index_response.json().get("weight_map", {})
            shard_names = sorted(set(weight_map.values()))
        except (requests.RequestException, ValueError, AttributeError) as e:
            print(f"Error downloading {index_name} from Hugging Face: {e}")
            continue
        filenames.append(index_name)
        filenames.extend(shard_names)

    for filename in filenames:
        try:
            with requests.get(
                f"{base_url}/{filename}", headers=headers, stream=True, timeout=60
            ) as response:
                if response.status_code != 200:
                    continue
                # Spool to a temporary file so multi-gigabyte checkpoints are
                # never held in memory in one piece.
                with tempfile.TemporaryFile() as spool:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        spool.write(chunk)
                    spool.seek(0)
                    bucket.blob(f"{model_name}/{filename}").upload_from_file(spool)
        except Exception as e:
            print(f"Error downloading {filename} from Hugging Face: {e}")
 
93
 
@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Run a Transformers pipeline for the requested model and return its output.

    The model files are looked up in GCS under ``<model_name>/``. If none of
    the known entry files is stored there yet, the model is first fetched
    from the Hugging Face Hub. The files are then copied to a temporary
    directory, because ``from_pretrained`` loads from a model id or a local
    directory, not from in-memory streams.

    Args:
        request: Body with ``model_name``, ``pipeline_task`` and ``input_text``.

    Returns:
        ``{"response": <pipeline output>}``.

    Raises:
        HTTPException: 500 when required files are missing or on any other
            failure; HTTP errors raised inside are passed through unchanged.
    """
    import os  # local imports keep this block self-contained
    import tempfile

    try:
        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
        model_prefix = request.model_name
        prefix = f"{model_prefix}/"
        entry_files = {
            "pytorch_model.bin",
            "config.json",
            "tokenizer.json",
            "model.safetensors",
            "pytorch_model.bin.index.json",
            "model.safetensors.index.json",
        }

        def list_model_blobs():
            """Map file name -> blob for objects stored directly under the prefix."""
            found = {}
            for blob in gcs_handler.bucket.list_blobs(prefix=prefix):
                relative_name = blob.name[len(prefix):]
                if relative_name and "/" not in relative_name:
                    found[relative_name] = blob
            return found

        # One listing replaces hundreds of per-name existence probes.
        blobs = list_model_blobs()
        if entry_files.isdisjoint(blobs):
            download_model_from_huggingface(model_prefix)
            blobs = list_model_blobs()

        if "config.json" not in blobs or "tokenizer.json" not in blobs:
            raise HTTPException(status_code=500, detail="Required model files missing.")

        with tempfile.TemporaryDirectory() as model_dir:
            for filename, blob in blobs.items():
                blob.download_to_filename(os.path.join(model_dir, filename))
            model = AutoModelForCausalLM.from_pretrained(model_dir)
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
            result = pipeline_(request.input_text)
        return {"response": result}
    except HTTPException:
        # Preserve the status code of HTTP errors raised above.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}")