Spaces:
Sleeping
Sleeping
NGUYEN, Xuan Phi
commited on
Commit
•
c1519e7
1
Parent(s):
7eb44d4
add sea lava16
Browse files
multipurpose_chatbot/engines/__init__.py
CHANGED
@@ -9,6 +9,7 @@ BACKENDS = [
|
|
9 |
# "llava_llama_cpp",
|
10 |
"debug",
|
11 |
"sealmmm_transformers",
|
|
|
12 |
]
|
13 |
|
14 |
ENGINE_LOADED = False
|
@@ -42,6 +43,9 @@ def load_multipurpose_chatbot_engine(backend: str):
|
|
42 |
elif backend == 'sealmmm_transformers':
|
43 |
from .sealmmm_engine import SeaLMMMv0Engine
|
44 |
model_engine = SeaLMMMv0Engine()
|
|
|
|
|
|
|
45 |
else:
|
46 |
raise ValueError(f'backend invalid: {BACKENDS} vs {backend}')
|
47 |
|
|
|
9 |
# "llava_llama_cpp",
|
10 |
"debug",
|
11 |
"sealmmm_transformers",
|
12 |
+
"sealava16_transformers"
|
13 |
]
|
14 |
|
15 |
ENGINE_LOADED = False
|
|
|
43 |
elif backend == 'sealmmm_transformers':
|
44 |
from .sealmmm_engine import SeaLMMMv0Engine
|
45 |
model_engine = SeaLMMMv0Engine()
|
46 |
+
elif backend == 'sealava16_transformers':
|
47 |
+
from .sealava16_transformers_engine import SeaLlava16Engine
|
48 |
+
model_engine = SeaLlava16Engine()
|
49 |
else:
|
50 |
raise ValueError(f'backend invalid: {BACKENDS} vs {backend}')
|
51 |
|
multipurpose_chatbot/engines/sealava16_transformers_engine.py
ADDED
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
from typing import Any, Iterator
|
7 |
+
from typing import Iterator, List, Optional, Tuple
|
8 |
+
import filelock
|
9 |
+
import glob
|
10 |
+
import json
|
11 |
+
import time
|
12 |
+
from gradio.routes import Request
|
13 |
+
from gradio.utils import SyncToAsyncIterator, async_iteration
|
14 |
+
from gradio.helpers import special_args
|
15 |
+
import anyio
|
16 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
17 |
+
|
18 |
+
from gradio_client.documentation import document, set_documentation_group
|
19 |
+
|
20 |
+
from typing import List, Optional, Union, Dict, Tuple
|
21 |
+
from tqdm.auto import tqdm
|
22 |
+
from huggingface_hub import snapshot_download
|
23 |
+
|
24 |
+
from gradio.components import Button
|
25 |
+
from gradio.events import Dependency, EventListenerMethod
|
26 |
+
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
|
27 |
+
import types
|
28 |
+
import sys
|
29 |
+
from .base_engine import BaseEngine
|
30 |
+
from .transformers_engine import TransformersEngine, NewGenerationMixin
|
31 |
+
|
32 |
+
from ..configs import (
|
33 |
+
STREAM_CHECK_MULTIPLE,
|
34 |
+
STREAM_YIELD_MULTIPLE,
|
35 |
+
)
|
36 |
+
|
37 |
+
CODE_PATH = os.environ.get("CODE_PATH", "")
|
38 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "")
|
39 |
+
|
40 |
+
IMAGE_TOKEN = "[IMAGE]<|image|>[/IMAGE]"
|
41 |
+
IMAGE_TOKEN = "<|image|>"
|
42 |
+
|
43 |
+
IMAGE_LENGTH = 576
|
44 |
+
MAX_PACHES = 5
|
45 |
+
|
46 |
+
|
47 |
+
BLOCK_LANGS = str(os.environ.get("BLOCK_LANGS", ""))
|
48 |
+
BLOCK_LANGS = [x.strip() for x in BLOCK_LANGS.strip().split(";")] if len(BLOCK_LANGS.strip()) > 0 else []
|
49 |
+
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
|
50 |
+
KEYWORDS = os.environ.get("KEYWORDS", "").strip()
|
51 |
+
KEYWORDS = KEYWORDS.split(";") if len(KEYWORDS) > 0 else []
|
52 |
+
KEYWORDS = [x.lower() for x in KEYWORDS]
|
53 |
+
|
54 |
+
LANG_BLOCK_MESSAGE = """Unsupported language."""
|
55 |
+
|
56 |
+
KEYWORD_BLOCK_MESSAGE = "Invalid request."
|
57 |
+
|
58 |
+
|
59 |
+
def _detect_lang(text):
|
60 |
+
# Disable language that may have safety risk
|
61 |
+
from langdetect import detect as detect_lang
|
62 |
+
dlang = None
|
63 |
+
try:
|
64 |
+
dlang = detect_lang(text)
|
65 |
+
except Exception as e:
|
66 |
+
if "No features in text." in str(e):
|
67 |
+
return "en"
|
68 |
+
else:
|
69 |
+
return "zh"
|
70 |
+
return dlang
|
71 |
+
|
72 |
+
|
73 |
+
def block_lang(
|
74 |
+
message: str,
|
75 |
+
history: List[Tuple[str, str]] = None,
|
76 |
+
) -> str:
|
77 |
+
# relieve history base block
|
78 |
+
if len(BLOCK_LANGS) == 0:
|
79 |
+
return False
|
80 |
+
|
81 |
+
if LANG_BLOCK_HISTORY and history is not None and any((LANG_BLOCK_MESSAGE in x[1].strip()) for x in history):
|
82 |
+
return True
|
83 |
+
else:
|
84 |
+
_lang = _detect_lang(message)
|
85 |
+
if _lang in BLOCK_LANGS:
|
86 |
+
# print(f'Detect blocked {_lang}: {message}')
|
87 |
+
return True
|
88 |
+
else:
|
89 |
+
return False
|
90 |
+
|
91 |
+
def safety_check(text, history=None, ) -> Optional[str]:
|
92 |
+
"""
|
93 |
+
Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
|
94 |
+
This provides an additional security measure to enhance safety and compliance with local regulations.
|
95 |
+
"""
|
96 |
+
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
97 |
+
return KEYWORD_BLOCK_MESSAGE
|
98 |
+
|
99 |
+
if len(BLOCK_LANGS) > 0:
|
100 |
+
if block_lang(text, history):
|
101 |
+
return LANG_BLOCK_MESSAGE
|
102 |
+
|
103 |
+
return None
|
104 |
+
|
105 |
+
|
106 |
+
def safety_check_conversation_string(text, delimiter=None) -> Optional[str]:
|
107 |
+
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
108 |
+
return KEYWORD_BLOCK_MESSAGE
|
109 |
+
if len(BLOCK_LANGS) > 0:
|
110 |
+
import re
|
111 |
+
delimiter = delimiter or (r"</s>\n<\|im_start\|>user\n", r"</s>\n<\|im_start\|>assistant\n", r"<\|im_start\|>system\n")
|
112 |
+
turns = re.split(r"|".join(delimiter), text)
|
113 |
+
turns = [t for t in turns if t.strip() != '']
|
114 |
+
for t in turns:
|
115 |
+
if block_lang(t):
|
116 |
+
return LANG_BLOCK_MESSAGE
|
117 |
+
return None
|
118 |
+
|
119 |
+
|
120 |
+
def is_check_safety():
|
121 |
+
return len(KEYWORDS) > 0 or len(BLOCK_LANGS) > 0
|
122 |
+
|
123 |
+
|
124 |
+
def safety_check_conversation(conversation) -> Optional[str]:
|
125 |
+
"""
|
126 |
+
Despite our effort in safety tuning and red teaming, our models may still generate harmful or illegal content.
|
127 |
+
This provides an additional security measure to enhance safety and compliance with local regulations.
|
128 |
+
"""
|
129 |
+
texts = [c['content'] for c in conversation]
|
130 |
+
for text in texts:
|
131 |
+
if len(KEYWORDS) > 0 and any(x in text.lower() for x in KEYWORDS):
|
132 |
+
return KEYWORD_BLOCK_MESSAGE
|
133 |
+
|
134 |
+
if len(BLOCK_LANGS) > 0:
|
135 |
+
if block_lang(text):
|
136 |
+
return LANG_BLOCK_MESSAGE
|
137 |
+
return None
|
138 |
+
|
139 |
+
|
140 |
+
class SeaLlava16Engine(TransformersEngine):
|
141 |
+
|
142 |
+
@property
|
143 |
+
def image_token(self):
|
144 |
+
return IMAGE_TOKEN
|
145 |
+
|
146 |
+
@property
|
147 |
+
def max_position_embeddings(self) -> int:
|
148 |
+
return self._model.config.max_position_embeddings
|
149 |
+
|
150 |
+
@property
|
151 |
+
def tokenizer(self):
|
152 |
+
return self._tokenizer
|
153 |
+
|
154 |
+
@property
|
155 |
+
def processor(self):
|
156 |
+
return self._processor
|
157 |
+
|
158 |
+
def load_model(self):
|
159 |
+
from transformers import AutoProcessor
|
160 |
+
import sys
|
161 |
+
# caution: path[0] is reserved for script path (or '' in REPL)
|
162 |
+
sys.path.append(CODE_PATH)
|
163 |
+
|
164 |
+
|
165 |
+
from transformers.models.llava_next.modeling_llava_next import LlavaNextForConditionalGeneration
|
166 |
+
from transformers.models.llava_next.processing_llava_next import LlavaNextProcessor
|
167 |
+
model_path = MODEL_PATH
|
168 |
+
print(f'Loading model from {model_path}')
|
169 |
+
|
170 |
+
print(f'model_path={model_path}')
|
171 |
+
if os.path.exists(f"{model_path}/pytorch_model_fsdp.bin") and not os.path.exists(f"{model_path}/pytorch_model.bin"):
|
172 |
+
os.symlink("pytorch_model_fsdp.bin", f"{model_path}/pytorch_model.bin")
|
173 |
+
|
174 |
+
self._processor = LlavaNextProcessor.from_pretrained(model_path)
|
175 |
+
self._model = LlavaNextForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="cuda").eval()
|
176 |
+
|
177 |
+
self._model.sample_old = self._model.sample
|
178 |
+
self._model._sample = types.MethodType(NewGenerationMixin.sample_stream, self._model)
|
179 |
+
|
180 |
+
self._tokenizer = self._processor.tokenizer
|
181 |
+
print(self._model)
|
182 |
+
print(f"{self.max_position_embeddings=}")
|
183 |
+
|
184 |
+
def get_multimodal_tokens(self, full_prompt, image_paths=None):
|
185 |
+
num_tokens = len(self.tokenizer.encode(full_prompt))
|
186 |
+
for image_path in image_paths:
|
187 |
+
num_tokens += IMAGE_LENGTH * MAX_PACHES
|
188 |
+
return num_tokens
|
189 |
+
|
190 |
+
def maybe_raise_safety(self, message, gen_index=-1):
|
191 |
+
if is_check_safety():
|
192 |
+
if gen_index < 0:
|
193 |
+
message_safety = safety_check_conversation_string(message)
|
194 |
+
if message_safety is not None:
|
195 |
+
raise gr.Error(message_safety)
|
196 |
+
else:
|
197 |
+
if STREAM_CHECK_MULTIPLE > 0 and gen_index % STREAM_CHECK_MULTIPLE == 0:
|
198 |
+
message_safety = safety_check_conversation_string(message)
|
199 |
+
if message_safety is not None:
|
200 |
+
raise gr.Error(message_safety)
|
201 |
+
|
202 |
+
def generate_yield_string(self, prompt, temperature, max_tokens, stop_strings: Optional[Tuple[str]] = None, **kwargs):
|
203 |
+
from transformers.generation.utils import GenerationConfig
|
204 |
+
from PIL import Image
|
205 |
+
image_paths = kwargs.get("image_paths", None)
|
206 |
+
image_paths = image_paths or []
|
207 |
+
|
208 |
+
images = [Image.open(x) for x in image_paths] if len(image_paths) > 0 else None
|
209 |
+
|
210 |
+
with torch.no_grad():
|
211 |
+
# inputs = self.processor(prompt, images, return_tensors='pt', concat_images=True)
|
212 |
+
inputs = self.processor(prompt, images, return_tensors='pt')
|
213 |
+
# inputs = inputs.to("cuda", torch.bfloat16)
|
214 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items() if v is not None}
|
215 |
+
num_tokens = self.get_multimodal_tokens(prompt, image_paths)
|
216 |
+
# non-streaming generation
|
217 |
+
# output = self._model.generate(
|
218 |
+
# **inputs,
|
219 |
+
# do_sample=True,
|
220 |
+
# temperature=temperature,
|
221 |
+
# max_new_tokens=max_tokens,
|
222 |
+
# pad_token_id=self.processor.tokenizer.pad_token_id,
|
223 |
+
# )
|
224 |
+
# # response = self.processor.tokenizer.decode(output[0][-inputs.input_ids.size(-1):], skip_special_tokens=True)
|
225 |
+
# full_output_text = self.processor.decode(output[0], skip_special_tokens=True)
|
226 |
+
# response = full_output_text.split("<|im_start|>assistant\n")[-1]
|
227 |
+
# num_tokens = self.get_multimodal_tokens(prompt + response, image_paths)
|
228 |
+
# print(prompt)
|
229 |
+
# print(response)
|
230 |
+
# print(num_tokens)
|
231 |
+
# yield response, num_tokens
|
232 |
+
|
233 |
+
# if i % 4 == 0 and i > 1:
|
234 |
+
# message_safety = safety_check(response)
|
235 |
+
# if message_safety is not None:
|
236 |
+
# history = undo_history(history)
|
237 |
+
# yield history, "", None
|
238 |
+
# raise gr.Error(message_safety)
|
239 |
+
self.maybe_raise_safety(prompt)
|
240 |
+
|
241 |
+
# # ! streaming
|
242 |
+
generator = self._model.generate(
|
243 |
+
**inputs,
|
244 |
+
do_sample=True,
|
245 |
+
temperature=temperature,
|
246 |
+
max_new_tokens=max_tokens,
|
247 |
+
pad_token_id=self.processor.tokenizer.pad_token_id,
|
248 |
+
)
|
249 |
+
|
250 |
+
out_tokens = []
|
251 |
+
response = None
|
252 |
+
print(f"{STREAM_YIELD_MULTIPLE=}")
|
253 |
+
for index, token in enumerate(generator):
|
254 |
+
out_tokens.append(token.item())
|
255 |
+
response = self.tokenizer.decode(out_tokens, skip_special_tokens=True)
|
256 |
+
|
257 |
+
self.maybe_raise_safety(response, gen_index=index)
|
258 |
+
|
259 |
+
if STREAM_YIELD_MULTIPLE > 0:
|
260 |
+
if index % STREAM_YIELD_MULTIPLE == 0 and index > 0:
|
261 |
+
yield response, num_tokens
|
262 |
+
else:
|
263 |
+
yield response, num_tokens
|
264 |
+
|
265 |
+
del generator
|
266 |
+
|
267 |
+
if response is not None:
|
268 |
+
self.maybe_raise_safety(prompt)
|
269 |
+
|
270 |
+
full_text = prompt + response
|
271 |
+
num_tokens = self.get_multimodal_tokens(full_text, image_paths)
|
272 |
+
yield response, num_tokens
|
273 |
+
|