DavidGF commited on
Commit
1491648
1 Parent(s): 70e003e

Delete modeling_kraken.py

Browse files
Files changed (1) hide show
  1. modeling_kraken.py +0 -82
modeling_kraken.py DELETED
@@ -1,82 +0,0 @@
1
- import torch
2
- from transformers import PreTrainedModel, AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TextClassificationPipeline
3
- from configuration_kraken import KrakenConfig
4
- import tokenizer_template_switch
5
-
6
- class KrakenForCausalLM(PreTrainedModel):
7
- config_class = KrakenConfig
8
-
9
- def __init__(self, config):
10
- super().__init__(config)
11
- self.tokenizers = {key: AutoTokenizer.from_pretrained(name, device_map="auto") for key, name in config.config_dict['tokenizers'].items()}
12
- self.models = self.load_expert_models(config.config_dict['models'], config.config_dict['quantization'])
13
- self.router_model = AutoModelForSequenceClassification.from_pretrained(config.config_dict['router'], trust_remote_code=True,device_map="auto")
14
- self.tokenizer = AutoTokenizer.from_pretrained(config.config_dict['router'], trust_remote_code=True,device_map="auto")
15
- self.router = TextClassificationPipeline(model=self.router_model, tokenizer=self.tokenizer)
16
- self.models_indices = config.config_dict['class_indices']
17
-
18
- def load_expert_models(self, models_dict, quantization_dict):
19
- models = {}
20
- for key, name in models_dict.items():
21
- quantization = quantization_dict.get(key)
22
- if quantization == "8bit":
23
- models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", load_in_8bit=True, torch_dtype="auto")
24
- elif quantization == "4bit":
25
- models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", load_in_4bit=True, torch_dtype="auto")
26
- elif quantization == "awq":
27
- models[key] = self.load_awq_model(name)
28
- else:
29
- models[key] = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto", torch_dtype="auto")
30
- return models
31
-
32
- def load_awq_model(self, name):
33
- return AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, device_map="auto")
34
-
35
- def tokenize_inputs(self, text, model_key):
36
- return self.tokenizers[model_key](text, return_tensors="pt")
37
-
38
- def determine_model(self, text):
39
- prediction = self.router(text)[0]["label"]
40
- model_decision_index = self.models_indices[prediction]
41
- model_keys = ['expert1', 'expert2', 'expert3', 'expert4','expert5']
42
- return model_keys[model_decision_index]
43
-
44
- def expert_tokenizer(self, text):
45
- model_key = self.determine_model(text)
46
- return self.tokenizers[model_key]
47
-
48
-
49
- def generate(self, input_ids, **generate_kwargs):
50
- # Tokenize the input_ids
51
- text = self.tokenizer.batch_decode(input_ids, skip_special_tokens=False)[0]
52
-
53
- msgs = tokenizer_template_switch.recover_chat_messages(text, self.tokenizer)
54
- if msgs and msgs[0]['role'] == 'system' and msgs[0]['content']=='<|im_start|>system':
55
- # Delete the first element
56
- msgs.pop(0)
57
- # Check if the last element has the role 'assistant'
58
- if msgs and msgs[-1]['role'] == 'assistant':
59
- # Delete the last element
60
- msgs.pop()
61
-
62
- # Determine the model key using the existing routing logic
63
- model_key = self.determine_model(text)
64
- # Show the routing result
65
- print(f"Choosing {model_key} ..")
66
- # Retrieve the model from the dictionary
67
- model = self.models[model_key]
68
-
69
- mod_txt = self.tokenizers[model_key].apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
70
- current_device = input_ids.device if isinstance(input_ids, torch.Tensor) else 'cpu'
71
-
72
- # Tokenize accordingly to the best model
73
-
74
- tok = self.tokenizers[model_key](mod_txt, return_tensors="pt")
75
- tok_input_ids = tok.input_ids.to(current_device)
76
- tok_attention_mask = tok.attention_mask.to(current_device)
77
-
78
- # Generate text using the retrieved model
79
- return model.generate(tok_input_ids, attention_mask=tok_attention_mask, **generate_kwargs)
80
-
81
-
82
-