Hjgugugjhuhjggg committed on
Commit
03ed2e0
1 Parent(s): c496fe5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -61
app.py CHANGED
@@ -1,21 +1,19 @@
1
  import os
2
  import json
3
- import uuid
4
  import logging
 
 
 
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
  from google.cloud import storage
8
- from google.auth import exceptions
9
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
10
  import uvicorn
11
- from google.cloud.storage.blob import Blob
12
  import requests
13
- import io
14
  from safetensors import safe_open
15
- import torch
16
-
17
- # Cargar las variables de entorno
18
  from dotenv import load_dotenv
 
19
  load_dotenv()
20
 
21
  API_KEY = os.getenv("API_KEY")
@@ -23,17 +21,15 @@ GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
23
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
24
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
25
 
26
- # Configuración del logger
27
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
28
  logger = logging.getLogger(__name__)
29
 
30
- # Inicializar el cliente de Google Cloud Storage
31
  try:
32
  credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
33
  storage_client = storage.Client.from_service_account_info(credentials_info)
34
  bucket = storage_client.bucket(GCS_BUCKET_NAME)
35
  logger.info(f"Conexión con Google Cloud Storage exitosa. Bucket: {GCS_BUCKET_NAME}")
36
- except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
37
  logger.error(f"Error al cargar las credenciales o bucket: {e}")
38
  raise RuntimeError(f"Error al cargar las credenciales o bucket: {e}")
39
 
@@ -50,76 +46,100 @@ class GCSHandler:
50
 
51
  def file_exists(self, blob_name):
52
  exists = self.bucket.blob(blob_name).exists()
53
- logger.debug(f"Comprobando existencia de archivo '{blob_name}': {exists}")
54
  return exists
55
 
56
  def download_file(self, blob_name):
57
  blob = self.bucket.blob(blob_name)
58
  if not blob.exists():
59
- logger.error(f"Archivo '{blob_name}' no encontrado en GCS.")
60
  raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
61
- logger.debug(f"Descargando archivo '{blob_name}' de GCS.")
62
- return blob
 
 
 
63
 
64
  def generate_signed_url(self, blob_name, expiration=3600):
65
  blob = self.bucket.blob(blob_name)
66
  url = blob.generate_signed_url(expiration=expiration)
67
- logger.debug(f"Generada URL firmada para '{blob_name}': {url}")
68
  return url
69
 
70
  def load_model_from_gcs(model_name: str, model_files: list):
71
  gcs_handler = GCSHandler(GCS_BUCKET_NAME)
72
  model_blobs = {file: gcs_handler.download_file(f"{model_name}/{file}") for file in model_files}
73
 
74
- # Verificar si el modelo es de safetensors o torch
75
  model_stream = model_blobs.get("pytorch_model.bin") or model_blobs.get("model.safetensors")
76
  config_stream = model_blobs.get("config.json")
77
  tokenizer_stream = model_blobs.get("tokenizer.json")
78
 
79
  if "safetensors" in model_stream.name:
80
- model = load_safetensors_model(model_stream, config_stream)
81
  else:
82
- model = AutoModelForCausalLM.from_pretrained(model_stream, config=config_stream)
83
 
84
- tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
85
 
86
  return model, tokenizer
87
 
88
- def load_safetensors_model(model_stream, config_stream):
89
- with safe_open(model_stream, framework="pt") as model_data:
90
  model = torch.load(model_data)
91
  return model
92
 
93
  def get_model_files_from_gcs(model_name: str):
94
  gcs_handler = GCSHandler(GCS_BUCKET_NAME)
95
  blob_list = list(gcs_handler.bucket.list_blobs(prefix=f"{model_name}/"))
96
- model_files = [blob.name for blob in blob_list if "pytorch_model" in blob.name or "model" in blob.name]
97
- model_files = sorted(model_files) # Asegurar que los archivos fragmentados estén en el orden correcto
98
  return model_files
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  @app.post("/predict/")
101
  async def predict(request: DownloadModelRequest):
102
- logger.info(f"Iniciando predicción para el modelo '{request.model_name}' con tarea '{request.pipeline_task}'...")
103
  try:
104
  gcs_handler = GCSHandler(GCS_BUCKET_NAME)
105
  model_prefix = request.model_name
106
 
107
- # Obtener los archivos del modelo (incluyendo fragmentados)
108
  model_files = get_model_files_from_gcs(model_prefix)
109
 
110
  if not model_files:
111
- logger.error(f"Modelos no encontrados en GCS para '{model_prefix}'.")
112
- raise HTTPException(status_code=404, detail="Model files not found in GCS.")
113
 
114
- # Cargar el modelo desde GCS
115
  model, tokenizer = load_model_from_gcs(model_prefix, model_files)
116
 
117
- # Instanciar el pipeline de Hugging Face
118
  pipe = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
119
 
120
  if request.pipeline_task in ["text-generation", "translation", "summarization"]:
121
  result = pipe(request.input_text)
122
- logger.info(f"Resultado generado para la tarea '{request.pipeline_task}': {result[0]}")
123
  return {"response": result[0]}
124
 
125
  elif request.pipeline_task == "image-generation":
@@ -160,40 +180,10 @@ async def predict(request: DownloadModelRequest):
160
  return {"response": {"model_3d_url": model_3d_url}}
161
 
162
  except HTTPException as e:
163
- logger.error(f"HTTPException: {e.detail}")
164
  raise e
165
  except Exception as e:
166
- logger.error(f"Error inesperado: {e}")
167
  raise HTTPException(status_code=500, detail=f"Error: {e}")
168
 
169
- def download_model_from_huggingface(model_name):
170
- url = f"https://huggingface.co/{model_name}/tree/main"
171
- headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
172
-
173
- try:
174
- logger.info(f"Descargando el modelo '{model_name}' desde Hugging Face...")
175
- response = requests.get(url, headers=headers)
176
- if response.status_code == 200:
177
- model_files = [
178
- "pytorch_model.bin",
179
- "config.json",
180
- "tokenizer.json",
181
- "model.safetensors",
182
- ]
183
- for file_name in model_files:
184
- file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
185
- file_content = requests.get(file_url).content
186
- blob_name = f"{model_name}/{file_name}"
187
- blob = bucket.blob(blob_name)
188
- blob.upload_from_string(file_content)
189
- logger.info(f"Archivo '{file_name}' subido exitosamente al bucket GCS.")
190
- else:
191
- logger.error(f"Error al acceder al árbol de archivos de Hugging Face para '{model_name}'.")
192
- raise HTTPException(status_code=404, detail="Error al acceder al árbol de archivos de Hugging Face.")
193
- except Exception as e:
194
- logger.error(f"Error descargando archivos de Hugging Face: {e}")
195
- raise HTTPException(status_code=500, detail=f"Error descargando archivos de Hugging Face: {e}")
196
-
197
  @app.on_event("startup")
198
  async def startup_event():
199
  logger.info("Iniciando la API...")
 
1
  import os
2
  import json
 
3
  import logging
4
+ import uuid
5
+ import threading
6
+ import io
7
  from fastapi import FastAPI, HTTPException
8
  from pydantic import BaseModel
9
  from google.cloud import storage
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
11
  import uvicorn
12
+ import torch
13
  import requests
 
14
  from safetensors import safe_open
 
 
 
15
  from dotenv import load_dotenv
16
+
17
  load_dotenv()
18
 
19
  API_KEY = os.getenv("API_KEY")
 
21
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
22
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
23
 
 
24
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
25
  logger = logging.getLogger(__name__)
26
 
 
27
  try:
28
  credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
29
  storage_client = storage.Client.from_service_account_info(credentials_info)
30
  bucket = storage_client.bucket(GCS_BUCKET_NAME)
31
  logger.info(f"Conexión con Google Cloud Storage exitosa. Bucket: {GCS_BUCKET_NAME}")
32
+ except (json.JSONDecodeError, KeyError, ValueError) as e:
33
  logger.error(f"Error al cargar las credenciales o bucket: {e}")
34
  raise RuntimeError(f"Error al cargar las credenciales o bucket: {e}")
35
 
 
46
 
47
  def file_exists(self, blob_name):
48
  exists = self.bucket.blob(blob_name).exists()
 
49
  return exists
50
 
51
  def download_file(self, blob_name):
52
  blob = self.bucket.blob(blob_name)
53
  if not blob.exists():
 
54
  raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
55
+ return blob.download_as_bytes()
56
+
57
+ def upload_file(self, blob_name, file_data):
58
+ blob = self.bucket.blob(blob_name)
59
+ blob.upload_from_file(file_data)
60
 
61
  def generate_signed_url(self, blob_name, expiration=3600):
62
  blob = self.bucket.blob(blob_name)
63
  url = blob.generate_signed_url(expiration=expiration)
 
64
  return url
65
 
66
  def load_model_from_gcs(model_name: str, model_files: list):
67
  gcs_handler = GCSHandler(GCS_BUCKET_NAME)
68
  model_blobs = {file: gcs_handler.download_file(f"{model_name}/{file}") for file in model_files}
69
 
 
70
  model_stream = model_blobs.get("pytorch_model.bin") or model_blobs.get("model.safetensors")
71
  config_stream = model_blobs.get("config.json")
72
  tokenizer_stream = model_blobs.get("tokenizer.json")
73
 
74
  if "safetensors" in model_stream.name:
75
+ model = load_safetensors_model(model_stream)
76
  else:
77
+ model = AutoModelForCausalLM.from_pretrained(io.BytesIO(model_stream), config=config_stream)
78
 
79
+ tokenizer = AutoTokenizer.from_pretrained(io.BytesIO(tokenizer_stream))
80
 
81
  return model, tokenizer
82
 
83
def load_safetensors_model(model_stream):
    """Deserialize safetensors bytes into a dict of torch tensors.

    *model_stream* is the raw content of a ``model.safetensors`` file (as
    returned by ``GCSHandler.download_file``). Returns the tensor state dict.

    NOTE(review): the previous implementation passed a ``BytesIO`` to
    ``safe_open`` (which expects a filesystem path) and then fed the handle to
    ``torch.load`` (safetensors is not a pickle) — both calls fail at runtime.
    ``safetensors.torch.load`` deserializes directly from bytes.
    """
    # Local import: only needed on the safetensors code path.
    from safetensors.torch import load as load_safetensors_bytes

    return load_safetensors_bytes(model_stream)
87
 
88
def get_model_files_from_gcs(model_name: str):
    """List model files stored under ``{model_name}/`` in the bucket.

    Returns sorted object names *relative to the model folder* (e.g.
    ``pytorch_model.bin``), ready for callers to re-prefix with
    ``{model_name}/``. Shard index files are skipped; sorting keeps sharded
    weight files in order.
    """
    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    prefix = f"{model_name}/"
    blob_list = list(gcs_handler.bucket.list_blobs(prefix=prefix))

    wanted = []
    for blob in blob_list:
        # Filter on the name relative to the folder. The full blob.name starts
        # with the prefix, so the old substring test matched *every* object
        # whenever model_name itself contained "model"; and the full names it
        # returned were re-prefixed by load_model_from_gcs, producing doubled
        # paths like "name/name/pytorch_model.bin".
        relative = blob.name[len(prefix):] if blob.name.startswith(prefix) else blob.name
        if "index" in relative:
            continue
        # config.json / tokenizer.json were previously filtered out even
        # though the loader looks them up in this list.
        if relative in ("config.json", "tokenizer.json") or any(
            part in relative for part in ("pytorch_model", "model")
        ):
            wanted.append(relative)
    return sorted(wanted)
94
 
95
def download_model_from_huggingface(model_name):
    """Mirror a model's files from the Hugging Face Hub into the GCS bucket.

    Downloads a fixed set of well-known files concurrently (one thread per
    file) and uploads each to ``{model_name}/{file_name}`` in the bucket.

    Raises:
        HTTPException: 404 when the model page is not reachable;
            500 when any per-file download or upload fails.
    """
    url = f"https://huggingface.co/{model_name}/tree/main"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

    try:
        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            model_files = [
                "pytorch_model.bin",
                "config.json",
                "tokenizer.json",
                "model.safetensors",
            ]
            # Collected by worker threads: an exception raised inside a Thread
            # target is otherwise silently lost to the caller.
            errors = []

            def download_file(file_name):
                try:
                    file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
                    file_response = requests.get(file_url)
                    # Skip files the repo does not have instead of uploading
                    # an HTML error page as if it were model data.
                    if file_response.status_code != 200:
                        return
                    blob = bucket.blob(f"{model_name}/{file_name}")
                    blob.upload_from_string(file_response.content)
                except Exception as exc:
                    errors.append((file_name, exc))

            threads = [threading.Thread(target=download_file, args=(file_name,)) for file_name in model_files]
            for thread in threads:
                thread.start()
            for thread in threads:
                thread.join()

            if errors:
                detail = "; ".join(f"{name}: {exc}" for name, exc in errors)
                raise HTTPException(status_code=500, detail=f"Error descargando archivos de Hugging Face: {detail}")
        else:
            raise HTTPException(status_code=404, detail="Error al acceder al árbol de archivos de Hugging Face.")
    except HTTPException:
        # Previously the generic handler below caught this deliberate 404/500
        # and rewrapped it as a fresh 500 — re-raise it untouched instead.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error descargando archivos de Hugging Face: {e}")
124
+
125
  @app.post("/predict/")
126
  async def predict(request: DownloadModelRequest):
 
127
  try:
128
  gcs_handler = GCSHandler(GCS_BUCKET_NAME)
129
  model_prefix = request.model_name
130
 
 
131
  model_files = get_model_files_from_gcs(model_prefix)
132
 
133
  if not model_files:
134
+ download_model_from_huggingface(model_prefix)
135
+ model_files = get_model_files_from_gcs(model_prefix)
136
 
 
137
  model, tokenizer = load_model_from_gcs(model_prefix, model_files)
138
 
 
139
  pipe = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
140
 
141
  if request.pipeline_task in ["text-generation", "translation", "summarization"]:
142
  result = pipe(request.input_text)
 
143
  return {"response": result[0]}
144
 
145
  elif request.pipeline_task == "image-generation":
 
180
  return {"response": {"model_3d_url": model_3d_url}}
181
 
182
  except HTTPException as e:
 
183
  raise e
184
  except Exception as e:
 
185
  raise HTTPException(status_code=500, detail=f"Error: {e}")
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  @app.on_event("startup")
188
  async def startup_event():
189
  logger.info("Iniciando la API...")