jingyaogong committed on
Commit
1107462
·
verified ·
1 Parent(s): d78c45a

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +49 -155
model.py CHANGED
@@ -1,6 +1,8 @@
1
  import math
2
  import struct
3
  import inspect
 
 
4
  from .LMConfig import LMConfig
5
  from typing import Any, Optional, Tuple
6
  import numpy as np
@@ -66,93 +68,66 @@ class Attention(nn.Module):
66
  super().__init__()
67
  self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
68
  assert args.n_heads % self.n_kv_heads == 0
69
- model_parallel_size = 1
70
- self.n_local_heads = args.n_heads // model_parallel_size
71
- self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
72
  self.n_rep = self.n_local_heads // self.n_local_kv_heads
73
  self.head_dim = args.dim // args.n_heads
74
  self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
75
  self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
76
  self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
77
  self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
 
78
  self.attn_dropout = nn.Dropout(args.dropout)
79
  self.resid_dropout = nn.Dropout(args.dropout)
80
  self.dropout = args.dropout
81
-
82
- # use flash attention or a manual implementation?
83
  self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
84
 
85
- if not self.flash:
86
- # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
87
- mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
88
- mask = torch.triu(mask, diagonal=1)
89
- self.register_buffer("mask", mask)
90
-
91
- def forward(
92
- self,
93
- x: torch.Tensor,
94
- pos_cis: torch.Tensor,
95
- use_kv_cache: bool = False,
96
- past_kv: Tuple[torch.Tensor] = None
97
- ):
98
  bsz, seqlen, _ = x.shape
99
- # QKV
100
- # inference
101
- if use_kv_cache:
102
- # 只计算最后一个token的Q
103
- current_token = x[:, -1:, :]
104
-
105
- if not past_kv:
106
- xq = self.wq(x)
107
- xk, xv = self.wk(x), self.wv(x)
108
- else:
109
- past_key, past_value = past_kv
110
- xq = torch.cat((torch.zeros_like(x[:, :-1, :]), self.wq(current_token)), dim=1)
111
- xk = torch.cat((past_key, self.wk(current_token)), dim=1)
112
- xv = torch.cat((past_value, self.wv(current_token)), dim=1)
113
 
114
- past_kv = (xk, xv)
115
- else:
116
- xq = self.wq(x)
117
- xk, xv = self.wk(x), self.wv(x)
118
 
119
  xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
120
  xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
121
  xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
122
 
123
- # RoPE relative positional embeddings
124
  xq, xk = apply_rotary_emb(xq, xk, pos_cis)
125
 
126
- # grouped multiquery attention: expand out keys and values
 
 
 
 
 
 
127
  xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
128
  xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
129
 
130
- # make heads into a batch dimension
131
- xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
132
  xk = xk.transpose(1, 2)
133
  xv = xv.transpose(1, 2)
134
 
135
- # flash implementation
136
- if self.flash:
137
  output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None,
138
  dropout_p=self.dropout if self.training else 0.0,
139
  is_causal=True)
140
  else:
141
- # manual implementation
142
  scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
143
- assert hasattr(self, 'mask')
144
  scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen)
145
  scores = F.softmax(scores.float(), dim=-1).type_as(xq)
146
  scores = self.attn_dropout(scores)
147
  output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
148
 
149
- # restore time as batch dimension and concat heads
150
  output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
151
 
152
- # final projection into the residual stream
153
  output = self.wo(output)
154
  output = self.resid_dropout(output)
155
- return output, past_kv
156
 
157
 
158
  class FeedForward(nn.Module):
@@ -182,7 +157,6 @@ class MoEGate(nn.Module):
182
  self.alpha = config.aux_loss_alpha
183
  self.seq_aux = config.seq_aux
184
 
185
- # topk selection algorithm
186
  self.norm_topk_prob = config.norm_topk_prob
187
  self.gating_dim = config.dim
188
  self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
@@ -194,7 +168,7 @@ class MoEGate(nn.Module):
194
 
195
  def forward(self, hidden_states):
196
  bsz, seq_len, h = hidden_states.shape
197
- ### compute gating score
198
  hidden_states = hidden_states.view(-1, h)
199
  logits = F.linear(hidden_states, self.weight, None)
200
  if self.scoring_func == 'softmax':
@@ -202,19 +176,15 @@ class MoEGate(nn.Module):
202
  else:
203
  raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
204
 
205
- ### select top-k experts
206
  topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
207
 
208
- ### norm gate to sum 1
209
  if self.top_k > 1 and self.norm_topk_prob:
210
  denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
211
  topk_weight = topk_weight / denominator
212
 
213
- ### expert-level computation auxiliary loss
214
  if self.training and self.alpha > 0.0:
215
  scores_for_aux = scores
216
  aux_topk = self.top_k
217
- # always compute aux loss based on the naive greedy topk method
218
  topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
219
  if self.seq_aux:
220
  scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
@@ -331,11 +301,10 @@ class TransformerBlock(nn.Module):
331
  dropout=args.dropout,
332
  )
333
 
334
- def forward(self, x, pos_cis, use_kv_cache=False, past_kv: Tuple[torch.Tensor] = None):
335
- attn_res, past_kv = self.attention(self.attention_norm(x), pos_cis, use_kv_cache, past_kv)
336
- h = x + attn_res
337
  out = h + self.feed_forward(self.ffn_norm(h))
338
- return out, past_kv
339
 
340
 
341
  class Transformer(PreTrainedModel):
@@ -357,22 +326,16 @@ class Transformer(PreTrainedModel):
357
  self.layers.append(TransformerBlock(layer_id, params))
358
  self.norm = RMSNorm(params.dim, eps=params.norm_eps)
359
  self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
360
-
361
- # share the unembedding parameters with the embedding parameters
362
- self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
363
-
364
- # some useful precompute for the RoPE relative positional embeddings
365
  pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
366
  self.register_buffer("pos_cis", pos_cis, persistent=False)
367
 
368
- # init all weights
369
  self.apply(self._init_weights)
370
- # apply special scaled init to the residual projections, per GPT-2 paper
371
  for pn, p in self.named_parameters():
372
  if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
373
  torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))
374
 
375
- # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
376
  self.last_loss = None
377
  self.OUT = CausalLMOutputWithPast()
378
 
@@ -384,78 +347,64 @@ class Transformer(PreTrainedModel):
384
  elif isinstance(module, nn.Embedding):
385
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
386
 
387
- def forward(self, tokens: Optional[torch.Tensor] = None,
388
- targets: Optional[torch.Tensor] = None,
389
- use_kv_cache=False, past_kvs=None, **keyargs):
390
- if past_kvs is None:
391
- past_kvs = [None for _ in range(self.n_layers)]
392
  if 'input_ids' in keyargs:
393
  tokens = keyargs['input_ids']
394
  if 'attention_mask' in keyargs:
395
  targets = keyargs['attention_mask']
 
 
396
 
397
  _bsz, seqlen = tokens.shape
398
  h = self.tok_embeddings(tokens)
399
  h = self.dropout(h)
400
- pos_cis = self.pos_cis[:seqlen]
401
  for idx, layer in enumerate(self.layers):
402
- h, past_kvs[idx] = layer(h, pos_cis, use_kv_cache, past_kvs[idx])
403
 
404
  h = self.norm(h)
405
 
406
  if targets is not None:
407
- # if we are given some desired targets also calculate the loss
408
  logits = self.output(h)
409
  self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
410
  else:
411
- # inference-time mini-optimization: only forward the output on the very last position
412
- logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
413
  self.last_loss = None
414
 
415
  self.OUT.__setitem__('logits', logits)
416
  self.OUT.__setitem__('last_loss', self.last_loss)
417
-
418
- if use_kv_cache:
419
- return self.OUT, past_kvs
420
  return self.OUT
421
 
422
-
423
  @torch.inference_mode()
424
- def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=None, stream=True, repetition_penalty=1.):
 
425
  index = idx.shape[1]
426
- use_kv_cache = True
427
- past_kvs = [None for _ in range(self.n_layers)]
428
  while idx.shape[1] < max_new_tokens - 1:
429
- # if the sequence context is growing too long we must crop it at block_size
430
- idx_cond = idx # if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
431
- # forward the model to get the logits for the index in the sequence
432
- inference_res = self(idx_cond, use_kv_cache=use_kv_cache, past_kvs=past_kvs)
433
- if use_kv_cache:
434
- logits, past_kvs = inference_res[0].logits, inference_res[1]
435
  else:
436
- logits = inference_res.logits
437
 
438
- logits = logits[:, -1, :] # crop to just the final time step
 
439
 
440
- # Apply repetition penalty
441
  for token in set(idx.tolist()[0]):
442
- logits[:, token] /= repetition_penalty
443
 
444
  if temperature == 0.0:
445
- # "sample" the single most likely index
446
- __, idx_next = torch.topk(logits, k=1, dim=-1)
447
  else:
448
- # pluck the logits at the final step and scale by desired temperature
449
  logits = logits / temperature
450
- # optionally crop the logits to only the top k options
451
  if top_k is not None:
452
- v, __ = torch.topk(logits, min(top_k, logits.size(-1)))
453
  logits[logits < v[:, [-1]]] = -float('Inf')
454
 
455
- # apply softmax to convert logits to (normalized) probabilities
456
  probs = F.softmax(logits, dim=-1)
457
  idx_next = torch.multinomial(probs, num_samples=1, generator=None)
458
- # append sampled index to the running sequence and continue
459
  if idx_next == eos:
460
  break
461
 
@@ -468,63 +417,8 @@ class Transformer(PreTrainedModel):
468
 
469
  @torch.inference_mode()
470
  def eval_answer(self, idx):
471
- # if the sequence context is growing too long we must crop it at block_size
472
  idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
473
- # forward the model to get the logits for the index in the sequence
474
- past_kvs = [None for _ in range(self.n_layers)]
475
- inference_res = self(idx_cond, use_kv_cache=False, past_kvs=past_kvs)
476
  logits = inference_res.logits
477
  logits = logits[:, -1, :]
478
  return logits
479
-
480
- def export(self, filepath='model.bin'):
481
- """export the model weights in fp32 into .bin file to be read from C"""
482
- f = open(filepath, 'wb')
483
-
484
- def serialize(t):
485
- d = t.detach().cpu().view(-1).numpy().astype(np.float32)
486
- b = struct.pack(f'{len(d)}f', *d)
487
- f.write(b)
488
-
489
- # first write out the header
490
- hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
491
- p = self.params
492
- n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
493
- header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
494
- n_kv_heads, p.vocab_size, p.max_seq_len)
495
- f.write(header)
496
-
497
- # next write out the embedding weights
498
- serialize(self.tok_embeddings.weight)
499
-
500
- # now all the layers
501
- # attention weights
502
- for layer in self.layers:
503
- serialize(layer.attention_norm.weight)
504
- for layer in self.layers:
505
- serialize(layer.attention.wq.weight)
506
- for layer in self.layers:
507
- serialize(layer.attention.wk.weight)
508
- for layer in self.layers:
509
- serialize(layer.attention.wv.weight)
510
- for layer in self.layers:
511
- serialize(layer.attention.wo.weight)
512
- # ffn weights
513
- for layer in self.layers:
514
- serialize(layer.ffn_norm.weight)
515
- for layer in self.layers:
516
- serialize(layer.feed_forward.w1.weight)
517
- for layer in self.layers:
518
- serialize(layer.feed_forward.w2.weight)
519
- for layer in self.layers:
520
- serialize(layer.feed_forward.w3.weight)
521
- # final rmsnorm
522
- serialize(self.norm.weight)
523
- # note: no need to write final classifier weights due to weight sharing
524
- # pos_cis
525
- serialize(self.freqs_cos[:p.max_seq_len])
526
- serialize(self.freqs_sin[:p.max_seq_len])
527
-
528
- # write to binary file
529
- f.close()
530
- print(f"wrote {filepath}")
 
1
  import math
2
  import struct
3
  import inspect
4
+ import time
5
+
6
  from .LMConfig import LMConfig
7
  from typing import Any, Optional, Tuple
8
  import numpy as np
 
68
  super().__init__()
69
  self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
70
  assert args.n_heads % self.n_kv_heads == 0
71
+ self.n_local_heads = args.n_heads
72
+ self.n_local_kv_heads = self.n_kv_heads
 
73
  self.n_rep = self.n_local_heads // self.n_local_kv_heads
74
  self.head_dim = args.dim // args.n_heads
75
  self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
76
  self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
77
  self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
78
  self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
79
+ self.k_cache, self.v_cache = None, None
80
  self.attn_dropout = nn.Dropout(args.dropout)
81
  self.resid_dropout = nn.Dropout(args.dropout)
82
  self.dropout = args.dropout
 
 
83
  self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn
84
 
85
+ # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
86
+ mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
87
+ mask = torch.triu(mask, diagonal=1)
88
+ self.register_buffer("mask", mask)
89
+
90
def forward(self, x: torch.Tensor, pos_cis: torch.Tensor, kv_cache=False):
    """Run self-attention over ``x`` and project the result with ``wo``.

    Args:
        x: Input whose shape unpacks as ``(bsz, seqlen, _)``.
        pos_cis: Rotary-embedding table, passed unchanged to
            ``apply_rotary_emb`` together with the query and key tensors.
        kv_cache: When true, the keys/values of this call are kept on the
            module in ``self.k_cache`` / ``self.v_cache``. A single-token
            call (``seqlen == 1``) appends to whatever is already stored;
            a multi-token call replaces the stored tensors.

    Returns:
        The attention output after the ``wo`` projection and
        ``resid_dropout``.
    """
    bsz, seqlen, _ = x.shape

    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

    # Split the projections into heads: (bsz, seqlen, heads, head_dim).
    xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
    xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

    xq, xk = apply_rotary_emb(xq, xk, pos_cis)

    # KV cache for incremental decoding.
    # The guard used to read `kv_cache and self.eval()`. nn.Module.eval()
    # returns the module itself (always truthy) and switches it to eval
    # mode, so the guard was equivalent to `kv_cache` while silently
    # disabling this module's dropout for every later call. Testing the
    # flag alone keeps the same caching decisions without that side effect.
    if kv_cache:
        if seqlen == 1 and self.k_cache is not None and self.v_cache is not None:
            # Prepend the stored keys/values along the sequence axis (dim 1).
            xk = torch.cat((self.k_cache, xk), dim=1)
            xv = torch.cat((self.v_cache, xv), dim=1)
        self.k_cache, self.v_cache = xk, xv

    # Expand the KV heads so they line up with the query heads.
    xk = repeat_kv(xk, self.n_rep)
    xv = repeat_kv(xv, self.n_rep)

    # Move the head axis in front of the sequence axis.
    xq = xq.transpose(1, 2)
    xk = xk.transpose(1, 2)
    xv = xv.transpose(1, 2)

    if self.flash and seqlen != 1:
        output = torch.nn.functional.scaled_dot_product_attention(
            xq, xk, xv,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,
        )
    else:
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        # For a single-token query the slice is one element, which
        # broadcasts over every key position (including cached ones).
        scores = scores + self.mask[:, :, :seqlen, :seqlen]
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)

    # Put the sequence axis back in front and merge the heads.
    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

    output = self.wo(output)
    output = self.resid_dropout(output)
    return output
131
 
132
 
133
  class FeedForward(nn.Module):
 
157
  self.alpha = config.aux_loss_alpha
158
  self.seq_aux = config.seq_aux
159
 
 
160
  self.norm_topk_prob = config.norm_topk_prob
161
  self.gating_dim = config.dim
162
  self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
 
168
 
169
  def forward(self, hidden_states):
170
  bsz, seq_len, h = hidden_states.shape
171
+
172
  hidden_states = hidden_states.view(-1, h)
173
  logits = F.linear(hidden_states, self.weight, None)
174
  if self.scoring_func == 'softmax':
 
176
  else:
177
  raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')
178
 
 
179
  topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
180
 
 
181
  if self.top_k > 1 and self.norm_topk_prob:
182
  denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
183
  topk_weight = topk_weight / denominator
184
 
 
185
  if self.training and self.alpha > 0.0:
186
  scores_for_aux = scores
187
  aux_topk = self.top_k
 
188
  topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
189
  if self.seq_aux:
190
  scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
 
301
  dropout=args.dropout,
302
  )
303
 
304
def forward(self, x, pos_cis, kv_cache=False):
    """Apply the attention sub-layer, then the feed-forward sub-layer.

    Each sub-layer reads a normalised copy of its input and its result is
    added back onto that input. ``pos_cis`` and ``kv_cache`` are passed
    through to ``self.attention`` untouched.
    """
    normed_input = self.attention_norm(x)
    h = x + self.attention(normed_input, pos_cis, kv_cache)
    ffn_result = self.feed_forward(self.ffn_norm(h))
    return h + ffn_result
308
 
309
 
310
  class Transformer(PreTrainedModel):
 
326
  self.layers.append(TransformerBlock(layer_id, params))
327
  self.norm = RMSNorm(params.dim, eps=params.norm_eps)
328
  self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
329
+ self.tok_embeddings.weight = self.output.weight
 
 
 
 
330
  pos_cis = precompute_pos_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
331
  self.register_buffer("pos_cis", pos_cis, persistent=False)
332
 
 
333
  self.apply(self._init_weights)
334
+
335
  for pn, p in self.named_parameters():
336
  if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
337
  torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers))
338
 
 
339
  self.last_loss = None
340
  self.OUT = CausalLMOutputWithPast()
341
 
 
347
  elif isinstance(module, nn.Embedding):
348
  torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
349
 
350
def forward(self, tokens: Optional[torch.Tensor] = None, targets: Optional[torch.Tensor] = None,
            kv_cache=False, **keyargs):
    """Embed ``tokens``, run every layer, and store logits/loss in ``self.OUT``.

    Keyword overrides read from ``keyargs``:
        ``input_ids`` replaces ``tokens``; ``attention_mask`` replaces
        ``targets``; ``current_idx`` (converted with ``int``) is the offset
        into ``self.pos_cis`` at which this chunk starts (default 0).

    With targets, logits cover every position and ``self.last_loss`` holds
    the cross-entropy (entries equal to -1 are ignored). Without targets,
    only the last position is projected and ``self.last_loss`` is ``None``.

    Returns:
        ``self.OUT`` with its ``'logits'`` and ``'last_loss'`` entries updated.
    """
    tokens = keyargs.get('input_ids', tokens)
    targets = keyargs.get('attention_mask', targets)
    start = int(keyargs['current_idx']) if 'current_idx' in keyargs else 0

    _, seqlen = tokens.shape
    hidden = self.dropout(self.tok_embeddings(tokens))
    rope_slice = self.pos_cis[start:start + seqlen]
    for block in self.layers:
        hidden = block(hidden, rope_slice, kv_cache)
    hidden = self.norm(hidden)

    if targets is None:
        # Index with a list so the sequence axis is kept.
        logits = self.output(hidden[:, [-1], :])
        self.last_loss = None
    else:
        logits = self.output(hidden)
        flat_logits = logits.view(-1, logits.size(-1))
        self.last_loss = F.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)

    result = self.OUT
    result['logits'] = logits
    result['last_loss'] = self.last_loss
    return result
379
 
 
380
  @torch.inference_mode()
381
+ def generate(self, idx, eos, max_new_tokens, temperature=0.7, top_k=8, stream=True, rp=1., kv_cache=True):
382
+ # rp: repetition_penalty
383
  index = idx.shape[1]
384
+ init_inference = True
 
385
  while idx.shape[1] < max_new_tokens - 1:
386
+ if init_inference or not kv_cache:
387
+ inference_res, init_inference = self(idx, kv_cache=kv_cache), False
 
 
 
 
388
  else:
389
+ inference_res = self(idx[:, -1:], kv_cache=kv_cache, current_idx=idx.shape[1] - 1)
390
 
391
+ logits = inference_res.logits
392
+ logits = logits[:, -1, :]
393
 
 
394
  for token in set(idx.tolist()[0]):
395
+ logits[:, token] /= rp
396
 
397
  if temperature == 0.0:
398
+ _, idx_next = torch.topk(logits, k=1, dim=-1)
 
399
  else:
 
400
  logits = logits / temperature
 
401
  if top_k is not None:
402
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
403
  logits[logits < v[:, [-1]]] = -float('Inf')
404
 
 
405
  probs = F.softmax(logits, dim=-1)
406
  idx_next = torch.multinomial(probs, num_samples=1, generator=None)
407
+
408
  if idx_next == eos:
409
  break
410
 
 
417
 
418
@torch.inference_mode()
def eval_answer(self, idx):
    """Return the logits at the final position for ``idx``.

    If ``idx`` has more than ``self.params.max_seq_len`` columns, only the
    trailing ``max_seq_len`` columns are fed to the model.
    """
    limit = self.params.max_seq_len
    window = idx
    if idx.size(1) > limit:
        window = idx[:, -limit:]
    return self(window).logits[:, -1, :]