# Source extracted from a Hugging Face file viewer (blob 43a7079, 1,485 bytes);
# viewer chrome and line-number gutter removed.
import os
from .configs.model2path import MODEL2PATH
class MInferenceConfig:
    """Configuration holder for MInference attention patching.

    Validates the requested attention implementation against the supported
    set and resolves the path to the per-model sparse-pattern config file
    (either an explicit ``config_path`` or a lookup in ``MODEL2PATH``).
    """

    # Closed set of supported attention implementations.
    ATTENTION_TYPES = [
        "minference",
        "minference_with_dense",
        "static",
        "dilated1",
        "dilated2",
        "streaming",
        "inf_llm",
        "vllm",
    ]

    def __init__(
        self,
        attn_type: str = "minference",
        model_name: str = None,
        config_path: str = None,
        starting_layer: int = -1,
        kv_cache_cpu: bool = False,
        use_snapkv: bool = False,
        is_search: bool = False,
        attn_kwargs: dict = None,
        **kwargs,
    ):
        """Build a config, validating ``attn_type`` and resolving ``config_path``.

        Args:
            attn_type: One of ``ATTENTION_TYPES``.
            model_name: Model identifier; used to look up a bundled config
                when ``config_path`` is not given.
            config_path: Explicit path to a sparse-pattern config file;
                takes precedence over the ``model_name`` lookup.
            starting_layer: First layer to apply sparse attention to
                (-1 presumably means "all layers" — confirm against callers).
            kv_cache_cpu: Whether to keep the KV cache on CPU.
            use_snapkv: Whether to enable SnapKV.
            is_search: Whether to run in pattern-search mode.
            attn_kwargs: Extra keyword arguments forwarded to the attention
                implementation. Defaults to a fresh empty dict per instance.
            **kwargs: Ignored; accepted for forward compatibility.

        Raises:
            ValueError: If ``attn_type`` is unsupported, or if neither
                ``config_path`` nor a known ``model_name`` is provided.
        """
        super().__init__()
        # Explicit raise (not `assert`) so validation survives `python -O`.
        if attn_type not in self.ATTENTION_TYPES:
            raise ValueError(
                f"The attention_type {attn_type} you specified is not supported."
            )
        self.attn_type = attn_type
        self.config_path = self.update_config_path(config_path, model_name)
        self.model_name = model_name
        self.is_search = is_search
        self.starting_layer = starting_layer
        self.kv_cache_cpu = kv_cache_cpu
        self.use_snapkv = use_snapkv
        # Avoid the shared-mutable-default pitfall: never store the default
        # dict object itself; each instance gets its own.
        self.attn_kwargs = {} if attn_kwargs is None else attn_kwargs

    def update_config_path(self, config_path: str, model_name: str):
        """Return ``config_path`` if given, else the bundled path for ``model_name``.

        Raises:
            ValueError: If ``config_path`` is None and ``model_name`` is not
                in ``MODEL2PATH``.
        """
        if config_path is not None:
            return config_path
        if model_name not in MODEL2PATH:
            raise ValueError(
                f"The model {model_name} you specified is not supported. You are welcome to add it and open a PR :)"
            )
        return MODEL2PATH[model_name]
|