Question
I successfully built jaxlib. However, how can I build jax[cuda12]? I did not find any instructions in the jax documentation for building a specific wheel with the [cuda12] tag. Thank you for your response!
Background
I have some applications running on jax, installed via pip install jax[cuda12]==0.4.34 jaxlib==0.4.34. Recently, I encountered this issue when enabling cuDNN and the persistent compilation cache at the same time: String field 'xla.gpu.CompilationResultProto.DnnCompiledGraphsEntry.value' contains invalid UTF-8 data.
I noticed that this bug was fixed by a commit to the xla repository on January 9, 2025, and that the first jax release after that commit was 0.5.0. So, after removing the old jax, I reinstalled the latest version using pip install jax[cuda12]==0.5.0 jaxlib==0.5.0. This resolved the issue!
However, my application does not work well with jax 0.5.0: it runs slower and produces NaN errors. So I decided to build jaxlib and jax[cuda12] from source against a local xla repository.
What I did
I cloned the jax and xla repositories and checked out specific commits using the following commands:
git clone --recurse-submodules https://github.com/jax-ml/jax.git
git clone --recurse-submodules https://github.com/openxla/xla.git
cd jax
# This corresponds to the [jax v0.4.34 release](https://github.com/jax-ml/jax/tree/jax-v0.4.34) version.
git checkout affba367c5533df8900e32cbc3d31ca92dd1c1ea
git submodule update --init --recursive
cd ..
cd xla
# This is the XLA version defined in the [/jax/third_party/xla/workspace.bzl](https://github.com/jax-ml/jax/commit/aa9ee7abfab2344ce56483af29266a31ca7b7708) file from the `jax v0.4.34 release`.
git checkout cd6e808c59f53b40a99df1f1b860db9a3e598bff
git submodule update --init --recursive
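The XLA commit to check out can be read directly from the JAX checkout instead of copied by hand. The sketch below is a hypothetical helper, not part of JAX, and it assumes that third_party/xla/workspace.bzl pins the revision with a line of the form XLA_COMMIT = "<40-character sha>".

```python
import re
from pathlib import Path

# Assumed format of jax/third_party/xla/workspace.bzl:
#   XLA_COMMIT = "cd6e808c59f53b40a99df1f1b860db9a3e598bff"
# Comment lines that merely mention XLA_COMMIT start with '#' and do not match.
_XLA_COMMIT_RE = re.compile(r'^\s*XLA_COMMIT\s*=\s*"([0-9a-f]{40})"', re.MULTILINE)


def parse_xla_commit(bzl_text: str) -> str:
    """Return the 40-character XLA commit hash pinned in workspace.bzl text."""
    match = _XLA_COMMIT_RE.search(bzl_text)
    if match is None:
        raise ValueError("no XLA_COMMIT assignment found")
    return match.group(1)


def read_xla_commit(jax_repo: str) -> str:
    """Read the pinned XLA commit from a local jax checkout (hypothetical helper)."""
    bzl = Path(jax_repo) / "third_party" / "xla" / "workspace.bzl"
    return parse_xla_commit(bzl.read_text())
```

For example, read_xla_commit("/home/thomas/jax") run after the jax checkout should return the commit pinned by that JAX revision, which can then be passed to git checkout in the xla clone.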
After modifying the XLA source code to fix the bug, I read developer.md and used the following command to build jaxlib:
python3 build/build.py \
--python_version=3.11 \
--enable_cuda \
--cuda_version=12.6.1 \
--cudnn_version=9.4.0 \
--bazel_options=--override_repository=xla=/home/thomas/xla \
--verbose
After a long wait, I received the following output:
Target //jaxlib/tools:build_wheel up-to-date:
bazel-bin/jaxlib/tools/build_wheel
INFO: Elapsed time: 1928.511s, Critical Path: 287.80s
INFO: 3390 processes: 24 internal, 3366 local.
INFO: Build completed successfully, 3390 total actions
INFO: Running command line: bazel-bin/jaxlib/tools/build_wheel '--output_path=/home/thomas/jax/dist' '--jaxlib_git_hash=affba367c5533df8900e32cbc3d31ca92dd1c1ea' '--cpu=x86_64'
Output wheel: /home/thomas/jax/dist/jaxlib-0.4.34.dev20250218-cp311-cp311-manylinux2014_x86_64.whl
To install the newly-built jaxlib wheel on system Python, run:
pip install /home/thomas/jax/dist/jaxlib-0.4.34.dev20250218-cp311-cp311-manylinux2014_x86_64.whl --force-reinstall
To install the newly-built jaxlib wheel on hermetic Python, run:
echo -e "\n/home/thomas/jax/dist/jaxlib-0.4.34.dev20250218-cp311-cp311-manylinux2014_x86_64.whl" >> build/requirements.in
bazel run //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION=3.11
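The wheel filename printed above encodes which interpreter it targets: cp311-cp311 means CPython 3.11 only, which follows from --python_version=3.11. As a minimal stdlib-only sketch, the following splits a wheel filename into its PEP 427 fields and checks the Python tag against the running interpreter. The helper names are hypothetical, and only the simple dotted form of compressed tag sets (such as py2.py3) is handled.

```python
import sys
from typing import NamedTuple


class WheelName(NamedTuple):
    name: str
    version: str
    python_tag: str
    abi_tag: str
    platform_tag: str


def parse_wheel_filename(filename: str) -> WheelName:
    """Split '{name}-{version}(-{build})?-{python}-{abi}-{platform}.whl' (PEP 427)."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 6:
        # Drop the optional build tag, which sits between version and python tag.
        parts = parts[:2] + parts[3:]
    if len(parts) != 5:
        raise ValueError(f"unexpected wheel filename: {filename}")
    return WheelName(*parts)


def matches_running_python(wheel: WheelName) -> bool:
    """True if the wheel's python tag covers the interpreter running this code."""
    expected = f"cp{sys.version_info.major}{sys.version_info.minor}"
    tags = wheel.python_tag.split(".")  # e.g. "py2.py3" -> ["py2", "py3"]
    # A CPython-specific tag must match exactly; "py3" means any Python 3.
    return expected in tags or "py3" in tags
```

Running matches_running_python(parse_wheel_filename(...)) on the jaxlib wheel name from the log under a Python other than 3.11 returns False, which is a hint to rebuild with a matching --python_version before installing.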
edited Mar 5 at 23:19 by talonmies
asked Mar 5 at 19:53 by Thomas
1 Answer
Quoting from https://docs.jax.dev/en/latest/developer.html:
If you would like to build jaxlib and the CUDA plugins: Run
python build/build.py build --wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt
to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and jax-cuda-pjrt). By default all CUDA compilation steps [are] performed by NVCC and clang, but it can be restricted to clang via the --build_cuda_with_clang flag.
The resulting jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels are local builds of the packages that jax[cuda12] installs, as you can see in JAX's setup.py definition: https://github.com/jax-ml/jax/blob/jax-v0.5.2/setup.py#L88-L91.
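The [cuda12] part of jax[cuda12] is a pip "extra": an optional group of additional requirements declared by the jax package, not a separate wheel. The sketch below parses a requirement string into name, extras, and version specifier and expands the extra through a lookup table. That table is a hypothetical, flattened illustration (it lists the plugin's transitive pjrt dependency directly and omits version bounds and the CUDA runtime packages); the authoritative definition is the setup.py linked above.

```python
import re

# name, optional [extra1,extra2], optional version specifier,
# e.g. "jax[cuda12]==0.4.34" -> ("jax", "cuda12", "==0.4.34")
_REQ_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)\s*(?:\[([^\]]*)\])?\s*(.*?)\s*$")

# Hypothetical, simplified table: NOT copied from JAX's setup.py.
ILLUSTRATIVE_EXTRAS = {
    "jax": {"cuda12": ["jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt"]},
}


def parse_requirement(req: str) -> tuple[str, list[str], str]:
    """Split 'name[extras]specifier' into (name, extras, specifier)."""
    match = _REQ_RE.match(req)
    if match is None:
        raise ValueError(f"cannot parse requirement: {req!r}")
    name, extras, spec = match.groups()
    extra_list = [e.strip() for e in extras.split(",") if e.strip()] if extras else []
    return name, extra_list, spec


def expand(req: str) -> list[str]:
    """The package itself plus whatever its extras add, per the table above."""
    name, extras, _ = parse_requirement(req)
    packages = [name]
    for extra in extras:
        packages.extend(ILLUSTRATIVE_EXTRAS.get(name, {}).get(extra, []))
    return packages
```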
To install the equivalent of jax[cuda12] with your local builds, you would install these three wheels manually.
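A minimal sketch of that manual install step, assuming all three wheels land in the dist/ directory that build.py reports (for example /home/thomas/jax/dist) and that their filenames start with the distribution names jaxlib, jax_cuda12_plugin and jax_cuda12_pjrt (wheel filenames use underscores where PyPI names use hyphens). The helper names are hypothetical. Passing all three paths to a single pip invocation lets pip resolve any dependencies among them against the local wheels.

```python
import shlex
from pathlib import Path

# Assumed distribution-name prefixes of the locally built wheels.
WANTED = ("jaxlib", "jax_cuda12_plugin", "jax_cuda12_pjrt")


def select_wheels(filenames: list[str]) -> dict[str, str]:
    """Pick one wheel per wanted distribution; raise if any is missing."""
    found: dict[str, str] = {}
    for fn in sorted(filenames):
        if not fn.endswith(".whl"):
            continue
        dist = fn.split("-", 1)[0]
        if dist in WANTED:
            # If several builds exist, the lexicographically last filename wins.
            found[dist] = fn
    missing = [w for w in WANTED if w not in found]
    if missing:
        raise FileNotFoundError(f"missing wheels: {', '.join(missing)}")
    return found


def versions_consistent(selected: dict[str, str]) -> bool:
    """True if every selected wheel carries the same version string."""
    return len({fn.split("-")[1] for fn in selected.values()}) == 1


def pip_install_command(dist_dir: str) -> str:
    """Build one pip command that installs the three local wheels together."""
    d = Path(dist_dir)
    selected = select_wheels([p.name for p in d.glob("*.whl")])
    paths = [str(d / selected[w]) for w in WANTED]
    return shlex.join(["pip", "install", "--force-reinstall", *paths])
```

For example, print(pip_install_command("/home/thomas/jax/dist")) should print a single pip install line listing the three wheel paths, or raise FileNotFoundError if one of the wheels was not built. Checking versions_consistent first is a cheap guard against mixing wheels from different builds.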