Hjgugugjhuhjggg committed on
Commit f7ca3aa
Parent: 2c926f6

Update app.py

Files changed (1): app.py (+143, -185)
app.py CHANGED
@@ -1,197 +1,155 @@
import os
import logging
- import time
- from io import BytesIO
- from typing import Union
-
- from fastapi import FastAPI, HTTPException, Response, Request, UploadFile, File
- from fastapi.responses import StreamingResponse
- from pydantic import BaseModel, ValidationError, field_validator
- from transformers import (
-     AutoConfig,
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     pipeline,
-     GenerationConfig,
-     StoppingCriteriaList
- )
- import boto3
- from huggingface_hub import hf_hub_download
- import soundfile as sf
- import numpy as np
- import torch
import uvicorn

- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s")

- AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
- AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
- AWS_REGION = os.getenv("AWS_REGION")
- S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
- HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")

- class GenerateRequest(BaseModel):
-     model_name: str
-     input_text: str = ""
-     task_type: str
-     temperature: float = 1.0
-     max_new_tokens: int = 200
-     stream: bool = False
-     top_p: float = 1.0
-     top_k: int = 50
-     repetition_penalty: float = 1.0
-     num_return_sequences: int = 1
-     do_sample: bool = True
-     chunk_delay: float = 0.0
-     stop_sequences: list[str] = []
-
-     model_config = {"protected_namespaces": ()}
-
-     @field_validator("model_name")
-     def model_name_cannot_be_empty(cls, v):
-         if not v:
-             raise ValueError("model_name cannot be empty.")
-         return v
-
-     @field_validator("task_type")
-     def task_type_must_be_valid(cls, v):
-         valid_types = ["text-to-text", "text-to-image", "text-to-speech", "text-to-video"]
-         if v not in valid_types:
-             raise ValueError(f"task_type must be one of: {valid_types}")
-         return v
-
- class S3ModelLoader:
-     def __init__(self, bucket_name, s3_client):
-         self.bucket_name = bucket_name
-         self.s3_client = s3_client
-
-     def _get_s3_uri(self, model_name):
-         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
-
-     async def load_model_and_tokenizer(self, model_name):
-         s3_uri = self._get_s3_uri(model_name)
-         try:
-             logging.info(f"Trying to load {model_name} from S3...")
-             config = AutoConfig.from_pretrained(s3_uri)
-             model = AutoModelForCausalLM.from_pretrained(s3_uri, config=config)
-             tokenizer = AutoTokenizer.from_pretrained(s3_uri, config=config)
-
-             if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                 tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
-             logging.info(f"Loaded {model_name} from S3 successfully.")
-             return model, tokenizer
-         except EnvironmentError:
-             logging.info(f"Model {model_name} not found in S3. Downloading...")
-             try:
-                 config = AutoConfig.from_pretrained(model_name)
-                 tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-                 model = AutoModelForCausalLM.from_pretrained(model_name, config=config, token=HUGGINGFACE_HUB_TOKEN)
-
-                 if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
-                     tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id
-
-                 logging.info(f"Downloaded {model_name} successfully.")
-                 logging.info(f"Saving {model_name} to S3...")
-                 model.save_pretrained(s3_uri)
-                 tokenizer.save_pretrained(s3_uri)
-                 logging.info(f"Saved {model_name} to S3 successfully.")
-                 return model, tokenizer
-             except Exception as e:
-                 logging.exception(f"Error downloading/uploading model: {e}")
-                 raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

- app = FastAPI()

- s3_client = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=AWS_REGION)
- model_loader = S3ModelLoader(S3_BUCKET_NAME, s3_client)

- @app.post("/generate")
- async def generate(request: Request, body: GenerateRequest):
    try:
-         validated_body = GenerateRequest(**body.model_dump())
-         model, tokenizer = await model_loader.load_model_and_tokenizer(validated_body.model_name)
-         device = "cuda" if torch.cuda.is_available() else "cpu"
-         model.to(device)
-
-         if validated_body.task_type == "text-to-text":
-             generation_config = GenerationConfig(
-                 temperature=validated_body.temperature,
-                 max_new_tokens=validated_body.max_new_tokens,
-                 top_p=validated_body.top_p,
-                 top_k=validated_body.top_k,
-                 repetition_penalty=validated_body.repetition_penalty,
-                 do_sample=validated_body.do_sample,
-                 num_return_sequences=validated_body.num_return_sequences
-             )
-
-             async def stream_text():
-                 input_text = validated_body.input_text
-                 generated_text = ""
-                 max_length = model.config.max_position_embeddings
-
-                 while True:
-                     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
-                     input_length = encoded_input["input_ids"].shape[1]
-                     remaining_tokens = max_length - input_length
-
-                     if remaining_tokens <= 0:
-                         break
-
-                     generation_config.max_new_tokens = min(remaining_tokens, validated_body.max_new_tokens)
-
-                     stopping_criteria = StoppingCriteriaList(
-                         [lambda _, outputs: tokenizer.decode(outputs[0][-1], skip_special_tokens=True) in validated_body.stop_sequences] if validated_body.stop_sequences else []
-                     )
-
-                     output = model.generate(**encoded_input, generation_config=generation_config, stopping_criteria=stopping_criteria)
-                     chunk = tokenizer.decode(output[0], skip_special_tokens=True)
-                     generated_text += chunk
-                     yield chunk
-                     time.sleep(validated_body.chunk_delay)
-                     input_text = generated_text
-
-             if validated_body.stream:
-                 return StreamingResponse(stream_text(), media_type="text/plain")
-             else:
-                 generated_text = ""
-                 async for chunk in stream_text():
-                     generated_text += chunk
-                 return {"result": generated_text}
-
-         elif validated_body.task_type == "text-to-image":
-             generator = pipeline("text-to-image", model=model, tokenizer=tokenizer, device=device)
-             image = generator(validated_body.input_text)[0]
-             image_bytes = image.tobytes()
-             return Response(content=image_bytes, media_type="image/png")
-
-         elif validated_body.task_type == "text-to-speech":
-             generator = pipeline("text-to-speech", model=model, tokenizer=tokenizer, device=device)
-             audio = generator(validated_body.input_text)
-             audio_bytesio = BytesIO()
-             sf.write(audio_bytesio, audio["sampling_rate"], np.int16(audio["audio"]))
-             audio_bytes = audio_bytesio.getvalue()
-             return Response(content=audio_bytes, media_type="audio/wav")
-
-         elif validated_body.task_type == "text-to-video":
-             try:
-                 generator = pipeline("text-to-video", model=model, tokenizer=tokenizer, device=device)
-                 video = generator(validated_body.input_text)
-                 return Response(content=video, media_type="video/mp4")
-             except Exception as e:
-                 raise HTTPException(status_code=500, detail=f"Error in text-to-video generation: {e}")
-
-         else:
-             raise HTTPException(status_code=400, detail="Unsupported task type")
-
-     except HTTPException as e:
-         raise e
-     except ValidationError as e:
-         raise HTTPException(status_code=422, detail=e.errors())
-     except Exception as e:
-         logging.exception(f"An unexpected error occurred: {e}")
-         raise HTTPException(status_code=500, detail="An unexpected error occurred.")
-

if __name__ == "__main__":
- uvicorn.run(app, host="0.0.0.0", port=7860)
import os
+ import json
+ import threading
import logging
+ from google.cloud import storage
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ from pydantic import BaseModel
+ from fastapi import FastAPI, HTTPException
+ import requests
import uvicorn
+ from dotenv import load_dotenv

+ load_dotenv()

+ API_KEY = os.getenv("API_KEY")
+ GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
+ GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")

+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)

+ credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
+ storage_client = storage.Client.from_service_account_info(credentials_info)
+ bucket = storage_client.bucket(GCS_BUCKET_NAME)

+ app = FastAPI()

+ class DownloadModelRequest(BaseModel):
+     model_name: str
+     pipeline_task: str
+     input_text: str
+
+ class GCSHandler:
+     def __init__(self, bucket_name):
+         self.bucket = storage_client.bucket(bucket_name)
+
+     def file_exists(self, blob_name):
+         return self.bucket.blob(blob_name).exists()
+
+     def create_folder_if_not_exists(self, folder_name):
+         if not self.file_exists(folder_name):
+             self.bucket.blob(folder_name + "/").upload_from_string("")
+
+     def upload_file(self, blob_name, file_stream):
+         self.create_folder_if_not_exists(os.path.dirname(blob_name))
+         blob = self.bucket.blob(blob_name)
+         blob.upload_from_file(file_stream)
+
+     def download_file(self, blob_name):
+         blob = self.bucket.blob(blob_name)
+         if not blob.exists():
+             raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
+         return blob.open("rb")
+
+     def generate_signed_url(self, blob_name, expiration=3600):
+         blob = self.bucket.blob(blob_name)
+         return blob.generate_signed_url(expiration=expiration)
+
+ def download_model_from_huggingface(model_name):
+     url = f"https://huggingface.co/{model_name}/tree/main"
+     headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
+     response = requests.get(url, headers=headers)
+     if response.status_code == 200:
+         model_files = [
+             "pytorch_model.bin",
+             "config.json",
+             "tokenizer.json",
+             "model.safetensors",
+         ]
+         for file_name in model_files:
+             file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
+             file_content = requests.get(file_url).content
+             blob_name = f"models/{model_name}/{file_name}"
+             bucket.blob(blob_name).upload_from_string(file_content)
+     else:
+         raise HTTPException(status_code=404, detail="Error accessing Hugging Face model files.")
+
+ def download_and_verify_model(model_name):
+     model_files = [
+         "pytorch_model.bin",
+         "config.json",
+         "tokenizer.json",
+         "model.safetensors",
+     ]
+     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
+     if not all(gcs_handler.file_exists(f"models/{model_name}/{file}") for file in model_files):
+         download_model_from_huggingface(model_name)
+
+ def load_model_from_gcs(model_name):
+     model_files = [
+         "pytorch_model.bin",
+         "config.json",
+         "tokenizer.json",
+         "model.safetensors",
+     ]
+     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
+     model_files_streams = {
+         file: gcs_handler.download_file(f"models/{model_name}/{file}")
+         for file in model_files if gcs_handler.file_exists(f"models/{model_name}/{file}")
+     }
+     model_stream = model_files_streams.get("pytorch_model.bin") or model_files_streams.get("model.safetensors")
+     tokenizer_stream = model_files_streams.get("tokenizer.json")
+     config_stream = model_files_streams.get("config.json")
+     model = AutoModelForCausalLM.from_pretrained(model_stream, config=config_stream)
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_stream)
+     return model, tokenizer
+
+ def load_model(model_name):
+     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
    try:
+         return load_model_from_gcs(model_name)
+     except HTTPException:
+         download_and_verify_model(model_name)
+         return load_model_from_gcs(model_name)
+
+ @app.on_event("startup")
+ async def startup():
+     gcs_handler = GCSHandler(GCS_BUCKET_NAME)
+     blobs = list(bucket.list_blobs(prefix="models/"))
+     model_names = set(blob.name.split("/")[1] for blob in blobs)
+     def download_model_thread(model_name):
+         try:
+             download_and_verify_model(model_name)
+         except Exception as e:
+             logger.error(f"Error downloading model '{model_name}': {e}")
+     threads = [threading.Thread(target=download_model_thread, args=(model_name,)) for model_name in model_names]
+     for thread in threads:
+         thread.start()
+     for thread in threads:
+         thread.join()
+
+ @app.post("/predict/")
+ async def predict(request: DownloadModelRequest):
+     model_name = request.model_name
+     pipeline_task = request.pipeline_task
+     input_text = request.input_text
+     model, tokenizer = load_model(model_name)
+     pipe = pipeline(pipeline_task, model=model, tokenizer=tokenizer)
+     result = pipe(input_text)
+     return {"result": result}
+
+ def download_all_models_in_background():
+     models_url = "https://huggingface.co/api/models"
+     response = requests.get(models_url)
+     if response.status_code == 200:
+         models = response.json()
+         for model in models:
+             download_model_from_huggingface(model["id"])
+
+ def run_in_background():
+     threading.Thread(target=download_all_models_in_background, daemon=True).start()

if __name__ == "__main__":
+ uvicorn.run(app, host="0.0.0.0", port=7860)
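
For reference, a minimal client call against the new /predict/ endpoint added by this commit might look like the sketch below. The payload fields mirror the DownloadModelRequest model; the model name ("gpt2") and pipeline task ("text-generation") are illustrative placeholders, not part of the commit, and the server is assumed to be running locally on port 7860 as set in the __main__ block.

import requests

# Illustrative payload; "gpt2" and "text-generation" are placeholders.
# The field names come from the DownloadModelRequest model in app.py.
payload = {
    "model_name": "gpt2",
    "pipeline_task": "text-generation",
    "input_text": "Hello, world",
}

# app.py binds uvicorn to 0.0.0.0:7860 in its __main__ block.
response = requests.post("http://localhost:7860/predict/", json=payload)
response.raise_for_status()
print(response.json()["result"])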