
Changes for basic LLaDA style diffusion masking support #238

Draft · wants to merge 2 commits into main
37 changes: 35 additions & 2 deletions fast_llm/data/data/gpt/data.py
@@ -35,15 +35,48 @@ class GPTBatch:


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
    """Collate function that supports LLaDA-style masking."""
    stacked_ids = np.stack([sample.token_ids for sample in batch])
    stacked_spans = None
    sequence_lengths = None

    token_ids = torch.from_numpy(stacked_ids)

    if sampling_parameters.diffusion.enabled:
        batch_size, seq_len = token_ids.shape
        device = token_ids.device
        t = torch.rand(batch_size, device=device)
        p_mask = (1 - sampling_parameters.diffusion.epsilon) * t + sampling_parameters.diffusion.epsilon
        p_mask = torch.min(p_mask, torch.tensor(sampling_parameters.diffusion.max_mask_prob))
        p_mask = p_mask[:, None].expand(-1, seq_len)

        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask

        if sampling_parameters.diffusion.pad_prob > 0:
            pad_mask = torch.rand((batch_size,), device=device) < sampling_parameters.diffusion.pad_prob
            if pad_mask.any():
                masked_indices[pad_mask] = True

        token_ids = torch.where(masked_indices, sampling_parameters.diffusion.mask_token_id, token_ids)

        if not stacked_spans:
            stacked_spans = []
        stacked_spans.extend([
            torch.stack([masked_indices[i], p_mask[i]])
            for i in range(batch_size)
        ])

    if sampling_parameters.use_loss_masking_spans:
        if not stacked_spans:
            stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]

    if not sampling_parameters.cross_document_attention:
        sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]

    return GPTBatch(
        token_ids=token_ids,
        loss_masking_spans=stacked_spans,
        sequence_lengths=sequence_lengths
    )
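
For reference, here is the masking step in isolation: a minimal, self-contained sketch in plain PyTorch (outside the Fast-LLM types; the constants mirror the DiffusionMaskingConfig defaults in this PR, and torch.clamp plays the role of the torch.min above).

    import torch

    epsilon, max_mask_prob, mask_token_id = 1e-3, 0.15, 103
    token_ids = torch.randint(0, 1000, (4, 8))   # toy batch: 4 sequences of length 8
    t = torch.rand(4)                            # one noise level per sequence
    p_mask = torch.clamp((1 - epsilon) * t + epsilon, max=max_mask_prob)
    p_mask = p_mask[:, None].expand(-1, 8)
    masked_indices = torch.rand(4, 8) < p_mask   # True means "replace with the mask token"
    masked_ids = torch.where(masked_indices, torch.full_like(token_ids, mask_token_id), token_ids)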


50 changes: 50 additions & 0 deletions fast_llm/data/dataset/gpt/config.py
@@ -64,6 +64,51 @@ class GPTSamplingConfig(SamplingConfig):
)


@config_class()
class DiffusionMaskingConfig(Config):
    """Configuration for diffusion-based masking during data preparation."""

    enabled: bool = Field(
        default=False,
        desc="Whether to use diffusion-based masking during training",
        hint=FieldHint.feature
    )

    epsilon: float = Field(
        default=1e-3,
        desc="Minimum masking probability",
        hint=FieldHint.performance,
        valid=check_field(Assert.gt, 0)
    )

    max_mask_prob: float = Field(
        default=0.15,
        desc="Maximum masking probability",
        hint=FieldHint.performance,
        valid=check_field(Assert.gt, 0)
    )

    pad_prob: float = Field(
        default=0.01,
        desc="Probability of applying padding-style masking to a sample (1% of samples by default)",
        hint=FieldHint.optional,
        valid=check_field(Assert.geq, 0)
    )

    mask_token_id: int = Field(
        default=103,
        desc="Token ID to use for masking",
        hint=FieldHint.required
    )

    def _validate(self) -> None:
        super()._validate()
        Assert.lt(self.epsilon, self.max_mask_prob, "epsilon must be less than max_mask_prob")
        Assert.lt(self.max_mask_prob, 1.0, "max_mask_prob must be less than 1.0")
        if self.enabled:
            Assert.is_not_none(self.mask_token_id, "mask_token_id must be set when masking is enabled")


@dataclasses.dataclass(kw_only=True)
class GPTSamplingParameters(SamplingParameters):
"""
@@ -77,6 +122,11 @@ class GPTSamplingParameters(SamplingParameters):
    # How many extra tokens to add to the sequence length.
    # This is used to provide labels even for the last tokens in the sequence.
    extra_tokens: int = 1

    # Diffusion masking configuration
    diffusion: DiffusionMaskingConfig = dataclasses.field(
        default_factory=DiffusionMaskingConfig
    )
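
For illustration only (not part of the diff), enabling the new options might look like the sketch below; keyword construction and the validate() call are assumed to follow the usual Fast-LLM Config conventions.

    # Hypothetical usage sketch; field names come from DiffusionMaskingConfig above.
    diffusion_config = DiffusionMaskingConfig(
        enabled=True,
        epsilon=1e-3,
        max_mask_prob=0.15,
        pad_prob=0.01,
        mask_token_id=103,
    )
    diffusion_config.validate()  # checks epsilon < max_mask_prob < 1.0 and that mask_token_id is set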


@dataclasses.dataclass(kw_only=True)
1 change: 1 addition & 0 deletions fast_llm/layers/__init__.py
@@ -0,0 +1 @@

58 changes: 58 additions & 0 deletions fast_llm/layers/language_model/head.py
@@ -298,3 +298,61 @@ def _logits_cross_entropy_forward_backward(
        # TODO: de-allocate earlier.
        del logits
        return loss, output_parallel_linear_backward(grad, context)


class MLMHead(LanguageModelHead):
    """
    A masked language model head for diffusion-based training.
    """

    def _logits_cross_entropy_forward_backward(
        self,
        input_: torch.Tensor,
        labels: torch.Tensor | None,
        weight: torch.Tensor,
        grad_output: float,
        kwargs: dict,
        losses: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        logits, context = output_parallel_linear_forward(
            input_=input_,
            weight=weight,
            bias=None,
            group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
            sequence_parallel=self._sequence_parallel and self._parallel_embeddings,
        )

        if self._z_loss_factor > 0.0:
            logits = z_loss(
                logits,
                self._z_loss_factor,
                self.training,
                grad_output,
                losses,
                LanguageModelLossNames.z_loss,
                logits_scale_factor=self._logits_scale_factor,
            )

        if labels is None:
            return logits * self._logits_scale_factor, None

        # Per-token mask and masking probabilities produced by the diffusion preprocessing.
        masked_indices = kwargs['masked_indices']
        p_mask = kwargs['p_mask']

        masked_logits = logits[masked_indices]
        masked_labels = labels[masked_indices]
        masked_p = p_mask[masked_indices]

        # Compute MLM loss, restricted to the masked positions.
        loss, grad = cross_entropy_forward_backward(
            masked_logits.flatten(0, -2),
            masked_labels,
            group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None,
            grad_output=grad_output / masked_p,  # re-weight each masked token by 1 / p_mask
            implementation=self._cross_entropy_impl,
            logits_scale_factor=self._logits_scale_factor,
        )

        # Normalize by the total number of token positions in the batch.
        loss = loss / (labels.shape[0] * labels.shape[1])

        del logits
        return loss, output_parallel_linear_backward(grad, context)
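
Side note (not part of the diff): dividing the gradient by masked_p together with the final normalization by labels.shape[0] * labels.shape[1] corresponds to the LLaDA-style objective, which up-weights each masked token's cross-entropy by the inverse of its masking probability. Writing B for the batch size and L for the sequence length:

    $\mathcal{L}(\theta) \approx -\frac{1}{B L} \sum_{b=1}^{B} \sum_{i=1}^{L} \frac{\mathbf{1}[x_t^{b,i} = \mathrm{MASK}]}{p_{\mathrm{mask}}^{b,i}} \log p_\theta\big(x_0^{b,i} \mid x_t^{b}\big)$

Since fewer tokens are masked at low noise levels, the $1/p_{\mathrm{mask}}$ factor compensates by giving those tokens a larger weight.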
26 changes: 26 additions & 0 deletions fast_llm/layers/transformer/config.py
@@ -469,11 +469,37 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
)


@config_class()
class DiffusionMaskingConfig(Config):
    """Configuration for diffusion-based masking in the transformer model.

    This config only contains model-specific parameters. For masking parameters,
    refer to fast_llm.data.dataset.gpt.config.DiffusionMaskingConfig.
    """

    enabled: bool = Field(
        default=False,
        desc="Whether to use diffusion-based masking during training",
        hint=FieldHint.feature
    )
    bidirectional_attention: bool = Field(
        default=True,
        desc="Whether to use bidirectional attention for masked tokens",
        hint=FieldHint.feature
    )

    def _validate(self) -> None:
        super()._validate()


@config_class()
class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
    normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)
    rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig)
    peft: TransformerPeftConfig = FieldUpdate(default_factory=TransformerPeftConfig)
    diffusion: DiffusionMaskingConfig = Field(
        default_factory=DiffusionMaskingConfig,
        desc="Configuration for diffusion-based masking",
        hint=FieldHint.feature
    )
    # Default: hidden_size**-0.5
    # TODO: Allow custom initialization (InitializationConfig?)
    init_method_std: float = Field(
79 changes: 76 additions & 3 deletions fast_llm/layers/transformer/preprocessing.py
@@ -195,15 +195,16 @@ class BackupAttentionPreprocessor(Preprocessor):
    _scalar_dim: TensorDim
    _kv_channels_dim: TensorDim
    _rotary_embedding_frequencies: torch.Tensor
    _mask: torch.Tensor | None
    _mask_value: torch.Tensor | None
    _tensor_cache_max_sequence_length: int = -1

    def __init__(
        self,
        config: TransformerConfig,
        tensor_space: TensorSpace,
    ):
        super().__init__()
        self._config = config
        self._tensor_space = tensor_space
        self._distributed_config = self._tensor_space.distributed_config
@@ -263,7 +264,7 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
        kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims(
            (self._scalar_dim,),
            tensor_name=TransformerKwargs.attention_mask_value,
            dtype=self._distributed_config.training_dtype.torch,
        )


@@ -337,3 +338,75 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
        )
        kwargs[TransformerKwargs.max_seqlen_q] = seqlens_q.max()
        kwargs[TransformerKwargs.max_seqlen_k] = seqlens_k.max()


class LLaDAMaskingPreprocessor(Preprocessor):
    """Preprocessor for LLaDA-style masking with diffusion-based training."""

    def __init__(self, config: TransformerConfig, tensor_space: TensorSpace):
        self._config = config
        self._tensor_space = tensor_space
        self._distributed_config = tensor_space.distributed_config
        self._scalar_dim = tensor_space.get_tensor_dim(DefaultDimNames.scalar)
        self._sequence_dim = tensor_space.get_tensor_dim(TransformerDimNames.sequence_q)

    def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
        """Apply LLaDA-style masking to the input sequence."""
        # Get diffusion config from dataset parameters
        diffusion_config = kwargs['parameters'].diffusion
        if not diffusion_config.enabled:
            return

        batch_size, seq_len = batch.shape
        device = batch.device

        t = torch.rand(batch_size, device=device)

        p_mask = (1 - diffusion_config.epsilon) * t + diffusion_config.epsilon

Review comment:

Some questions/thoughts (I am just browsing quickly, and I am not looking at the paper right now):

  • Why is the lower bound epsilon and the upper bound max_mask_prob?
  • My gut tells me you never want the mask probability to be exactly 1, for the same kind of reasons you don't want it to be exactly 0.
  • This approach using torch.min puts a discrete probability mass on p_mask being exactly max_mask_prob.

Author reply:

Are you saying this because we could have many timesteps with the masking level set to exactly max_mask_prob? So are you suggesting some soft clipping instead of a hard upper bound?
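
To make the point concrete: with the defaults in this PR (epsilon = 1e-3, max_mask_prob = 0.15), any t above roughly 0.149 is clipped, so about 85% of samples end up with p_mask exactly equal to max_mask_prob. One soft alternative (an illustrative sketch only, not part of this PR) is to rescale t into [epsilon, max_mask_prob] instead of clipping it:

    # Illustrative alternative to the torch.min clipping below: a linear rescale
    # spreads p_mask over [epsilon, max_mask_prob] with no point mass at the upper bound.
    p_mask = diffusion_config.epsilon + (diffusion_config.max_mask_prob - diffusion_config.epsilon) * t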

        p_mask = torch.min(p_mask, torch.tensor(diffusion_config.max_mask_prob))
        p_mask = p_mask[:, None].expand(-1, seq_len)

        masked_indices = torch.rand((batch_size, seq_len), device=device) < p_mask

Review comment: Assuming True means "masked".


        if diffusion_config.pad_prob > 0:

Review comment: Meta: I currently can't comment on the padding; it will have to wait until next week, as I need to re-read the paper more carefully (our own work doesn't do padding).

            pad_mask = torch.rand((batch_size,), device=device) < diffusion_config.pad_prob
            if pad_mask.any():
                masked_indices[pad_mask] = True

        kwargs['masked_indices'] = masked_indices
        kwargs['p_mask'] = p_mask

        if self._config.diffusion.bidirectional_attention:

Review comment:

You may want a string instead of a boolean, as there are many possible attention choices (e.g., blocks) that may come up. Also see the next comment below.

Author reply:

I agree, will change this!
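
For illustration, the string-valued option discussed here could be modelled as an enum (hypothetical sketch, not in this PR), leaving room for block or other attention patterns later:

    import enum

    # Hypothetical replacement for the bidirectional_attention boolean.
    class DiffusionAttentionMode(str, enum.Enum):
        causal = "causal"
        bidirectional = "bidirectional"
        block = "block"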

            # Bidirectional attention - all tokens can attend to all other tokens
            attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool)
        else:
            # Causal attention
            attention_mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool).tril_()

Review comment:

My understanding is that you never want such a triangular causal attention, as this would give a strictly worse model than an autoregressive one.

Suppose that, at inference, tokens are unmasked in the order (4, 2, 3, 0, 1). Token 4 is unmasked first, but this triangular matrix prevents all other tokens from ever "seeing" it.

The closest variant that makes sense to me would be to permute the rows and columns of the triangular matrix using (4, 2, 3, 0, 1), so that token 2 can see token 4, token 3 can see tokens 2 and 4, etc.

Author reply:

Yes, permuting the rows and columns makes sense, so we preserve the order in which tokens were unmasked. I will update this.
I guess this idea is similar to this paper (XLNet)? https://arxiv.org/abs/1906.08237
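
An illustrative sketch of the permuted-causal-mask idea (not part of this PR): rank[i] records the step at which token i is unmasked, and token i may attend to token j exactly when j was unmasked no later than i.

    import torch

    def permuted_causal_mask(order: torch.Tensor) -> torch.Tensor:
        # order[k] = index of the token unmasked at step k, e.g. torch.tensor([4, 2, 3, 0, 1]).
        seq_len = order.shape[0]
        rank = torch.empty_like(order)
        rank[order] = torch.arange(seq_len)    # rank[i] = unmasking step of token i
        return rank[:, None] >= rank[None, :]  # True where attention is allowed

    mask = permuted_causal_mask(torch.tensor([4, 2, 3, 0, 1]))
    # mask[2, 4] is True: token 2 (unmasked second) can see token 4 (unmasked first),
    # while mask[4, 2] is False.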



        kwargs[TransformerKwargs.attention_mask] = attention_mask
        kwargs[TransformerKwargs.attention_mask_value] = torch.tensor(-10000.0, device=device)

    def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
        """Define tensor metadata for masking tensors."""
        # Get diffusion config from dataset parameters
        diffusion_config = kwargs['parameters'].diffusion
        if not diffusion_config.enabled:
            return

        kwargs['masked_indices'] = TensorMeta.from_dims(
            (self._scalar_dim, self._sequence_dim),
            tensor_name='masked_indices'
        )
        kwargs['p_mask'] = TensorMeta.from_dims(
            (self._scalar_dim, self._sequence_dim),
            tensor_name='p_mask'
        )
        kwargs[TransformerKwargs.attention_mask] = TensorMeta.from_dims(
            (self._scalar_dim, self._scalar_dim, self._sequence_dim, self._sequence_dim),
            tensor_name=TransformerKwargs.attention_mask
        )
        kwargs[TransformerKwargs.attention_mask_value] = TensorMeta.from_dims(
            (self._scalar_dim,),
            tensor_name=TransformerKwargs.attention_mask_value
        )