DavidGF committed
Commit 99e0c29
1 Parent(s): fc726c1

Upload 4 files

kraken_model/__init__.py ADDED
@@ -0,0 +1,57 @@
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import TYPE_CHECKING
+
+ # Import from transformers directly; a relative "...utils" import only
+ # resolves inside the transformers source tree, not in a standalone repo.
+ from transformers.utils import (
+     OptionalDependencyNotAvailable,
+     _LazyModule,
+     is_torch_available,
+ )
+
+
+ _import_structure = {
+     "configuration_kraken_lora": ["KrakenConfig"],
+ }
+
+ try:
+     if not is_torch_available():
+         raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+     pass
+ else:
+     _import_structure["modeling_kraken_lora"] = [
+         "KrakenForCausalLM",
+     ]
+
+
+ if TYPE_CHECKING:
+     from .configuration_kraken_lora import KrakenConfig
+
+     try:
+         if not is_torch_available():
+             raise OptionalDependencyNotAvailable()
+     except OptionalDependencyNotAvailable:
+         pass
+     else:
+         from .modeling_kraken_lora import (
+             KrakenForCausalLM,
+         )
+
+
+ else:
+     import sys
+
+     sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
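
With the lazy module registered, importing the package stays cheap: each public name is resolved on first attribute access, which is when the defining submodule (and, for the model class, torch) is actually imported. A minimal sketch of that behavior, assuming the package is importable as kraken_model:

import kraken_model

# No submodule has been imported yet; attribute access triggers it.
cfg_cls = kraken_model.KrakenConfig         # imports configuration_kraken_lora
model_cls = kraken_model.KrakenForCausalLM  # imports modeling_kraken_lora (needs torch)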
kraken_model/configuration_kraken_lora.py ADDED
@@ -0,0 +1,10 @@
+ from transformers import PretrainedConfig
+
+ class KrakenConfig(PretrainedConfig):
+     model_type = "kraken"
+
+     def __init__(self, config_dict=None, **kwargs):
+         super().__init__(**kwargs)
+         # Nested dict holding tokenizer/model/router paths, quantization
+         # settings, LoRA adapter paths, and router class indices.
+         self.config_dict = config_dict or {}
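
For reference, the modeling code in this commit reads a fixed set of keys out of config_dict. A hypothetical instantiation (all repo ids below are placeholders, not the actual Kraken experts):

config = KrakenConfig(
    config_dict={
        "tokenizers": {"lora_expert1": "org/expert1-tokenizer"},  # one per expert key
        "models": {"base": "org/base-model"},                     # shared base model
        "quantization": {"base": "4bit"},                         # "8bit", "4bit", "awq", or ""
        "lora_adapters": {"lora_expert1": "org/expert1-lora"},    # PEFT adapter repos
        "router": "org/router-classifier",                        # sequence classifier
        "class_indices": {"LABEL_0": 0},                          # router label -> expert index
    }
)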
kraken_model/modeling_kraken_lora.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ from transformers import (
+     PreTrainedModel,
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     AutoModelForSequenceClassification,
+     TextClassificationPipeline,
+ )
+ from .configuration_kraken_lora import KrakenConfig
+ from . import tokenizer_template_switch
+ from peft import PeftModel  # LoRA adapter loading
+
+ class KrakenForCausalLM(PreTrainedModel):
+     config_class = KrakenConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         # One tokenizer per expert key (tokenizers take no device_map argument)
+         self.tokenizers = {key: AutoTokenizer.from_pretrained(name) for key, name in config.config_dict['tokenizers'].items()}
+         # Load only the shared base model; experts are applied as LoRA adapters
+         self.model = self.load_base_model(config.config_dict['models']['base'], config.config_dict['quantization']['base'])
+         self.lora_adapters = config.config_dict['lora_adapters']  # LoRA adapter paths
+         self.router_model = AutoModelForSequenceClassification.from_pretrained(config.config_dict['router'], trust_remote_code=True, device_map="auto")
+         self.tokenizer = AutoTokenizer.from_pretrained(config.config_dict['router'], trust_remote_code=True)
+         self.router = TextClassificationPipeline(model=self.router_model, tokenizer=self.tokenizer)
+         self.models_indices = config.config_dict['class_indices']
+
+     def load_base_model(self, model_name, quantization):
+         if quantization == "8bit":
+             return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", load_in_8bit=True, torch_dtype="auto")
+         elif quantization == "4bit":
+             return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", load_in_4bit=True, torch_dtype="auto")
+         elif quantization == "awq":
+             return self.load_awq_model(model_name)
+         else:
+             return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+
+     def load_awq_model(self, name):
+         # AWQ checkpoints carry their own quantization config
+         return AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto")
+
+     def load_lora_adapter(self, base_model, adapter_path):
+         print("Loading adapter: " + adapter_path)
+         return PeftModel.from_pretrained(base_model, adapter_path)
+
+     def tokenize_inputs(self, text, adapter_key):
+         return self.tokenizers[adapter_key](text, return_tensors="pt")
+
+     def determine_adapter(self, text):
+         # The router's predicted label maps, via class_indices, to one of the experts
+         prediction = self.router(text)[0]["label"]
+         model_decision_index = self.models_indices[prediction]
+         adapter_keys = ['lora_expert1', 'lora_expert2', 'lora_expert3', 'lora_expert4', 'lora_expert5']
+         return adapter_keys[model_decision_index]
+
+     def expert_tokenizer(self, text):
+         adapter_key = self.determine_adapter(text)
+         return self.tokenizers[adapter_key]
+
+     def generate(self, input_ids, **generate_kwargs):
+         # Decode the incoming ids back to text with the router tokenizer
+         text = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)[0]
+
+         # Recover the chat messages from the rendered template
+         msgs = tokenizer_template_switch.recover_chat_messages(text, self.tokenizer)
+         # Drop a spurious system message left over from template recovery
+         if msgs and msgs[0]['role'] == 'system' and msgs[0]['content'] == '<|im_start|>system':
+             msgs.pop(0)
+         # Drop a trailing (empty) assistant turn, if present
+         if msgs and msgs[-1]['role'] == 'assistant':
+             msgs.pop()
+
+         # Determine the appropriate LoRA adapter
+         adapter_key = self.determine_adapter(text)
+         print(f"Choosing LoRA adapter {adapter_key} ...")
+         # Apply the selected LoRA adapter to the shared base model
+         lora_adapter_path = self.lora_adapters[adapter_key]
+         model_with_lora = self.load_lora_adapter(self.model, lora_adapter_path)
+
+         # Re-render the conversation with the selected expert's chat template
+         mod_txt = self.tokenizers[adapter_key].apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+         current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'
+
+         # Tokenize with the selected expert's tokenizer
+         tok = self.tokenizers[adapter_key](mod_txt, return_tensors="pt")
+         tok_input_ids = tok.input_ids.to(current_device)
+         tok_attention_mask = tok.attention_mask.to(current_device)
+
+         # Generate with the adapted model
+         return model_with_lora.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
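
A plausible end-to-end call, assuming the checkpoint is published with remote code enabled and the router tokenizer carries a ChatML-style chat template (the repo id is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "your-org/kraken-lora"  # placeholder, not the actual repo id
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)  # router tokenizer

messages = [{"role": "user", "content": "Write a haiku about the sea."}]
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")

# generate() decodes with the router tokenizer, picks an expert adapter,
# re-renders the chat with that expert's template, and generates.
output_ids = model.generate(input_ids, max_new_tokens=128)
print(model.expert_tokenizer(messages[0]["content"]).decode(output_ids[0], skip_special_tokens=True))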
kraken_model/tokenizer_template_switch.py ADDED
@@ -0,0 +1,93 @@
+ import re
+ from transformers import AutoTokenizer
+
+ def extract_separators(template):
+     """
+     Extracts the role separators used in a Jinja chat template.
+     """
+     # Match the span between '{{' and "+ message['content']"
+     pattern = r"\{\{\s*([^{}]+?)\s*\+ message\['content'\]"
+     matches = re.findall(pattern, template)
+     # Clean up any extra spaces
+     separators = [match.strip() for match in matches]
+
+     # If the template builds the separator from message['role'],
+     # expand it into one concrete separator per role.
+     if any("message['role']" in element for element in separators):
+         roles = ["system", "user", "assistant"]
+         separators_ = []
+         for role in roles:
+             separators_.append(separators[0].replace(" + message['role'] + ", role).replace("'", ""))
+         return separators_
+
+     return separators
+
+ def detect_eos_token(jinja_template, tokenizer):
+     if "<|im_end|>" in jinja_template:
+         return "<|im_end|>"
+     if "</s>" in jinja_template:
+         return "</s>"
+     if "eos_token" in jinja_template:
+         return tokenizer.eos_token
+     return "<|endoftext|>"
+
+ def recover_messages(formatted_message, separators, eos_token):
+     """
+     Recovers the original messages from the formatted message string.
+     """
+     # Split the formatted message on the end-of-sequence token
+     split_messages = formatted_message.split(eos_token)
+
+     # Remove the trailing empty string left by a final separator
+     if split_messages and split_messages[-1].strip() == '':
+         split_messages.pop()
+
+     # Prepare the list to hold the recovered messages
+     recovered_messages = []
+
+     # Roles after the first message alternate between "user" and "assistant"
+     alternate_roles = ["user", "assistant"]
+
+     for index, message_content in enumerate(split_messages):
+         # The first message is "system"; later ones alternate
+         if index == 0:
+             role = "system"
+         else:
+             role = alternate_roles[(index - 1) % 2]
+
+         # Strip whitespace and remove the leading role separator
+         clean_content = message_content.strip()
+         for separator in separators:
+             clean_content = clean_content.replace(separator.strip("'"), '', 1).strip()
+
+         recovered_messages.append({"role": role, "content": clean_content})
+
+     return recovered_messages
+
+ def recover_chat_messages(tokenized_chat, tokenizer):
+     """
+     Given a rendered chat string and a tokenizer, returns the list of
+     message dictionaries that produced it.
+     """
+     jinja_template = tokenizer.chat_template
+     separators = extract_separators(jinja_template)
+     eos_token = detect_eos_token(jinja_template, tokenizer)
+     return recover_messages(tokenized_chat, separators, eos_token)
+
+ # Example usage
+ if __name__ == "__main__":
+     checkpoint = "Qwen/Qwen1.5-0.5B"
+     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+     messages = [
+         {
+             "role": "system",
+             "content": "You are a friendly chatbot who always responds in the style of a pirate",
+         },
+         {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
+     ]
+     tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=False)
+     print(tokenized_chat)
+
+     recovered_messages = recover_chat_messages(tokenized_chat, tokenizer)
+     print(recovered_messages)
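
To see what extract_separators actually recovers, here is a self-contained check against a hand-written ChatML-style template (illustrative only, not any repo's real template):

chatml = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
)
# The message['role'] placeholder is expanded into one separator per role:
# ['<|im_start|>system\n', '<|im_start|>user\n', '<|im_start|>assistant\n']
print(extract_separators(chatml))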