Hjgugugjhuhjggg committed on
Commit abeeac6
1 Parent(s): e909ba4

Update app.py

Files changed (1)
  1. app.py +49 -116
app.py CHANGED
@@ -6,7 +6,6 @@ from google.cloud import storage
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from pydantic import BaseModel
 from fastapi import FastAPI, HTTPException
-from io import BytesIO
 import requests
 import uvicorn
 from dotenv import load_dotenv
@@ -21,14 +20,9 @@ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-try:
-    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
-    storage_client = storage.Client.from_service_account_info(credentials_info)
-    bucket = storage_client.bucket(GCS_BUCKET_NAME)
-    logger.info(f"Successfully connected to Google Cloud Storage. Bucket: {GCS_BUCKET_NAME}")
-except Exception as e:
-    logger.error(f"Error loading credentials or bucket: {e}")
-    raise RuntimeError(f"Error loading credentials or bucket: {e}")
+credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
+storage_client = storage.Client.from_service_account_info(credentials_info)
+bucket = storage_client.bucket(GCS_BUCKET_NAME)
 
 app = FastAPI()
 
@@ -46,23 +40,16 @@ class GCSHandler:
 
     def create_folder_if_not_exists(self, folder_name):
         if not self.file_exists(folder_name):
-            logger.debug(f"Creating folder {folder_name} in GCS.")
             self.bucket.blob(folder_name + "/").upload_from_string("")
 
     def upload_file(self, blob_name, file_stream):
         self.create_folder_if_not_exists(os.path.dirname(blob_name))
         blob = self.bucket.blob(blob_name)
-        try:
-            blob.upload_from_file(file_stream)
-            logger.info(f"File '{blob_name}' uploaded successfully to GCS.")
-        except Exception as e:
-            logger.error(f"Error uploading file '{blob_name}' to GCS: {e}")
-            raise HTTPException(status_code=500, detail=f"Error uploading file '{blob_name}' to GCS")
+        blob.upload_from_file(file_stream)
 
     def download_file(self, blob_name):
         blob = self.bucket.blob(blob_name)
         if not blob.exists():
-            logger.error(f"File '{blob_name}' not found in GCS.")
             raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
         return blob.open("rb")
 
@@ -73,29 +60,21 @@ class GCSHandler:
 def download_model_from_huggingface(model_name):
     url = f"https://huggingface.co/{model_name}/tree/main"
     headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
-
-    try:
-        logger.info(f"Downloading model '{model_name}' from Hugging Face...")
-        response = requests.get(url, headers=headers)
-        if response.status_code == 200:
-            model_files = [
-                "pytorch_model.bin",
-                "config.json",
-                "tokenizer.json",
-                "model.safetensors",
-            ]
-            for file_name in model_files:
-                file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
-                file_content = requests.get(file_url).content
-                blob_name = f"lilmeaty_garca/{model_name}/{file_name}"
-                bucket.blob(blob_name).upload_from_string(file_content)
-                logger.info(f"File '{file_name}' uploaded successfully to the GCS bucket.")
-        else:
-            logger.error(f"Error accessing the Hugging Face file tree for '{model_name}'.")
-            raise HTTPException(status_code=404, detail="Error accessing the Hugging Face file tree.")
-    except Exception as e:
-        logger.error(f"Error downloading files from Hugging Face: {e}")
-        raise HTTPException(status_code=500, detail=f"Error downloading files from Hugging Face: {e}")
+    response = requests.get(url, headers=headers)
+    if response.status_code == 200:
+        model_files = [
+            "pytorch_model.bin",
+            "config.json",
+            "tokenizer.json",
+            "model.safetensors",
+        ]
+        for file_name in model_files:
+            file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
+            file_content = requests.get(file_url).content
+            blob_name = f"models/{model_name}/{file_name}"
+            bucket.blob(blob_name).upload_from_string(file_content)
+    else:
+        raise HTTPException(status_code=404, detail="Error accessing Hugging Face model files.")
 
 def download_and_verify_model(model_name):
     model_files = [
@@ -105,8 +84,7 @@ def download_and_verify_model(model_name):
         "model.safetensors",
     ]
     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-    model_files_exist = all(gcs_handler.file_exists(f"lilmeaty_garca/{model_name}/{file}") for file in model_files)
-    if not model_files_exist:
+    if not all(gcs_handler.file_exists(f"models/{model_name}/{file}") for file in model_files):
        download_model_from_huggingface(model_name)
 
 def load_model_from_gcs(model_name):
@@ -117,103 +95,58 @@ def load_model_from_gcs(model_name):
         "model.safetensors",
     ]
     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-    model_files_streams = {}
-
-    for file in model_files:
-        file_path = f"lilmeaty_garca/{model_name}/{file}"
-        if gcs_handler.file_exists(file_path):
-            model_files_streams[file] = gcs_handler.download_file(file_path)
-        else:
-            logger.error(f"File '{file}' not found in GCS.")
-            raise HTTPException(status_code=500, detail=f"File '{file}' not found.")
-
+    model_files_streams = {
+        file: gcs_handler.download_file(f"models/{model_name}/{file}")
+        for file in model_files if gcs_handler.file_exists(f"models/{model_name}/{file}")
+    }
     model_stream = model_files_streams.get("pytorch_model.bin") or model_files_streams.get("model.safetensors")
     tokenizer_stream = model_files_streams.get("tokenizer.json")
     config_stream = model_files_streams.get("config.json")
-
     model = AutoModelForCausalLM.from_pretrained(model_stream, config=config_stream)
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
-
     return model, tokenizer
 
 def load_model(model_name):
     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
     try:
-        model, tokenizer = load_model_from_gcs(model_name)
-        logger.info(f"Model '{model_name}' loaded successfully from GCS.")
-        return model, tokenizer
+        return load_model_from_gcs(model_name)
     except HTTPException:
-        logger.warning(f"Model '{model_name}' not found in GCS. Trying it as a model_id...")
-
-        try:
         download_and_verify_model(model_name)
-            model, tokenizer = load_model_from_gcs(model_name)
-            logger.info(f"Model '{model_name}' loaded successfully from Hugging Face.")
-            return model, tokenizer
-        except Exception as e:
-            logger.error(f"Error trying to load model '{model_name}': {e}")
-            raise HTTPException(status_code=500, detail=f"Error loading model '{model_name}': {e}")
+        return load_model_from_gcs(model_name)
 
 @app.on_event("startup")
 async def startup():
-    try:
-        logger.info("Starting model downloads in the background...")
-        run_in_background()
-        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-        blobs = list(bucket.list_blobs(prefix="lilmeaty_garca/"))
-        model_names = set([blob.name.split("/")[1] for blob in blobs])
-
-        def download_model_thread(model_name):
-            try:
-                download_and_verify_model(model_name)
-            except Exception as e:
-                logger.error(f"Error downloading model '{model_name}': {e}")
-
-        threads = []
-        for model_name in model_names:
-            thread = threading.Thread(target=download_model_thread, args=(model_name,))
-            thread.start()
-            threads.append(thread)
-
-        for thread in threads:
-            thread.join()
-
-        logger.info("All models downloaded successfully or were already present.")
-    except Exception as e:
-        logger.error(f"Error while downloading models at startup: {e}")
-        raise HTTPException(status_code=500, detail=f"Error while downloading models: {e}")
+    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
+    blobs = list(bucket.list_blobs(prefix="models/"))
+    model_names = set(blob.name.split("/")[1] for blob in blobs)
+    def download_model_thread(model_name):
+        try:
+            download_and_verify_model(model_name)
+        except Exception as e:
+            logger.error(f"Error downloading model '{model_name}': {e}")
+    threads = [threading.Thread(target=download_model_thread, args=(model_name,)) for model_name in model_names]
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
 
 @app.post("/predict/")
 async def predict(request: DownloadModelRequest):
-    try:
-        model_name = request.model_name
-        pipeline_task = request.pipeline_task
-        input_text = request.input_text
-
-        model, tokenizer = load_model(model_name)
-        pipe = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
-        result = pipe(input_text)
-
-        return {"result": result}
-    except Exception as e:
-        logger.error(f"Error processing the request: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
+    model_name = request.model_name
+    pipeline_task = request.pipeline_task
+    input_text = request.input_text
+    model, tokenizer = load_model(model_name)
+    pipe = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
+    result = pipe(input_text)
+    return {"result": result}
 
 def download_all_models_in_background():
     models_url = "https://huggingface.co/api/models"
-    try:
-        response = requests.get(models_url)
-        if response.status_code != 200:
-            logger.error("Error fetching the model list from Hugging Face.")
-            raise HTTPException(status_code=500, detail="Error fetching the model list.")
-
+    response = requests.get(models_url)
+    if response.status_code == 200:
         models = response.json()
         for model in models:
-            model_name = model["id"]
-            download_model_from_huggingface(model_name)
-    except Exception as e:
-        logger.error(f"Error downloading models in the background: {e}")
-        raise HTTPException(status_code=500, detail="Error downloading models in the background.")
+            download_model_from_huggingface(model["id"])
 
 def run_in_background():
     threading.Thread(target=download_all_models_in_background, daemon=True).start()
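
For reference, a minimal sketch of how the reworked /predict/ endpoint could be exercised once this revision is deployed. The JSON fields mirror DownloadModelRequest in app.py; the host/port and the example model name are illustrative assumptions, not part of this commit:

import requests

# Hypothetical client call against the /predict/ endpoint defined in app.py.
# localhost:7860 (a common Spaces port) and "gpt2" are assumed example values.
payload = {
    "model_name": "gpt2",
    "pipeline_task": "text-generation",
    "input_text": "Hello, world",
}
resp = requests.post("http://localhost:7860/predict/", json=payload)
resp.raise_for_status()
print(resp.json()["result"])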