Spaces:
Running
on
Zero
Running
on
Zero
# Copyright: DAMO Academy, Alibaba Group | |
# By Xuan Phi Nguyen at DAMO Academy, Alibaba Group | |
# Description: | |
""" | |
VLLM-based demo script to launch Language chat model for Southeast Asian Languages | |
""" | |
import os | |
import numpy as np | |
import argparse | |
import torch | |
import gradio as gr | |
from typing import Any, Iterator | |
from typing import Iterator, List, Optional, Tuple | |
import filelock | |
import glob | |
import json | |
from gradio_client.documentation import document, set_documentation_group | |
from typing import List, Optional, Union, Dict, Tuple | |
from tqdm.auto import tqdm | |
from huggingface_hub import snapshot_download | |
# @@ environments ================
DEBUG = bool(int(os.environ.get("DEBUG", "1")))
BLOCK_ZH = bool(int(os.environ.get("BLOCK_ZH", "1")))
# for lang block, whether to block in history too
LANG_BLOCK_HISTORY = bool(int(os.environ.get("LANG_BLOCK_HISTORY", "0")))
TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
DTYPE = os.environ.get("DTYPE", "bfloat16")
# ! (no debug) whether to download HF_MODEL_NAME and save to MODEL_PATH
DOWNLOAD_SNAPSHOT = bool(int(os.environ.get("DOWNLOAD_SNAPSHOT", "0")))
LOG_RESPONSE = bool(int(os.environ.get("LOG_RESPONSE", "0")))
# ! show model path in the demo page, only for internal
DISPLAY_MODEL_PATH = bool(int(os.environ.get("DISPLAY_MODEL_PATH", "1")))
# ! uploaded model path, will be downloaded to MODEL_PATH
HF_MODEL_NAME = os.environ.get("HF_MODEL_NAME", "DAMO-NLP-SG/seal-13b-chat-a")
# ! if model is private, need HF_TOKEN to access the model
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# ! path where the model is downloaded, either on ./ or persistent disc
MODEL_PATH = os.environ.get("MODEL_PATH", "./seal-13b-chat-a")
# ! !! Whether to delete the folder, ONLY SET THIS IF YOU WANT TO DELETE SAVED MODEL ON PERSISTENT DISC
DELETE_FOLDER = os.environ.get("DELETE_FOLDER", "")
# DELETE_FOLDER is always a str (default ""), and "" never exists, so this is
# True only when an existing path was configured.
IS_DELETE_FOLDER = os.path.exists(DELETE_FOLDER)
print(f'DELETE_FOLDER: {DELETE_FOLDER} | {DOWNLOAD_SNAPSHOT=}')


def parse_keywords(raw: str) -> list[str]:
    """Split a ``;``-separated keyword list into lowercase keywords.

    Entries that are empty or whitespace-only (e.g. from ``"a;;b"`` or a
    trailing ``;``) are dropped. Otherwise such an entry would be a substring
    of practically every message. Non-empty entries keep their inner
    whitespace, so a deliberate leading/trailing space still works.
    """
    return [kw.lower() for kw in raw.strip().split(";") if kw.strip()]


# ! list of keywords to disabled as security measures to comply with local regulation
KEYWORDS = parse_keywords(os.environ.get("KEYWORDS", ""))
# gradio config
PORT = int(os.environ.get("PORT", "7860"))
# how many iterations to yield response
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
# how many iterations to perform safety check on response
STREAM_CHECK_MULTIPLE = int(os.environ.get("STREAM_CHECK_MULTIPLE", "0"))
# generation defaults (max new tokens, sampling temperature, frequency penalty);
# presumably forwarded to vLLM SamplingParams later in the file.
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
TEMPERATURE = float(os.environ.get("TEMPERATURE", "0.1"))
FREQUENCE_PENALTY = float(os.environ.get("FREQUENCE_PENALTY", "0.4"))
gpu_memory_utilization = float(os.environ.get("gpu_memory_utilization", "0.9"))
# whether to enable quantization, currently not in use
QUANTIZATION = str(os.environ.get("QUANTIZATION", ""))
""" | |
Internal instructions of how to configure the DEMO | |
1. Upload SFT model as a model to huggingface: hugginface/models/seal_13b_a | |
2. If the model weights is private, set HF_TOKEN=<your private hf token> in https://huggingface.co/spaces/????/?????/settings | |
3. space config env: `HF_MODEL_NAME=DAMO-NLP-SG/seal-13b-chat-a` or the underlining model | |
4. If enable persistent storage: set | |
HF_HOME=/data/.huggingface | |
MODEL_PATH=/data/.huggingface/seal-13b-chat-a | |
if not: | |
MODEL_PATH=./seal-13b-chat-a | |
""" | |
# ==============================
# Startup diagnostics: report debug mode and the torch / CUDA build.
# Every CUDA probe is best-effort; a failure is printed, never raised.
print(f'DEBUG mode: {DEBUG}')
print(f'Torch version: {torch.__version__}')
try:
    print(f'Torch CUDA version: {torch.version.cuda}')
except Exception as e:
    print(f'Failed to print cuda version: {e}')
try:
    compute_capability = torch.cuda.get_device_capability()
except Exception as e:
    print(f'Failed to print compute_capability version: {e}')
else:
    print(f'Torch CUDA compute_capability: {compute_capability}')
# @@ constants ================ | |
DTYPES = { | |
'float16': torch.float16, | |
'bfloat16': torch.bfloat16 | |
} | |
llm = None | |
demo = None | |
BOS_TOKEN = '<s>' | |
EOS_TOKEN = '</s>' | |
B_INST, E_INST = "[INST]", "[/INST]" | |
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n" | |
SYSTEM_PROMPT_1 = """You are a multilingual, helpful, respectful and honest assistant. Your name is SeaLLM and you are built by DAMO Academy, Alibaba Group. \ | |
Please always answer as helpfully as possible, while being safe. Your \ | |
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \ | |
that your responses are socially unbiased and positive in nature. | |
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ | |
correct. If you don't know the answer to a question, please don't share false information. | |
As a multilingual assistant, you must respond and follow instructions in the native language of the user by default, unless told otherwise. \ | |
Your response should adapt to the norms and customs of the respective language and culture. | |
""" | |
# ============ CONSTANT ============
# https://github.com/gradio-app/gradio/issues/884
MODEL_NAME = "SeaLLM-13B"
# HTML page header (logo + title). It replaces the former plain-text title
# "SeaLLM-13B - An Assistant for Southeast Asian Languages", which was
# assigned here and immediately overwritten.
MODEL_TITLE = """
<div class="container" style="
align-items: center;
justify-content: center;
display: flex;
">
<div class="image" >
<img src="file/seal_logo.png" style="
max-width: 10em;
max-height: 5%;
height: 3em;
width: 3em;
float: left;
margin-left: auto;
">
</div>
<div class="text" style="
padding-left: 20px;
padding-top: 1%;
float: left;
">
<h1>SeaLLMs - Large Language Models for Southeast Asia</h1>
</div>
</div>
"""
# HTML description shown under the title: badges, model blurb and terms of use.
# NOTE(review): the Github and Paper badges have an empty href; fill in the
# real links.
MODEL_DESC = """
<div style='display:flex; gap: 0.25rem; '>
<a href=''><img src='https://img.shields.io/badge/Github-Code-success'></a>
<a href='https://huggingface.co/spaces/SeaLLMs/SeaLLM-Chat-13b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
<a href='https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
<a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
</div>
<span style="font-size: larger">
This is <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">SeaLLM-13B-Chat</a> - a chatbot assistant optimized for Southeast Asian Languages. It produces helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
Explore <a href="https://huggingface.co/SeaLLMs/SeaLLM-Chat-13b" target="_blank">our article</a> for more details.
</span>
<br>
<span >
NOTE: The chatbot may produce inaccurate and harmful information about people, places, or facts.
<u style="color: red">By using our service, you are required to agree to the following terms:</u><br>
<ul>
<li >
You must not use our service to generate any harmful, unethical or illegal content that violates locally applicable and international laws or regulations,
including but not limited to hate speech, violence, pornography and deception.</li>
<li >
The service collects user dialogue data for testing and performance improvement, and reserves the right to distribute it under
<a href="https://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution (CC-BY)</a> or similar license. So do not enter any personal information!
</li>
</ul>
</span>
""".strip()
# Markdown citation block (BibTeX) for the project.
cite_markdown = """
## Citation
If you find our project useful, hope you can star our repo and cite our paper as follows:
```
@article{damonlpsg2023seallm,
author = {Xuan-Phi Nguyen*, Wenxuan Zhang*, Xin Li*, Mahani Aljunied*, Qingyu Tan, Liying Cheng, Guanzheng Chen, Yue Deng, Sen Yang, Chaoqun Liu, Hang Zhang, Lidong Bing},
title = {SeaLLMs - Large Language Models for Southeast Asia},
year = 2023,
}
```
"""
# Template for showing the loaded model path; fill with .format(model_path=...).
path_markdown = """
#### Model path:
{model_path}
"""
def _detect_lang(text):
    """Return the language code that ``langdetect`` assigns to ``text``.

    This feeds a safety filter, so errors are mapped to fixed codes instead of
    being raised:

    * langdetect's "No features in text." error -> ``"en"`` (nothing to block);
    * any other detection error -> ``"zh"`` (fail closed for the zh block).
    """
    # Disable language that may have safety risk
    from langdetect import DetectorFactory
    from langdetect import detect as detect_lang
    # langdetect is non-deterministic unless seeded. Pin the seed so the same
    # message is always classified (and blocked or allowed) the same way.
    DetectorFactory.seed = 0
    try:
        return detect_lang(text)
    except Exception as e:
        print(f'Error: {e}')
        if "No features in text." in str(e):
            return "en"
        return "zh"
def custom_hf_model_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield ``(name, tensor)`` for every weight of a Hugging Face checkpoint.

    Replacement for vLLM 0.1.4's weight iterator that also accepts
    ``*model*.safetensors`` checkpoints. ``*model*.bin`` files take priority
    when both kinds are present.

    Args:
        model_name_or_path: A local checkpoint directory, or a Hugging Face
            repo id. A repo id is resolved with ``local_files_only=True``, so
            it must already be in the local HF cache.
        cache_dir: HF cache directory. Also holds the lock file (``/tmp``
            when None).
        use_np_cache: Convert the ``*.bin`` weights once into per-tensor
            numpy files under ``<folder>/np`` and load from there.

    Yields:
        ``(parameter_name, tensor)`` pairs.

    Raises:
        ValueError: if no usable weight files are found. With
            ``use_np_cache`` this means: the numpy cache does not exist yet
            and there are no ``*.bin`` files to build it from.
    """
    # ! if use vllm==0.1.4, use this to augment hf_model_weights_iterator loader
    from vllm.model_executor.weight_utils import Disabledtqdm
    # File lock so that multiple processes do not resolve / convert the same
    # model weights at the same time.
    lock_dir = cache_dir if cache_dir is not None else "/tmp"
    lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
    # Resolve the checkpoint folder. Only the local HF cache is consulted
    # (local_files_only=True); nothing is downloaded here.
    if os.path.isdir(model_name_or_path):
        hf_folder = model_name_or_path
    else:
        with lock:
            hf_folder = snapshot_download(model_name_or_path,
                                          allow_patterns="*.bin",
                                          cache_dir=cache_dir,
                                          local_files_only=True,
                                          tqdm_class=Disabledtqdm)
    hf_bin_files = [
        x for x in glob.glob(os.path.join(hf_folder, "*model*.bin"))
        if not x.endswith("training_args.bin")
    ]
    hf_safetensors_files = glob.glob(os.path.join(hf_folder, "*model*.safetensors"))

    if use_np_cache:
        # Convert the model weights from torch tensors to numpy arrays for
        # faster loading.
        np_folder = os.path.join(hf_folder, "np")
        os.makedirs(np_folder, exist_ok=True)
        weight_names_file = os.path.join(np_folder, "weight_names.json")
        with lock:
            if not os.path.exists(weight_names_file):
                if not hf_bin_files:
                    # Without this check an empty weight_names.json would be
                    # written, and later runs would silently load no weights.
                    raise ValueError(
                        f'use_np_cache requires *.bin weight files, none found in {hf_folder}')
                weight_names = []
                for bin_file in hf_bin_files:
                    # NOTE: torch.load unpickles; only use trusted checkpoints.
                    state = torch.load(bin_file, map_location="cpu")
                    for name, param in state.items():
                        param_path = os.path.join(np_folder, name)
                        with open(param_path, "wb") as f:
                            np.save(f, param.cpu().detach().numpy())
                        weight_names.append(name)
                with open(weight_names_file, "w") as f:
                    json.dump(weight_names, f)
        with open(weight_names_file, "r") as f:
            weight_names = json.load(f)
        for name in weight_names:
            param_path = os.path.join(np_folder, name)
            with open(param_path, "rb") as f:
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif hf_bin_files:
        print(F'Load bin files: {hf_bin_files}')
        for bin_file in hf_bin_files:
            # NOTE: torch.load unpickles; only use trusted checkpoints.
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
    elif hf_safetensors_files:
        print(F'Load safetensor files: {hf_safetensors_files}')
        from safetensors.torch import load_file
        for safe_file in hf_safetensors_files:
            state = load_file(safe_file)
            for name, param in state.items():
                yield name, param
            del state
            torch.cuda.empty_cache()
    else:
        raise ValueError('no files available either bin or safe')
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Materialize a safetensors ``PySafeSlice`` as a ``torch.Tensor``.

    A ``PySafeSlice`` can be indexed lazily, which limits how much data is
    read into memory. It lacks richer tensor operations such as ``.view()``
    or ``.t()``, though, so callers that need those convert first.

    A ``torch.Tensor`` is returned as-is (same object). Any other object is
    materialized with the full slice ``x[:]``.
    """
    return x if isinstance(x, torch.Tensor) else x[:]
def load_padded_tensor_parallel_vocab(
    param: torch.Tensor,
    loaded_weight: Any,  # `torch.Tensor` or `PySafeSlice`
    tensor_model_parallel_rank: int,
) -> None:
    """Copy this rank's vocabulary shard of ``loaded_weight`` into ``param``.

    Rank ``r`` owns rows ``[r * n, (r + 1) * n)`` of the checkpoint, where
    ``n = param.shape[0]``. If the checkpoint has fewer rows than that
    (padded vocab), only the available rows are copied. The remaining rows of
    ``param`` are left untouched.
    """
    rows_per_rank = param.shape[0]
    first_row = tensor_model_parallel_rank * rows_per_rank
    shard = convert_pyslice_to_tensor(loaded_weight[first_row:first_row + rows_per_rank])
    param[:shard.shape[0]].copy_(shard)
def llama_load_weights(
    self,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_np_cache: bool = False,
    load_format: str = "auto",
    revision: Optional[str] = None
):
    """Load Hugging Face Llama weights into a vLLM 0.1.4 ``LlamaForCausalLM``.

    Installed as ``LlamaForCausalLM.load_weights`` when vLLM 0.1.4 is
    detected. Weights are read with ``custom_hf_model_weights_iterator`` and
    loaded as follows:

    * ``embed_tokens`` / ``lm_head`` are padded along dim 0 to
      ``param.shape[0] * tp_size`` rows;
    * separate ``q/k/v_proj`` and ``gate/up_proj`` tensors are sharded into
      the fused ``qkv_proj`` / ``gate_up_proj`` parameters;
    * already-fused ``qkv_proj`` / ``gate_up_proj`` checkpoints are re-sliced
      per tensor-parallel rank;
    * everything else goes through vLLM's ``load_tensor_parallel_weights``.

    Finally it prints whether the number of loaded tensors matches the
    model's state dict. ``load_format`` and ``revision`` are accepted for
    signature compatibility and are not used.
    """
    # if use vllm==0.1.4
    from vllm.model_executor.weight_utils import (
        load_tensor_parallel_weights
    )
    from vllm.model_executor.parallel_utils.parallel_state import (
        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
    tp_size = get_tensor_model_parallel_world_size()
    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
    # Per-rank row counts inside the fused qkv_proj. k/v use
    # num_key_value_heads when the config defines it, else num_attention_heads.
    q_proj_shard_size = (self.config.hidden_size // tp_size)
    kv_proj_shard_size = (self.config.hidden_size //
                          self.config.num_attention_heads *
                          getattr(self.config, "num_key_value_heads", self.config.num_attention_heads) // tp_size)
    attention_weight_specs = [
        # (weight_name, shard_size, offset)
        ("q_proj", q_proj_shard_size, 0),
        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
        ("v_proj", kv_proj_shard_size,
         q_proj_shard_size + kv_proj_shard_size),
    ]
    state_dict = self.state_dict()
    need_to_load = len(state_dict)
    loaded = 0
    iterator = custom_hf_model_weights_iterator(model_name_or_path, cache_dir, use_np_cache)
    for name, loaded_weight in iterator:
        if "rotary_emb.inv_freq" in name:
            continue
        if "embed_tokens" in name or "lm_head" in name:
            param = state_dict[name]
            # Consider padding in the vocab size. The extra row count is
            # measured against the checkpoint's actual rows (not
            # config.vocab_size). The padded tensor then falls through to the
            # generic loader at the end of the loop.
            padded_vocab_size = (param.shape[0] * tp_size)
            num_extra_rows = padded_vocab_size - loaded_weight.size(0)
            load_size = loaded_weight.size()
            # NOTE(review): padded rows come from torch.empty (uninitialized);
            # presumably never used since they lie beyond the real vocab.
            extra_rows = torch.empty(num_extra_rows,
                                     loaded_weight.shape[1])
            extra_rows = extra_rows.to(loaded_weight)
            loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
            if num_extra_rows > 0:
                print(f'Add empty to {num_extra_rows} extra row for {name}')
                print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
        is_attention_weight = False
        for weight_name, shard_size, offset in attention_weight_specs:
            if weight_name not in name or "qkv_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "qkv_proj")]
            loaded_weight = loaded_weight[
                shard_size * tensor_model_parallel_rank:shard_size *
                (tensor_model_parallel_rank + 1)]
            param_slice = param.data[offset:offset + shard_size]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            # each of q/k/v fills one third of the fused parameter
            loaded += 1.0 / 3
            is_attention_weight = True
            break
        if is_attention_weight:
            continue
        # ! qkv_proj is sharded differently if concatenated into qkv
        # qkv: qqqq kkkk vvvv
        # lweight: qq0qq1 kk0kk1 vv0vv1
        # q_shard_size: hidden_size // tp_size = qq
        # qkv_s0: qq0_kk0_vv0
        # qkv_s1: qq1_kk1_vv1
        if "qkv_proj" in name:
            param = state_dict[name]
            # loaded_weight
            qsize = self.config.hidden_size
            kvsize = self.config.hidden_size // self.config.num_attention_heads * getattr(self.config, "num_key_value_heads", self.config.num_attention_heads)
            q_offsets = (
                q_proj_shard_size * tensor_model_parallel_rank,
                q_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            k_offsets = (
                qsize + kv_proj_shard_size * tensor_model_parallel_rank,
                qsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            v_offsets = (
                qsize + kvsize + kv_proj_shard_size * tensor_model_parallel_rank,
                qsize + kvsize + kv_proj_shard_size * (tensor_model_parallel_rank + 1)
            )
            _loaded_weight = torch.cat(
                [
                    loaded_weight[q_offsets[0]:q_offsets[1]],
                    loaded_weight[k_offsets[0]:k_offsets[1]],
                    loaded_weight[v_offsets[0]:v_offsets[1]],
                ], 0
            )
            assert param.shape == _loaded_weight.shape, f'{param.shape=} != {_loaded_weight.shape=}'
            param.data.copy_(_loaded_weight)
            loaded += 1.0
            is_attention_weight = True
        if is_attention_weight:
            continue
        is_gate_up_weight = False
        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
            if weight_name not in name or "gate_up_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "gate_up_proj")]
            shard_size = param.shape[0] // 2
            loaded_weight = loaded_weight[
                shard_size * tensor_model_parallel_rank:shard_size *
                (tensor_model_parallel_rank + 1)]
            param_slice = param.data[shard_size * stride_id:shard_size *
                                     (stride_id + 1)]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            # gate and up each fill one half of the fused parameter
            loaded += 1.0 / 2
            is_gate_up_weight = True
            break
        if is_gate_up_weight:
            continue
        if "gate_up_proj" in name:
            # Already-fused checkpoint: rows [0, intermediate_size) are gate,
            # rows [intermediate_size, 2 * intermediate_size) are up.
            param = state_dict[name]
            shard_size = param.shape[0] // 2
            intermediate_size = self.config.intermediate_size
            g_offsets = (
                shard_size * tensor_model_parallel_rank,
                shard_size * (tensor_model_parallel_rank + 1)
            )
            u_offsets = (
                intermediate_size + shard_size * tensor_model_parallel_rank,
                intermediate_size + shard_size * (tensor_model_parallel_rank + 1)
            )
            _loaded_weight = torch.cat(
                [
                    loaded_weight[g_offsets[0]:g_offsets[1]],
                    loaded_weight[u_offsets[0]:u_offsets[1]],
                ], 0
            )
            assert param.shape == _loaded_weight.shape
            param.data.copy_(_loaded_weight)
            loaded += 1.0
            is_gate_up_weight = True
        if is_gate_up_weight:
            continue
        param = state_dict[name]
        load_tensor_parallel_weights(param, loaded_weight, name,
                                     self._column_parallel_weights,
                                     self._row_parallel_weights,
                                     tensor_model_parallel_rank)
        loaded += 1
    # `loaded` accumulates fractions (1/3, 1/2), so compare with a tolerance.
    # Warn only when the count does NOT match (the check used to be inverted).
    if abs(loaded - need_to_load) >= 0.01:
        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
    else:
        print(f'Loaded all {loaded} params loaded out of {need_to_load}')
def new_llama_load_weights(
    self,
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    load_format: str = "auto",
    revision: Optional[str] = None
):
    """Load Hugging Face Llama weights into a newer-vLLM ``LlamaForCausalLM``.

    Installed as ``LlamaForCausalLM.load_weights`` for vLLM versions other
    than 0.1.4. Weights come from vLLM's ``hf_model_weights_iterator``.
    Separate ``q/k/v_proj`` and ``gate/up_proj`` tensors are copied into the
    fused ``qkv_proj`` / ``gate_up_proj`` parameters, taking the quant
    config's packed / transposed layouts into account. ``embed_tokens`` /
    ``lm_head`` go through ``load_padded_tensor_parallel_vocab``, and
    everything else through ``load_tensor_parallel_weights``.

    Only ``tensor_parallel_size == 1`` is supported (asserted). At the end it
    prints whether the number of loaded tensors matches the state dict.
    """
    # If use newest vllm, not been thoroughly tested yet.
    from vllm.model_executor.weight_utils import (
        load_tensor_parallel_weights, hf_model_weights_iterator
    )
    from vllm.model_executor.parallel_utils.parallel_state import (
        get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

    if self.quant_config is None:
        weight_suffixes = ["weight"]
    else:
        weight_suffixes = self.quant_config.get_tp_tensor_names()
    column_parallel_weights: List[str] = []
    for layer in self._column_parallel_layers:
        for suffix in weight_suffixes:
            column_parallel_weights.append(f"{layer}.{suffix}")
    row_parallel_weights: List[str] = []
    for layer in self._row_parallel_layers:
        for suffix in weight_suffixes:
            row_parallel_weights.append(f"{layer}.{suffix}")
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()
    assert tp_size == 1, f'tensorparallel >=2 not allowed. {tp_size}'
    q_proj_shard_size = (self.config.hidden_size // tp_size)
    num_kv_heads_replicas = max(1,
                                tp_size // self.config.num_key_value_heads)
    num_kv_heads_per_gpu = max(1,
                               self.config.num_key_value_heads // tp_size)
    kv_proj_shard_size = (self.config.hidden_size //
                          self.config.num_attention_heads *
                          num_kv_heads_per_gpu)
    attention_weight_specs = [
        # (weight_name, shard_size, offset)
        ("q_proj", q_proj_shard_size, 0),
        ("k_proj", kv_proj_shard_size, q_proj_shard_size),
        ("v_proj", kv_proj_shard_size,
         q_proj_shard_size + kv_proj_shard_size),
    ]
    state_dict = self.state_dict()
    need_to_load = len(state_dict)
    loaded = 0
    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision):
        if "rotary_emb.inv_freq" in name:
            continue
        is_packed = False
        is_transposed = False
        if self.quant_config is not None:
            is_packed = self.quant_config.is_packed(name)
            is_transposed = self.quant_config.is_transposed(name)
        if is_transposed:
            loaded_weight = convert_pyslice_to_tensor(loaded_weight)
            loaded_weight = loaded_weight.T
        is_attention_weight = False
        for weight_name, shard_size, offset in attention_weight_specs:
            if weight_name not in name or "qkv_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "qkv_proj")]
            if is_transposed:
                param = param.T
            if is_packed:
                shard_size //= self.quant_config.pack_factor
                offset //= self.quant_config.pack_factor
            # k/v heads may be replicated across ranks when tp_size exceeds
            # num_key_value_heads, so they index by replica group.
            if weight_name in ["k_proj", "v_proj"]:
                shard_id = tp_rank // num_kv_heads_replicas
            else:
                shard_id = tp_rank
            loaded_weight = loaded_weight[shard_size *
                                          shard_id:shard_size *
                                          (shard_id + 1)]
            param_slice = param.data[offset:offset + shard_size]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            loaded += 1.0 / 3
            is_attention_weight = True
            break
        if is_attention_weight:
            continue
        # TODO: need to figure out to do sharding with qkv_proj fused
        is_gate_up_weight = False
        for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
            if weight_name not in name or "gate_up_proj" in name:
                continue
            param = state_dict[name.replace(weight_name, "gate_up_proj")]
            if is_transposed:
                param = param.T
            shard_size = param.shape[0] // 2
            loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                          (tp_rank + 1)]
            param_slice = param.data[shard_size * stride_id:shard_size *
                                     (stride_id + 1)]
            assert param_slice.shape == loaded_weight.shape
            param_slice.copy_(loaded_weight)
            loaded += 1.0 / 2
            is_gate_up_weight = True
            break
        if is_gate_up_weight:
            continue
        # TODO: need to figure out to do sharding with gate_up_proj fused
        param = state_dict[name]
        if is_transposed:
            param = param.T
        if "embed_tokens" in name or "lm_head" in name:
            load_padded_tensor_parallel_vocab(param, loaded_weight,
                                              tp_rank)
            loaded += 1
            continue
        load_tensor_parallel_weights(param, loaded_weight, name,
                                     column_parallel_weights,
                                     row_parallel_weights, tp_rank)
        loaded += 1
    # `loaded` accumulates fractions (1/3, 1/2), so compare with a tolerance.
    # Warn only when the count does NOT match (the check used to be inverted).
    if abs(loaded - need_to_load) >= 0.01:
        print(f'WARNING: only {loaded} params loaded out of {need_to_load}')
    else:
        print(f'Loaded all {loaded} params loaded out of {need_to_load}')
# Reassign LlamaForCausalLM.load_weights with a version-appropriate loader.
# Skipped in DEBUG mode. Any failure is printed and swallowed so this module
# still imports.
if not DEBUG:
    try:
        import vllm
        from vllm.model_executor.model_loader import _MODEL_REGISTRY
        from vllm.model_executor.models import LlamaForCausalLM
        # Also register LlamaForCausalLM under the architecture name
        # "FasterLlamaForCausalLM".
        _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
        # vLLM 0.1.4 gets the legacy loader; any other version the newer one.
        LlamaForCausalLM.load_weights = (
            llama_load_weights
            if vllm.__version__ == "0.1.4"
            else new_llama_load_weights
        )
        if DTYPE == "bfloat16":
            # bfloat16 needs compute capability >= 8.0; otherwise downgrade
            # to float16.
            try:
                compute_capability = torch.cuda.get_device_capability()
                major_cc = compute_capability[0]
                if major_cc < 8:
                    gpu_name = torch.cuda.get_device_name()
                    warning = (
                        "Bfloat16 is only supported on GPUs with compute capability "
                        f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                        f"{compute_capability[0]}.{compute_capability[1]}. --> Move to FLOAT16"
                    )
                    print(warning)
                    DTYPE = "float16"
            except Exception as e:
                print(f'Unable to obtain compute_capability: {e}')
    except Exception as e:
        print(f'Failing import and reconfigure VLLM: {str(e)}')
# ! ==================================================================
# gradio_client helper: sets the current documentation group to "component";
# presumably for any @document-decorated objects defined after this point.
set_documentation_group("component")
# Module-level flag initialised to False; not read in this part of the file.
# Presumably flipped elsewhere once something has been printed.
RES_PRINTED = False
def llama_chat_sys_input_seq_constructor(text, sys_prompt=SYSTEM_PROMPT_1, bos_token=BOS_TOKEN, eos_token=EOS_TOKEN):
    """Build a single-turn chat prompt that includes the system prompt.

    The pieces ``bos_token + B_INST``, ``B_SYS``, ``sys_prompt``, ``E_SYS``,
    ``text`` and ``E_INST`` are joined with single spaces. ``eos_token`` is
    accepted for signature parity with the multi-turn builder and is unused.
    """
    opening = f"{bos_token}{B_INST}"
    pieces = (opening, B_SYS, sys_prompt, E_SYS, text, E_INST)
    return " ".join(map(str, pieces))
def llama_chat_multiturn_sys_input_seq_constructor(
    message: str,
    history: List[Tuple[str, str]],
    sys_prompt=SYSTEM_PROMPT_1,
    bos_token=BOS_TOKEN,
    eos_token=EOS_TOKEN,
):
    """Render ``history`` plus the new ``message`` as one chat prompt string.

    Layout:
    ```
    <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST] Answer <eos>
    <bos>[INST] Prompt [/INST]
    ```
    The system prompt is attached only to the first turn. A history turn
    whose answer is None contributes its prompt only. When the history is
    empty (or renders to whitespace), the new message becomes the first turn
    and carries the system prompt.
    """
    def open_turn(prompt, with_system):
        # One "<bos>[INST] ... [/INST]" opener, optionally with the system block.
        if with_system:
            return f"{bos_token}{B_INST} {B_SYS} {sys_prompt} {E_SYS} {prompt} {E_INST}"
        return f"{bos_token}{B_INST} {prompt} {E_INST}"

    pieces = []
    for turn_idx, (user_prompt, answer) in enumerate(history):
        pieces.append(open_turn(user_prompt, turn_idx == 0))
        if answer is not None:
            pieces.append(f" {answer} {eos_token} ")
    rendered = "".join(pieces)

    if not history or not rendered.strip():
        return open_turn(message, True)
    return rendered + open_turn(message, False)
class ChatBot(gr.Chatbot):
    """``gr.Chatbot`` whose string messages are stripped and use ``<br>`` for newlines."""

    def _postprocess_chat_messages(
        self, chat_message
    ):
        processed = super()._postprocess_chat_messages(chat_message)
        if not isinstance(processed, str):
            return processed
        return processed.strip().replace("\n", "<br>")
from gradio.components import Button | |
from gradio.events import Dependency, EventListenerMethod | |
# Replace ChatInterface events so that the submit button is disabled during generation
# when no stop_btn is found; this prevents weird behavior.
def _setup_stop_events(
    self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
) -> None:
    """Toggle buttons around a running chat event (``ChatInterface`` patch).

    * With a stop button and a generator fn: while ``event_to_cancel`` runs,
      the stop button is shown (and the submit button hidden, if present).
      Afterwards this is reverted, and clicking stop cancels the event.
    * Otherwise, if there is a submit button: it is made non-interactive
      while the event runs and re-enabled afterwards.
    * In all cases the clear button resets the chat, re-enables submit and
      cancels the event.

    Args:
        event_triggers: one trigger, or a list/tuple of triggers, that start
            the event (e.g. ``textbox.submit``).
        event_to_cancel: the dependency chain started by those triggers.
    """
    if not isinstance(event_triggers, (list, tuple)):
        event_triggers = [event_triggers]
    silent = {"api_name": False, "queue": False}

    if self.stop_btn and self.is_generator:
        if self.submit_btn:
            # While generating: hide submit, show stop; swap back afterwards.
            for trigger in event_triggers:
                trigger(
                    lambda: (Button.update(visible=False), Button.update(visible=True)),
                    None,
                    [self.submit_btn, self.stop_btn],
                    **silent,
                )
            event_to_cancel.then(
                lambda: (Button.update(visible=True), Button.update(visible=False)),
                None,
                [self.submit_btn, self.stop_btn],
                **silent,
            )
        else:
            # No submit button: only the stop button's visibility changes.
            for trigger in event_triggers:
                trigger(lambda: Button.update(visible=True), None, [self.stop_btn], **silent)
            event_to_cancel.then(lambda: Button.update(visible=False), None, [self.stop_btn], **silent)
        self.stop_btn.click(None, None, None, cancels=event_to_cancel, api_name=False)
    elif self.submit_btn:
        # No usable stop button: grey out submit while the event runs.
        for trigger in event_triggers:
            trigger(lambda: Button.update(interactive=False), None, [self.submit_btn], **silent)
        event_to_cancel.then(lambda: Button.update(interactive=True), None, [self.submit_btn], **silent)

    # upon clear, cancel the submit event as well
    if self.clear_btn:
        self.clear_btn.click(
            lambda: ([], [], None, Button.update(interactive=True)),
            None,
            [self.chatbot, self.chatbot_state, self.saved_input, self.submit_btn],
            cancels=event_to_cancel,
            **silent,
        )
# TODO: reconfigure clear button as stop and clear button | |
def _setup_events(self) -> None:
    """Register submit / retry / undo events (``ChatInterface`` patch).

    Submit and retry both run the chat function through
    ``_setup_stop_events``, so buttons are toggled (and generation can be
    cancelled) while a response is produced. Undo removes the last turn and
    puts its user message back in the textbox.

    Raises:
        ValueError: if the installed gradio lacks ``gradio.events.on``.
    """
    try:
        from gradio.events import on
    except ImportError:
        on = None
    submit_fn = self._stream_fn if self.is_generator else self._submit_fn
    if on is None:
        raise ValueError('Better install new gradio version than 3.44.0')

    # --- submit: save & clear textbox -> echo input -> generate ---
    submit_triggers = [self.textbox.submit]
    if self.submit_btn:
        submit_triggers.append(self.submit_btn.click)
    saved = on(
        submit_triggers,
        self._clear_and_save_textbox,
        [self.textbox],
        [self.textbox, self.saved_input],
        api_name=False,
        queue=False,
    )
    echoed = saved.then(
        self._display_input,
        [self.saved_input, self.chatbot_state],
        [self.chatbot, self.chatbot_state],
        api_name=False,
        queue=False,
    )
    submit_event = echoed.then(
        submit_fn,
        [self.saved_input, self.chatbot_state] + self.additional_inputs,
        [self.chatbot, self.chatbot_state],
        api_name=False,
    )
    self._setup_stop_events(submit_triggers, submit_event)

    # --- retry: drop last turn -> echo input -> regenerate ---
    if self.retry_btn:
        rolled_back = self.retry_btn.click(
            self._delete_prev_fn,
            [self.chatbot_state],
            [self.chatbot, self.saved_input, self.chatbot_state],
            api_name=False,
            queue=False,
        )
        redisplayed = rolled_back.then(
            self._display_input,
            [self.saved_input, self.chatbot_state],
            [self.chatbot, self.chatbot_state],
            api_name=False,
            queue=False,
        )
        retry_event = redisplayed.then(
            submit_fn,
            [self.saved_input, self.chatbot_state] + self.additional_inputs,
            [self.chatbot, self.chatbot_state],
            api_name=False,
        )
        self._setup_stop_events([self.retry_btn.click], retry_event)

    # --- undo: drop last turn and restore its message into the textbox ---
    if self.undo_btn:
        undone = self.undo_btn.click(
            self._delete_prev_fn,
            [self.chatbot_state],
            [self.chatbot, self.saved_input, self.chatbot_state],
            api_name=False,
            queue=False,
        )
        undone.then(
            lambda x: x,
            [self.saved_input],
            [self.textbox],
            api_name=False,
            queue=False,
        )
# replace
# Monkey-patch gradio's ChatInterface so every instance uses the
# _setup_stop_events / _setup_events functions defined in this module.
gr.ChatInterface._setup_stop_events = _setup_stop_events
gr.ChatInterface._setup_events = _setup_events
def vllm_abort(self: Any):
    """Abort every request queued in a ``vllm.LLM``'s scheduler.

    All sequence groups are removed from the scheduler's ``waiting``,
    ``running`` and ``swapped`` queues. Each not-yet-finished sequence is
    freed with status ``FINISHED_ABORTED``.
    """
    from vllm.sequence import SequenceStatus
    scheduler = self.llm_engine.scheduler
    for state_queue in (scheduler.waiting, scheduler.running, scheduler.swapped):
        # Iterate over a snapshot. Removing items from the container being
        # iterated skips every other element (list) or raises RuntimeError
        # (deque), which left requests un-aborted.
        for seq_group in list(state_queue):
            state_queue.remove(seq_group)
            for seq in seq_group.seqs:
                if seq.is_finished():
                    continue
                scheduler.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Iterator[Dict[str, Any]]:
    """Step a ``vllm.LLM``'s engine until no requests remain, streaming state.

    Streaming counterpart of ``LLM._run_engine``. After each engine step that
    leaves at least one output, it yields a dict mapping ``request_id`` to
    the latest ``RequestOutput`` for every request seen so far. The same dict
    object is updated in place between yields.

    Args:
        self: a ``vllm.LLM`` instance (only ``self.llm_engine`` is used).
        use_tqdm: show a progress bar that counts finished requests.

    Yields:
        ``Dict[str, RequestOutput]`` snapshots of the generation state.
    """
    from vllm.outputs import RequestOutput
    pbar = None
    if use_tqdm:
        num_requests = self.llm_engine.get_num_unfinished_requests()
        pbar = tqdm(total=num_requests, desc="Processed prompts")
    outputs: Dict[str, RequestOutput] = {}
    finished_ids = set()
    try:
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                outputs[output.request_id] = output
                # Advance the bar once per request as it finishes
                # (previously it was created but never updated).
                if pbar is not None and output.finished and output.request_id not in finished_ids:
                    finished_ids.add(output.request_id)
                    pbar.update(1)
            if len(outputs) > 0:
                yield outputs
    finally:
        # Also runs when the consumer stops iterating early.
        if pbar is not None:
            pbar.close()
def vllm_generate_stream(
    self: Any,
    prompts: Optional[Union[str, List[str]]] = None,
    sampling_params: Optional[Any] = None,
    prompt_token_ids: Optional[List[List[int]]] = None,
    use_tqdm: bool = False,
) -> Iterator[Dict[str, Any]]:
    """Stream completions for the input prompts from a ``vllm.LLM``.

    Streaming variant of ``LLM.generate``. All prompts are queued on the
    engine, then the engine is stepped and intermediate results are yielded
    as they are produced.

    NOTE: This function automatically batches the given prompts, considering
    the memory constraint. For the best performance, put all of your prompts
    into a single list and pass it to this method.

    Args:
        prompts: A prompt or list of prompts to generate completions for.
        sampling_params: The sampling parameters for text generation. If
            None, default ``SamplingParams()`` are used.
        prompt_token_ids: A list of token IDs for the prompts. If None, the
            tokenizer converts the prompts to token IDs.
        use_tqdm: Whether to use tqdm to display the progress bar.

    Yields:
        After every engine step that produced output, a dict mapping each
        request id to its latest ``RequestOutput`` (partial until that
        output's ``finished`` flag is set). The same dict object is reused
        across yields.

    Raises:
        ValueError: if neither ``prompts`` nor ``prompt_token_ids`` is given,
            or if both are given with different lengths. Because this is a
            generator, the error surfaces on the first iteration.
    """
    from vllm import SamplingParams
    if prompts is None and prompt_token_ids is None:
        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")
    if isinstance(prompts, str):
        # Convert a single prompt to a list.
        prompts = [prompts]
    if prompts is not None and prompt_token_ids is not None:
        if len(prompts) != len(prompt_token_ids):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
    if sampling_params is None:
        # Use default sampling params.
        sampling_params = SamplingParams()
    # Add requests to the engine.
    num_requests = len(prompts) if prompts is not None else len(prompt_token_ids)
    for i in range(num_requests):
        prompt = prompts[i] if prompts is not None else None
        token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
        self._add_request(prompt, sampling_params, token_ids)
    yield from _vllm_run_engine(self, use_tqdm)
# Canned replies returned instead of a model answer when a request is blocked:
# LANG_BLOCK_MESSAGE when the language is blocked (see BLOCK_ZH), and
# KEYWORD_BLOCK_MESSAGE when a blocked keyword is matched.
LANG_BLOCK_MESSAGE = (
    "Sorry, the language you have asked is currently not supported. "
    "If you have questions in other supported languages, I'll be glad to help. "
    "Please also consider clearing the chat box for a better experience."
)
KEYWORD_BLOCK_MESSAGE = (
    "Sorry, I cannot fulfill your request. "
    "If you have any unrelated question, I'll be glad to help."
)
def block_zh(
    message: str,
    history: Optional[List[Tuple[str, str]]] = None,
) -> bool:
    """Decide whether ``message`` should be answered with LANG_BLOCK_MESSAGE.

    Args:
        message: The incoming user message.
        history: Previous ``(user, bot)`` turns, or None. Only consulted when
            LANG_BLOCK_HISTORY is enabled.

    Returns:
        True if LANG_BLOCK_HISTORY is on and a previous bot reply contains
        LANG_BLOCK_MESSAGE. Also True if ``'zh'`` is in
        ``_detect_lang(message)``. Otherwise False.
    """
    # History-based block: once a bot turn carried the block message, keep
    # blocking. Turns without a textual bot reply (e.g. None left behind by an
    # interrupted generation) are skipped instead of crashing on `.strip()`.
    if LANG_BLOCK_HISTORY and history is not None and any(
        isinstance(turn[1], str) and LANG_BLOCK_MESSAGE in turn[1].strip()
        for turn in history
    ):
        return True
    if 'zh' in _detect_lang(message):
        print(f'Detect zh: {message}')
        return True
    return False
def log_responses(history, message, response):
    """Hook for recording a finished exchange; intentionally a no-op for now.

    Invoked by ``chat_response_stream_multiturn`` when LOG_RESPONSE is set.
    All arguments are accepted and ignored.
    """
    return None
def safety_check(text, history=None) -> Optional[str]:
    """Return a block message if ``text`` must be refused, otherwise None.

    Despite our effort in safety tuning and red teaming, our models may still
    generate harmful or illegal content. This provides an additional security
    measure to enhance safety and compliance with local regulations.

    Checks, in order:
      1. Keyword block. If any entry of KEYWORDS occurs in the lower-cased
         text, return KEYWORD_BLOCK_MESSAGE.
      2. Language block, only when BLOCK_ZH is set. With ``history`` given,
         delegate to ``block_zh(text, history)``. Without it, block when
         ``'zh'`` is in ``_detect_lang(text)``. Blocking returns
         LANG_BLOCK_MESSAGE.
    """
    if KEYWORDS:
        lowered = text.lower()
        if any(keyword in lowered for keyword in KEYWORDS):
            return KEYWORD_BLOCK_MESSAGE
    if not BLOCK_ZH:
        return None
    if history is None:
        zh_detected = "zh" in _detect_lang(text)
    else:
        zh_detected = block_zh(text, history)
    return LANG_BLOCK_MESSAGE if zh_detected else None
def chat_response_stream_multiturn(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float,
    max_tokens: int,
    frequency_penalty: float,
    system_prompt: Optional[str] = SYSTEM_PROMPT_1
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the previous chat turns.

    The multi-turn prompt is built as::

        <bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
        <bos>[INST] Prompt [/INST] Answer <eos>
        <bos>[INST] Prompt [/INST]

    Args:
        message: The incoming user prompt.
        history: Previous ``(user, bot)`` turns; does NOT include ``message``.
        temperature: Sampling temperature (cast with ``float()``).
        max_tokens: Maximum generated tokens (cast with ``int()``).
        frequency_penalty: Frequency penalty (cast with ``float()``).
        system_prompt: System prompt; must be a non-empty string.

    Yields:
        The partially generated reply as it grows, then the final reply. If a
        safety check fails, the corresponding block message is yielded instead
        and the stream ends.

    Raises:
        gr.Error: If ``message`` is empty or the prompt has >= 1000 tokens.
    """
    from vllm import SamplingParams
    global llm
    assert llm is not None
    assert system_prompt is not None and system_prompt.strip() != '', 'system prompt is empty'
    tokenizer = llm.get_tokenizer()
    # Abort everything the engine still holds (e.g. a previous, unfinished
    # stream); the loop below asserts exactly one request is in flight.
    vllm_abort(llm)

    temperature = float(temperature)
    frequency_penalty = float(frequency_penalty)
    max_tokens = int(max_tokens)

    message = message.strip()
    if len(message) == 0:
        raise gr.Error("The message cannot be empty!")
    message_safety = safety_check(message, history=history)
    if message_safety is not None:
        yield message_safety
        return

    # history will be appended with message later on
    full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
        message, history, sys_prompt=system_prompt
    )
    if len(tokenizer.encode(full_prompt, add_special_tokens=False)) >= 1000:
        raise gr.Error("Conversation or prompt is too long, please clear the chatbox or try shorter input.")

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        frequency_penalty=frequency_penalty,
        stop=['<s>', '</s>', '<<SYS>>', '<</SYS>>', '[INST]', '[/INST]']
    )
    cur_out = None
    for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
        # Emit the text collected in the previous step, every
        # STREAM_YIELD_MULTIPLE steps (every step if it is < 1).
        if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
            cur_out = cur_out.replace("\\n", "\n")
            # optionally check safety, and respond
            if STREAM_CHECK_MULTIPLE > 0 and j % STREAM_CHECK_MULTIPLE == 0:
                message_safety = safety_check(cur_out, history=None)
                if message_safety is not None:
                    yield message_safety
                    return
            yield cur_out
        assert len(gen) == 1, f'{gen}'
        item = next(iter(gen.values()))
        cur_out = item.outputs[0].text

    print(f'@@@@@@@@@@\n{full_prompt}<<<{cur_out}>>>\n##########\n')

    if cur_out is None:
        # The engine produced no output at all: nothing to check, show or log.
        return
    if "\\n" in cur_out:
        print(f'double slash-n in cur_out:\n{cur_out}')
        cur_out = cur_out.replace("\\n", "\n")
    # Check the completed reply BEFORE yielding it, so a reply that fails the
    # check is replaced by the block message rather than shown first.
    message_safety = safety_check(cur_out, history=None)
    if message_safety is not None:
        yield message_safety
        return
    yield cur_out
    if LOG_RESPONSE:
        log_responses(history, message, cur_out)
def debug_chat_response_echo(
    message: str,
    history: List[Tuple[str, str]],
    temperature: float = 0.0,
    max_tokens: int = 4096,
    frequency_penalty: float = 0.4,
    system_prompt: str = SYSTEM_PROMPT_1,
) -> Iterator[str]:
    """Debug responder used instead of the model: echo the user's message.

    Takes the same parameters as ``chat_response_stream_multiturn``; all of
    them except ``message`` are ignored. It sleeps 0.5 s, presumably to mimic
    generation latency, then yields the single string ``"repeat: <message>"``.
    """
    import time
    time.sleep(0.5)
    yield f"repeat: {message}"
def check_model_path(model_path) -> str:
    """Verify that ``model_path`` exists and return its checkpoint info text.

    If ``model_path`` is a directory containing ``info.txt``, that file's
    contents are printed and returned. In every other case (no ``info.txt``,
    or ``model_path`` is a plain file), the string ``"None"`` is returned. For
    directories, the directory listing is also printed for diagnostics.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist. An explicit raise
            is used because ``assert`` is stripped under ``python -O``.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f'{model_path} not found')
    ckpt_info = "None"
    if os.path.isdir(model_path):
        info_path = os.path.join(model_path, 'info.txt')
        if os.path.exists(info_path):
            with open(info_path, 'r', encoding='utf-8') as f:
                ckpt_info = f.read()
            print(f'Checkpoint info:\n{ckpt_info}\n-----')
        else:
            print(f'info.txt not found in {model_path}')
        print(f'model path dir: {os.listdir(model_path)}')
    return ckpt_info
def maybe_delete_folder():
    """Delete every entry inside DELETE_FOLDER when deletion is enabled.

    Runs only when both IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT are truthy. The
    folder itself is kept. Files and symlinks are unlinked; sub-directories are
    removed recursively. A failure on one entry is printed and the loop moves
    on to the next entry.
    """
    # Imported locally so `shutil.rmtree` is always in scope: shutil is not among
    # the module's top-of-file imports, and a NameError here would be swallowed
    # by the except below, silently leaving directories behind.
    import shutil
    if IS_DELETE_FOLDER and DOWNLOAD_SNAPSHOT:
        print(f'DELETE ALL FILES IN {DELETE_FOLDER}')
        for filename in os.listdir(DELETE_FOLDER):
            file_path = os.path.join(DELETE_FOLDER, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                # Best effort: report and continue with the remaining entries.
                print(f'Failed to delete {file_path}. Reason: {e}')
def launch():
    """Build the Gradio chat demo and launch it on PORT.

    In DEBUG mode, no model is loaded and ``debug_chat_response_echo``
    answers.

    Otherwise:
      1. If DOWNLOAD_SNAPSHOT is set, download the checkpoint from
         HF_MODEL_NAME into MODEL_PATH.
      2. Load it with vLLM, using AWQ int4 when QUANTIZATION == 'awq'.
      3. Answer with ``chat_response_stream_multiturn``.

    Sets the module global ``demo``, and also ``llm`` outside DEBUG mode.
    """
    global demo, llm
    model_desc = MODEL_DESC
    model_path = MODEL_PATH
    model_title = MODEL_TITLE
    hf_model_name = HF_MODEL_NAME
    tensor_parallel = TENSOR_PARALLEL
    assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
    dtype = DTYPE
    sys_prompt = SYSTEM_PROMPT_1
    max_tokens = MAX_TOKENS
    temperature = TEMPERATURE
    frequence_penalty = FREQUENCE_PENALTY
    print(
        f'Launch config: {tensor_parallel=} / {dtype=} / {max_tokens} | {BLOCK_ZH=} '
        f'\n| model_title=`{model_title}` '
        f'\n| STREAM_YIELD_MULTIPLE={STREAM_YIELD_MULTIPLE} '
        f'\n| STREAM_CHECK_MULTIPLE={STREAM_CHECK_MULTIPLE} '
        f'\n| DISPLAY_MODEL_PATH={DISPLAY_MODEL_PATH} '
        f'\n| LANG_BLOCK_HISTORY={LANG_BLOCK_HISTORY} '
        f'\n| frequence_penalty={frequence_penalty} '
        f'\n| temperature={temperature} '
        f'\n| hf_model_name={hf_model_name} '
        f'\n| model_path={model_path} '
        f'\n| DOWNLOAD_SNAPSHOT={DOWNLOAD_SNAPSHOT} '
        f'\n| gpu_memory_utilization={gpu_memory_utilization} '
        f'\n| KEYWORDS={KEYWORDS} '
        f'\n| Sys={SYSTEM_PROMPT_1}'
        f'\n| Desc={model_desc}'
    )

    if DEBUG:
        model_desc += "\n<br>!!!!! This is in debug mode, responses will copy original"
        response_fn = debug_chat_response_echo
        print(f'Creating in DEBUG MODE')
    else:
        # ! load the model
        if DOWNLOAD_SNAPSHOT:
            print(f'Downloading from HF_MODEL_NAME={hf_model_name} -> {model_path}')
            if HF_TOKEN is not None:
                # Never print the token itself: it is a secret and would leak
                # into the Space logs. `token=` alone authenticates; the extra
                # `use_auth_token=True` was redundant and is ignored by
                # huggingface_hub when `token` is given.
                print('Load with HF_TOKEN: <redacted>')
                snapshot_download(hf_model_name, local_dir=model_path, token=HF_TOKEN)
            else:
                snapshot_download(hf_model_name, local_dir=model_path)

        import vllm
        from vllm import LLM

        print(F'VLLM: {vllm.__version__}')
        ckpt_info = check_model_path(model_path)

        print(f'Load path: {model_path} | {ckpt_info}')
        llm_kwargs = dict(
            model=model_path,
            dtype=dtype,
            tensor_parallel_size=tensor_parallel,
            gpu_memory_utilization=gpu_memory_utilization,
        )
        if QUANTIZATION == 'awq':
            print(F'Load model in int4 quantization')
            llm_kwargs["quantization"] = "awq"
        llm = LLM(**llm_kwargs)

        try:
            print(llm.llm_engine.workers[0].model)
        except Exception as e:
            print(f'Cannot print model worker: {e}')

        try:
            llm.llm_engine.scheduler_config.max_model_len = 4096
            llm.llm_engine.scheduler_config.max_num_batched_tokens = 4096
            llm.llm_engine.tokenizer.add_special_tokens = False
        except Exception as e:
            print(f'Cannot set parameters: {e}')

        print(f'Use system prompt:\n{sys_prompt}')

        response_fn = chat_response_stream_multiturn
        print(F'respond: {response_fn}')

    demo = gr.ChatInterface(
        response_fn,
        chatbot=ChatBot(
            label=MODEL_NAME,
            bubble_full_width=False,
            latex_delimiters=[
                { "left": "$", "right": "$", "display": False},
                { "left": "$$", "right": "$$", "display": True},
            ]
        ),
        textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
        submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
        # ! consider preventing the stop button
        stop_btn=None,
        title=f"{model_title}",
        description=f"{model_desc}",
        additional_inputs=[
            gr.Number(value=temperature, label='Temperature (higher -> more random)'),
            gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
            gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens)'),
            # ! Remove the system prompt textbox to avoid jailbreaking
            # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
        ],
    )
    demo.title = MODEL_NAME
    with demo:
        # gr.Markdown(warning_markdown)
        gr.Markdown(cite_markdown)
        if DISPLAY_MODEL_PATH:
            gr.Markdown(path_markdown.format(model_path=model_path))

    demo.queue()
    demo.launch(server_port=PORT)
def main():
    """Script entry point: build the demo and start serving it."""
    return launch()


if __name__ == "__main__":
    main()