Delete modeling_kraken.py
Browse files- modeling_kraken.py +0 -82
modeling_kraken.py
DELETED
@@ -1,82 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from transformers import PreTrainedModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TextClassificationPipeline
|
3 |
-
from configuration_kraken import KrakenConfig
|
4 |
-
import tokenizer_template_switch
|
5 |
-
|
6 |
-
class KrakenForCausalLM(PreTrainedModel):
|
7 |
-
config_class = KrakenConfig
|
8 |
-
|
9 |
-
def __init__(self, config):
|
10 |
-
super().__init__(config)
|
11 |
-
self.tokenizers = {key: AutoTokenizer.from_pretrained(name, device_map="auto") for key, name in config.config_dict['tokenizers'].items()}
|
12 |
-
self.models = self.load_expert_models(config.config_dict['models'], config.config_dict['quantization'])
|
13 |
-
self.router_model = AutoModelForSequenceClassification.from_pretrained(config.config_dict['router'], trust_remote_code=True,device_map="auto")
|
14 |
-
self.tokenizer = AutoTokenizer.from_pretrained(config.config_dict['router'], trust_remote_code=True,device_map="auto")
|
15 |
-
self.router = TextClassificationPipeline(model=self.router_model, tokenizer=self.tokenizer)
|
16 |
-
self.models_indices = config.config_dict['class_indices']
|
17 |
-
|
18 |
-
def load_expert_models(self, models_dict, quantization_dict):
|
19 |
-
models = {}
|
20 |
-
for key, name in models_dict.items():
|
21 |
-
quantization = quantization_dict.get(key)
|
22 |
-
if quantization == "8bit":
|
23 |
-
models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", load_in_8bit=True, torch_dtype="auto")
|
24 |
-
elif quantization == "4bit":
|
25 |
-
models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", load_in_4bit=True, torch_dtype="auto")
|
26 |
-
elif quantization == "awq":
|
27 |
-
models[key] = self.load_awq_model(name)
|
28 |
-
else:
|
29 |
-
models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", torch_dtype="auto")
|
30 |
-
return models
|
31 |
-
|
32 |
-
def load_awq_model(self, name):
|
33 |
-
return AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto")
|
34 |
-
|
35 |
-
def tokenize_inputs(self, text, model_key):
|
36 |
-
return self.tokenizers[model_key](text, return_tensors="pt")
|
37 |
-
|
38 |
-
def determine_model(self, text):
|
39 |
-
prediction = self.router(text)[0]["label"]
|
40 |
-
model_decision_index = self.models_indices[prediction]
|
41 |
-
model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
|
42 |
-
return model_keys[model_decision_index]
|
43 |
-
|
44 |
-
def expert_tokenizer(self, text):
|
45 |
-
model_key = self.determine_model(text)
|
46 |
-
return self.tokenizers[model_key]
|
47 |
-
|
48 |
-
|
49 |
-
def generate(self, input_ids, **generate_kwargs):
|
50 |
-
# Tokenize the input_ids
|
51 |
-
text = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)[0]
|
52 |
-
|
53 |
-
msgs = tokenizer_template_switch.recover_chat_messages(text, self.tokenizer)
|
54 |
-
if msgs and msgs[0]['role'] == 'system' and msgs[0]['content']=='<|im_start|>system':
|
55 |
-
# Delete the first element
|
56 |
-
msgs.pop(0)
|
57 |
-
# Check if the last element has the role 'assistant'
|
58 |
-
if msgs and msgs[-1]['role'] == 'assistant':
|
59 |
-
# Delete the last element
|
60 |
-
msgs.pop()
|
61 |
-
|
62 |
-
# Determine the model key using the existing routing logic
|
63 |
-
model_key = self.determine_model(text)
|
64 |
-
# Show the routing result
|
65 |
-
print(f"Choosing {model_key} ..")
|
66 |
-
# Retrieve the model from the dictionary
|
67 |
-
model = self.models[model_key]
|
68 |
-
|
69 |
-
mod_txt = self.tokenizers[model_key].apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
|
70 |
-
current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'
|
71 |
-
|
72 |
-
# Tokenize accordingly to the best model
|
73 |
-
|
74 |
-
tok = self.tokenizers[model_key](mod_txt, return_tensors="pt")
|
75 |
-
tok_input_ids = tok.input_ids.to(current_device)
|
76 |
-
tok_attention_mask = tok.attention_mask.to(current_device)
|
77 |
-
|
78 |
-
# Generate text using the retrieved model
|
79 |
-
return model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
|
80 |
-
|
81 |
-
|
82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|