DavidGF commited on
Commit
fc726c1
1 Parent(s): e1c643a

Delete modeling_kraken_lora.py

Browse files
Files changed (1) hide show
  1. modeling_kraken_lora.py +0 -86
modeling_kraken_lora.py DELETED
@@ -1,86 +0,0 @@
1
- import torch
2
- from transformers import PreTrainedModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TextClassificationPipeline
3
- from configuration_kraken_lora import KrakenConfig
4
- import tokenizer_template_switch
5
- from peft import PeftModel, PeftConfig # Import necessary modules for LoRA
6
-
7
- class KrakenForCausalLM(PreTrainedModel):
8
- config_class = KrakenConfig
9
-
10
- def __init__(self, config):
11
- super().__init__(config)
12
- self.tokenizers = {key: AutoTokenizer.from_pretrained(name, device_map="auto") for key, name in config.config_dict['tokenizers'].items()}
13
- self.model = self.load_base_model(config.config_dict['models']['base'], config.config_dict['quantization']['base']) # Load only expert1 as the base model
14
- self.lora_adapters = config.config_dict['lora_adapters'] # Load LoRA adapter paths
15
- self.router_model = AutoModelForSequenceClassification.from_pretrained(config.config_dict['router'], trust_remote_code=True, device_map="auto")
16
- self.tokenizer = AutoTokenizer.from_pretrained(config.config_dict['router'], trust_remote_code=True, device_map="auto")
17
- self.router = TextClassificationPipeline(model=self.router_model, tokenizer=self.tokenizer)
18
- self.models_indices = config.config_dict['class_indices']
19
-
20
- def load_base_model(self, model_name, quantization):
21
- if quantization == "8bit":
22
- return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", load_in_8bit=True, torch_dtype="auto")
23
- elif quantization == "4bit":
24
- return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", load_in_4bit=True, torch_dtype="auto")
25
- elif quantization == "awq":
26
- return self.load_awq_model(model_name)
27
- else:
28
- return AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, device_map="auto", torch_dtype="auto")
29
-
30
- def load_awq_model(self, name):
31
- return AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto")
32
-
33
- def load_lora_adapter(self, base_model, adapter_path):
34
- print("Loading adapter: "+adapter_path)
35
- return PeftModel.from_pretrained(base_model, adapter_path)
36
-
37
- def tokenize_inputs(self, text, adapter_key):
38
- return self.tokenizers[adapter_key](text, return_tensors="pt")
39
-
40
- def determine_adapter(self, text):
41
- prediction = self.router(text)[0]["label"]
42
- model_decision_index = self.models_indices[prediction]
43
- adapter_keys = ['lora_expert1', 'lora_expert2', 'lora_expert3', 'lora_expert4', 'lora_expert5']
44
- return adapter_keys[model_decision_index]
45
-
46
- def expert_tokenizer(self, text):
47
- adapter_key = self.determine_adapter(text)
48
- return self.tokenizers[adapter_key]
49
-
50
-
51
- def generate(self, input_ids, **generate_kwargs):
52
- # Tokenize the input_ids
53
- text = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)[0]
54
-
55
- msgs = tokenizer_template_switch.recover_chat_messages(text, self.tokenizer)
56
- if msgs and msgs[0]['role'] == 'system' and msgs[0]['content']=='<|im_start|>system':
57
- # Delete the first element
58
- msgs.pop(0)
59
- # Check if the last element has the role 'assistant'
60
- if msgs and msgs[-1]['role'] == 'assistant':
61
- # Delete the last element
62
- msgs.pop()
63
-
64
-
65
- # Determine the appropriate LoRA adapter
66
- adapter_key = self.determine_adapter(text)
67
- print(f"Choosing LoRA adapter for {adapter_key} ..")
68
- # Load and apply the LoRA adapter to the base model (expert1)
69
- lora_adapter_path = self.lora_adapters[adapter_key]
70
- model_with_lora = self.load_lora_adapter(self.model, lora_adapter_path)
71
-
72
- # Use the tokenizer for the selected expert to tokenize the inputs
73
- mod_txt = self.tokenizers[adapter_key].apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
74
- current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'
75
-
76
- # Tokenize accordingly to the best model
77
-
78
- tok = self.tokenizers[adapter_key](mod_txt, return_tensors="pt")
79
- tok_input_ids = tok.input_ids.to(current_device)
80
- tok_attention_mask = tok.attention_mask.to(current_device)
81
-
82
- # Generate text using the modified model
83
- return model_with_lora.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
84
-
85
-
86
-