# gcs / app.py
# Hjgugugjhuhjggg's picture
# Update app.py
# 1692af9 verified
# raw
# history blame
# 4.94 kB
import os
import json
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from google.cloud import storage
from google.auth import exceptions
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from io import BytesIO
from dotenv import load_dotenv
import uvicorn
import tempfile
# Pull settings from a local .env file (when present) into the environment
# before any of them is read below.
load_dotenv()

API_KEY = os.getenv("API_KEY")
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HF_API_TOKEN = os.getenv("HF_API_TOKEN")

# Build the shared GCS client and bucket handle once, at import time.
# Any configuration problem is reported as a single RuntimeError.
try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    storage_client = storage.Client.from_service_account_info(credentials_info)
    bucket = storage_client.bucket(GCS_BUCKET_NAME)
except (
    exceptions.DefaultCredentialsError,
    json.JSONDecodeError,
    KeyError,
    # json.loads(None) raises TypeError when the credentials variable is
    # unset; without this entry the raw TypeError escaped the handler.
    TypeError,
    ValueError,
) as e:
    raise RuntimeError(f"Error al cargar credenciales o bucket: {e}") from e

app = FastAPI()
class DownloadModelRequest(BaseModel):
    """Request body accepted by the ``/predict/`` endpoint."""

    # Model identifier; used both as the Hugging Face repository path and as
    # the blob-name prefix inside the GCS bucket.
    model_name: str
    # Task name handed to ``transformers.pipeline``.
    pipeline_task: str
    # Text passed to the constructed pipeline as its input.
    input_text: str
class GCSHandler:
    """Small convenience wrapper around one Google Cloud Storage bucket."""

    def __init__(self, bucket_name):
        """Bind the handler to *bucket_name* via the module-level client."""
        self.bucket = storage_client.bucket(bucket_name)

    def file_exists(self, blob_name):
        """Return whether a blob named *blob_name* exists in the bucket."""
        target = self.bucket.blob(blob_name)
        return target.exists()

    def upload_file(self, blob_name, file_stream):
        """Upload the contents of *file_stream* to the blob *blob_name*."""
        self.bucket.blob(blob_name).upload_from_file(file_stream)

    def download_file(self, blob_name):
        """Return the content of *blob_name* wrapped in a ``BytesIO``.

        Raises:
            HTTPException: with status 404 when the blob does not exist.
        """
        source = self.bucket.blob(blob_name)
        if source.exists():
            return BytesIO(source.download_as_bytes())
        raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
def download_model_from_huggingface(model_name):
    """Copy a model's files from the Hugging Face Hub into the GCS bucket.

    The repository's ``tree/main`` page is requested first to confirm that
    the repository is reachable. Each known model file is then fetched from
    ``resolve/main`` and uploaded to the module-level ``bucket`` under
    ``"{model_name}/{file_name}"``. Files the repository does not contain
    (HTTP 404) are skipped rather than uploaded.

    Args:
        model_name: Hugging Face repository identifier, also used as the
            blob-name prefix in the bucket.

    Raises:
        HTTPException: 404 when the repository tree cannot be accessed;
            500 for any other failure while downloading or uploading.
    """
    model_files = (
        "pytorch_model.bin",
        "config.json",
        "tokenizer.json",
        "model.safetensors",
    )
    # Only send credentials when a token is configured; otherwise the header
    # would literally read "Bearer None".
    headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
    # (connect, read) timeouts in seconds. The read timeout bounds the gap
    # between received chunks, not the total transfer time.
    timeout = (10, 60)

    try:
        tree_url = f"https://huggingface.co/{model_name}/tree/main"
        response = requests.get(tree_url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            raise HTTPException(status_code=404, detail="Error al acceder al árbol de archivos de Hugging Face.")

        for file_name in model_files:
            file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
            file_response = requests.get(file_url, headers=headers, timeout=timeout)
            if file_response.status_code == 404:
                # The repository does not ship this file. Uploading the
                # error body would store a corrupt blob under a model name.
                continue
            file_response.raise_for_status()
            blob_name = f"{model_name}/{file_name}"
            bucket.blob(blob_name).upload_from_file(BytesIO(file_response.content))
    except HTTPException:
        # Keep the deliberate 404 above; the generic handler below would
        # otherwise rewrite it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error descargando archivos de Hugging Face: {e}") from e
@app.post("/predict/")
async def predict(request: DownloadModelRequest):
    """Run a transformers pipeline on the request text using a GCS-cached model.

    If the bucket does not hold a loadable set of files for
    ``request.model_name``, the files are fetched from Hugging Face first.
    The files are then copied into a temporary directory, loaded with
    ``AutoModelForCausalLM`` / ``AutoTokenizer``, and ``request.input_text``
    is run through a pipeline for ``request.pipeline_task``.

    Args:
        request: Model name, pipeline task and input text.

    Returns:
        ``{"response": <pipeline output>}``.

    Raises:
        HTTPException: propagated unchanged from the helpers; 500 when the
            required files are still missing after the download attempt or
            when any other error occurs.
    """
    # NOTE(review): every call below (GCS, HTTP, model loading, inference) is
    # blocking and therefore runs on the event loop of this ``async`` handler.
    try:
        gcs_handler = GCSHandler(GCS_BUCKET_NAME)
        model_prefix = request.model_name

        file_names = _select_cached_model_files(gcs_handler, model_prefix)
        if file_names is None:
            download_model_from_huggingface(model_prefix)
            file_names = _select_cached_model_files(gcs_handler, model_prefix)
        if file_names is None:
            raise HTTPException(status_code=500, detail="Required model files missing.")

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Copy one file at a time so only a single file's bytes are held
            # in memory at once.
            for file_name in file_names:
                stream = gcs_handler.download_file(f"{model_prefix}/{file_name}")
                with open(os.path.join(tmp_dir, file_name), "wb") as f:
                    f.write(stream.read())

            # The directory holds PyTorch/safetensors weights, so the model
            # must not be loaded with ``from_tf=True``, which looks for a
            # TensorFlow checkpoint. Loading and inference stay inside the
            # ``with`` block because the directory is deleted on exit.
            model = AutoModelForCausalLM.from_pretrained(tmp_dir)
            tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
            pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
            result = pipeline_(request.input_text)
        return {"response": result}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error: {e}") from e


def _select_cached_model_files(gcs_handler, model_prefix):
    """Return the cached file names needed to load the model, or ``None``.

    A model is loadable when ``config.json``, ``tokenizer.json`` and at least
    one weight file exist under ``model_prefix`` in the bucket.
    """
    def exists(file_name):
        return gcs_handler.file_exists(f"{model_prefix}/{file_name}")

    required = ["config.json", "tokenizer.json"]
    if not all(exists(name) for name in required):
        return None
    # Prefer the PyTorch checkpoint and fall back to safetensors only when the
    # checkpoint is absent; exactly one weight file is selected.
    weights = next(
        (name for name in ("pytorch_model.bin", "model.safetensors") if exists(name)),
        None,
    )
    if weights is None:
        return None
    return [*required, weights]
if __name__ == "__main__":
    # Serve the FastAPI app directly when this file is executed as a script.
    # "0.0.0.0" binds every network interface.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)