Is there any update on this? We'd love to use this in DESC (PlasmaControl/DESC#1294), but having to build from source to get GPU support is limiting, as all of our other dependencies can be handled with `pip install ...`
Just starting to write down my thoughts on how we could build and distribute GPU wheels.
For background on GPU wheels, this is the best summary of the current state of affairs I've found: https://pypackaging-native.github.io/key-issues/gpus/
It seems there are two broadly different approaches we could take:

1. bundle CUDA in the wheel, or
2. use "external" CUDA.
(1) is the traditional method and results in large wheels. (2) is what JAX does, but is pretty cutting-edge and under-documented.
The way (1) would work is that we would statically link the CUDA libraries (which is what we're currently doing, I think), or dynamically link but let `auditwheel` copy the libraries into the wheel. There are a few parts of this I still don't understand, such as how it would work with JAX linking against one CUDA runtime but jax-finufft potentially having another. Would that result in two CUDA contexts? Clearly it's already working somehow!
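For concreteness, here's roughly what the dynamic-link variant of (1) could look like in a manylinux container; the wheel name and the claim about sizes are illustrative, not tested:

```bash
# Build the wheel against a system-installed CUDA toolkit.
python -m pip wheel . -w dist/

# auditwheel grafts every non-whitelisted shared library (libcufft, libcudart,
# ...) into the wheel and rewrites rpaths; the vendored CUDA libraries are
# what make these wheels so large.
auditwheel repair dist/jax_finufft-*.whl -w wheelhouse/
```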
With (2), we would use the NVIDIA CUDA wheels on PyPI. I can't find any official documentation on them, but the Python CUDA tech lead did write this nice tutorial in the cuQuantum repo: https://github.com/NVIDIA/cuQuantum/tree/main/extra/demo_build_with_wheels
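As a rough sketch of what those wheels provide (the package names follow NVIDIA's `nvidia-<lib>-cu12` pattern; I'd guess we mainly need the runtime and cuFFT):

```bash
python -m pip install nvidia-cuda-runtime-cu12 nvidia-cufft-cu12

# The shared libraries land inside site-packages, e.g.
#   .../site-packages/nvidia/cufft/lib/libcufft.so.11
ls "$(python -c 'import sysconfig; print(sysconfig.get_paths()["purelib"])')"/nvidia/*/lib/
```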
It's somewhat hacky, but the basic ideas are clear. We would set the rpath to find the pip-installed CUDA libraries, using some helper scripts and `auditwheel --exclude` to leave specific shared libraries out of the wheel. At runtime, the linker will look for the pip-installed CUDA, or user/system installations if that fails.
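A minimal sketch of those two steps, assuming the extension lands in `site-packages/jax_finufft/` next to the `nvidia/` wheel directories (the soname versions and extension filename are illustrative):

```bash
# Repair the wheel, but leave the CUDA libraries out of it so they get
# resolved at import time instead of being vendored.
auditwheel repair --exclude libcudart.so.12 --exclude libcufft.so.11 \
    dist/jax_finufft-*.whl -w wheelhouse/

# Point the dynamic linker at the pip-installed CUDA wheels, relative to where
# the extension module sits in site-packages; the linker falls back to the
# usual system search paths if these directories don't exist. (In practice
# this would run on the unpacked wheel before re-zipping.)
patchelf --add-rpath '$ORIGIN/../nvidia/cuda_runtime/lib:$ORIGIN/../nvidia/cufft/lib' \
    jax_finufft_gpu.cpython-312-x86_64-linux-gnu.so
```

Note the ordering: the rpath patch probably has to happen after `auditwheel repair`, since repair rewrites rpaths — this is exactly the kind of detail the helper scripts would handle.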
Either way, I think the build itself can be done on cibuildwheel, probably just with a `yum install` of the CUDA development libraries (like this project does: https://github.com/OpenNMT/CTranslate2/blob/master/python/tools/prepare_build_environment_linux.sh).
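Something along these lines for the `before-all` step, modeled on that CTranslate2 script; the repo URL and rpm package names below are my guesses at NVIDIA's naming for a manylinux2014 (RHEL 7 based) image:

```bash
# Enable NVIDIA's CUDA rpm repository inside the manylinux container.
yum install -y yum-utils
yum-config-manager --add-repo \
    https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo

# Install just the pieces the build needs: nvcc, the runtime headers, and cuFFT.
yum install -y cuda-nvcc-12-3 cuda-cudart-devel-12-3 libcufft-devel-12-3
```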
In terms of PyPI distribution, with CUDA minor version compatibility, I think we can just do what cupy does and use `jax-finufft-cuda12x` and `jax-finufft-cuda11x` (if we want to support CUDA 11); no need for a custom package index URL. With (2), we would use `[cuda_local]` and `[cuda_pip]` extras. I don't think we need a full matrix of cuDNN versions like JAX does, but I could be wrong.
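A sketch of what the extras might look like in `pyproject.toml` (the dependency list is a guess at the minimal set we'd need):

```bash
# Appending to pyproject.toml just to show the stanza; the extra names follow
# the [cuda_local] / [cuda_pip] scheme above.
cat >> pyproject.toml <<'EOF'
[project.optional-dependencies]
cuda_local = []  # user supplies CUDA via a system or module install
cuda_pip = [
    "nvidia-cuda-runtime-cu12",
    "nvidia-cufft-cu12",
]
EOF
```

Users would then choose between `pip install "jax-finufft[cuda_local]"` and `pip install "jax-finufft[cuda_pip]"`.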
This is all a bit experimental since this isn't a "vanilla" CUDA extension, but one that has to work with JAX! For that reason, (2) seems more appealing, since it's more likely to pick up the same CUDA installation that JAX does.