import os
import json
import asyncio
from typing import AsyncIterator, Dict, List, Optional

import torch
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, field_validator
from transformers import (
    AutoConfig,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    pipeline,
)
from transformers.utils import logging
from google.cloud import storage
from google.auth.exceptions import DefaultCredentialsError
from huggingface_hub import login, snapshot_download

load_dotenv()

GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
HUGGINGFACE_HUB_TOKEN = os.getenv("HF_API_TOKEN")

if HUGGINGFACE_HUB_TOKEN:
    # Authenticate once against the Hugging Face Hub; storing the git credential lets
    # gated or private repositories be fetched with the same token.
    os.system("git config --global credential.helper store")
    login(token=HUGGINGFACE_HUB_TOKEN, add_to_git_credential=True)

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

try:
    credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
    client = storage.Client.from_service_account_info(credentials_info)
    bucket = client.get_bucket(GCS_BUCKET_NAME)
    logger.info(f"Connection to Google Cloud Storage successful. Bucket: {GCS_BUCKET_NAME}")

except (DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError, TypeError) as e:
    logger.error(f"Error loading credentials or bucket: {e}")
    raise RuntimeError(f"Error loading credentials or bucket: {e}")

app = FastAPI()

class GenerateRequest(BaseModel):
    model_name: str
    input_text: str
    task_type: str
    temperature: float = 1.0
    stream: bool = True
    top_p: float = 1.0
    top_k: int = 50
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = False
    chunk_delay: float = 0.0
    max_new_tokens: int = 10
    stopping_strings: Optional[List[str]] = None

    @field_validator("model_name")
    def model_name_cannot_be_empty(cls, v):
        if not v:
            raise ValueError("model_name cannot be empty.")
        return v

    @field_validator("task_type")
    def task_type_must_be_valid(cls, v):
        valid_types = ["text-generation"]
        if v not in valid_types:
            raise ValueError(f"task_type must be one of: {valid_types}")
        return v
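
# Illustrative request body for /generate (values are examples only, not enforced defaults):
# {
#     "model_name": "gpt2",
#     "input_text": "Once upon a time",
#     "task_type": "text-generation",
#     "max_new_tokens": 10,
#     "stream": true,
#     "stopping_strings": ["\n\n"]
# }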

class StopOnKeywords(StoppingCriteria):
    """Stop generation once any configured stop token-id sequence appears at the end of the output."""

    def __init__(self, stop_words_ids: List[List[int]], tokenizer, encounters: int = 1):
        super().__init__()
        self.stop_words_ids = stop_words_ids
        self.tokenizer = tokenizer
        self.encounters = encounters
        self.current_encounters = 0

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in self.stop_words_ids:
            # Skip empty or overlong stop sequences; otherwise compare against the tail of the output.
            if not stop_ids or len(stop_ids) > input_ids.shape[1]:
                continue
            if torch.all(input_ids[0][-len(stop_ids):] == torch.tensor(stop_ids, device=input_ids.device)):
                self.current_encounters += 1
                if self.current_encounters >= self.encounters:
                    return True
        return False
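
# Sketch of how the criteria above are typically wired up (assumes a loaded `tokenizer`
# and that "###" is the caller's chosen stop marker):
#     stop_ids = [tokenizer.encode("###", add_special_tokens=False)]
#     criteria = StoppingCriteriaList([StopOnKeywords(stop_ids, tokenizer)])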

class GCSModelLoader:
    def __init__(self, bucket):
        self.bucket = bucket

    def _get_gcs_uri(self, model_name):
        # Model files are stored in the bucket under a prefix equal to the model name.
        return model_name

    def _blob_exists(self, blob_path):
        blob = self.bucket.blob(blob_path)
        return blob.exists()

    def _create_model_folder(self, model_name):
        gcs_model_folder = self._get_gcs_uri(model_name)
        if not self._blob_exists(f"{gcs_model_folder}/.touch"):
            blob = self.bucket.blob(f"{gcs_model_folder}/.touch")
            blob.upload_from_string("")
            logger.info(f"Created folder '{gcs_model_folder}' in GCS.")

    def check_model_exists_locally(self, model_name):
        # "Locally" here means the GCS bucket backing this service, not the container's disk.
        gcs_model_path = self._get_gcs_uri(model_name)
        blobs = self.bucket.list_blobs(prefix=gcs_model_path)
        return any(blobs)

    def download_model_from_huggingface(self, model_name):
        logger.info(f"Downloading model '{model_name}' from Hugging Face.")
        try:
            # snapshot_download fetches the full repository (config, tokenizer and weight files)
            # into a local cache directory; mirror every file into the GCS bucket.
            local_dir = snapshot_download(repo_id=model_name, token=HUGGINGFACE_HUB_TOKEN)
            gcs_model_folder = self._get_gcs_uri(model_name)
            self._create_model_folder(model_name)
            for root, _, files in os.walk(local_dir):
                for filename in files:
                    local_path = os.path.join(root, filename)
                    relative_path = os.path.relpath(local_path, local_dir)
                    blob = self.bucket.blob(f"{gcs_model_folder}/{relative_path}")
                    blob.upload_from_filename(local_path)
            logger.info(f"Model '{model_name}' downloaded and saved to GCS.")
            return True
        except Exception as e:
            logger.error(f"Error downloading model from Hugging Face: {e}")
            return False

model_loader = GCSModelLoader(bucket)

@app.post("/generate")
async def generate(request: GenerateRequest):
    model_name = request.model_name
    input_text = request.input_text
    task_type = request.task_type
    requested_max_new_tokens = request.max_new_tokens
    generation_params = request.model_dump(
        exclude_none=True,
        exclude={'model_name', 'input_text', 'task_type', 'stream', 'chunk_delay', 'max_new_tokens', 'stopping_strings'}
    )
    user_defined_stopping_strings = request.stopping_strings

    try:
        if not model_loader.check_model_exists_locally(model_name):
            if not model_loader.download_model_from_huggingface(model_name):
                raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")

        tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
        config = AutoConfig.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
        stopping_criteria_list = StoppingCriteriaList()

        if user_defined_stopping_strings:
            stop_words_ids = [tokenizer.encode(stop_string, add_special_tokens=False) for stop_string in user_defined_stopping_strings]
            stopping_criteria_list.append(StopOnKeywords(stop_words_ids, tokenizer)) # Pass tokenizer

        # Also stop on the model's EOS token(s), using the token ids directly.
        if config.eos_token_id is not None:
            if isinstance(config.eos_token_id, int):
                stop_words_ids_eos = [[config.eos_token_id]]
            else:
                stop_words_ids_eos = [[eos_id] for eos_id in config.eos_token_id]
            stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos, tokenizer))
        elif tokenizer.eos_token is not None:
            stop_words_ids_eos = [tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)]
            stopping_criteria_list.append(StopOnKeywords(stop_words_ids_eos, tokenizer))

        async def generate_responses() -> AsyncIterator[Dict[str, List[Dict[str, str]]]]:
            nonlocal input_text
            stop_reason = None

            # Build the pipeline once; re-creating it on every iteration would reload the model each time.
            text_pipeline = pipeline(
                task_type,
                model=model_name,
                tokenizer=tokenizer,
                token=HUGGINGFACE_HUB_TOKEN,
                stopping_criteria=stopping_criteria_list,
                **generation_params,
                max_new_tokens=requested_max_new_tokens
            )

            while True:
                # Run the blocking pipeline call in a worker thread so the event loop stays responsive;
                # return_full_text=False yields only the newly generated continuation, not the prompt.
                result = await asyncio.to_thread(text_pipeline, input_text, return_full_text=False)
                newly_generated_text = result[0]['generated_text']

                # Check whether any user-defined stopping string appeared in this chunk.
                for criteria in stopping_criteria_list:
                    if isinstance(criteria, StopOnKeywords):
                        for stop_ids in criteria.stop_words_ids:
                            decoded_stop_string = tokenizer.decode(stop_ids)
                            if decoded_stop_string in newly_generated_text:
                                stop_reason = f"stopping_string: {decoded_stop_string}"
                                break
                        if stop_reason:
                            break

                if stop_reason:
                    break

                yield {"response": [{'generated_text': newly_generated_text}]}

                # Stop once an EOS token shows up in the decoded text.
                if config.eos_token_id is not None:
                    eos_tokens = [config.eos_token_id] if isinstance(config.eos_token_id, int) else list(config.eos_token_id)
                    if any(tokenizer.decode([eos_token]) in newly_generated_text for eos_token in eos_tokens):
                        stop_reason = "eos_token"
                        break
                elif tokenizer.eos_token is not None and tokenizer.eos_token in newly_generated_text:
                    stop_reason = "eos_token"
                    break

                # Feed the continuation back in so the next chunk picks up where this one ended.
                input_text += newly_generated_text

        async def text_stream():
            async for data in generate_responses():
                yield f"data: {json.dumps(data)}\n\n"
                if request.chunk_delay:
                    # Optional pacing between streamed SSE chunks.
                    await asyncio.sleep(request.chunk_delay)
            yield "data: [DONE]\n\n"

        return StreamingResponse(text_stream(), media_type="text/event-stream")

    except HTTPException as e:
        raise e
    except Exception as e:
        logger.error(f"Internal server error: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
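
# Example invocation (assumes the server is running locally on port 7860; adjust the
# payload to a model you actually have access to):
#   curl -N -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"model_name": "gpt2", "input_text": "Hello", "task_type": "text-generation", "max_new_tokens": 20}'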