Hjgugugjhuhjggg committed on
Commit f173552
1 parent: cff40fe

Update app.py

Files changed (1)
  1. app.py +48 -59
app.py CHANGED
@@ -7,8 +7,7 @@ from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from google.cloud import storage
 from google.auth import exceptions
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.hf_api import HfApi, HfFolder, HfLoginManager
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from io import BytesIO
 from dotenv import load_dotenv
 import uvicorn
@@ -33,16 +32,6 @@ except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError) as e
 # FastAPI initialization
 app = FastAPI()
 
-# Hugging Face login
-try:
-    if not HF_API_TOKEN:
-        raise ValueError("The Hugging Face token is not set in the environment variables.")
-    HfApi().set_access_token(HF_API_TOKEN)
-    print("Hugging Face login successful.")
-except Exception as e:
-    print(f"Error logging in to Hugging Face: {e}")
-    exit(1)
-
 
 class DownloadModelRequest(BaseModel):
     model_name: str
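Note on the deleted login block: it relied on `transformers.hf_api`, a module that no longer ships with transformers, and `set_access_token` is gone as well; token handling now lives in the separate `huggingface_hub` package. If the login step is ever reinstated, a minimal sketch using `huggingface_hub` (assumes `HF_API_TOKEN` is exported in the environment, as elsewhere in this file):

    import os
    from huggingface_hub import login

    # Read the token from the environment; fail loudly if it is missing.
    token = os.environ.get("HF_API_TOKEN")
    if not token:
        raise ValueError("HF_API_TOKEN is not set in the environment variables.")
    # Validates the token against the Hub and stores it for later calls.
    login(token=token)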
@@ -88,28 +77,65 @@ class GCSStreamHandler:
         return model_files
 
 
+def download_model_from_huggingface(model_name):
+    """
+    Download a model from Hugging Face and stream it into GCS.
+    """
+    file_patterns = [
+        "pytorch_model.bin",
+        "model.safetensors",
+        "config.json",
+        "tokenizer.json",
+    ]
+
+    # Add patterns for sharded model files
+    for i in range(1, 100):
+        file_patterns.append(f"pytorch_model-{i:05}-of-{100:05}")
+        file_patterns.append(f"model-{i:05}")
+
+    for filename in file_patterns:
+        url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
+        headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
+        try:
+            response = requests.get(url, headers=headers, stream=True)
+            if response.status_code == 200:
+                blob_name = f"{model_name}/{filename}"
+                blob = bucket.blob(blob_name)
+                blob.upload_from_file(BytesIO(response.content))
+                print(f"File {filename} uploaded to GCS.")
+        except Exception as e:
+            print(f"File {filename} not found: {e}")
+
+
 @app.post("/predict/")
 async def predict(request: DownloadModelRequest):
+    """
+    Prediction endpoint. If the model does not exist in GCS, it is downloaded automatically.
+    """
     try:
         gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)
 
-        # Ensure the bucket structure exists
-        gcs_handler.ensure_bucket_structure(request.model_name)
-
-        # Define patterns for the model files
+        # Check whether the model is already in GCS
+        model_prefix = request.model_name
         model_patterns = [
             r"pytorch_model-\d+-of-\d+",
             r"model-\d+",
             r"pytorch_model.bin",
-            r"model.safetensors"
+            r"model.safetensors",
         ]
 
-        # Load the model files from the bucket
-        model_files = gcs_handler.stream_model_files(request.model_name, model_patterns)
+        if not any(
+            gcs_handler.file_exists(f"{model_prefix}/{pattern}") for pattern in model_patterns
+        ):
+            print(f"Model {model_prefix} not found in GCS. Downloading from Hugging Face...")
+            download_model_from_huggingface(model_prefix)
+
+        # Load the model files from GCS
+        model_files = gcs_handler.stream_model_files(model_prefix, model_patterns)
 
-        # Load configuration and model
-        config_stream = gcs_handler.stream_file_from_gcs(f"{request.model_name}/config.json")
-        tokenizer_stream = gcs_handler.stream_file_from_gcs(f"{request.model_name}/tokenizer.json")
+        # Configuration and tokenization
+        config_stream = gcs_handler.stream_file_from_gcs(f"{model_prefix}/config.json")
+        tokenizer_stream = gcs_handler.stream_file_from_gcs(f"{model_prefix}/tokenizer.json")
 
         model = AutoModelForCausalLM.from_pretrained(BytesIO(config_stream))
         state_dict = {}
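A caveat on the new `download_model_from_huggingface` helper: `requests.get(..., stream=True)` is immediately undone by `response.content`, which buffers the whole checkpoint in memory before wrapping it in `BytesIO`, so nothing is actually streamed. (Note too that `f"pytorch_model-{i:05}-of-{100:05}"` hard-codes a 100-shard total, so only names ending in `-of-00100` are probed.) A sketch of a genuinely streaming variant, where the helper name is illustrative and `bucket` is the module-level GCS bucket the code above already uses:

    import requests

    def stream_file_to_gcs(url: str, blob_name: str, token: str) -> bool:
        # Stream an HTTP download straight into a GCS blob without buffering it in RAM.
        headers = {"Authorization": f"Bearer {token}"}
        with requests.get(url, headers=headers, stream=True) as response:
            if response.status_code != 200:
                return False  # e.g. 404 for shard names that do not exist
            response.raw.decode_content = True  # transparently handle gzip/deflate
            bucket.blob(blob_name).upload_from_file(response.raw)
        return True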
@@ -135,42 +161,5 @@ async def predict(request: DownloadModelRequest):
         raise HTTPException(status_code=500, detail=f"Error: {e}")
 
 
-@app.post("/upload/")
-async def upload_model_to_gcs(model_name: str):
-    """
-    Download a model from Hugging Face and stream it into GCS.
-    """
-    try:
-        gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)
-
-        # Common model files
-        file_patterns = [
-            "pytorch_model.bin",
-            "model.safetensors",
-            "config.json",
-            "tokenizer.json",
-        ]
-
-        # Add patterns for sharded model files
-        for i in range(1, 100):
-            file_patterns.append(f"pytorch_model-{i:05}-of-{100:05}")
-            file_patterns.append(f"model-{i:05}")
-
-        for filename in file_patterns:
-            url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
-            headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
-            try:
-                response = requests.get(url, headers=headers, stream=True)
-                if response.status_code == 200:
-                    blob_name = f"{model_name}/{filename}"
-                    blob = bucket.blob(blob_name)
-                    blob.upload_from_file(BytesIO(response.content))
-                    print(f"File {filename} uploaded to GCS.")
-            except Exception as e:
-                print(f"File {filename} not found: {e}")
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error uploading model: {e}")
-
-
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)
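Two more notes on the predict path. First, the existence check passes regex patterns such as `pytorch_model-\d+-of-\d+` to `file_exists` as literal blob names, so sharded checkpoints will never be detected and the Hugging Face download will be re-triggered on every request. A pattern-aware check could list blobs under the model prefix and match client-side; a sketch, where the helper name is hypothetical and access to the underlying `storage.Bucket` is assumed:

    import re
    from google.cloud import storage

    def model_exists_in_gcs(bucket: storage.Bucket, model_prefix: str, patterns: list) -> bool:
        # Return True if any blob under the model prefix matches one of the regex patterns.
        compiled = [re.compile(p) for p in patterns]
        for blob in bucket.list_blobs(prefix=f"{model_prefix}/"):
            filename = blob.name.rsplit("/", 1)[-1]
            if any(rx.search(filename) for rx in compiled):
                return True
        return False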
 
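Second, the unchanged context line `model = AutoModelForCausalLM.from_pretrained(BytesIO(config_stream))` will raise at runtime: `from_pretrained` accepts a model id or a local directory, not a byte stream. One workaround is to materialize the streamed files into a temporary directory and load from there; a minimal sketch, assuming `file_streams` maps filenames (config, tokenizer, weights) to their raw bytes:

    import os
    import tempfile
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def load_model_from_streams(file_streams):
        # Write each streamed file to disk so from_pretrained can read a directory.
        with tempfile.TemporaryDirectory() as tmp:
            for filename, data in file_streams.items():
                with open(os.path.join(tmp, filename), "wb") as f:
                    f.write(data)
            # Weights are loaded into memory here, so the temp dir can be discarded.
            model = AutoModelForCausalLM.from_pretrained(tmp)
            tokenizer = AutoTokenizer.from_pretrained(tmp)
        return model, tokenizer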