Hjgugugjhuhjggg committed
Commit e909ba4 (verified)
1 Parent(s): 2bb4773

Update app.py

Files changed (1):
  1. app.py +144 -42
app.py CHANGED
@@ -21,10 +21,14 @@ HF_API_TOKEN = os.getenv("HF_API_TOKEN")
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
-credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
-storage_client = storage.Client.from_service_account_info(credentials_info)
-bucket = storage_client.bucket(GCS_BUCKET_NAME)
-logger.info(f"Connection to Google Cloud Storage succeeded. Bucket: {GCS_BUCKET_NAME}")
+try:
+    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
+    storage_client = storage.Client.from_service_account_info(credentials_info)
+    bucket = storage_client.bucket(GCS_BUCKET_NAME)
+    logger.info(f"Connection to Google Cloud Storage succeeded. Bucket: {GCS_BUCKET_NAME}")
+except Exception as e:
+    logger.error(f"Error loading credentials or bucket: {e}")
+    raise RuntimeError(f"Error loading credentials or bucket: {e}")
 
 app = FastAPI()
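
The initialization is now wrapped so a bad credential fails fast at startup instead of surfacing later as an unhandled exception. For reference, a minimal standalone sanity check of the same credential path (a sketch, not part of the commit; it assumes only the two environment variables already used above):

    import json
    import os

    from google.cloud import storage

    # Parse the service-account JSON carried in the environment variable.
    credentials_info = json.loads(os.environ["GOOGLE_APPLICATION_CREDENTIALS_JSON"])
    client = storage.Client.from_service_account_info(credentials_info)
    bucket = client.bucket(os.environ["GCS_BUCKET_NAME"])
    print("bucket reachable:", bucket.exists())  # fails fast on bad credentials
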
@@ -42,65 +46,142 @@ class GCSHandler:
 
     def create_folder_if_not_exists(self, folder_name):
         if not self.file_exists(folder_name):
+            logger.debug(f"Creating folder {folder_name} in GCS.")
             self.bucket.blob(folder_name + "/").upload_from_string("")
 
     def upload_file(self, blob_name, file_stream):
         self.create_folder_if_not_exists(os.path.dirname(blob_name))
         blob = self.bucket.blob(blob_name)
-        blob.upload_from_file(file_stream)
+        try:
+            blob.upload_from_file(file_stream)
+            logger.info(f"File '{blob_name}' uploaded to GCS successfully.")
+        except Exception as e:
+            logger.error(f"Error uploading file '{blob_name}' to GCS: {e}")
+            raise HTTPException(status_code=500, detail=f"Error uploading file '{blob_name}' to GCS")
 
     def download_file(self, blob_name):
         blob = self.bucket.blob(blob_name)
         if not blob.exists():
+            logger.error(f"File '{blob_name}' not found in GCS.")
             raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
-        return BytesIO(blob.download_as_bytes())
+        return blob.open("rb")
+
+    def generate_signed_url(self, blob_name, expiration=3600):
+        blob = self.bucket.blob(blob_name)
+        return blob.generate_signed_url(expiration=expiration)
 
 def download_model_from_huggingface(model_name):
-    base_url = f"https://huggingface.co/{model_name}/resolve/main/"
+    url = f"https://huggingface.co/{model_name}/tree/main"
+    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
+
+    try:
+        logger.info(f"Downloading model '{model_name}' from Hugging Face...")
+        response = requests.get(url, headers=headers)
+        if response.status_code == 200:
+            model_files = [
+                "pytorch_model.bin",
+                "config.json",
+                "tokenizer.json",
+                "model.safetensors",
+            ]
+            for file_name in model_files:
+                file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
+                file_content = requests.get(file_url).content
+                blob_name = f"lilmeaty_garca/{model_name}/{file_name}"
+                bucket.blob(blob_name).upload_from_string(file_content)
+                logger.info(f"File '{file_name}' uploaded to the GCS bucket successfully.")
+        else:
+            logger.error(f"Error accessing the Hugging Face file tree for '{model_name}'.")
+            raise HTTPException(status_code=404, detail="Error accessing the Hugging Face file tree.")
+    except Exception as e:
+        logger.error(f"Error downloading files from Hugging Face: {e}")
+        raise HTTPException(status_code=500, detail=f"Error downloading files from Hugging Face: {e}")
+
+def download_and_verify_model(model_name):
     model_files = [
-        "pytorch_model.bin", "config.json", "tokenizer.json", "model.safetensors",
-        "pytorch_model.bin.index.json", "tokenizer_config.json",
-        "special_tokens_map.json", "vocab.json", "merges.txt"
+        "pytorch_model.bin",
+        "config.json",
+        "tokenizer.json",
+        "model.safetensors",
     ]
-    for filename in model_files:
-        try:
-            url = base_url + filename
-            response = requests.get(url, stream=True, headers={"Authorization": f"Bearer {HF_API_TOKEN}"})
-            response.raise_for_status()
-            blob_name = f"lilmeaty_garca/{model_name}/{filename}"
-            gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-            gcs_handler.upload_file(blob_name, response.raw)
-        except requests.exceptions.RequestException as e:
-            logger.warning(f"Could not download {filename} for {model_name}: {e}")
+    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
+    model_files_exist = all(gcs_handler.file_exists(f"lilmeaty_garca/{model_name}/{file}") for file in model_files)
+    if not model_files_exist:
+        download_model_from_huggingface(model_name)
 
 def load_model_from_gcs(model_name):
+    model_files = [
+        "pytorch_model.bin",
+        "config.json",
+        "tokenizer.json",
+        "model.safetensors",
+    ]
+    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
+    model_files_streams = {}
+
+    for file in model_files:
+        file_path = f"lilmeaty_garca/{model_name}/{file}"
+        if gcs_handler.file_exists(file_path):
+            model_files_streams[file] = gcs_handler.download_file(file_path)
+        else:
+            logger.error(f"File '{file}' not found in GCS.")
+            raise HTTPException(status_code=500, detail=f"File '{file}' not found.")
+
+    model_stream = model_files_streams.get("pytorch_model.bin") or model_files_streams.get("model.safetensors")
+    tokenizer_stream = model_files_streams.get("tokenizer.json")
+    config_stream = model_files_streams.get("config.json")
+
+    model = AutoModelForCausalLM.from_pretrained(model_stream, config=config_stream)
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
+
+    return model, tokenizer
+
+def load_model(model_name):
     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
     try:
-        model = AutoModelForCausalLM.from_pretrained(f"gs://{GCS_BUCKET_NAME}/lilmeaty_garca/{model_name}")
-        tokenizer = AutoTokenizer.from_pretrained(f"gs://{GCS_BUCKET_NAME}/lilmeaty_garca/{model_name}")
+        model, tokenizer = load_model_from_gcs(model_name)
+        logger.info(f"Model '{model_name}' loaded successfully from GCS.")
+        return model, tokenizer
+    except HTTPException:
+        logger.warning(f"Model '{model_name}' not found in GCS. Trying it as a model_id...")
+
+    try:
+        download_and_verify_model(model_name)
+        model, tokenizer = load_model_from_gcs(model_name)
+        logger.info(f"Model '{model_name}' loaded successfully from Hugging Face.")
         return model, tokenizer
     except Exception as e:
-        logger.error(f"Error loading model '{model_name}' from GCS: {e}")
+        logger.error(f"Error while trying to load model '{model_name}': {e}")
         raise HTTPException(status_code=500, detail=f"Error loading model '{model_name}': {e}")
 
 @app.on_event("startup")
 async def startup():
-    def download_all_models_in_background():
-        models_url = "https://huggingface.co/api/models?full=true&limit=100"
-        try:
-            while models_url:
-                response = requests.get(models_url)
-                response.raise_for_status()
-                models_data = response.json()
-                for model in models_data["models"]:  # Corrected: Access 'models' list
-                    model_name = model["id"]
-                    download_model_from_huggingface(model_name)
-
-                models_url = models_data.get("next")  # removed , None because its not necessary
-        except Exception as e:
-            logger.error(f"Error downloading models in the background: {e}")
-
-    threading.Thread(target=download_all_models_in_background, daemon=True).start()
+    try:
+        logger.info("Starting the background model download...")
+        run_in_background()
+        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
+        blobs = list(bucket.list_blobs(prefix="lilmeaty_garca/"))
+        model_names = set([blob.name.split("/")[1] for blob in blobs])
+
+        def download_model_thread(model_name):
+            try:
+                download_and_verify_model(model_name)
+            except Exception as e:
+                logger.error(f"Error downloading model '{model_name}': {e}")
+
+        threads = []
+        for model_name in model_names:
+            thread = threading.Thread(target=download_model_thread, args=(model_name,))
+            thread.start()
+            threads.append(thread)
+
+        for thread in threads:
+            thread.join()
+
+        logger.info("All models were downloaded successfully or were already present.")
+    except Exception as e:
+        logger.error(f"Error downloading models at startup: {e}")
+        raise HTTPException(status_code=500, detail=f"Error downloading models: {e}")
 
 @app.post("/predict/")
 async def predict(request: DownloadModelRequest):
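
One caveat on the new load_model_from_gcs: transformers' from_pretrained expects a model id or a local directory, not open file handles, so the blob.open("rb") streams collected above will not load as written. A hedged alternative under the same GCS layout, staging the blobs in a temporary directory first (load_model_from_gcs_via_tmpdir is a hypothetical name, not part of the commit):

    import os
    import tempfile

    from transformers import AutoModelForCausalLM, AutoTokenizer

    def load_model_from_gcs_via_tmpdir(bucket, model_name):
        # Stage every blob under lilmeaty_garca/<model_name>/ on local disk,
        # then point from_pretrained at that directory.
        local_dir = tempfile.mkdtemp(prefix=model_name.replace("/", "_") + "_")
        for blob in bucket.list_blobs(prefix=f"lilmeaty_garca/{model_name}/"):
            filename = os.path.basename(blob.name)
            if filename:  # skip the zero-byte "folder" placeholder objects
                blob.download_to_filename(os.path.join(local_dir, filename))
        model = AutoModelForCausalLM.from_pretrained(local_dir)
        tokenizer = AutoTokenizer.from_pretrained(local_dir)
        return model, tokenizer
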
@@ -108,13 +189,34 @@ async def predict(request: DownloadModelRequest):
         model_name = request.model_name
         pipeline_task = request.pipeline_task
         input_text = request.input_text
-        model, tokenizer = load_model_from_gcs(model_name)
-        pipe = pipeline(pipeline_task, model=model, tokenizer=tokenizer, device=0 if os.getenv("USE_GPU") else -1)
+
+        model, tokenizer = load_model(model_name)
+        pipe = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
         result = pipe(input_text)
+
         return {"result": result}
     except Exception as e:
         logger.error(f"Error processing the request: {e}")
         raise HTTPException(status_code=500, detail=str(e))
 
+def download_all_models_in_background():
+    models_url = "https://huggingface.co/api/models"
+    try:
+        response = requests.get(models_url)
+        if response.status_code != 200:
+            logger.error("Error fetching the model list from Hugging Face.")
+            raise HTTPException(status_code=500, detail="Error fetching the model list.")
+
+        models = response.json()
+        for model in models:
+            model_name = model["id"]
+            download_model_from_huggingface(model_name)
+    except Exception as e:
+        logger.error(f"Error downloading models in the background: {e}")
+        raise HTTPException(status_code=500, detail="Error downloading models in the background.")
+
+def run_in_background():
+    threading.Thread(target=download_all_models_in_background, daemon=True).start()
+
 if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
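
The new generate_signed_url helper forwards to google-cloud-storage's Blob.generate_signed_url, which also accepts a datetime.timedelta; passing one makes the expiration an unambiguous duration rather than a bare integer. A usage sketch (the blob path is hypothetical):

    from datetime import timedelta

    gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    url = gcs_handler.generate_signed_url(
        "lilmeaty_garca/gpt2/config.json",  # hypothetical blob path
        expiration=timedelta(hours=1),
    )
    print(url)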
 
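For completeness, a minimal client sketch for the /predict/ endpoint, assuming the app is served locally on port 7860 and that DownloadModelRequest exposes exactly the three fields used above:

    import requests

    resp = requests.post(
        "http://localhost:7860/predict/",
        json={
            "model_name": "gpt2",               # hypothetical model id
            "pipeline_task": "text-generation",
            "input_text": "Hello, world",
        },
    )
    resp.raise_for_status()
    print(resp.json()["result"])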