Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduces latency of rendering Jax arrays. #40

Merged
merged 1 commit into from
Oct 1, 2024
Merged

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Sep 29, 2024

Reduces latency of rendering Jax arrays.

Before:
Cold rendering array of any size would take ~15-25 seconds (in colab).

Now:
Cold rendering of small-ish (less than 10M bytes) arrays is instantaneous (<< 1s)
Larger arrays (e.g. > 2K x 2K) take 3-4 seconds, which is much more manageable

To do this we do two things:
a) We convert to numpy for arrays smaller than 10M for computing stats and doing slicing. (We still maintain jax visualization of sharding and types)
b) We use a single jitted function for arrays larger than 10m to compute summaries rather than individual jax invocations for each stat, which end up jitted separately.

@copybara-service copybara-service bot force-pushed the test_680276759 branch 6 times, most recently from 7757819 to f47b55b Compare October 1, 2024 16:35
Before:
  Cold rendering array of any size would take ~15-25 seconds (in colab).

Now:
  Cold rendering of small-ish (less than 10M bytes) arrays is instantaneous (<< 1s)
  Larger arrays (e.g. > 2K x 2K) take 3-4 seconds, which is much more manageable

To do this we do two things:
 a) We convert to numpy for arrays smaller than 10M for computing stats and doing slicing. (We still maintain jax visualization of sharding and types)
 b) We use a single jitted function for arrays larger than 10m to compute summaries rather than individual jax invocations for each stat, which end up jitted separately.

PiperOrigin-RevId: 681055675
@copybara-service copybara-service bot merged commit 438b9a5 into main Oct 1, 2024
@copybara-service copybara-service bot deleted the test_680276759 branch October 1, 2024 16:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant