oweller2 committed
Commit: e1a243a
Parent(s): d831694

same as training code

Browse files:
- modeling_flexbert.py +16 -38
- padding.py +1 -1
modeling_flexbert.py CHANGED
@@ -1529,16 +1529,13 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         self.unpad_embeddings = config.unpad_embeddings
         self.pad_logits = config.pad_logits
         self.compile_model = config.compile_model
-        self.vocab_size = config.vocab_size
         # self.masked_prediction = config.masked_prediction

         # Initialize weights and apply final processing
         self._init_weights(reset_params=False)

     def _init_weights(self, module: Optional[nn.Module] = None, reset_params: Optional[bool] = None):
-        # Handle the XOR condition
         assert (module is None) != (reset_params is None), "arg module xor reset_params must be specified"
-
         if module is not None:
             # Add basic initialization for common module types
             if isinstance(module, (nn.Linear, nn.Embedding)):
@@ -1552,7 +1549,7 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
             assert isinstance(reset_params, bool)
             self.bert._init_weights(reset_params=reset_params)
             self.lm_head._init_weights(reset_params=reset_params)
-
+
             if not self.config.tie_word_embeddings:
                 init_weights(self.config, self.decoder, self.config.hidden_size, type_of_module=ModuleType.final_out)

@@ -1640,27 +1637,22 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
         #
         # Prediction scores are only computed for masked tokens and the (bs,
         # seqlen) dimensions are flattened
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
-            batch_size, seq_len = input_ids.shape[:2]
-            if attention_mask is None:
-                # unpad expects a encoder-like mask where all non-padding are ones
-                attention_mask = torch.ones_like(input_ids)
-                attention_mask[input_ids == 50283] = 0 # zero out pad tokens
+        if self.unpad_embeddings and (indices is None and cu_seqlens is None and max_seqlen is None):
+            batch_size, seq_len = input_ids.shape[:2]
             input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = self.unpad_inputs(
                 input_ids, attention_mask, position_ids, labels
            )

-
         hidden_states = self.bert(
             input_ids,
-            attention_mask=None, # let FA
+            attention_mask=None, # let FA do this
             position_ids=position_ids,
             indices=indices,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
-        # print(hidden_states.shape)

         if self.compile_model:
             logits = self.compiled_lm_head(hidden_states)
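For context on the unpadded path this hunk simplifies: when `unpad_embeddings` is set, the padded `(batch, seqlen)` inputs are flattened into a single token stream plus the FlashAttention-style metadata (`indices`, `cu_seqlens`, `max_seqlen`) before the encoder runs. The sketch below is an assumption about what that flattening produces; it is not the repo's `unpad_inputs`, and the PAD id 50283 is taken from the removed lines above.

# Minimal sketch, assuming a ModernBERT-style PAD id of 50283.
import torch

PAD_ID = 50283

def unpad_sketch(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                       # real tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()   # positions of real tokens
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    unpadded_ids = input_ids.flatten()[indices]                                   # (total_tokens,)
    return unpadded_ids, indices, cu_seqlens, max_seqlen

ids = torch.tensor([[1, 2, 3, PAD_ID], [4, 5, PAD_ID, PAD_ID]])
mask = (ids != PAD_ID).int()
print(unpad_sketch(ids, mask))  # 5 tokens, cu_seqlens = [0, 3, 5], max_seqlen = 3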
@@ -1673,26 +1665,24 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
             shift_labels = torch.full_like(input_ids, -100)
             shift_labels[:-1] = input_ids[1:]

-            # Mask boundaries
+            # Mask boundaries, so eos doesn't predict bos
             for i in range(len(cu_seqlens) - 1):
                 boundary_pos = cu_seqlens[i+1] - 1
                 shift_labels[boundary_pos] = -100
-
-            # Mask out PAD tokens
-            mask = (shift_labels == 50283)
-            shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
-

-
-
-
-
-            # breakpoint() # pkill -u oweller2 -f wandb
+            # NOTE: no padding or mask in there for now
+            assert 50283 not in shift_labels, f"PAD token found in shift_labels: {shift_labels}"
+            assert 50284 not in shift_labels, f"MASK token found in shift_labels: {shift_labels}"
+            assert shift_labels.shape == logits.shape[:-1] # Verify shapes align

         else:
             # Padded case: simple shift
             shift_labels = input_ids[..., 1:].contiguous()
             logits = logits[..., :-1, :].contiguous()
+            # mask out PAD tokens in the shift_labels
+            mask = (shift_labels == 50283)
+            shift_labels = torch.where(mask, torch.tensor(-100, device=shift_labels.device), shift_labels)
+            assert shift_labels.shape == logits.shape[:-1] # Verify shapes align

         # For both cases, we'll use the shifted input_ids as our labels
         labels = shift_labels
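A toy, self-contained illustration of the boundary masking retained above: with packed (unpadded) sequences, a plain one-position shift would train the last token of each sequence to predict the first token of the next one, so the positions `cu_seqlens[1:] - 1` are set to the ignore index -100. The tensor values here are made up.

# Hedged sketch with two packed sequences, [10, 11, 12] and [20, 21].
import torch

input_ids = torch.tensor([10, 11, 12, 20, 21])
cu_seqlens = torch.tensor([0, 3, 5])

shift_labels = torch.full_like(input_ids, -100)
shift_labels[:-1] = input_ids[1:]
for i in range(len(cu_seqlens) - 1):
    shift_labels[cu_seqlens[i + 1] - 1] = -100   # mask each sequence boundary

print(shift_labels)  # tensor([  11,   12, -100,   21, -100])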
@@ -1703,26 +1693,14 @@ class FlexBertForCausalLM(FlexBertPreTrainedModel):
                 shift_labels.view(-1)
             )

-        if self.unpad_embeddings: # revert back to normal logits
-            logits = logits.view(batch_size, -1, self.vocab_size)
-
         if self.pad_logits:
-            # print(f"Padding logits: {logits.shape}")
-            new_logits = self.pad_inputs(logits, indices, batch_size, seq_len-1)[0]
-            # print(f"New logits: {new_logits.shape}")
-            # print(new_logits.shape)
-            # if new_logits.dim() == 2:
-            #     new_logits = new_logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
-                logits=new_logits,
+                logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0],
                 hidden_states=None,
                 attentions=None,
             )
         else:
-            # print(f"Non-padding logits: {logits.shape}")
-            # if logits.dim() == 2:
-            #     logits = logits.unsqueeze(0)
             return CausalLMOutput(
                 loss=loss,
                 logits=logits,
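The new `logits=self.pad_inputs(logits, indices, batch_size, seq_len)[0]` call re-expands the unpadded `(total_tokens, vocab)` logits to `(batch, seqlen, vocab)` for `CausalLMOutput`. Below is a minimal sketch of that scatter-and-reshape, assuming the same `indices` produced during unpadding; it is an approximation, not the repo's `pad_inputs`.

# Hedged sketch: scatter unpadded logits back into a zero-filled padded buffer.
import torch

def pad_logits_sketch(logits: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    padded = torch.zeros(batch * seqlen, logits.shape[-1], dtype=logits.dtype, device=logits.device)
    padded[indices] = logits                    # place each real token's logits at its original slot
    return padded.view(batch, seqlen, -1)

logits = torch.randn(5, 8)                      # 5 unpadded tokens, vocab size 8 (assumed values)
indices = torch.tensor([0, 1, 2, 4, 5])         # positions in the flattened (2, 4) batch
print(pad_logits_sketch(logits, indices, batch=2, seqlen=4).shape)  # torch.Size([2, 4, 8])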
@@ -1947,4 +1925,4 @@ def init_mlm_model_from_pretrained(
             pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode, bias_only=True
         )
     else:
-        tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
+        tile_linear(pretrained_model.decoder, new_model.decoder, linear_type=TileLinear.default, mode=mode)
padding.py CHANGED
@@ -84,4 +84,4 @@ def pad_input(
     padded_labels[indices] = labels
     padded_labels = padded_labels.view(batch, seqlen)

-    return padded_inputs, padded_labels
+    return padded_inputs, padded_labels
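For reference, the `pad_input` lines shown above scatter unpadded labels back into a `(batch, seqlen)` grid. A hedged sketch of that pattern, using -100 as the fill so padded positions stay ignored by the loss (the fill value and shapes are assumptions, not a copy of the repo's helper):

# Hedged sketch: rebuild a padded label grid from unpadded labels and their indices.
import torch

def pad_labels_sketch(labels: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int):
    # Start from an all-ignored grid, then place each unpadded label at its original position.
    padded_labels = torch.full((batch * seqlen,), -100, dtype=labels.dtype, device=labels.device)
    padded_labels[indices] = labels
    return padded_labels.view(batch, seqlen)

labels = torch.tensor([11, 12, -100, 21, -100])   # 5 unpadded tokens (boundaries already masked)
indices = torch.tensor([0, 1, 2, 4, 5])           # their positions in the flattened (2, 4) batch
print(pad_labels_sketch(labels, indices, batch=2, seqlen=4))
# tensor([[  11,   12, -100, -100],
#         [  21, -100, -100, -100]])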