
Improve precision of 32-bit gammaln? #21284

Open
lucascolley opened this issue May 17, 2024 · 4 comments
Comments

@lucascolley

Description

In [1]: import jax
In [2]: from jax.scipy.special import gammaln as gammaln_jax
In [3]: x = jax.numpy.asarray(2.00001)
In [4]: gammaln_jax(x)
Out[4]: Array(5.722046e-06, dtype=float32, weak_type=True)
...
In [1]: import numpy as np
In [2]: from scipy.special import gammaln
In [3]: x = np.asarray(2.00001)
In [4]: gammaln(x)
Out[4]: 4.227875597648359e-06

https://www.wolframalpha.com/input?i=ln%28Gamma%282.00001%29%29

Found in scipy/scipy#20085 (comment).
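
For scale, a quick check of the discrepancy (a minimal sketch, not part of the original report, using SciPy's float64 result as the reference):

import numpy as np
from scipy.special import gammaln

reference = gammaln(np.float64(2.00001))        # 4.227875597648359e-06
jax_result = 5.722046e-06                       # float32 value from jax.scipy.special.gammaln above
print(abs(jax_result - reference) / reference)  # ~0.35, i.e. roughly 35% relative error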

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.27
jaxlib: 0.4.23.dev20240502
numpy:  1.26.4
python: 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:35:20) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Lucass-MacBook-Air-4.local', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:12:41 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T8103', machine='arm64')
@lucascolley added the bug label May 17, 2024
@jakevdp (Collaborator) commented May 17, 2024

Thanks for the report! This looks like a floating-point precision issue. SciPy computes in 64-bit precision, while JAX uses 32-bit precision by default. If you enable 64-bit precision in JAX, you get the expected output:

In [1]: import jax
In [2]: jax.config.update('jax_enable_x64', True)
In [3]: x = jax.numpy.asarray(2.00001)
In [4]: jax.scipy.special.gammaln(x)
Out[4]: Array(4.2278756e-06, dtype=float64, weak_type=True)

That said, it seems like the 32-bit computation should be able to return a more precise answer here – in the case of gammaln, JAX is more-or-less directly calling OpenXLA's lgamma function, which is implemented here: https://github.com/openxla/xla/blob/1b9e830bc32da1e75a4bab4130376accd7e61e4d/xla/client/lib/math.cc#L513

I wonder if there's a different series expansion we could use that would be more accurate for small outputs?
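
One candidate (a sketch of the idea, not OpenXLA's actual code): gammaln has zeros at x = 1 and x = 2, and near those zeros the generic Lanczos-based formula produces a tiny result as the difference of O(1) terms, so most float32 significand bits are lost to cancellation. A short Taylor expansion around the zero sidesteps that:

import numpy as np
from scipy.special import gammaln

PSI_2 = np.float32(0.42278433)    # digamma(2) = 1 - (Euler-Mascheroni constant)
PSI1_2 = np.float32(0.6449341)    # trigamma(2) = pi**2/6 - 1

def gammaln_near_two(x):
    """gammaln(x) for x near 2, evaluated entirely in float32."""
    x = np.float32(x)
    eps = x - np.float32(2.0)     # exact for x in [1, 4] (Sterbenz lemma)
    return PSI_2 * eps + PSI1_2 * eps * eps / np.float32(2.0)

x32 = np.float32(2.00001)        # rounds to 2 + 84 * 2**-23
print(gammaln_near_two(x32))     # ~4.2336e-06
print(gammaln(np.float64(x32)))  # ~4.2336e-06, the float64 reference for the rounded float32 input

The expansion matches the float64 reference for the actual (rounded) float32 input, where the generic path returns 5.72e-06; whether a branch like this is worth adding to XLA's lgamma is a separate question.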

@jakevdp self-assigned this May 17, 2024
@lucascolley (Author)

Interesting, I thought we were enabling 64-bit JAX over in SciPy. Let me check.

@lucascolley (Author)

We have this in our conftest.py, so I'm not sure why we observed the less precise value in CI:

import jax.numpy  # type: ignore[import-not-found]
xp_available_backends.update({'jax.numpy': jax.numpy})
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_default_device", jax.devices(SCIPY_DEVICE)[0])

@lucascolley (Author) commented May 17, 2024

Feel free to close this. I re-ran CI and the test passed, so perhaps it was a temporary blip out of 64-bit mode for some reason.

This was our bad: the test was passing in float32 arrays, and the tolerance was only being violated intermittently.
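
That would explain it: enabling x64 changes JAX's default dtypes but does not promote arrays that are already float32, so gammaln still runs its 32-bit path. An illustrative sketch (not the actual SciPy test):

import numpy as np
import jax
from jax.scipy.special import gammaln

jax.config.update("jax_enable_x64", True)
x32 = np.asarray(2.00001, dtype=np.float32)
print(gammaln(x32).dtype)  # float32 -- the input dtype is preserved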

@lucascolley changed the title from "BUG: scipy.special.gammaln(2.00001) incorrect" to "Improve precision of 32-bit gammaln?" May 18, 2024