oweller2 commited on
Commit
e1a243a
1 Parent(s): d831694

same as training code

Browse files
Files changed (2) hide show
  1. modeling_flexbert.py +16 -38
  2. padding.py +1 -1
modeling_flexbert.py CHANGED
@@ -1529,16 +1529,13 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1529
  self.unpad_embeddings = config.unpad_embeddings
1530
  self.pad_logits = config.pad_logits
1531
  self.compile_model = config.compile_model
1532
- self.vocab_size = config.vocab_size
1533
  # self.masked_prediction = config.masked_prediction
1534
 
1535
  # Initialize weights and apply final processing
1536
  self._init_weights(reset_params=False)
1537
 
1538
  def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
1539
- # Handle the XOR condition
1540
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
1541
-
1542
  if module is not None:
1543
  # Add basic initialization for common module types
1544
  if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -1552,7 +1549,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1552
  assert isinstance(reset_params, bool)
1553
  self.bert._init_weights(reset_params=reset_params)
1554
  self.lm_head._init_weights(reset_params=reset_params)
1555
-
1556
  if not self.config.tie_word_embeddings:
1557
  init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1558
 
@@ -1640,27 +1637,22 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1640
  #
1641
  # Prediction scores are only computed for masked tokens and the (bs,
1642
  # seqlen) dimensions are flattened
 
1643
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1644
- if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1645
- batch_size, seq_len = input_ids.shape[:2]
1646
- if attention_mask is None:
1647
- # unpad expects a encoder-like mask where all non-padding are ones
1648
- attention_mask = torch.ones_like(input_ids)
1649
- attention_mask[input_ids == 50283] = 0 # zero out pad tokens
1650
  input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1651
  input_ids, attention_mask, position_ids, labels
1652
  )
1653
 
1654
-
1655
  hidden_states = self.bert(
1656
  input_ids,
1657
- attention_mask=None, # let FA handle it
1658
  position_ids=position_ids,
1659
  indices=indices,
1660
  cu_seqlens=cu_seqlens,
1661
  max_seqlen=max_seqlen,
1662
  )
1663
- # print(hidden_states.shape)
1664
 
1665
  if self.compile_model:
1666
  logits = self.compiled_lm_head(hidden_states)
@@ -1673,26 +1665,24 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1673
  shift_labels = torch.full_like(input_ids, -100)
1674
  shift_labels[:-1] = input_ids[1:]
1675
 
1676
- # Mask boundaries
1677
  for i in range(len(cu_seqlens) - 1):
1678
  boundary_pos = cu_seqlens[i+1] - 1
1679
  shift_labels[boundary_pos] = -100
1680
-
1681
- # Mask out PAD tokens
1682
- mask = (shift_labels == 50283)
1683
- shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
1684
-
1685
 
1686
- # print input_ids[(cu_seqlens[2]+1)-5:(cu_seqlens[2]+1)+5]
1687
- # print shift_labels[(cu_seqlens[2]+1)-5:(cu_seqlens[2]+1)+5]
1688
- # print input_ids[(cu_seqlens[-2]+1)-5:(cu_seqlens[-2]+1)+5]
1689
- # print shift_labels[(cu_seqlens[-2]+1)-5:(cu_seqlens[-2]+1)+5]
1690
- # breakpoint() # pkill -u oweller2 -f wandb
1691
 
1692
  else:
1693
  # Padded case: simple shift
1694
  shift_labels = input_ids[..., 1:].contiguous()
1695
  logits = logits[..., :-1, :].contiguous()
 
 
 
 
1696
 
1697
  # For both cases, we'll use the shifted input_ids as our labels
1698
  labels = shift_labels
@@ -1703,26 +1693,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
1703
  shift_labels.view(-1)
1704
  )
1705
 
1706
- if self.unpad_embeddings: # revert back to normal logits
1707
- logits = logits.view(batch_size, -1, self.vocab_size)
1708
-
1709
  if self.pad_logits:
1710
- # print(f"Padding logits: {logits.shape}")
1711
- new_logits = self.pad_inputs(logits, indices, batch_size, seq_len-1)[0]
1712
- # print(f"New logits: {new_logits.shape}")
1713
- # print(new_logits.shape)
1714
- # if new_logits.dim() == 2:
1715
- # new_logits = new_logits.unsqueeze(0)
1716
  return CausalLMOutput(
1717
  loss=loss,
1718
- logits=new_logits,
1719
  hidden_states=None,
1720
  attentions=None,
1721
  )
1722
  else:
1723
- # print(f"Non-padding logits: {logits.shape}")
1724
- # if logits.dim() == 2:
1725
- # logits = logits.unsqueeze(0)
1726
  return CausalLMOutput(
1727
  loss=loss,
1728
  logits=logits,
@@ -1947,4 +1925,4 @@ def init_mlm_model_from_pretrained(
1947
  pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True
1948
  )
1949
  else:
1950
- tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
 
1529
  self.unpad_embeddings = config.unpad_embeddings
1530
  self.pad_logits = config.pad_logits
1531
  self.compile_model = config.compile_model
 
1532
  # self.masked_prediction = config.masked_prediction
1533
 
1534
  # Initialize weights and apply final processing
1535
  self._init_weights(reset_params=False)
1536
 
1537
  def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
 
1538
  assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
 
1539
  if module is not None:
1540
  # Add basic initialization for common module types
1541
  if isinstance(module, (nn.Linear, nn.Embedding)):
 
1549
  assert isinstance(reset_params, bool)
1550
  self.bert._init_weights(reset_params=reset_params)
1551
  self.lm_head._init_weights(reset_params=reset_params)
1552
+
1553
  if not self.config.tie_word_embeddings:
1554
  init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)
1555
 
 
1637
  #
1638
  # Prediction scores are only computed for masked tokens and the (bs,
1639
  # seqlen) dimensions are flattened
1640
+
1641
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1642
+ if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
1643
+ batch_size, seq_len = input_ids.shape[:2]
 
 
 
 
1644
  input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
1645
  input_ids, attention_mask, position_ids, labels
1646
  )
1647
 
 
1648
  hidden_states = self.bert(
1649
  input_ids,
1650
+ attention_mask=None, # let FA do this
1651
  position_ids=position_ids,
1652
  indices=indices,
1653
  cu_seqlens=cu_seqlens,
1654
  max_seqlen=max_seqlen,
1655
  )
 
1656
 
1657
  if self.compile_model:
1658
  logits = self.compiled_lm_head(hidden_states)
 
1665
  shift_labels = torch.full_like(input_ids, -100)
1666
  shift_labels[:-1] = input_ids[1:]
1667
 
1668
+ # Mask boundaries, so eos doesn't predict bos
1669
  for i in range(len(cu_seqlens) - 1):
1670
  boundary_pos = cu_seqlens[i+1] - 1
1671
  shift_labels[boundary_pos] = -100
 
 
 
 
 
1672
 
1673
+ # NOTE: no padding or mask in there for now
1674
+ assert 50283 not in shift_labels, f"PAD token found in shift_labels: {shift_labels}"
1675
+ assert 50284 not in shift_labels, f"MASK token found in shift_labels: {shift_labels}"
1676
+ assert shift_labels.shape == logits.shape[:-1] # Verify shapes align
 
1677
 
1678
  else:
1679
  # Padded case: simple shift
1680
  shift_labels = input_ids[..., 1:].contiguous()
1681
  logits = logits[..., :-1, :].contiguous()
1682
+ # mask out PAD tokens in the shift_labels
1683
+ mask = (shift_labels == 50283)
1684
+ shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
1685
+ assert shift_labels.shape == logits.shape[:-1] # Verify shapes align
1686
 
1687
  # For both cases, we'll use the shifted input_ids as our labels
1688
  labels = shift_labels
 
1693
  shift_labels.view(-1)
1694
  )
1695
 
 
 
 
1696
  if self.pad_logits:
 
 
 
 
 
 
1697
  return CausalLMOutput(
1698
  loss=loss,
1699
+ logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
1700
  hidden_states=None,
1701
  attentions=None,
1702
  )
1703
  else:
 
 
 
1704
  return CausalLMOutput(
1705
  loss=loss,
1706
  logits=logits,
 
1925
  pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True
1926
  )
1927
  else:
1928
+ tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
padding.py CHANGED
@@ -84,4 +84,4 @@ def pad_input(
84
  padded_labels[indices] = labels
85
  padded_labels = padded_labels.view(batch, seqlen)
86
 
87
- return padded_inputs, padded_labels
 
84
  padded_labels[indices] = labels
85
  padded_labels = padded_labels.view(batch, seqlen)
86
 
87
+ return padded_inputs, padded_labels