-
Notifications
You must be signed in to change notification settings - Fork 29
Changes for basic LLaDA style diffusion masking support #238
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -195,15 +195,16 @@ class BackupAttentionPreprocessor(Preprocessor): | |
_scalar_dim: TensorDim | ||
_kv_channels_dim: TensorDim | ||
_rotary_embedding_frequencies: torch.Tensor | ||
_mask: torch.Tensor | ||
_mask_value: torch.Tensor | ||
_mask: torch.Tensor | None | ||
_mask_value: torch.Tensor | None | ||
_tensor_cache_max_sequence_length: int = -1 | ||
|
||
def __init__( | ||
self, | ||
config: TransformerConfig, | ||
tensor_space: TensorSpace, | ||
): | ||
super().__init__() | ||
self._config = config | ||
self._tensor_space = tensor_space | ||
self._distributed_config = self._tensor_space.distributed_config | ||
|
@@ -263,7 +264,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: | |
kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( | ||
(self._scalar_dim,), | ||
tensor_name=TransformerKwargs.attention_mask_value, | ||
dtype=self._tensor_space.distributed_config.training_dtype.torch, | ||
dtype=self._distributed_config.training_dtype.torch, | ||
) | ||
|
||
|
||
|
@@ -337,3 +338,75 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: | |
) | ||
kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max() | ||
kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max() | ||
|
||
|
||
class LLaDAMaskingPreprocessor(Preprocessor): | ||
"""Preprocessor for LLaDA-style masking with diffusion-based training.""" | ||
|
||
def __init__(self, config: TransformerConfig, tensor_space: TensorSpace): | ||
self._config = config | ||
self._tensor_space = tensor_space | ||
self._distributed_config = tensor_space.distributed_config | ||
self._scalar_dim = tensor_space.get_tensor_dim(DefaultDimNames.scalar) | ||
self._sequence_dim = tensor_space.get_tensor_dim(TransformerDimNames.sequence_q) | ||
|
||
def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: | ||
"""Apply LLaDA-style masking to the input sequence.""" | ||
# Get diffusion config from dataset parameters | ||
diffusion_config = kwargs['parameters'].diffusion | ||
if not diffusion_config.enabled: | ||
return | ||
|
||
batch_size, seq_len = batch.shape | ||
device = batch.device | ||
|
||
t = torch.rand(batch_size, device=device) | ||
|
||
p_mask = (1 - diffusion_config.epsilon) * t + diffusion_config.epsilon | ||
p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob)) | ||
p_mask = p_mask[:, None].expand(-1, seq_len) | ||
|
||
masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Assuming |
||
|
||
if diffusion_config.pad_prob > 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Meta: I currently can't comment about padding; it will have to wait for next week, as I need to re-read the paper better (our own work doesn't do padding). |
||
pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob | ||
if pad_mask.any(): | ||
masked_indices[pad_mask] = True | ||
|
||
kwargs['masked_indices'] = masked_indices | ||
kwargs['p_mask'] = p_mask | ||
|
||
if self._config.diffusion.bidirectional_attention: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You may want a string instead of a boolean, as there are many possible attention choices (e.g., blocks) that may come up. Also see the next comment below. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, will change this! |
||
# Bidirectional attention - all tokens can attend to all other tokens | ||
attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool) | ||
else: | ||
# Causal attention | ||
attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool).tril_() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding is that you never want such a triangular causal attention, as this would give a strictly worse model than an autoregressive model. Suppose that, at inference, tokens are unmasked in the order (4, 2, 3, 0, 1). Token 4 is unmasked first, but this triangular matrix prevents all other tokens from ever "seeing" it. What is the closest case that makes sense to me would be to permute the rows and columns of the triangular matrix using (4,2,3,0,1), so that token 2 can see token 4, token 3 can see tokens 2 and 4, etc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, permuted rows and columns makes sense - so we can preserve the order in which it was unmasked. I will update this. |
||
|
||
|
||
kwargs[TransformerKwargs.attention_mask] = attention_mask | ||
kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(-10000.0, device=device) | ||
|
||
def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: | ||
"""Define tensor metadata for masking tensors.""" | ||
# Get diffusion config from dataset parameters | ||
diffusion_config = kwargs['parameters'].diffusion | ||
if not diffusion_config.enabled: | ||
return | ||
|
||
kwargs['masked_indices'] = TensorMeta.from_dims( | ||
(self._scalar_dim, self._sequence_dim), | ||
tensor_name='masked_indices' | ||
) | ||
kwargs['p_mask'] = TensorMeta.from_dims( | ||
(self._scalar_dim, self._sequence_dim), | ||
tensor_name='p_mask' | ||
) | ||
kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims( | ||
(self._scalar_dim, self._scalar_dim, self._sequence_dim, self._sequence_dim), | ||
tensor_name=TransformerKwargs.attention_mask | ||
) | ||
kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims( | ||
(self._scalar_dim,), | ||
tensor_name=TransformerKwargs.attention_mask_value | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some questions/thoughts (I am just browsing quickly, and I am not looking at the paper right now):
epsilon
and the upper boundmax_mask_prob
?torch.min
will put a discrete probability forp_mask
to be exactlymax_mask_prob
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you saying this coz we could have many timesteps with the exact masking level set to
max_mask_prob
? So are you suggesting some soft clipping instead of a hard upper bound?