Hjgugugjhuhjggg committed on
Commit 8e4fcb7
Parent: 03ed2e0

Update app.py

Files changed (1): app.py  +15 -11
app.py CHANGED
@@ -45,8 +45,7 @@ class GCSHandler:
         self.bucket = storage_client.bucket(bucket_name)
 
     def file_exists(self, blob_name):
-        exists = self.bucket.blob(blob_name).exists()
-        return exists
+        return self.bucket.blob(blob_name).exists()
 
     def download_file(self, blob_name):
         blob = self.bucket.blob(blob_name)
@@ -60,8 +59,11 @@ class GCSHandler:
 
     def generate_signed_url(self, blob_name, expiration=3600):
         blob = self.bucket.blob(blob_name)
-        url = blob.generate_signed_url(expiration=expiration)
-        return url
+        return blob.generate_signed_url(expiration=expiration)
+
+    def create_folder(self, folder_name):
+        blob = self.bucket.blob(folder_name + "/")
+        blob.upload_from_string("")  # Create an empty "folder"
 
 def load_model_from_gcs(model_name: str, model_files: list):
     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
@@ -71,7 +73,7 @@ def load_model_from_gcs(model_name: str, model_files: list):
     config_stream = model_blobs.get("config.json")
     tokenizer_stream = model_blobs.get("tokenizer.json")
 
-    if "safetensors" in model_stream.name:
+    if model_stream and model_stream.endswith(".safetensors"):
         model = load_safetensors_model(model_stream)
     else:
         model = AutoModelForCausalLM.from_pretrained(io.BytesIO(model_stream), config=config_stream)
@@ -122,18 +124,20 @@ def download_model_from_huggingface(model_name):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error descargando archivos de Hugging Face: {e}")
 
+def download_model_files(model_name: str):
+    model_files = get_model_files_from_gcs(model_name)
+    if not model_files:
+        download_model_from_huggingface(model_name)
+        model_files = get_model_files_from_gcs(model_name)
+    return model_files
+
 @app.post("/predict/")
 async def predict(request: DownloadModelRequest):
     try:
         gcs_handler = GCSHandler(GCS_BUCKET_NAME)
         model_prefix = request.model_name
 
-        model_files = get_model_files_from_gcs(model_prefix)
-
-        if not model_files:
-            download_model_from_huggingface(model_prefix)
-            model_files = get_model_files_from_gcs(model_prefix)
-
+        model_files = download_model_files(model_prefix)
         model, tokenizer = load_model_from_gcs(model_prefix, model_files)
 
         pipe = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
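For context on the GCSHandler additions: create_folder leans on the usual Cloud Storage convention that a "folder" is just a zero-byte object whose name ends in "/". Below is a standalone sketch of the handler's new surface, not the app's exact code: the bucket name is a placeholder, it needs google-cloud-storage plus working credentials, and the signed-URL expiration is passed as a timedelta to keep its meaning unambiguous.

# Standalone sketch of GCSHandler's new surface; "my-example-bucket" is a placeholder.
from datetime import timedelta
from google.cloud import storage

class GCSHandler:
    def __init__(self, bucket_name):
        self.bucket = storage.Client().bucket(bucket_name)

    def file_exists(self, blob_name):
        return self.bucket.blob(blob_name).exists()

    def generate_signed_url(self, blob_name, expiration=timedelta(hours=1)):
        # A timedelta keeps the expiration unambiguous across signing versions.
        return self.bucket.blob(blob_name).generate_signed_url(expiration=expiration)

    def create_folder(self, folder_name):
        # GCS has no real folders; an empty object ending in "/" stands in for one.
        self.bucket.blob(folder_name.rstrip("/") + "/").upload_from_string("")

handler = GCSHandler("my-example-bucket")
handler.create_folder("models/example-model")
print(handler.file_exists("models/example-model/"))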
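And a minimal, runnable sketch of the cache-or-download pattern behind the new download_model_files helper, with the GCS and Hugging Face calls replaced by an in-memory stand-in (the real get_model_files_from_gcs and download_model_from_huggingface live elsewhere in app.py and are not part of this diff):

# The control flow of download_model_files, exercised against an in-memory "bucket".
FAKE_BUCKET = {}  # stand-in for the GCS bucket listing

def get_model_files_from_gcs(model_name):
    # Stand-in: list whatever "files" are cached under this prefix.
    return FAKE_BUCKET.get(model_name, [])

def download_model_from_huggingface(model_name):
    # Stand-in: pretend to fetch the repo and upload its files to the bucket.
    FAKE_BUCKET[model_name] = ["config.json", "tokenizer.json", "model.safetensors"]

def download_model_files(model_name: str):
    # Use the cached copy if present; otherwise populate the bucket and re-list it.
    model_files = get_model_files_from_gcs(model_name)
    if not model_files:
        download_model_from_huggingface(model_name)
        model_files = get_model_files_from_gcs(model_name)
    return model_files

print(download_model_files("some-org/some-model"))  # first call triggers the "download"
print(download_model_files("some-org/some-model"))  # second call is served from the "cache"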