Conversation

@AntonOresten
Contributor

This is mostly a proof-of-concept for rethinking batching: contexts are stacked into a single sequence that forms a block-diagonal structure in attention space, with interactions between contexts from different "documents" masked out and fully masked tiles skipped. For training runs with high variability in sequence length, this avoids a huge amount of needless computation on padding tokens. Dense layers in a model stack also benefit from it.
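As a minimal sketch of the idea (the `lengths` and `document_id` names are illustrative, not from this PR): each token in the packed sequence is labelled with the id of the document it came from, and attention is only allowed between tokens sharing the same id, which yields the block-diagonal mask.

```python
import torch

# Hypothetical example: three documents of different lengths packed into one sequence.
lengths = torch.tensor([5, 3, 4])
document_id = torch.repeat_interleave(torch.arange(len(lengths)), lengths)
# document_id == tensor([0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 2])

# Block-diagonal mask: a query token may only attend to key tokens
# from the same document, so cross-document interactions are masked out.
mask = document_id[:, None] == document_id[None, :]   # (12, 12) bool
```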

Ideally this would be generalized to fully-fledged flex attention, but even then, document_ids and lengths might need to remain a special case in order to construct the block mask efficiently.

See also Flex Attention
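For comparison, a rough sketch of how the same document masking looks with the FlexAttention API in recent PyTorch (2.5+); this is not part of this PR, and the lengths here are arbitrary, chosen to sum to a multiple of the tile size:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Hypothetical packed sequence: document lengths chosen to sum to 1024.
lengths = torch.tensor([300, 212, 512], device="cuda")
document_id = torch.repeat_interleave(
    torch.arange(len(lengths), device="cuda"), lengths
)
seq_len = document_id.numel()

def document_mask(b, h, q_idx, kv_idx):
    # Only allow attention within the same document.
    return document_id[q_idx] == document_id[kv_idx]

# create_block_mask precomputes, per tile, whether it is fully masked,
# so the fused kernel can skip those tiles entirely.
block_mask = create_block_mask(document_mask, B=None, H=None,
                               Q_LEN=seq_len, KV_LEN=seq_len)

q = k = v = torch.randn(1, 8, seq_len, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, block_mask=block_mask)
```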

@AntonOresten
Contributor Author

Added grouped-query attention from #19, and also #18 to avoid NaNs, assuming it is correct. This should make this branch quite generally useful for reducing memory usage.

@AntonOresten
Contributor Author

I think a flexible way of constructing the block mask is the trickiest part. The actual attention computation within each block should mostly remain the same.
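One rough shape that could take (the helper and names below are made up for illustration, not part of this branch): materialize whatever predicate defines the mask once at tile granularity, and record which tiles are fully masked (skippable), fully visible (no masking needed), or partial.

```python
import torch

def classify_tiles(mask: torch.Tensor, tile: int = 128):
    """Reduce a dense (Q, K) boolean mask to tile granularity.

    Assumes Q and K are multiples of `tile` (pad otherwise). Returns two
    (Q//tile, K//tile) boolean tensors: tiles with at least one allowed
    interaction, and tiles where every interaction is allowed.
    """
    q, k = mask.shape
    tiles = (mask.view(q // tile, tile, k // tile, tile)
                 .permute(0, 2, 1, 3)
                 .reshape(q // tile, k // tile, tile * tile))
    any_visible = tiles.any(dim=-1)
    all_visible = tiles.all(dim=-1)
    return any_visible, all_visible

# Example usage:
# any_visible, all_visible = classify_tiles(mask)
# skip = ~any_visible                     # fully masked tiles: never computed
# partial = any_visible & ~all_visible    # boundary tiles: apply the mask element-wise
```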
