import os
import json
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.cloud import storage
from google.auth import exceptions
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from io import BytesIO
from dotenv import load_dotenv
import uvicorn
import tempfile

load_dotenv()

API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
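
# A minimal .env sketch for local runs (every value below is a hypothetical
# placeholder, not a real credential):
#
#   API_KEY=your-api-key
#   GCS_BUCKET_NAME=my-model-cache-bucket
#   GOOGLE_APPLICATION_CREDENTIALS_JSON={"type": "service_account", "project_id": "my-project", ...}
#   HF_API_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx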

try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
# TypeError covers a missing GOOGLE_APPLICATION_CREDENTIALS_JSON env var (json.loads(None)).
except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
    raise RuntimeError(f"Failed to load GCS credentials or bucket: {e}")

app = FastAPI()

class DownloadModelRequest(BaseModel):
    model_name: str
    pipeline_task: str
    input_text: str

class GCSHandler:
    def __init__(self, bucket_name):
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        return self.bucket.blob(blob_name).exists()

    def upload_file(self, blob_name, file_stream):
        blob = self.bucket.blob(blob_name)
        blob.upload_from_file(file_stream)

    def download_file(self, blob_name):
        blob = self.bucket.blob(blob_name)
        if not blob.exists():
            raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
        return BytesIO(blob.download_as_bytes())
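
# Usage sketch for the handler above (blob name is hypothetical), round-tripping
# a small payload through the bucket:
#   handler = GCSHandler(GCS_BUCKET_NAME)
#   handler.upload_file("demo/hello.txt", BytesIO(b"hello"))
#   text = handler.download_file("demo/hello.txt").read()  # b"hello"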

def download_model_from_huggingface(model_name):
    url = f"https://huggingface.co/{model_name}/tree/main"
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

    try:
        response = requests.get(url, headers=headers)
        if response.status_code != 200:
            raise HTTPException(status_code=404, detail="Could not access the Hugging Face file tree for this model.")
        # Mirror the usual model files into GCS; repos rarely ship all of them.
        model_files = [
            "pytorch_model.bin",
            "config.json",
            "tokenizer.json",
            "model.safetensors",
        ]
        for file_name in model_files:
            file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
            file_response = requests.get(file_url, headers=headers)
            if file_response.status_code != 200:
                # Skip files this repo does not provide (e.g. .bin vs .safetensors weights)
                # instead of uploading an HTML error page to the bucket.
                continue
            blob_name = f"{model_name}/{file_name}"
            bucket.blob(blob_name).upload_from_file(BytesIO(file_response.content))
    except HTTPException:
        raise
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=500, detail=f"Error downloading files from Hugging Face: {e}")
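
# Note: the /resolve/main/<file> URLs above follow Hugging Face's public
# file-serving convention; for gated or private repos the request only succeeds
# when the Authorization header carries a token with access to the repo.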

@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    try:
        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
        model_prefix = request.model_name
        model_files = [
            "pytorch_model.bin",
            "config.json",
            "tokenizer.json",
            "model.safetensors",
        ]

        # Check whether the model is already cached in GCS. Repos usually ship
        # either .bin or .safetensors weights, not both, so require the config
        # and tokenizer plus at least one weight file.
        def cached(file_name):
            return gcs_handler.file_exists(f"{model_prefix}/{file_name}")

        has_weights = cached("pytorch_model.bin") or cached("model.safetensors")
        if not (cached("config.json") and cached("tokenizer.json") and has_weights):
            # Download the model from Hugging Face if it is not cached yet.
            download_model_from_huggingface(model_prefix)

        # Fetch whichever of the expected files now exist in the bucket.
        model_files_streams = {
            file_name: gcs_handler.download_file(f"{model_prefix}/{file_name}")
            for file_name in model_files
            if cached(file_name)
        }

        # Make sure the essential files are present.
        if "config.json" not in model_files_streams or "tokenizer.json" not in model_files_streams:
            raise HTTPException(status_code=500, detail="Required model files missing.")
        if "pytorch_model.bin" not in model_files_streams and "model.safetensors" not in model_files_streams:
            raise HTTPException(status_code=500, detail="No model weights found.")

        # Write the files into a temporary directory so transformers can load them.
        with tempfile.TemporaryDirectory() as tmp_dir:
            for file_name, stream in model_files_streams.items():
                with open(os.path.join(tmp_dir, file_name), "wb") as f:
                    f.write(stream.read())

            # Load the model and tokenizer from the temporary directory.
            model = AutoModelForCausalLM.from_pretrained(tmp_dir)
            tokenizer = AutoTokenizer.from_pretrained(tmp_dir)

            # Build a pipeline for the requested task and run the prediction.
            pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
            result = pipeline_(request.input_text)

        return {"response": result}

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}")
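
# Example request against a local run (model name and task are illustrative;
# any model loadable by AutoModelForCausalLM should fit this shape):
#   curl -X POST http://localhost:7860/predict/ \
#        -H "Content-Type: application/json" \
#        -d '{"model_name": "gpt2", "pipeline_task": "text-generation", "input_text": "Hello"}'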

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)