Hjgugugjhuhjggg committed on
Commit
f84a20c
1 Parent(s): f06b80f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -65
app.py CHANGED
@@ -1,43 +1,42 @@
1
  import os
2
  import re
3
  import requests
 
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
 
6
  from google.cloud import storage
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
  from io import BytesIO
9
  from dotenv import load_dotenv
10
  import uvicorn
11
- import json
12
- from google.auth import exceptions
13
 
14
  load_dotenv()
15
 
16
- # Variables de entorno
17
  API_KEY = os.getenv("API_KEY")
18
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
19
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
20
- HF_API_TOKEN = os.getenv("HF_API_TOKEN") # Token de Hugging Face
21
 
22
- # Validar nombre del bucket
23
  def validate_bucket_name(bucket_name):
 
 
 
 
24
  if not re.match(r"^[a-z0-9][a-z0-9\-\.]*[a-z0-9]$", bucket_name):
25
- raise ValueError(f"El nombre del bucket '{bucket_name}' no es válido. Debe comenzar y terminar con una letra o número.")
 
 
 
 
 
 
 
 
26
  return bucket_name
27
 
28
- # Validar nombre del repositorio en Hugging Face
29
- def validate_huggingface_repo_name(repo_name):
30
- if not isinstance(repo_name, str) or not re.match(r"^[a-zA-Z0-9_.-]+$", repo_name):
31
- raise ValueError(f"El nombre del repositorio '{repo_name}' no es válido. Debe contener solo letras, números, '-', '_', y '.'")
32
- if repo_name.startswith(('-', '.')) or repo_name.endswith(('-', '.')) or '..' in repo_name:
33
- raise ValueError(f"El nombre del repositorio '{repo_name}' contiene caracteres no permitidos. Verifica los caracteres al inicio o final.")
34
- if len(repo_name) > 96:
35
- raise ValueError(f"El nombre del repositorio '{repo_name}' es demasiado largo. La longitud máxima es 96 caracteres.")
36
- return repo_name
37
-
38
- # Inicialización del cliente de GCS
39
  try:
40
- GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME) # Validar el nombre del bucket
41
  credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
42
  storage_client = storage.Client.from_service_account_info(credentials_info)
43
  bucket = storage_client.bucket(GCS_BUCKET_NAME)
@@ -45,16 +44,13 @@ except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, Valu
45
  print(f"Error al cargar credenciales o bucket: {e}")
46
  exit(1)
47
 
48
- # Inicialización de FastAPI
49
  app = FastAPI()
50
 
51
-
52
  class DownloadModelRequest(BaseModel):
53
  model_name: str
54
  pipeline_task: str
55
  input_text: str
56
 
57
-
58
  class GCSStreamHandler:
59
  def __init__(self, bucket_name):
60
  self.bucket = storage_client.bucket(bucket_name)
@@ -65,21 +61,18 @@ class GCSStreamHandler:
65
  def stream_file_from_gcs(self, blob_name):
66
  blob = self.bucket.blob(blob_name)
67
  if not blob.exists():
68
- raise HTTPException(status_code=404, detail=f"Archivo '{blob_name}' no encontrado en GCS.")
69
  return blob.download_as_bytes()
70
 
71
  def upload_file_to_gcs(self, blob_name, data_stream):
72
  blob = self.bucket.blob(blob_name)
73
  blob.upload_from_file(data_stream)
74
- print(f"Archivo {blob_name} subido a GCS.")
75
 
76
  def ensure_bucket_structure(self, model_prefix):
77
- # Crea automáticamente la estructura en el bucket si no existe
78
  required_files = ["config.json", "tokenizer.json"]
79
  for filename in required_files:
80
  blob_name = f"{model_prefix}/{filename}"
81
  if not self.file_exists(blob_name):
82
- print(f"Creando archivo ficticio: {blob_name}")
83
  self.bucket.blob(blob_name).upload_from_string("{}", content_type="application/json")
84
 
85
  def stream_model_files(self, model_prefix, model_patterns):
@@ -88,29 +81,19 @@ class GCSStreamHandler:
88
  blobs = list(self.bucket.list_blobs(prefix=f"{model_prefix}/"))
89
  for blob in blobs:
90
  if re.match(pattern, blob.name.split('/')[-1]):
91
- print(f"Archivo encontrado: {blob.name}")
92
  model_files[blob.name.split('/')[-1]] = BytesIO(blob.download_as_bytes())
93
  return model_files
94
 
95
-
96
  def download_model_from_huggingface(model_name):
97
- """
98
- Descarga un modelo desde Hugging Face y lo sube a GCS en streaming.
99
- """
100
- model_name = validate_huggingface_repo_name(model_name) # Validar nombre del repositorio
101
-
102
  file_patterns = [
103
  "pytorch_model.bin",
104
  "model.safetensors",
105
  "config.json",
106
  "tokenizer.json",
107
  ]
108
-
109
- # Agregar patrones para fragmentos de modelos
110
  for i in range(1, 100):
111
  file_patterns.append(f"pytorch_model-{i:05}-of-{100:05}")
112
  file_patterns.append(f"model-{i:05}")
113
-
114
  for filename in file_patterns:
115
  url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
116
  headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
@@ -120,20 +103,13 @@ def download_model_from_huggingface(model_name):
120
  blob_name = f"{model_name}/{filename}"
121
  blob = bucket.blob(blob_name)
122
  blob.upload_from_file(BytesIO(response.content))
123
- print(f"Archivo {filename} subido correctamente a GCS.")
124
  except Exception as e:
125
- print(f"Archivo {filename} no encontrado: {e}")
126
-
127
 
128
  @app.post("/predict/")
129
  async def predict(request: DownloadModelRequest):
130
- """
131
- Endpoint para realizar predicciones. Si el modelo no existe en GCS, se descarga automáticamente.
132
- """
133
  try:
134
  gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)
135
-
136
- # Verificar si el modelo ya está en GCS
137
  model_prefix = request.model_name
138
  model_patterns = [
139
  r"pytorch_model-\d+-of-\d+",
@@ -141,43 +117,22 @@ async def predict(request: DownloadModelRequest):
141
  r"pytorch_model.bin",
142
  r"model.safetensors",
143
  ]
144
-
145
  if not any(
146
  gcs_handler.file_exists(f"{model_prefix}/{pattern}") for pattern in model_patterns
147
  ):
148
- print(f"Modelo {model_prefix} no encontrado en GCS. Descargando desde Hugging Face...")
149
  download_model_from_huggingface(model_prefix)
150
-
151
- # Carga archivos del modelo desde GCS
152
  model_files = gcs_handler.stream_model_files(model_prefix, model_patterns)
153
-
154
- # Configuración y tokenización
155
  config_stream = gcs_handler.stream_file_from_gcs(f"{model_prefix}/config.json")
156
  tokenizer_stream = gcs_handler.stream_file_from_gcs(f"{model_prefix}/tokenizer.json")
157
-
158
  model = AutoModelForCausalLM.from_pretrained(BytesIO(config_stream))
159
- state_dict = {}
160
-
161
- for filename, stream in model_files.items():
162
- state_dict.update(torch.load(stream, map_location="cpu"))
163
-
164
- model.load_state_dict(state_dict)
165
  tokenizer = AutoTokenizer.from_pretrained(BytesIO(tokenizer_stream))
166
-
167
- # Crear pipeline
168
  pipeline_task = request.pipeline_task
169
- if pipeline_task not in ["text-generation", "sentiment-analysis", "translation", "fill-mask", "question-answering"]:
170
- raise HTTPException(status_code=400, detail="Unsupported pipeline task")
171
-
172
  pipeline_ = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
173
  input_text = request.input_text
174
  result = pipeline_(input_text)
175
-
176
  return {"response": result}
177
-
178
  except Exception as e:
179
  raise HTTPException(status_code=500, detail=f"Error: {e}")
180
 
181
-
182
  if __name__ == "__main__":
183
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
  import re
3
  import requests
4
+ import json
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
+ from google.auth import exceptions
8
  from google.cloud import storage
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
10
  from io import BytesIO
11
  from dotenv import load_dotenv
12
  import uvicorn
 
 
13
 
14
  load_dotenv()
15
 
 
16
  API_KEY = os.getenv("API_KEY")
17
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
18
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
19
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
20
 
 
def validate_bucket_name(bucket_name):
    """Validate a bucket name and return it unchanged.

    The name is accepted only if it:
      * is a ``str`` of 3 to 63 characters,
      * consists solely of lowercase ASCII letters, digits, hyphens and
        periods,
      * starts and ends with a letter or digit, and
      * contains none of the sequences ``--``, ``..``, ``.-`` or ``-.``.

    Args:
        bucket_name: Candidate bucket name (typically read from the
            ``GCS_BUCKET_NAME`` environment variable, so it may be ``None``).

    Returns:
        The same ``bucket_name`` object when every check passes.

    Raises:
        ValueError: If any of the rules above is violated, including when
            ``bucket_name`` is not a string.
    """
    if not isinstance(bucket_name, str):
        raise ValueError("Bucket name must be a string.")
    if len(bucket_name) < 3 or len(bucket_name) > 63:
        raise ValueError("Bucket name must be between 3 and 63 characters long.")
    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, which let names such as "my-bucket\n" slip through.
    # NOTE(review): underscores are rejected here — confirm against the
    # storage provider's naming rules if they should be allowed.
    if not re.fullmatch(r"[a-z0-9][a-z0-9\-\.]*[a-z0-9]", bucket_name):
        raise ValueError(
            f"Invalid bucket name '{bucket_name}'. Bucket names must:"
            " - Use only lowercase letters, numbers, hyphens (-), and periods (.)"
            " - Start and end with a letter or number."
        )
    if any(seq in bucket_name for seq in ("--", "..", ".-", "-.")):
        raise ValueError(
            f"Invalid bucket name '{bucket_name}'. Bucket names cannot contain consecutive periods, hyphens, or use '.-' or '-.'"
        )
    return bucket_name
37
 
 
 
 
 
 
 
 
 
 
 
 
38
  try:
39
+ GCS_BUCKET_NAME = validate_bucket_name(GCS_BUCKET_NAME)
40
  credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
41
  storage_client = storage.Client.from_service_account_info(credentials_info)
42
  bucket = storage_client.bucket(GCS_BUCKET_NAME)
 
44
  print(f"Error al cargar credenciales o bucket: {e}")
45
  exit(1)
46
 
 
47
  app = FastAPI()
48
 
 
class DownloadModelRequest(BaseModel):
    """Request body accepted by the ``POST /predict/`` endpoint."""

    # Model identifier; used both as the object prefix inside the bucket and
    # as the repository path when fetching files from huggingface.co.
    model_name: str
    # Task name handed to ``transformers.pipeline`` as its first argument.
    pipeline_task: str
    # Text passed to the constructed pipeline to produce the response.
    input_text: str
53
 
 
54
  class GCSStreamHandler:
55
  def __init__(self, bucket_name):
56
  self.bucket = storage_client.bucket(bucket_name)
 
61
  def stream_file_from_gcs(self, blob_name):
62
  blob = self.bucket.blob(blob_name)
63
  if not blob.exists():
64
+ raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found in GCS.")
65
  return blob.download_as_bytes()
66
 
67
  def upload_file_to_gcs(self, blob_name, data_stream):
68
  blob = self.bucket.blob(blob_name)
69
  blob.upload_from_file(data_stream)
 
70
 
71
  def ensure_bucket_structure(self, model_prefix):
 
72
  required_files = ["config.json", "tokenizer.json"]
73
  for filename in required_files:
74
  blob_name = f"{model_prefix}/{filename}"
75
  if not self.file_exists(blob_name):
 
76
  self.bucket.blob(blob_name).upload_from_string("{}", content_type="application/json")
77
 
78
  def stream_model_files(self, model_prefix, model_patterns):
 
81
  blobs = list(self.bucket.list_blobs(prefix=f"{model_prefix}/"))
82
  for blob in blobs:
83
  if re.match(pattern, blob.name.split('/')[-1]):
 
84
  model_files[blob.name.split('/')[-1]] = BytesIO(blob.download_as_bytes())
85
  return model_files
86
 
 
87
  def download_model_from_huggingface(model_name):
 
 
 
 
 
88
  file_patterns = [
89
  "pytorch_model.bin",
90
  "model.safetensors",
91
  "config.json",
92
  "tokenizer.json",
93
  ]
 
 
94
  for i in range(1, 100):
95
  file_patterns.append(f"pytorch_model-{i:05}-of-{100:05}")
96
  file_patterns.append(f"model-{i:05}")
 
97
  for filename in file_patterns:
98
  url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
99
  headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
 
103
  blob_name = f"{model_name}/{filename}"
104
  blob = bucket.blob(blob_name)
105
  blob.upload_from_file(BytesIO(response.content))
 
106
  except Exception as e:
107
+ pass
 
108
 
109
  @app.post("/predict/")
110
  async def predict(request: DownloadModelRequest):
 
 
 
111
  try:
112
  gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)
 
 
113
  model_prefix = request.model_name
114
  model_patterns = [
115
  r"pytorch_model-\d+-of-\d+",
 
117
  r"pytorch_model.bin",
118
  r"model.safetensors",
119
  ]
 
120
  if not any(
121
  gcs_handler.file_exists(f"{model_prefix}/{pattern}") for pattern in model_patterns
122
  ):
 
123
  download_model_from_huggingface(model_prefix)
 
 
124
  model_files = gcs_handler.stream_model_files(model_prefix, model_patterns)
 
 
125
  config_stream = gcs_handler.stream_file_from_gcs(f"{model_prefix}/config.json")
126
  tokenizer_stream = gcs_handler.stream_file_from_gcs(f"{model_prefix}/tokenizer.json")
 
127
  model = AutoModelForCausalLM.from_pretrained(BytesIO(config_stream))
 
 
 
 
 
 
128
  tokenizer = AutoTokenizer.from_pretrained(BytesIO(tokenizer_stream))
 
 
129
  pipeline_task = request.pipeline_task
 
 
 
130
  pipeline_ = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
131
  input_text = request.input_text
132
  result = pipeline_(input_text)
 
133
  return {"response": result}
 
134
  except Exception as e:
135
  raise HTTPException(status_code=500, detail=f"Error: {e}")
136
 
 
137
  if __name__ == "__main__":
138
+ uvicorn.run(app, host="0.0.0.0", port=8000)