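"""
FastAPI service that streams text generation from Hugging Face models.

Tokenizer and config files are mirrored into a Google Cloud Storage bucket the
first time a model is requested; generation runs through a transformers
pipeline and results are streamed back as server-sent events from the
/generate endpoint.
"""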
import os
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
from transformers import pipeline, AutoConfig, AutoTokenizer
from transformers.utils import logging
from google.cloud import storage
from google.auth.exceptions import DefaultCredentialsError
import uvicorn
import asyncio
import json
from huggingface_hub import login
from dotenv import load_dotenv
from threading import Thread
from typing import AsyncIterator, List, Dict, Optional
from transformers import StoppingCriteria, StoppingCriteriaList
import torch
load_dotenv()
GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HUGGINGFACE_HUB_TOKEN = os.getenv("HF_API_TOKEN")
os.system("git config --global credential.helper store")
if HUGGINGFACE_HUB_TOKEN:
    # A single login is sufficient; add_to_git_credential also stores the token for git-based downloads.
    login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True)
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    client = storage.Client.from_service_account_info(credentials_info)
    bucket = client.get_bucket(GCS_BUCKET_NAME)
    logger.info(f"Connection to Google Cloud Storage successful. Bucket: {GCS_BUCKET_NAME}")
except (DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
    logger.error(f"Error loading credentials or bucket: {e}")
    raise RuntimeError(f"Error loading credentials or bucket: {e}") from e
app = FastAPI()
class GenerateRequest(BaseModel):
    model_name: str
    input_text: str
    task_type: str
    temperature: float = 1.0
    stream: bool = True
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = False
    chunk_delay: float = 0.0
    max_new_tokens: int = 10
    stopping_strings: Optional[List[str]] = None

    @field_validator("model_name")
    @classmethod
    def model_name_cannot_be_empty(cls, v):
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    @classmethod
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-generation"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v
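# Example request body (illustrative values only; "gpt2" is just a placeholder model name):
# {
#     "model_name": "gpt2",
#     "input_text": "Once upon a time",
#     "task_type": "text-generation",
#     "max_new_tokens": 10,
#     "stopping_strings": ["\n\n"]
# }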
class StopOnKeywords(StoppingCriteria):
    """Stopping criterion that halts generation once any stop-word id sequence appears."""

    def __init__(self, stop_words_ids: List[List[int]], tokenizer, encounters: int = 1):
        super().__init__()
        self.stop_words_ids = stop_words_ids
        self.tokenizer = tokenizer
        self.encounters = encounters
        self.current_encounters = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Compare the tail of the generated ids against each stop sequence.
        for stop_ids in self.stop_words_ids:
            if not stop_ids or len(stop_ids) > input_ids.shape[1]:
                continue
            if torch.all(input_ids[0][-len(stop_ids):] == torch.tensor(stop_ids).to(input_ids.device)):
                self.current_encounters += 1
                if self.current_encounters >= self.encounters:
                    return True
        return False
class GCSModelLoader:
    """Checks for cached model files in GCS and mirrors Hugging Face downloads into the bucket."""

    def __init__(self, bucket):
        self.bucket = bucket

    def _get_gcs_uri(self, model_name):
        return f"{model_name}"

    def _blob_exists(self, blob_path):
        blob = self.bucket.blob(blob_path)
        return blob.exists()

    def _create_model_folder(self, model_name):
        gcs_model_folder = self._get_gcs_uri(model_name)
        if not self._blob_exists(f"{gcs_model_folder}/.touch"):
            blob = self.bucket.blob(f"{gcs_model_folder}/.touch")
            blob.upload_from_string("")
            logger.info(f"Created folder '{gcs_model_folder}' in GCS.")

    def check_model_exists_locally(self, model_name):
        gcs_model_path = self._get_gcs_uri(model_name)
        blobs = self.bucket.list_blobs(prefix=gcs_model_path)
        return any(blobs)

    def download_model_from_huggingface(self, model_name):
        logger.info(f"Downloading model '{model_name}' from Hugging Face.")
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
            config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
            gcs_model_folder = self._get_gcs_uri(model_name)
            self._create_model_folder(model_name)
            # save_pretrained writes the tokenizer/config files to a local folder named after the model.
            tokenizer.save_pretrained(gcs_model_folder)
            config.save_pretrained(gcs_model_folder)
            # Upload any weight files present in that local folder; save_pretrained above only
            # writes tokenizer/config files, so .bin/.safetensors weights are uploaded only if
            # they were placed there separately.
            for filename in os.listdir(gcs_model_folder):
                if filename.endswith((".bin", ".safetensors")):
                    blob = self.bucket.blob(f"{gcs_model_folder}/{filename}")
                    blob.upload_from_filename(os.path.join(gcs_model_folder, filename))
            logger.info(f"Model '{model_name}' downloaded and saved to GCS.")
            return True
        except Exception as e:
            logger.error(f"Error downloading model from Hugging Face: {e}")
            return False
model_loader = GCSModelLoader(bucket)
@app.post("/generate")
async def generate(request: GenerateRequest):
    model_name = request.model_name
    input_text = request.input_text
    task_type = request.task_type
    requested_max_new_tokens = request.max_new_tokens
    generation_params = request.model_dump(
        exclude_none=True,
        exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens', 'stopping_strings'}
    )
    user_defined_stopping_strings = request.stopping_strings
    try:
        # Make sure the model is mirrored in GCS before loading tokenizer/config from the Hub.
        if not model_loader.check_model_exists_locally(model_name):
            if not model_loader.download_model_from_huggingface(model_name):
                raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
        config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)

        # Build stopping criteria from user-supplied strings and the model's EOS token(s).
        stopping_criteria_list = StoppingCriteriaList()
        if user_defined_stopping_strings:
            stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
            stopping_criteria_list.append(StopOnKeywords(stop_words_ids, tokenizer))
        if config.eos_token_id is not None:
            eos_token_ids = [config.eos_token_id]
            if isinstance(config.eos_token_id, int):
                eos_token_ids = [[config.eos_token_id]]
            elif isinstance(config.eos_token_id, list):
                eos_token_ids = [[eos_id] for eos_id in config.eos_token_id]
            stop_words_ids_eos = [tokenizer.encode(tokenizer.decode(eos_id), add_special_tokens=False) for eos_id in eos_token_ids]
            stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos, tokenizer))
        elif tokenizer.eos_token is not None:
            stop_words_ids_eos = [tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)]
            stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos, tokenizer))
        async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
            nonlocal input_text
            all_generated_text = ""
            stop_reason = None
            loop = asyncio.get_running_loop()
            while True:
                text_pipeline = pipeline(
                    task_type,
                    model=model_name,
                    tokenizer=tokenizer,
                    token=HUGGINGFACE_HUB_TOKEN,
                    stopping_criteria=stopping_criteria_list,
                    **generation_params,
                    max_new_tokens=requested_max_new_tokens
                )

                def generate_on_thread(pipeline_obj, current_input_text, output_queue):
                    # Run the blocking pipeline call off the event loop and hand the result
                    # back through the loop so the asyncio.Queue is only touched thread-safely.
                    result = pipeline_obj(current_input_text)
                    loop.call_soon_threadsafe(output_queue.put_nowait, result)

                output_queue = asyncio.Queue()
                thread = Thread(target=generate_on_thread, args=(text_pipeline, input_text, output_queue))
                thread.start()
                result = await output_queue.get()
                thread.join()
                newly_generated_text = result[0]['generated_text']

                # Decode the stop-word ids and check whether any stopping string appeared.
                for criteria in stopping_criteria_list:
                    if isinstance(criteria, StopOnKeywords):
                        for stop_ids in criteria.stop_words_ids:
                            decoded_stop_string = tokenizer.decode(stop_ids)
                            if decoded_stop_string in newly_generated_text:
                                stop_reason = f"stopping_string: {decoded_stop_string}"
                                break
                        if stop_reason:
                            break
                if stop_reason:
                    break

                all_generated_text += newly_generated_text
                yield {"response": [{'generated_text': newly_generated_text}]}

                # Stop once an EOS token shows up in the decoded output.
                if config.eos_token_id is not None:
                    eos_tokens = [config.eos_token_id]
                    if isinstance(config.eos_token_id, int):
                        eos_tokens = [config.eos_token_id]
                    elif isinstance(config.eos_token_id, list):
                        eos_tokens = config.eos_token_id
                    for eos_token in eos_tokens:
                        if tokenizer.decode([eos_token]) in newly_generated_text:
                            stop_reason = "eos_token"
                            break
                    if stop_reason:
                        break
                elif tokenizer.eos_token is not None and tokenizer.eos_token in newly_generated_text:
                    stop_reason = "eos_token"
                    break

                # Feed the accumulated text back in so the next iteration continues from it.
                input_text = all_generated_text
        async def text_stream():
            # Wrap each chunk as a server-sent event and terminate the stream with [DONE].
            async for data in generate_responses():
                yield f"data: {json.dumps(data)}\n\n"
            yield "data: [DONE]\n\n"

        return StreamingResponse(text_stream(), media_type="text/event-stream")
    except HTTPException as e:
        raise e
    except Exception as e:
        logger.error(f"Internal server error: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
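# Example client call (illustrative only; assumes the server is running locally on the
# port configured above and that "gpt2" is available on the Hugging Face Hub):
#   curl -N -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"model_name": "gpt2", "input_text": "Hello", "task_type": "text-generation"}'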