Hjgugugjhuhjggg committed on
Commit
319a292
1 Parent(s): 51669c7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import requests
5
+ import torch
6
+ from fastapi import FastAPI, HTTPException
7
+ from pydantic import BaseModel
8
+ from google.cloud import storage
9
+ from google.auth import exceptions
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from transformers.hf_api import HfApi, HfFolder, HfLoginManager
12
+ from io import BytesIO
13
+ from dotenv import load_dotenv
14
+ import uvicorn
15
+
16
+ load_dotenv()
17
+
18
# Environment variables (populated from the process environment / .env file).
API_KEY = os.getenv("API_KEY")  # read here; not referenced in the visible part of this file
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # Hugging Face token

# GCS client initialisation: the service-account key is supplied as a JSON
# string in the environment rather than as a file on disk.
try:
    if not GCS_BUCKET_NAME:
        raise ValueError("La variable de entorno GCS_BUCKET_NAME no está definida.")
    # json.loads(None) raises TypeError when the variable is unset, and
    # from_service_account_info raises ValueError on malformed key data, so
    # both are handled below instead of escaping as a raw traceback.
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (
    exceptions.DefaultCredentialsError,
    json.JSONDecodeError,
    KeyError,
    TypeError,
    ValueError,
) as e:
    print(f"Error al cargar credenciales o bucket: {e}")
    # `exit` is injected by the `site` module for interactive use and is not
    # guaranteed to exist; raising SystemExit is the portable equivalent.
    raise SystemExit(1)
32
+
33
# FastAPI application instance; the route decorators below register on it.
app = FastAPI()

# Hugging Face login. The process stops with exit status 1 when the token is
# missing or cannot be registered, printing the same messages as before.
if not HF_API_TOKEN:
    print(
        "Error al iniciar sesión en Hugging Face: "
        "El token de Hugging Face no está definido en las variables de entorno."
    )
    # `exit` comes from the `site` module and is meant for interactive use;
    # SystemExit is the portable way to stop a program.
    raise SystemExit(1)

try:
    # NOTE(review): `HfApi.set_access_token` and the `transformers.hf_api`
    # module may not exist in recent library releases - confirm against the
    # pinned versions.
    HfApi().set_access_token(HF_API_TOKEN)
except Exception as e:
    print(f"Error al iniciar sesión en Hugging Face: {e}")
    raise SystemExit(1)

print("Inicio de sesión en Hugging Face exitoso.")
45
+
46
+
47
class DownloadModelRequest(BaseModel):
    # Request body accepted by the /predict/ endpoint.

    # Used by predict() as the bucket prefix under which the model files live.
    model_name: str

    # Checked by predict() against its list of supported pipeline tasks.
    pipeline_task: str

    # Text handed to the pipeline as its input.
    input_text: str
51
+
52
+
53
class GCSStreamHandler:
    """Helper around a single GCS bucket for reading and writing model files.

    The bucket handle is obtained from the module-level ``storage_client``.
    """

    def __init__(self, bucket_name):
        """Bind the handler to the bucket called ``bucket_name``."""
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return whether ``blob_name`` exists in the bucket."""
        return self.bucket.blob(blob_name).exists()

    def stream_file_from_gcs(self, blob_name):
        """Return the whole content of ``blob_name`` as bytes.

        Raises:
            HTTPException: with status 404 when the blob does not exist.
        """
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"Archivo '{blob_name}' no encontrado en GCS.")
        return blob.download_as_bytes()

    def upload_file_to_gcs(self, blob_name, data_stream):
        """Upload the file-like object ``data_stream`` as ``blob_name``."""
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(data_stream)
        print(f"Archivo {blob_name} subido a GCS.")

    def ensure_bucket_structure(self, model_prefix):
        """Create placeholder JSON files under ``model_prefix`` when missing.

        ``config.json`` and ``tokenizer.json`` are each written with the
        content ``{}`` if they do not exist yet; existing blobs are untouched.
        """
        # NOTE(review): the placeholders are empty JSON objects, not usable
        # model/tokenizer configurations - confirm this is intended.
        required_files = ["config.json", "tokenizer.json"]
        for filename in required_files:
            blob_name = f"{model_prefix}/{filename}"
            if not self.file_exists(blob_name):
                print(f"Creando archivo ficticio: {blob_name}")
                self.bucket.blob(blob_name).upload_from_string("{}", content_type="application/json")

    def stream_model_files(self, model_prefix, model_patterns):
        """Download every blob under ``model_prefix`` whose name matches a pattern.

        Args:
            model_prefix: Bucket "directory" to search (a trailing ``/`` is added).
            model_patterns: Regular expressions applied with ``re.match`` (so
                anchored at the start only) to the last path segment of each
                blob name.

        Returns:
            dict mapping the last path segment of each matching blob to a
            ``BytesIO`` holding its content.
        """
        compiled = [re.compile(pattern) for pattern in model_patterns]
        model_files = {}
        if not compiled:
            # Nothing can match, so skip the bucket listing entirely.
            return model_files

        # The listing does not depend on the pattern, so it is fetched once
        # rather than once per pattern, and each blob is downloaded at most
        # once even when several patterns match it.
        for blob in self.bucket.list_blobs(prefix=f"{model_prefix}/"):
            filename = blob.name.split('/')[-1]
            if any(pattern.match(filename) for pattern in compiled):
                print(f"Archivo encontrado: {blob.name}")
                model_files[filename] = BytesIO(blob.download_as_bytes())
        return model_files
89
+
90
+
91
@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Load a model stored in GCS and run one pipeline call on the input text.

    Args:
        request: Model name (used as the bucket prefix), pipeline task and
            input text.

    Returns:
        ``{"response": <pipeline output>}``.

    Raises:
        HTTPException: 400 for an unsupported task, 404 when ``config.json``
            or ``tokenizer.json`` is missing from the bucket, 500 for any
            other failure.
    """
    # Local imports: neither name is imported at the top of the file, and
    # `pipeline` in particular was referenced without ever being imported.
    import tempfile

    from transformers import pipeline

    supported_tasks = (
        "text-generation",
        "sentiment-analysis",
        "translation",
        "fill-mask",
        "question-answering",
    )

    # Validate before touching GCS so a bad request costs nothing and is
    # reported as 400 rather than being wrapped into a 500 below.
    pipeline_task = request.pipeline_task
    if pipeline_task not in supported_tasks:
        raise HTTPException(status_code=400, detail="Unsupported pipeline task")

    # NOTE(review): this coroutine performs blocking network I/O and model
    # inference; confirm whether it should be a plain `def` endpoint instead.
    try:
        gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)

        # Make sure the expected files exist under the model prefix.
        gcs_handler.ensure_bucket_structure(request.model_name)

        # File-name patterns for the model weights (single file or shards).
        model_patterns = [
            r"pytorch_model-\d+-of-\d+",
            r"model-\d+",
            r"pytorch_model.bin",
            r"model.safetensors"
        ]

        # Fetch the weights, configuration and tokenizer from the bucket.
        model_files = gcs_handler.stream_model_files(request.model_name, model_patterns)
        config_bytes = gcs_handler.stream_file_from_gcs(f"{request.model_name}/config.json")
        tokenizer_bytes = gcs_handler.stream_file_from_gcs(f"{request.model_name}/tokenizer.json")

        # `from_pretrained` takes a model id or a directory path, not an
        # in-memory stream, so the files are staged in a temporary directory
        # and loaded from there. This also lets the library pick the right
        # reader for .bin and .safetensors weights.
        with tempfile.TemporaryDirectory() as model_dir:
            for filename, stream in model_files.items():
                with open(os.path.join(model_dir, filename), "wb") as target:
                    target.write(stream.getbuffer())
            with open(os.path.join(model_dir, "config.json"), "wb") as target:
                target.write(config_bytes)
            with open(os.path.join(model_dir, "tokenizer.json"), "wb") as target:
                target.write(tokenizer_bytes)

            # NOTE(review): a causal-LM head is loaded for every task; confirm
            # this is appropriate for the non-generation tasks listed above.
            model = AutoModelForCausalLM.from_pretrained(model_dir)
            tokenizer = AutoTokenizer.from_pretrained(model_dir)

            # Build the pipeline and run it while the staged files still exist.
            pipeline_ = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
            result = pipeline_(request.input_text)

        return {"response": result}

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 404 for a missing file) intact.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e
136
+
137
+
138
# (connect, read) timeouts in seconds for Hugging Face requests, so a stalled
# connection cannot block the endpoint forever.
_HF_TIMEOUT = (10, 300)
# Size of the pieces copied from the HTTP response to the spool file.
_HF_CHUNK_SIZE = 8 * 1024 * 1024


def _list_hf_shards(model_name, index_name, headers):
    """Return ``[index_name, *shard files]`` for a sharded checkpoint, else ``[]``.

    The shard names are read from the ``weight_map`` of the checkpoint index
    file; an empty list is returned when the index is absent or unreadable.
    """
    url = f"https://huggingface.co/{model_name}/resolve/main/{index_name}"
    try:
        response = requests.get(url, headers=headers, timeout=_HF_TIMEOUT)
        if response.status_code != 200:
            return []
        shard_names = sorted(set(response.json().get("weight_map", {}).values()))
    except (requests.RequestException, ValueError, AttributeError, TypeError) as e:
        print(f"Archivo {index_name} no encontrado: {e}")
        return []
    return [index_name, *shard_names]


def _copy_hf_file_to_gcs(target_bucket, model_name, filename, headers):
    """Copy one repository file into ``target_bucket``; return True if uploaded.

    The download is spooled to a temporary file in fixed-size chunks so memory
    use stays bounded regardless of the file size.
    """
    import tempfile

    url = f"https://huggingface.co/{model_name}/resolve/main/{filename}"
    # The context manager releases the connection even on early return/error.
    with requests.get(url, headers=headers, stream=True, timeout=_HF_TIMEOUT) as response:
        if response.status_code != 200:
            return False
        with tempfile.TemporaryFile() as spool:
            for chunk in response.iter_content(chunk_size=_HF_CHUNK_SIZE):
                spool.write(chunk)
            spool.seek(0)
            target_bucket.blob(f"{model_name}/{filename}").upload_from_file(spool)
    return True


@app.post("/upload/")
async def upload_model_to_gcs(model_name: str):
    """Download a model from Hugging Face and upload its files to GCS.

    The common single-file names are always attempted; shard files are taken
    from the checkpoint index files when the repository provides them. Each
    file is handled on a best-effort basis: a failure is printed and the
    remaining files are still processed.

    Args:
        model_name: Hugging Face repository id, also used as the bucket prefix.

    Returns:
        ``{"model_name": ..., "uploaded": [...]}`` listing the files copied.

    Raises:
        HTTPException: 500 when the transfer cannot be set up at all.
    """
    try:
        gcs_handler = GCSStreamHandler(GCS_BUCKET_NAME)
        headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

        # Common model files.
        filenames = [
            "pytorch_model.bin",
            "model.safetensors",
            "config.json",
            "tokenizer.json",
        ]

        # Sharded checkpoints: the index names the exact shard files, which
        # replaces guessing extension-less names with a fixed shard count.
        for index_name in ("pytorch_model.bin.index.json", "model.safetensors.index.json"):
            filenames.extend(_list_hf_shards(model_name, index_name, headers))

        uploaded = []
        for filename in filenames:
            try:
                if _copy_hf_file_to_gcs(gcs_handler.bucket, model_name, filename, headers):
                    uploaded.append(filename)
                    print(f"Archivo {filename} subido correctamente a GCS.")
            except Exception as e:
                # Best effort by design: report and continue with the next file.
                print(f"Archivo {filename} no encontrado: {e}")

        return {"model_name": model_name, "uploaded": uploaded}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error al subir modelo: {e}") from e
173
+
174
+
175
if __name__ == "__main__":
    # Started as a script: serve the app on all interfaces, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)