Improve precision of 32-bit gammaln #21284

Comments
Thanks for the report! This looks like a float precision issue. SciPy uses 64-bit precision, while JAX uses 32-bit precision by default. If you enable 64-bit precision in JAX, you get the expected output:

```python
In [1]: import jax

In [2]: jax.config.update('jax_enable_x64', True)

In [3]: x = jax.numpy.asarray(2.00001)

In [4]: jax.scipy.special.gammaln(x)
Out[4]: Array(4.2278756e-06, dtype=float64, weak_type=True)
```

That said, it seems like the 32-bit computation should be able to return a more precise answer here. I wonder if there's a different series expansion we could use that would be more accurate for small outputs?
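For intuition on the small-output case: since Γ(2) = 1, ln Γ(2 + ε) = ε·ψ(2) + ε²·ψ′(2)/2 + O(ε³), with ψ(2) = 1 − γ (digamma) and ψ′(2) = π²/6 − 1 (trigamma). A minimal standard-library sketch of this idea (purely an illustration, not JAX's actual implementation):

```python
import math

# Taylor expansion of ln(Gamma(2 + eps)) around eps = 0:
#   ln Gamma(2 + eps) = eps * psi(2) + eps**2 * psi'(2) / 2 + O(eps**3)
# with psi(2) = 1 - euler_gamma and psi'(2) = pi**2 / 6 - 1.
eps = 1e-5
euler_gamma = 0.5772156649015329
series = eps * (1 - euler_gamma) + 0.5 * eps**2 * (math.pi**2 / 6 - 1)

reference = math.lgamma(2 + eps)  # float64 reference
rel_err = abs(series - reference) / reference
print(series, reference, rel_err)
```

The two leading terms already agree with the float64 reference to many digits, which suggests a dedicated series near the zeros of ln Γ could preserve relative accuracy even in 32-bit arithmetic.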
Interesting, I thought that we were enabling 64-bit JAX over in SciPy. Let me check.
We have this in our test setup:

```python
import jax.numpy  # type: ignore[import-not-found]
xp_available_backends.update({'jax.numpy': jax.numpy})
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_device", jax.devices(SCIPY_DEVICE)[0])
```
This was our bad, the test was sending in
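For what it's worth, part of the discrepancy is already in the input: 2.00001 is not exactly representable in float32, so even a perfectly accurate 32-bit gammaln sees a perturbed argument. A quick standard-library sketch (a hypothetical illustration, not tied to JAX internals):

```python
import struct

# Round-trip 2.00001 through IEEE float32 to see how much the input moves.
x64 = 2.00001
x32 = struct.unpack('f', struct.pack('f', x64))[0]
input_err = abs(x32 - x64)

# Since ln Gamma(2 + eps) is roughly 0.4228 * eps near 2, an absolute input
# error of ~1e-8 in eps = 1e-5 already limits the achievable relative
# accuracy of the output to roughly 1e-3, regardless of how gammaln itself
# is computed.
print(x32, input_err)
```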
Description

```python
scipy.special.gammaln(2.00001)
```

https://www.wolframalpha.com/input?i=ln%28Gamma%282.00001%29%29

Found in scipy/scipy#20085 (comment).
System info (python version, jaxlib version, accelerator, etc.)