Skip to content

[jax_intro] Update %time magic with %timeit #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

HumphreyYang
Copy link
Collaborator

This PR fixes #206.

Copy link

netlify bot commented Apr 7, 2025

Deploy Preview for incomparable-parfait-2417f8 ready!

Name Link
🔨 Latest commit f7bdbb3
🔍 Latest deploy log https://app.netlify.com/sites/incomparable-parfait-2417f8/deploys/67f3daebe555670008250a96
😎 Deploy Preview https://deploy-preview-207--incomparable-parfait-2417f8.netlify.app
📱 Preview on mobile
Toggle QR Code...

QR Code

Use your smartphone camera to open QR code link.

To edit notification comments on pull requests, go to your Netlify site configuration.

Copy link

github-actions bot commented Apr 7, 2025

@github-actions github-actions bot temporarily deployed to pull request April 7, 2025 14:14 Inactive
@github-actions github-actions bot temporarily deployed to pull request April 7, 2025 14:18 Inactive
@HumphreyYang
Copy link
Collaborator Author

Hi @jstac,

Since %timeit runs the code multiple times, it's not useful for examples where we want to show the compilation time. So, I only used %timeit on the lines that were causing issues.

Interestingly, f_jit(x) is still slower than f(x), even under %timeit, in the preview here:

%timeit f(x).block_until_ready()
100 ms ± 27 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit f_jit(x).block_until_ready()
231 ms ± 313 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

but I cannot replicate it on Colab. On Colab, the jit version runs much faster:

%timeit f(x).block_until_ready()
127 ms ± 365 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit f_jit(x).block_until_ready()
67.4 ms ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Colab uses a newer version of JAX (jax==0.5.2), while we are using jax==0.4.35. Nonetheless, I don't think the issue is related to the JAX version, given that it was working properly before.

One approach might be to replace f with a more computationally intensive function so that the jit-compiled version shows a more noticeable performance difference.

Please let me know your thoughts on how we should proceed.

@jstac
Copy link
Contributor

jstac commented Apr 8, 2025

Thanks @HumphreyYang , much appreciated.

@mmcky , do you have thoughts here?

I know you were considering having a separate runner that we control. Overall, it would be nice to control our environment and be a bit more up to date with JAX versions.

@mmcky
Copy link
Contributor

mmcky commented Apr 8, 2025

Thanks @HumphreyYang just doing some version checking in #208 as I'm not sure why the jax version would be that old.

After I get the MIT Solve application together I am going to setup custom GitHub runners. I have done some experiments and we should be able to get the GitHub runner going on the GPU server. It would be ideal to have a dedicated machine though as it will be running arbitrary code from GitHub.

@mmcky
Copy link
Contributor

mmcky commented Apr 8, 2025

@HumphreyYang the environment on GitHub actions should be using

jax                       0.5.3                    pypi_0    pypi

can you let me know where you found the old version of jax? Thanks

@HumphreyYang
Copy link
Collaborator Author

Many thanks @mmcky,

I found it under the preview building action:

IMG_5636

@HumphreyYang
Copy link
Collaborator Author

Resolved by using JAX==0.6.0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Replace %time with %timeit
3 participants