Florian committed on
Commit 5b2e6a5 · 1 Parent(s): 48f630f

first commit

Files changed (5)
  1. .gitignore +1 -0
  2. app.py +94 -0
  3. requirements.txt +4 -0
  4. src/BranchyModel.py +469 -0
  5. src/utils.py +57 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ model/*
app.py ADDED
@@ -0,0 +1,94 @@
+ # Save this as app.py and run with `streamlit run app.py`
+ import streamlit as st
+ import torch
+ import pandas as pd
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ from src.utils import generate_next_token, breaking_ties
+ from src.BranchyModel import BranchyModel
+
+ st.title("Multi-Head LLM Demo")
+
+ def add_and_run(token, head):
+     # Append the chosen head and the running mean of all heads used so far to the DataFrame
+     head_list = st.session_state["computation_pd"]["Head"].to_list() + [head]
+     mean = sum(head_list) / len(head_list)
+     st.session_state["computation_pd"] = pd.concat([st.session_state["computation_pd"], pd.DataFrame({"Head": [head], "Mean": [mean], "Base model consumption": [st.session_state['head_number']]})], ignore_index=True)
+
+     st.session_state['current_sentence'] += token
+     _, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
+
+ def reset():
+     st.session_state['computation_pd'] = pd.DataFrame(columns=["Head", "Mean", "Base model consumption"])
+     st.session_state['current_sentence'] = "The climate in"
+     _, st.session_state['logits'], _, st.session_state['head_tokens'] = generate_next_token(st.session_state.model, st.session_state.tokenizer, st.session_state['current_sentence'])
+
+ @st.cache_resource
+ def load_model(penalty_alpha):
+     penalty_map = {0.1: "model_20240118-144039.bin",
+                    0.5: "model_20240118-192548.bin",
+                    2: "model_20240118-211943.bin",
+                    5: "model_20240118-231333.bin",
+                    10: "model_20240119-010725.bin",
+                    20: "model_20240119-030115.bin",
+                    0: "model_20240119-135506.bin",
+                    1: "model_20240119-154900.bin",
+                    -20: "model_20240208-072350.bin",
+                    -10: "model_20240208-052958.bin",
+                    -5: "model_20240208-033606.bin",
+                    -2: "model_20240208-014211.bin",
+                    -1: "model_20240207-234817.bin",
+                    -0.5: "model_20240207-215423.bin",
+                    -0.1: "model_20240207-200020.bin"}
+
+     model_str = "susnato/phi-1_5_dev"
+     model = AutoModelForCausalLM.from_pretrained(model_str).to("cuda:1")
+     tokenizer = AutoTokenizer.from_pretrained(model_str)
+
+     branch_locations = list(range(0, 23, 5))
+     model = BranchyModel(branch_locations=branch_locations, model=model).to("cuda:1")
+
+     # Load the head weights matching this penalty_alpha
+     model_path = penalty_map.get(penalty_alpha)
+     if model_path:
+         model.load_state_dict(torch.load(model_path, map_location="cuda:1"))
+     else:
+         print("Invalid penalty_alpha. Using default model weights.")
+
+     return model, tokenizer
+
+
+ if "model" not in st.session_state or "tokenizer" not in st.session_state:
+     print("Loading model...")
+     st.session_state.model, st.session_state.tokenizer = load_model(penalty_alpha=-2)  # Example penalty_alpha
+     st.session_state["head_number"] = len(st.session_state.model.branch_locations) + 1
+     print(f"Head number: {st.session_state['head_number']}")
+ # Session state to store the current sentence
+ if 'current_sentence' not in st.session_state:
+     reset()
+
+ # Create a container to hold the buttons
+ cols = st.columns(len(st.session_state.head_tokens))  # Create a column for each token
+
+ # Iterate through each head token and create a button in a separate column
+ for i, (col, token) in enumerate(zip(cols, st.session_state.head_tokens)):
+     col.button(f"{st.session_state['head_tokens'][i]}",
+                key=f"head_{i}",
+                use_container_width=True,
+                on_click=add_and_run,
+                args=(st.session_state['head_tokens'][i], i))
+
+
+ # Display the current sentence
+ st.markdown(f"{st.session_state['current_sentence']}")
+
+ # Reset button to start over
+ st.button('Reset', on_click=reset)
+
+ if 'computation_pd' in st.session_state and not st.session_state['computation_pd'].empty:
+     st.line_chart(st.session_state['computation_pd'])
+     # Use the last row to compute how much of the base model's compute was saved
+     saved_budget = 100 - ((st.session_state["computation_pd"]["Mean"].iloc[-1] * 100) / st.session_state["computation_pd"]["Base model consumption"].iloc[-1])
+     st.markdown(f"You saved **{saved_budget:.2f}%** of the base model consumption.")
+     #st.write(st.session_state['computation_pd'])
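Note on the budget figure at the end of app.py: each button press records the index of the head that produced the accepted token, and the saved budget compares the running mean of those indices to the total head count (5 branches plus the final lm_head, so head_number is 6 for the configuration in load_model). Below is a minimal, self-contained sketch of that bookkeeping; the chosen head indices are made-up example values, not output of the model.

import pandas as pd

head_number = 6           # 5 branches + the final lm_head, as configured in load_model
chosen_heads = [0, 2, 1]  # hypothetical heads clicked by the user

rows = []
for i, head in enumerate(chosen_heads):
    mean = sum(chosen_heads[: i + 1]) / (i + 1)
    rows.append({"Head": head, "Mean": mean, "Base model consumption": head_number})

df = pd.DataFrame(rows)
saved_budget = 100 - (df["Mean"].iloc[-1] * 100) / df["Base model consumption"].iloc[-1]
print(df)
print(f"Saved {saved_budget:.2f}% of the base model consumption")  # 83.33% for these example choices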
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ streamlit==1.31.0
+ torch==2.0.1
+ pandas==2.0.3
+ transformers==4.36.0
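Optional sanity check (not part of the commit): a short Python snippet to confirm the pinned dependencies resolve in the active environment before launching the app.

import importlib.metadata as md

for pkg in ("streamlit", "torch", "pandas", "transformers"):
    # Prints the installed version; compare against requirements.txt
    print(pkg, md.version(pkg))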
src/BranchyModel.py ADDED
@@ -0,0 +1,469 @@
+ from typing import Dict, List, Optional, Tuple
+ from dataclasses import dataclass
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from transformers import PreTrainedModel
+ from transformers.cache_utils import Cache, DynamicCache
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.utils import ModelOutput
+
+
+ @dataclass
+ class CausalBranchyLLMOutputWithPast(ModelOutput):
+     loss: Optional[torch.Tensor] = None
+     lm_loss: Optional[torch.Tensor] = None
+     head_loss: Optional[torch.Tensor] = None
+     logits: torch.Tensor = None
+     head_outputs: Optional[torch.Tensor] = None
+     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+ class Branch(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
+
+     def forward(self, x):
+         x = self.layernorm(x)
+         x = self.lm_head(x)
+         return x
+
+ class BranchyModel(PreTrainedModel):
+     """
+     This class is a wrapper for transformer models with added functionality for branchy networks.
+     It uses BranchyConfig to initialize a model and later will be extended to add branches.
+
+     Args:
+         branch_locations (List[int]): The locations of the branches in the model.
+             Indexing starts at 0: branch 0 is placed after layer 0.
+         model (PreTrainedModel): The underlying transformer model to wrap.
+
+     Returns:
+         A model instance with the given configuration.
+     """
+
+     def __init__(self, branch_locations, model, loss_type="kl_div", penality_weight=None):
+         super().__init__(model.config)
+         # Initialize the base transformer model
+         self.model = model
+         self.branch_locations = branch_locations
+         self.loss_type = loss_type
+         self.penality_weight = penality_weight
+         if self.loss_type == "penalized_cross_entropy":
+             assert self.penality_weight is not None, "penality_weight must be provided for penalized_cross_entropy loss"
+         # Get details on layering inside the model
+         if hasattr(self.model.config, "n_layer") or hasattr(
+             self.model.config, "num_hidden_layers"
+         ):  # If there is no n_layer in the config, there might be ways to get it from the model itself
+             self.num_layers = (
+                 self.model.config.n_layer
+                 if hasattr(self.model.config, "n_layer")
+                 else self.model.config.num_hidden_layers
+             )
+         else:
+             raise ValueError("cannot find n_layer in config")
+         # If no branch locations are specified, branch at every layer
+         if self.branch_locations is None:
+             self.branch_locations = list(range(self.num_layers - 1))
+
+         assert self.num_layers > 0, "The number of layers must be greater than 0"
+         assert (
+             len(self.branch_locations) < self.num_layers
+         ), "The number of branches must be less than the number of layers"
+         assert all(
+             [0 <= i < self.num_layers for i in self.branch_locations]
+         ), "The branch locations must be between 0 and num_layers"
+
+
+         # Make sure the base model is frozen
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+         # Instantiate heads. Default: heads are copies of the lm_head
+         self.model.heads = torch.nn.ModuleList(
+             [
+                 Branch(self.model.config) for _ in range(len(self.branch_locations))
+             ]
+         )
+
+         # Initialize heads
+         for head in self.model.heads:
+             head.apply(self.model._init_weights)
+             # Make them trainable
+             for param in head.parameters():
+                 param.requires_grad = True
+
+         self.post_init()
+
+     # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         **kwargs,
+     ):
+         if past_key_values is not None:
+             if isinstance(past_key_values, Cache):
+                 cache_length = past_key_values.get_seq_length()
+                 past_length = past_key_values.seen_tokens
+                 max_cache_length = past_key_values.get_max_length()
+             else:
+                 cache_length = past_length = past_key_values[0][0].shape[2]
+                 max_cache_length = None
+
+             # Keep only the unprocessed tokens:
+             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+             #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+             #     input)
+             if (
+                 attention_mask is not None
+                 and attention_mask.shape[1] > input_ids.shape[1]
+             ):
+                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+             #     input_ids based on the past_length.
+             elif past_length < input_ids.shape[1]:
+                 input_ids = input_ids[:, past_length:]
+             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+             if (
+                 max_cache_length is not None
+                 and attention_mask is not None
+                 and cache_length + input_ids.shape[1] > max_cache_length
+             ):
+                 attention_mask = attention_mask[:, -max_cache_length:]
+
+         position_ids = kwargs.get("position_ids", None)
+         if attention_mask is not None and position_ids is None:
+             # create position_ids on the fly for batch generation
+             position_ids = attention_mask.long().cumsum(-1) - 1
+             position_ids.masked_fill_(attention_mask == 0, 1)
+             if past_key_values:
+                 position_ids = position_ids[:, -input_ids.shape[1] :]
+
+         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+         if inputs_embeds is not None and past_key_values is None:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "position_ids": position_ids,
+                 "past_key_values": past_key_values,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "fixed_output_head": kwargs.get("fixed_output_head", None),
+             }
+         )
+         return model_inputs
+
+     def compute_self_supervision_loss(
+         self,
+         aux_logits: torch.Tensor,
+         lm_logits: torch.Tensor,
+         return_per_head: bool = False,
+     ) -> Dict[str, torch.Tensor]:
+         last_aux_logits = aux_logits[..., -1, :]
+         last_lm_logits = lm_logits[..., -1, :]
+
+         repeated_last_lm_logits = last_lm_logits.repeat(
+             last_aux_logits.shape[0], 1, 1, 1
+         )
+         losses = []
+         # A detailed loss per head can be useful for comparing head performance
+         if return_per_head:
+             for head_logit in last_aux_logits:
+                 if self.loss_type == "kl_div":
+                     losses.append(
+                         nn.KLDivLoss(reduction="batchmean")(
+                             F.log_softmax(head_logit, dim=-1),
+                             F.softmax(last_lm_logits, dim=-1),
+                         )
+                     )
+                 elif self.loss_type == "cross_entropy":
+                     losses.append(
+                         nn.CrossEntropyLoss(reduction="mean")(
+                             head_logit, torch.argmax(last_lm_logits, dim=-1)
+                         )
+                     )
+                 elif self.loss_type == "penalized_cross_entropy":
+                     ce_loss = nn.CrossEntropyLoss(reduction="mean")(
+                         head_logit, torch.argmax(last_lm_logits, dim=-1)
+                     )
+                     probas = F.softmax(head_logit, dim=-1)
+                     entropy = torch.mean(-torch.sum(probas * torch.log(probas + 1e-8), dim=-1))
+                     #losses.append(ce_loss - self.penality_weight * (1.0 / (1.0 + entropy)))
+                     losses.append(ce_loss - self.penality_weight * entropy)
+                 else:
+                     raise ValueError(
+                         "The loss type must be kl_div, cross_entropy, or penalized_cross_entropy"
+                     )
+             loss = torch.stack(losses, dim=0).mean(dim=-1)
+         else:
+             # Compute a single aggregated loss between the auxiliary heads and the final LM head
+             if self.loss_type == "kl_div":
+                 loss = nn.KLDivLoss(reduction="batchmean")(
+                     F.log_softmax(last_aux_logits.view(-1, self.config.vocab_size), dim=-1),
+                     F.softmax(
+                         repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
+                     ),
+                 )
+             elif self.loss_type == "cross_entropy":
+                 loss = nn.CrossEntropyLoss(reduction="mean")(
+                     last_aux_logits.view(-1, self.config.vocab_size),
+                     torch.argmax(
+                         repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
+                     ),
+                 )
+             elif self.loss_type == "penalized_cross_entropy":
+                 ce_loss = nn.CrossEntropyLoss(reduction="mean")(
+                     last_aux_logits.view(-1, self.config.vocab_size),
+                     torch.argmax(
+                         repeated_last_lm_logits.view(-1, self.config.vocab_size), dim=-1
+                     ),
+                 )
+                 probas = F.softmax(
+                     last_aux_logits.view(-1, self.config.vocab_size), dim=-1
+                 )
+                 entropy = torch.mean(-torch.sum(probas * torch.log(probas + 1e-8), dim=-1))
+                 loss = ce_loss + self.penality_weight * entropy
+             else:
+                 raise ValueError(
+                     "The loss type must be kl_div, cross_entropy, or penalized_cross_entropy"
+                 )
+         if return_per_head:
+             return {"loss": loss, "aux_loss": torch.stack(losses)}
+         else:
+             return {"loss": loss, "aux_loss": None}
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         self_supervision: Optional[bool] = None,
+         fixed_output_head: Optional[int] = None,
+     ):
+         output_attentions = (
+             output_attentions
+             if output_attentions is not None
+             else self.config.output_attentions
+         )
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+         if self_supervision:
+             output_hidden_states = True
+             return self.forward_for_training(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_values=past_key_values,
+                 inputs_embeds=inputs_embeds,
+                 labels=labels,
+                 use_cache=use_cache,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         else:
+             return self.forward_for_inference(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_values=past_key_values,
+                 inputs_embeds=inputs_embeds,
+                 use_cache=use_cache,
+                 return_dict=return_dict,
+                 fixed_output_head=fixed_output_head,
+             )
+
+     def forward_for_inference(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         fixed_output_head: Optional[int] = None,
+     ):
+         if fixed_output_head is not None and fixed_output_head != -1 and fixed_output_head not in self.branch_locations:
+             raise ValueError(
+                 "The fixed output head must be one of the branch locations"
+             )
+         # retrieve input_ids and inputs_embeds
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+         elif inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         past_key_values_length = 0
+
+         if use_cache:
+             use_legacy_cache = not isinstance(past_key_values, Cache)
+             if use_legacy_cache:
+                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+             past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+         if position_ids is None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             position_ids = torch.arange(
+                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+             )
+             position_ids = position_ids.unsqueeze(0)
+
+         if inputs_embeds is None:
+             inputs_embeds = self.model.model.embed_tokens(input_ids)
+
+         inputs_embeds = self.model.model.embed_dropout(inputs_embeds)
+
+         # Attention mask.
+         if self.model.model._use_flash_attention_2:
+             # 2d mask is passed through the layers
+             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+         else:
+             # 4d mask is passed through the layers
+             attention_mask = _prepare_4d_causal_attention_mask(
+                 attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+             )
+         all_head_logits = []
+         hidden_states = inputs_embeds
+         is_early_exited = False
+         for layer_idx, decoder_layer in enumerate(self.model.model.layers):
+             layer_outputs = decoder_layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_value=past_key_values,
+                 use_cache=use_cache,
+             )
+
+             hidden_states = layer_outputs[0]
+
+             if use_cache:
+                 next_decoder_cache = layer_outputs[1]
+
+             if fixed_output_head is not None and layer_idx == fixed_output_head:
+                 # find the position of layer_idx in branch_locations
+                 branch_idx = self.branch_locations.index(layer_idx)
+                 logits = self.model.heads[branch_idx](hidden_states)
+                 is_early_exited = True
+                 break
+             elif fixed_output_head == -1 and layer_idx in self.branch_locations:
+                 # -1 means output all heads
+                 branch_idx = self.branch_locations.index(layer_idx)
+                 logits = self.model.heads[branch_idx](hidden_states)
+                 all_head_logits.append(logits)
+
+         if not is_early_exited:
+             hidden_states = self.model.model.final_layernorm(hidden_states)
+             logits = self.model.lm_head(hidden_states)
+             if fixed_output_head == -1:
+                 all_head_logits.append(logits)
+         # Only stack head outputs when branches were collected (i.e. fixed_output_head == -1)
+         all_head_logits = torch.stack(all_head_logits, dim=0) if all_head_logits else None
+         next_cache = None
+         if use_cache:
+             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+         if not return_dict:
+             return tuple(v for v in [logits, next_cache] if v is not None)
+
+         return CausalBranchyLLMOutputWithPast(
+             logits=logits,
+             head_outputs=all_head_logits,
+             past_key_values=next_cache,
+         )
+
+     def forward_for_training(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+
+         if not output_hidden_states:
+             raise ValueError("output_hidden_states must be True for BranchyLLM")
+         if labels is not None:
+             raise NotImplementedError("BranchyLLM only supports self-supervision")
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         if not hasattr(outputs, "hidden_states") or outputs.hidden_states is None:
+             raise ValueError("The model must return hidden states")
+         hidden_states = outputs.hidden_states
+
+
+         heads_logits = []
+         for i, branch in enumerate(self.branch_locations):
+             heads_logits.append(
+                 self.model.heads[i](
+                     hidden_states[branch]
+                 )
+             )
+         lm_logits = self.model.lm_head(hidden_states[-1])
+
+         heads_logits = torch.stack(heads_logits, dim=0).float()
+         lm_logits = lm_logits.float()
+         logits = torch.cat([heads_logits, lm_logits.unsqueeze(0)], dim=0)
+
+         loss = None
+         lm_loss = None
+         aux_loss = None
+
+         losses = self.compute_self_supervision_loss(
+             heads_logits, lm_logits, return_per_head=True
+         )
+         loss = losses["loss"]
+         if losses["aux_loss"] is not None:
+             aux_loss = losses["aux_loss"]
+
+         if not return_dict:
+             output = (logits,) + outputs[1:]
+             return ((loss, aux_loss, lm_loss) + output) if loss is not None else output
+
+         return CausalBranchyLLMOutputWithPast(
+             loss=loss,
+             lm_loss=lm_loss,
+             head_loss=aux_loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
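Usage note for src/BranchyModel.py: the wrapper freezes the base model, attaches one Branch (LayerNorm + linear head) per entry in branch_locations, and exposes two paths through forward: fixed_output_head=-1 returns logits from every branch plus the final lm_head, while self_supervision=True computes the head-training loss against the final head. The sketch below is illustrative only; the checkpoint name, branch locations, and prompt are taken from app.py, and device placement (app.py uses "cuda:1") is omitted for brevity.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.BranchyModel import BranchyModel

base = AutoModelForCausalLM.from_pretrained("susnato/phi-1_5_dev")
tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")

# Attach a branch after layers 0, 5, 10, 15 and 20; the base model's weights stay frozen.
model = BranchyModel(branch_locations=[0, 5, 10, 15, 20], model=base)

input_ids = tokenizer("The climate in", return_tensors="pt").input_ids

# Inference: fixed_output_head=-1 collects logits from every branch plus the final lm_head.
with torch.no_grad():
    out = model(input_ids, fixed_output_head=-1)
print(out.head_outputs.shape)  # (num_heads, batch, seq_len, vocab_size)

# Training the heads: self_supervision=True distills the final head into the branches.
loss = model(input_ids, self_supervision=True).loss
loss.backward()

Since the branch heads are the only trainable parameters, loss.backward() only populates gradients for them; the base model acts as a frozen teacher.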
src/utils.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+
+ def generate_next_token(model, tokenizer, input, method='greedy'):
+     """
+     Generate the next token of a sequence using the given model and tokenizer.
+     Specific to multi-branched models.
+     Only the token from the last head is appended to the sequence.
+
+     Args:
+         model (torch.nn.Module): The model to use for generation.
+         tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for generation.
+         input (str): The input text to generate from.
+         method (str): Decoding strategy: 'greedy', 'sample', 'top_k', or 'top_p'.
+
+     Returns:
+         token (str): The next token in the sequence.
+         logits (torch.Tensor): The logits of the next token, of shape [head_count, vocab_size].
+         new_sequence (str): The new sequence after adding the next token.
+         head_tokens (List[str]): The token proposed by each head.
+     """
+     device = model.device
+     input_ids = tokenizer.encode(input, return_tensors="pt").to(device)
+     model.eval()
+     head_outputs = model(input_ids, fixed_output_head=-1).head_outputs
+     if head_outputs is None:
+         raise ValueError("Model does not have head_outputs")
+     logits = head_outputs[..., -1, :].squeeze(1)  # squeeze the batch dimension (batch size is 1); new shape is (head_count, vocab_size)
+     if method == 'greedy':
+         head_tokens = torch.argmax(logits, dim=-1)
+     elif method == 'sample':
+         head_tokens = torch.multinomial(torch.nn.functional.softmax(logits, dim=-1), num_samples=1)
+     elif method == 'top_k':
+         k = 5
+         top_k = torch.topk(logits, k, dim=-1)
+         top_k_logits, top_k_indices = top_k.values, top_k.indices
+         top_k_probs = torch.nn.functional.softmax(top_k_logits, dim=-1)
+         head_tokens = top_k_indices[torch.arange(top_k_probs.shape[0]), torch.multinomial(top_k_probs, num_samples=1).squeeze()]
+     elif method == 'top_p':
+         # logits is of shape [head_count, vocab_size]
+         p = 0.9
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+         cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+         sorted_indices_to_remove = cumulative_probs > p
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+         indices_to_remove = sorted_indices[sorted_indices_to_remove]
+         tmp_logits = logits.clone()
+         for i in range(logits.shape[0]):
+             tmp_logits[i, indices_to_remove[i]] = float('-inf')
+         head_tokens = torch.multinomial(torch.nn.functional.softmax(tmp_logits, dim=-1), num_samples=1).squeeze()
+     else:
+         raise ValueError(f"Unknown method: {method}")
+     head_tokens = tokenizer.batch_decode(head_tokens)  # Treat the head dimension as the batch dimension
+     new_sequence = input + head_tokens[-1]
+     return head_tokens[-1], logits, new_sequence, head_tokens
+
+
+ def breaking_ties(tensor):
+     # Confidence margin: difference between the top-1 and top-2 values along the last dimension
+     return torch.sub(torch.topk(tensor, 2, dim=-1).values[..., 0], torch.topk(tensor, 2, dim=-1).values[..., 1]).squeeze()
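breaking_ties returns the margin between the two largest values along the last dimension; it is imported by app.py but not yet called in this commit. One plausible use, sketched below, is an early-exit rule that picks the first head whose top-1/top-2 margin is large enough. The pick_head helper, the threshold, and the random logits are all hypothetical, not part of the committed code.

import torch
from src.utils import breaking_ties

def pick_head(head_logits: torch.Tensor, threshold: float = 2.0) -> int:
    # head_logits: (head_count, vocab_size), as returned by generate_next_token
    margins = breaking_ties(head_logits)            # (head_count,) margin per head
    confident = (margins > threshold).nonzero(as_tuple=True)[0]
    # Fall back to the final head if no branch is confident enough
    return int(confident[0]) if len(confident) > 0 else head_logits.shape[0] - 1

logits = torch.randn(6, 50_000)  # fake logits for 6 heads, for illustration only
print(pick_head(logits))         # index of the earliest "confident" head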