Hjgugugjhuhjggg committed on
Commit 4bf1bd9 (1 parent: abeeac6)

Update app.py

Files changed (1): app.py (+181, -132)
app.py CHANGED
@@ -1,155 +1,204 @@
  import os
- import json
- import threading
  import logging
- from google.cloud import storage
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from pydantic import BaseModel
- from fastapi import FastAPI, HTTPException
  import requests
  import uvicorn
- from dotenv import load_dotenv
-
- load_dotenv()
-
- API_KEY = os.getenv("API_KEY")
- GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
- GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
-
- logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
- logger = logging.getLogger(__name__)
-
- credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
- storage_client = storage.Client.from_service_account_info(credentials_info)
- bucket = storage_client.bucket(GCS_BUCKET_NAME)
-
- app = FastAPI()
-
- class DownloadModelRequest(BaseModel):
      model_name: str
-     pipeline_task: str
      input_text: str
-
- class GCSHandler:
-     def __init__(self, bucket_name):
-         self.bucket = storage_client.bucket(bucket_name)
-
-     def file_exists(self, blob_name):
-         return self.bucket.blob(blob_name).exists()
-
-     def create_folder_if_not_exists(self, folder_name):
-         if not self.file_exists(folder_name):
-             self.bucket.blob(folder_name + "/").upload_from_string("")
-
-     def upload_file(self, blob_name, file_stream):
-         self.create_folder_if_not_exists(os.path.dirname(blob_name))
-         blob = self.bucket.blob(blob_name)
-         blob.upload_from_file(file_stream)
-
-     def download_file(self, blob_name):
-         blob = self.bucket.blob(blob_name)
-         if not blob.exists():
-             raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
-         return blob.open("rb")
-
-     def generate_signed_url(self, blob_name, expiration=3600):
-         blob = self.bucket.blob(blob_name)
-         return blob.generate_signed_url(expiration=expiration)
-
- def download_model_from_huggingface(model_name):
-     url = f"https://huggingface.co/{model_name}/tree/main"
-     headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
-     response = requests.get(url, headers=headers)
-     if response.status_code == 200:
-         model_files = [
-             "pytorch_model.bin",
-             "config.json",
-             "tokenizer.json",
-             "model.safetensors",
-         ]
-         for file_name in model_files:
-             file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
-             file_content = requests.get(file_url).content
-             blob_name = f"models/{model_name}/{file_name}"
-             bucket.blob(blob_name).upload_from_string(file_content)
-     else:
-         raise HTTPException(status_code=404, detail="Error accessing Hugging Face model files.")
-
- def download_and_verify_model(model_name):
-     model_files = [
-         "pytorch_model.bin",
-         "config.json",
-         "tokenizer.json",
-         "model.safetensors",
-     ]
-     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-     if not all(gcs_handler.file_exists(f"models/{model_name}/{file}") for file in model_files):
-         download_model_from_huggingface(model_name)
-
- def load_model_from_gcs(model_name):
-     model_files = [
-         "pytorch_model.bin",
-         "config.json",
-         "tokenizer.json",
-         "model.safetensors",
-     ]
-     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-     model_files_streams = {
-         file: gcs_handler.download_file(f"models/{model_name}/{file}")
-         for file in model_files if gcs_handler.file_exists(f"models/{model_name}/{file}")
-     }
-     model_stream = model_files_streams.get("pytorch_model.bin") or model_files_streams.get("model.safetensors")
-     tokenizer_stream = model_files_streams.get("tokenizer.json")
-     config_stream = model_files_streams.get("config.json")
-     model = AutoModelForCausalLM.from_pretrained(model_stream, config=config_stream)
-     tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
-     return model, tokenizer
-
- def load_model(model_name):
-     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-     try:
-         return load_model_from_gcs(model_name)
-     except HTTPException:
-         download_and_verify_model(model_name)
-         return load_model_from_gcs(model_name)
-
- @app.on_event("startup")
- async def startup():
-     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-     blobs = list(bucket.list_blobs(prefix="models/"))
-     model_names = set(blob.name.split("/")[1] for blob in blobs)
-     def download_model_thread(model_name):
          try:
-             download_and_verify_model(model_name)
          except Exception as e:
-             logger.error(f"Error downloading model '{model_name}': {e}")
-     threads = [threading.Thread(target=download_model_thread, args=(model_name,)) for model_name in model_names]
-     for thread in threads:
-         thread.start()
-     for thread in threads:
-         thread.join()
-
- @app.post("/predict/")
- async def predict(request: DownloadModelRequest):
-     model_name = request.model_name
-     pipeline_task = request.pipeline_task
-     input_text = request.input_text
-     model, tokenizer = load_model(model_name)
-     pipe = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
-     result = pipe(input_text)
-     return {"result": result}

  def download_all_models_in_background():
      models_url = "https://huggingface.co/api/models"
-     response = requests.get(models_url)
-     if response.status_code == 200:
          models = response.json()
          for model in models:
-             download_model_from_huggingface(model["id"])

  def run_in_background():
      threading.Thread(target=download_all_models_in_background, daemon=True).start()

  if __name__ == "__main__":
-     uvicorn.run(app, host="0.0.0.0", port=7860)
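A note on the removed GCS version: AutoModelForCausalLM.from_pretrained and AutoTokenizer.from_pretrained accept a local directory path or a Hub repo id, not open file handles, so passing blob streams as load_model_from_gcs did would fail at load time. A minimal sketch of a working variant, assuming the same models/<model_name>/ blob layout (load_model_via_local_dir and the cache_root path are illustrative names, not part of the commit):

import os
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_via_local_dir(model_name, bucket, cache_root="/tmp/models"):
    # Stage every blob under models/<model_name>/ into a local directory,
    # then point from_pretrained at that directory (it expects a path).
    local_dir = os.path.join(cache_root, model_name)
    os.makedirs(local_dir, exist_ok=True)
    for blob in bucket.list_blobs(prefix=f"models/{model_name}/"):
        filename = os.path.basename(blob.name)
        if filename:  # skip the zero-byte "folder" placeholder blobs
            blob.download_to_filename(os.path.join(local_dir, filename))
    model = AutoModelForCausalLM.from_pretrained(local_dir)
    tokenizer = AutoTokenizer.from_pretrained(local_dir)
    return model, tokenizer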
 
  import os
  import logging
  import requests
+ import asyncio
+ import threading
+ from io import BytesIO
+ from fastapi import FastAPI, HTTPException, Response, Request
+ from fastapi.responses import StreamingResponse
+ from pydantic import BaseModel
+ from transformers import (
+     AutoConfig,
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     pipeline,
+     GenerationConfig
+ )
+ import boto3
+ from huggingface_hub import hf_hub_download
+ import soundfile as sf
+ import numpy as np
+ import torch
  import uvicorn
+ from tqdm import tqdm

+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

+ AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
+ AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
+ AWS_REGION = os.getenv("AWS_REGION")
+ S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
+ HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

+ class GenerateRequest(BaseModel):
      model_name: str
      input_text: str
+     task_type: str
+     temperature: float = 1.0
+     max_new_tokens: int = 200
+     stream: bool = False
+     top_p: float = 1.0
+     top_k: int = 50
+     repetition_penalty: float = 1.0
+     num_return_sequences: int = 1
+     do_sample: bool = True
+
+ class S3ModelLoader:
+     def __init__(self, bucket_name, s3_client):
+         self.bucket_name = bucket_name
+         self.s3_client = s3_client
+
+     def _get_s3_uri(self, model_name):
+         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
+
+     def download_model_from_s3(self, model_name):
+         try:
+             logging.info(f"Trying to load {model_name} from S3...")
+             config = AutoConfig.from_pretrained(f"s3://{self.bucket_name}/{model_name}")
+             model = AutoModelForCausalLM.from_pretrained(f"s3://{self.bucket_name}/{model_name}", config=config)
+             tokenizer = AutoTokenizer.from_pretrained(f"s3://{self.bucket_name}/{model_name}")
+             logging.info(f"Loaded {model_name} from S3 successfully.")
+             return model, tokenizer
+         except Exception as e:
+             logging.error(f"Error loading {model_name} from S3: {e}")
+             return None, None

+     async def load_model_and_tokenizer(self, model_name):
+         try:
+             model, tokenizer = self.download_model_from_s3(model_name)
+             if model is None or tokenizer is None:
+                 model, tokenizer = await self.download_and_save_model_from_huggingface(model_name)
+             return model, tokenizer
+         except Exception as e:
+             raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

+     async def download_and_save_model_from_huggingface(self, model_name):
+         try:
+             logging.info(f"Downloading {model_name} from Hugging Face...")
+             with tqdm(unit="B", unit_scale=True, desc=f"Downloading {model_name}"):
+                 model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+             tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
+             logging.info(f"Downloaded {model_name} successfully.")
+             self.upload_model_to_s3(model_name, model, tokenizer)
+             return model, tokenizer
+         except Exception as e:
+             logging.error(f"Error downloading model from Hugging Face: {e}")
+             raise HTTPException(status_code=500, detail=f"Error downloading model from Hugging Face: {e}")
+
+     def upload_model_to_s3(self, model_name, model, tokenizer):
          try:
+             s3_uri = self._get_s3_uri(model_name)
+             model.save_pretrained(s3_uri)
+             tokenizer.save_pretrained(s3_uri)
+             logging.info(f"Saved {model_name} to S3 successfully.")
          except Exception as e:
+             logging.error(f"Error saving {model_name} to S3: {e}")
+             raise HTTPException(status_code=500, detail=f"Error saving model to S3: {e}")
+
+ app = FastAPI()
+
+ s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
+ model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)
+
+ @app.post("/generate")
+ async def generate(request: Request, body: GenerateRequest):
+     try:
+         model, tokenizer = await model_loader.load_model_and_tokenizer(body.model_name)
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         model.to(device)
+
+         if body.task_type == "text-to-text":
+             generation_config = GenerationConfig(
+                 temperature=body.temperature,
+                 max_new_tokens=body.max_new_tokens,
+                 top_p=body.top_p,
+                 top_k=body.top_k,
+                 repetition_penalty=body.repetition_penalty,
+                 do_sample=body.do_sample,
+                 num_return_sequences=body.num_return_sequences
+             )
+
+             async def stream_text():
+                 input_text = body.input_text
+                 max_length = model.config.max_position_embeddings
+                 generated_text = ""
+
+                 while True:
+                     inputs = tokenizer(input_text, return_tensors="pt").to(device)
+                     input_length = inputs.input_ids.shape[1]
+                     remaining_tokens = max_length - input_length
+                     if remaining_tokens < body.max_new_tokens:
+                         generation_config.max_new_tokens = remaining_tokens
+                     if remaining_tokens <= 0:
+                         break
+
+                     output = model.generate(**inputs, generation_config=generation_config)
+                     chunk = tokenizer.decode(output[0], skip_special_tokens=True)
+                     generated_text += chunk
+                     yield chunk
+                     if len(tokenizer.encode(generated_text)) >= max_length:
+                         break
+                     input_text = chunk
+
+             if body.stream:
+                 return StreamingResponse(stream_text(), media_type="text/plain")
+             else:
+                 generated_text = ""
+                 async for chunk in stream_text():
+                     generated_text += chunk
+                 return {"result": generated_text}
+
+         elif body.task_type == "text-to-image":
+             generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
+             image = generator(body.input_text)[0]
+             buffer = BytesIO()
+             image.save(buffer, format="PNG")
+             return Response(content=buffer.getvalue(), media_type="image/png")
+
+         elif body.task_type == "text-to-speech":
+             generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
+             audio = generator(body.input_text)
+             audio_bytesio = BytesIO()
+             sf.write(audio_bytesio, audio["audio"], audio["sampling_rate"], format="WAV")
+             audio_bytes = audio_bytesio.getvalue()
+             return Response(content=audio_bytes, media_type="audio/wav")
+
+         elif body.task_type == "text-to-video":
+             try:
+                 generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
+                 video = generator(body.input_text)
+                 return Response(content=video, media_type="video/mp4")
+             except Exception as e:
+                 raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
+
+         else:
+             raise HTTPException(status_code=400, detail="Unsupported task type")
+
+     except HTTPException as e:
+         raise e
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))

  def download_all_models_in_background():
      models_url = "https://huggingface.co/api/models"
+     try:
+         response = requests.get(models_url)
+         if response.status_code != 200:
+             logging.error("Error fetching the model list from Hugging Face.")
+             raise HTTPException(status_code=500, detail="Error fetching the model list.")
+
          models = response.json()
          for model in models:
+             model_name = model["id"]
+             asyncio.run(model_loader.download_and_save_model_from_huggingface(model_name))
+     except Exception as e:
+         logging.error(f"Error downloading models in the background: {e}")
+         raise HTTPException(status_code=500, detail="Error downloading models in the background.")

  def run_in_background():
      threading.Thread(target=download_all_models_in_background, daemon=True).start()

+ @app.on_event("startup")
+ async def startup_event():
+     run_in_background()
+
  if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
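Two caveats on the new version, each with a hedged sketch. First, transformers does not understand s3:// URIs: from_pretrained and save_pretrained take local paths or Hub repo ids, so download_model_from_s3 will always fall back to Hugging Face and upload_model_to_s3 will not actually write to the bucket. A sketch of an S3 round-trip through a local directory, using standard boto3 client calls (sync_model_to_s3, load_model_from_s3_objects, and the /tmp/models cache path are illustrative assumptions, not the commit's API):

import os
from transformers import AutoModelForCausalLM, AutoTokenizer

def sync_model_to_s3(model, tokenizer, s3_client, bucket_name, model_name, cache_root="/tmp/models"):
    # save_pretrained only writes to local paths, so serialize locally,
    # then upload each produced file as an object under <prefix>/.
    prefix = model_name.replace("/", "-")
    local_dir = os.path.join(cache_root, prefix)
    model.save_pretrained(local_dir)
    tokenizer.save_pretrained(local_dir)
    for filename in os.listdir(local_dir):
        s3_client.upload_file(os.path.join(local_dir, filename), bucket_name, f"{prefix}/{filename}")

def load_model_from_s3_objects(s3_client, bucket_name, model_name, cache_root="/tmp/models"):
    # Mirror the objects under <prefix>/ back into a local directory,
    # then load the model and tokenizer from that directory.
    prefix = model_name.replace("/", "-")
    local_dir = os.path.join(cache_root, prefix)
    os.makedirs(local_dir, exist_ok=True)
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name, Prefix=f"{prefix}/"):
        for obj in page.get("Contents", []):
            target = os.path.join(local_dir, os.path.basename(obj["Key"]))
            s3_client.download_file(bucket_name, obj["Key"], target)
    model = AutoModelForCausalLM.from_pretrained(local_dir)
    tokenizer = AutoTokenizer.from_pretrained(local_dir)
    return model, tokenizer

Second, the stream_text loop is not true token streaming: each pass re-runs generate and re-decodes the full output (prompt included), so chunks repeat text. For incremental streaming, transformers provides TextIteratorStreamer, which yields decoded pieces while generate runs on a worker thread. A minimal sketch (stream_generate is an illustrative name):

import threading
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, prompt, device, generation_config):
    # TextIteratorStreamer yields decoded text pieces as generate() emits
    # tokens on a worker thread, instead of re-running generate() in a loop.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    worker = threading.Thread(
        target=model.generate,
        kwargs={**inputs, "generation_config": generation_config, "streamer": streamer},
    )
    worker.start()
    for piece in streamer:
        yield piece
    worker.join()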