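"""FastAPI inference service that caches Hugging Face model files in a
Google Cloud Storage bucket and serves text, image, and audio generation
tasks from that cache via a single /predict/ endpoint."""
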
import os
import json
import logging
import threading
from datetime import timedelta

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.cloud import storage
from google.auth import exceptions
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
import requests
from diffusers import StableDiffusionPipeline
# audiocraft has no AudioLM class; MusicGen is its text-to-audio model.
from audiocraft.models import MusicGen
import uvicorn
import soundfile as sf  # to write generated audio to disk

load_dotenv()

API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
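
# Expected .env contents (values here are placeholders, not real credentials):
#   API_KEY=<client api key>
#   GCS_BUCKET_NAME=<bucket name>
#   GOOGLE_APPLICATION_CREDENTIALS_JSON=<service-account JSON, inline>
#   HF_API_TOKEN=<Hugging Face access token>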

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# GCS setup
try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
    logger.info(f"Connected to Google Cloud Storage. Bucket: {GCS_BUCKET_NAME}")
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
    # TypeError covers a missing GOOGLE_APPLICATION_CREDENTIALS_JSON (json.loads(None)).
    logger.error(f"Failed to load credentials or bucket: {e}")
    raise RuntimeError(f"Failed to load credentials or bucket: {e}")

app = FastAPI()

class DownloadModelRequest(BaseModel):
    model_name: str
    pipeline_task: str
    input_text: str

class GCSHandler:
    def __init__(self, bucket_name):
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        exists = self.bucket.blob(blob_name).exists()
        logger.debug(f"Checked existence of file '{blob_name}': {exists}")
        return exists

    def download_file_as_stream(self, blob_name):
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            logger.error(f"File '{blob_name}' not found in GCS.")
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
        logger.debug(f"Downloading file '{blob_name}' from GCS.")
        return blob.open("rb")  # returns a binary read stream

    def upload_file(self, blob_name, file_stream):
        blob = self.bucket.blob(blob_name)
        try:
            blob.upload_from_file(file_stream)
            logger.info(f"File '{blob_name}' uploaded successfully to GCS.")
        except Exception as e:
            logger.error(f"Error uploading file '{blob_name}' to GCS: {e}")
            raise HTTPException(status_code=500, detail=f"Error uploading file '{blob_name}' to GCS")

    def generate_signed_url(self, blob_name, expiration=3600):
        blob = self.bucket.blob(blob_name)
        # Pass a timedelta so the expiration is relative to now; a bare int is
        # treated as an absolute epoch timestamp by generate_signed_url.
        url = blob.generate_signed_url(expiration=timedelta(seconds=expiration))
        logger.debug(f"Generated signed URL for '{blob_name}': {url}")
        return url
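
# Illustrative use of GCSHandler (the blob name below is a placeholder):
#
#   handler = GCSHandler(GCS_BUCKET_NAME)
#   if handler.file_exists("some-model/config.json"):
#       url = handler.generate_signed_url("some-model/config.json")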

def download_model_from_huggingface(model_name):
    url = f"https://huggingface.co/{model_name}/tree/main"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

    try:
        logger.info(f"Downloading model '{model_name}' from Hugging Face...")
        response = requests.get(url, headers=headers)
        if response.status_code == 200:
            model_files = [
                "config.json",
                "tokenizer.json",
                "pytorch_model.bin",
                "model.safetensors",
            ]
            for file_name in model_files:
                file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
                # Send the token here too, or gated/private repos will fail.
                file_response = requests.get(file_url, headers=headers)
                if file_response.status_code != 200:
                    # Repos usually ship either .bin or .safetensors, not both; skip missing files.
                    logger.warning(f"'{file_name}' not available for '{model_name}', skipping.")
                    continue
                # Upload the file straight from the response body.
                blob_name = f"{model_name}/{file_name}"
                blob = bucket.blob(blob_name)
                blob.upload_from_string(file_response.content)
                logger.info(f"File '{file_name}' uploaded successfully to the GCS bucket.")
        else:
            logger.error(f"Could not access the Hugging Face file tree for '{model_name}'.")
            raise HTTPException(status_code=404, detail="Could not access the Hugging Face file tree.")
    except HTTPException:
        raise  # don't re-wrap the 404 above as a 500
    except Exception as e:
        logger.error(f"Error downloading files from Hugging Face: {e}")
        raise HTTPException(status_code=500, detail=f"Error downloading files from Hugging Face: {e}")

def load_model_from_gcs(model_name, gcs_handler):
    """Mirror the model files from GCS into a local directory and return its
    path, fetching them from Hugging Face first if the bucket lacks them.
    from_pretrained() expects a path (or repo id), not raw byte streams."""
    model_files = ["config.json", "tokenizer.json", "pytorch_model.bin", "model.safetensors"]
    local_dir = os.path.join("/tmp/models", model_name)
    os.makedirs(local_dir, exist_ok=True)

    for file_name in model_files:
        blob_name = f"{model_name}/{file_name}"
        if not gcs_handler.file_exists(blob_name):
            logger.info(f"'{blob_name}' not found in GCS, downloading from Hugging Face...")
            download_model_from_huggingface(model_name)
        if gcs_handler.file_exists(blob_name):
            local_path = os.path.join(local_dir, file_name)
            with gcs_handler.download_file_as_stream(blob_name) as stream, open(local_path, "wb") as f:
                f.write(stream.read())

    return local_dir

def load_diffuser_model(model_dir, model_name):
    # StableDiffusionPipeline cannot be built from a raw byte stream; it loads
    # from a directory (note: a full diffusers pipeline also expects its
    # subfolders such as unet/ and vae/ to be mirrored into model_dir).
    logger.info(f"Loading Diffusers model for '{model_name}'...")
    try:
        return StableDiffusionPipeline.from_pretrained(model_dir)
    except Exception as e:
        raise HTTPException(status_code=404, detail=f"No compatible model found in the bucket: {e}")

def load_audiocraft_model(model_name):
    # audiocraft models cannot be loaded from arbitrary byte streams either;
    # MusicGen.get_pretrained() resolves the checkpoint by model name.
    logger.info(f"Loading audiocraft model for '{model_name}'...")
    try:
        return MusicGen.get_pretrained(model_name)
    except Exception as e:
        raise HTTPException(status_code=404, detail=f"No compatible audio model found: {e}")

@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    logger.info(f"Starting prediction for model '{request.model_name}' with task '{request.pipeline_task}'...")
    try:
        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
        # Mirror the model files to a local directory that from_pretrained() can read.
        model_dir = load_model_from_gcs(request.model_name, gcs_handler)

        if request.pipeline_task == "text-generation":
            # Plain Hugging Face causal LM for text tasks.
            tokenizer = AutoTokenizer.from_pretrained(model_dir)
            model = AutoModelForCausalLM.from_pretrained(model_dir)
            pipe = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
            result = pipe(request.input_text)[0]
        elif request.pipeline_task == "image-generation":
            # Diffusers pipeline for image generation; save the first image to disk.
            pipe = load_diffuser_model(model_dir, request.model_name)
            image = pipe(request.input_text).images[0]
            image_path = "output.png"
            image.save(image_path)
            result = image_path
        elif request.pipeline_task in ("audio-generation", "text-to-audio"):
            # audiocraft (MusicGen) for audio generation; generate() takes a list
            # of descriptions and returns a [batch, channels, time] tensor.
            model = load_audiocraft_model(request.model_name)
            audio_output = model.generate([request.input_text])
            audio_path = "output_audio.wav"
            sf.write(audio_path, audio_output[0, 0].cpu().numpy(), model.sample_rate)
            result = audio_path
        elif request.pipeline_task == "text-to-speech":
            # Transformers TTS pipeline; it returns a dict with the raw samples
            # and their sampling rate.
            tts_pipe = pipeline("text-to-speech", model=model_dir)
            speech = tts_pipe(request.input_text)
            audio_path = "output.wav"
            sf.write(audio_path, speech["audio"].squeeze(), speech["sampling_rate"])
            result = audio_path
        else:
            raise HTTPException(status_code=400, detail="Unsupported task.")

        logger.info(f"Result generated for task '{request.pipeline_task}': {result}")
        return {"response": result}

    except HTTPException as e:
        logger.error(f"HTTPException: {e.detail}")
        raise e
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Error: {e}")

def download_model_in_background(model_name):
    try:
        logger.info(f"Starting background download of model '{model_name}' to GCS...")
        download_model_from_huggingface(model_name)
        logger.info(f"Download of model '{model_name}' completed.")
    except Exception as e:
        logger.error(f"Error downloading model '{model_name}' in the background: {e}")

def run_in_background():
    logger.info("Starting background model download...")
    # "example_model" is a placeholder; replace it with a real Hugging Face model id.
    threading.Thread(target=download_model_in_background, args=("example_model",), daemon=True).start()

@app.on_event("startup")
async def startup_event():
    run_in_background()

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
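
# Example request (assumes the server is running locally on port 7860, and
# "gpt2" stands in for any causal LM mirrored into the bucket):
#
#   curl -X POST http://localhost:7860/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "gpt2", "pipeline_task": "text-generation", "input_text": "Hello"}'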