Add split method to RngStream #5270

Open
samanklesaria wants to merge 3 commits into google:main from samanklesaria:split_method

Conversation

@samanklesaria
Collaborator

This is a small convenience method for RngStreams. Specifically, self.split(k) = self.fork(split=k).
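
A minimal sketch of that relationship, using a toy stand-in class rather than the real `RngStream`. The names `ToyStream` and `_derive`, the hash-based seed derivation, and the choice to return a single child stream carrying `k` seeds are all assumptions made for illustration; only the `split(k)` / `fork(split=k)` equivalence comes from the PR description.

```python
import hashlib
from typing import Optional, Tuple


def _derive(seed: int, index: int) -> int:
  """Deterministically derive a child seed (stand-in for a real key split)."""
  digest = hashlib.sha256(f"{seed}:{index}".encode()).digest()
  return int.from_bytes(digest[:8], "big")


class ToyStream:
  """Toy stand-in for an RNG stream; not Flax's RngStream."""

  def __init__(self, seeds: Tuple[int, ...]):
    self.seeds = seeds  # one seed per "lane" of the stream
    self.count = 0      # number of times this stream has been forked

  def fork(self, *, split: Optional[int] = None) -> "ToyStream":
    base = _derive(self.seeds[0], self.count)
    self.count += 1  # forking advances the parent stream
    if split is None:
      return ToyStream((base,))
    # With `split`, the child carries `split` distinct derived seeds.
    return ToyStream(tuple(_derive(base, i) for i in range(split)))

  def split(self, k: int) -> "ToyStream":
    # The convenience alias: identical to fork(split=k).
    return self.fork(split=k)


# Two identically seeded streams give the same result either way.
a, b = ToyStream((0,)), ToyStream((0,))
child_a = a.split(3)
child_b = b.fork(split=3)
assert child_a.seeds == child_b.seeds
```

Because forking advances the parent in this toy model, calling `split` twice on the same stream yields different children, so the comparison uses two separate streams with the same seed.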

@gemini-code-assist
Contributor

Summary of Changes

Hello @samanklesaria, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new convenience method, split, to the RngStream class. This method simplifies the process of forking random number generator streams by providing a more direct and readable way to call the existing fork functionality with a specified split parameter.

Highlights

  • New split method: Introduced a split method to the RngStream class, providing a convenient alias for self.fork(split=k).

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a convenient split method to the RngStream class, acting as an alias for fork(split=k). My review suggests enhancing this by allowing the split factor k to also be a tuple of integers, to fully align with the fork method's capabilities, and by adding a docstring to improve code clarity and maintainability.

Comment on lines +124 to +125
  def split(self, k: int):
    return self.fork(split=k)
medium

This is a great convenience method! To make it even more useful and align it better with the full capabilities of the fork method, consider allowing k to be a tuple of integers as well. This would make split a more complete alias for the splitting functionality of fork.

Additionally, adding a docstring would improve clarity for future users.

  def split(self, k: int | tuple[int, ...]):
    """Forks the RngStream into `k` new streams.

    This is a convenience method for `self.fork(split=k)`.

    Args:
      k: The number of new streams to fork. Can be an integer or a tuple of
        integers to specify the shape of the split keys.

    Returns:
      A new `Rngs` object with `k` forked streams.
    """
    return self.fork(split=k)

@cgarciae
Collaborator

cgarciae commented Feb 25, 2026

We should add split to Rngs as well. I was also wondering how to let users easily treat the streams in an Rngs differently when passing it to transforms. At least for vmap, we could allow creating a prefix tree via a Rngs.prefix classmethod, so users could do something like this:

rngs = Rngs(params=0, dropout=1)
...
rngs = rngs.split(5, only='dropout')

@vmap(in_axes=Rngs.prefix(params=None, dropout=0), ...)
def f(rngs):
  ...

Implementation could look something like:

class Rngs:
  @classmethod
  def prefix(cls, default=None, **kwargs):
    if default is not None:
      kwargs['default'] = default
    rngs = cls() # init empty
    for name, value in kwargs.items():
      setattr(rngs, name, nnx.data(value)) # set the inputs directly as attributes
    return rngs
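
To illustrate what such a prefix tree would express, here is a small Flax-free sketch. The functions `prefix` and `toy_vmap` and the dict-of-streams representation are hypothetical stand-ins, not the proposed `Rngs.prefix` or `nnx.vmap`. The toy supports only `None` (broadcast) and `0` (map over the leading axis).

```python
def prefix(default=None, **axes):
  """Build a per-stream axis spec (hypothetical helper mirroring the proposal)."""
  if default is not None:
    axes["default"] = default
  return axes


def toy_vmap(fn, in_axes, streams, size):
  """Call `fn` once per batch index.

  Streams whose axis is None are broadcast unchanged; all others are
  indexed along their leading axis. Streams missing from `in_axes` are
  treated as broadcast.
  """
  results = []
  for i in range(size):
    sliced = {
      name: value if in_axes.get(name) is None else value[i]
      for name, value in streams.items()
    }
    results.append(fn(sliced))
  return results


# `params` is shared across the batch; `dropout` holds one value per batch entry.
streams = {"params": 7, "dropout": [10, 11, 12]}
out = toy_vmap(
  lambda s: (s["params"], s["dropout"]),
  prefix(params=None, dropout=0),
  streams,
  size=3,
)
print(out)  # [(7, 10), (7, 11), (7, 12)]
```

The point of the spec is that it has the same top-level structure as the object being mapped, so each stream's axis can be looked up by name.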

@samanklesaria
Collaborator Author

samanklesaria commented Mar 2, 2026

Also I was wondering how to allow users to easily pass the Rngs differently to transforms, at least for vmap we can allow creating a prefix tree via a Rngs.prefix classmethod.

@cgarciae It seems like the situation where we want to vmap different parts of a pytree with different batch axes extends beyond Rngs. The most generic way of doing this, as you mentioned before, is split and merge:

rngs = nnx.Rngs(params=0, dropout=1)
rngs = rngs.split(5, only='dropout')
graphdef, dropout, params = nnx.split(rngs, 'dropout', ...)

@nnx.vmap(in_axes=(None, 0), ...)
def f(params, dropout):
  rngs = nnx.merge(graphdef, params, dropout)
  ...
  
f(params, dropout)
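
The same split-and-merge shape can be sketched in plain Python, with dicts standing in for the Rngs state. The helpers `split_state` and `merge_state` are hypothetical stand-ins for a filtered split and a merge, not the `nnx.split` / `nnx.merge` API, and the explicit loop stands in for the mapped axis.

```python
def split_state(state, *names):
  """Partition `state` into the selected entries and everything else."""
  selected = {k: v for k, v in state.items() if k in names}
  rest = {k: v for k, v in state.items() if k not in names}
  return selected, rest


def merge_state(*parts):
  """Recombine partitions into a single state dict."""
  merged = {}
  for part in parts:
    merged.update(part)
  return merged


state = {"params": 7, "dropout": [10, 11, 12]}
dropout, params = split_state(state, "dropout")


def f(params, dropout_slice):
  # Inside the mapped function, merge the shared and per-example parts back together.
  rngs = merge_state(params, dropout_slice)
  return rngs["params"] + rngs["dropout"]


# `params` is passed whole on every call; `dropout` is sliced along its leading axis.
outputs = [f(params, {"dropout": d}) for d in dropout["dropout"]]
print(outputs)  # [17, 18, 19]
```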

But I agree that the 'prefix' method you propose above is easier to use. What if we add it to Pytree instead? That way, it could be used for sharing some subset of parameters when vmapping modules as well.

samanklesaria force-pushed the split_method branch 4 times, most recently from 2c90755 to 8bf1e6d (March 2, 2026 21:21)