Machine Learning Frameworks

GPU Accelerated Support

This matrix shows which GPU-accelerated Machine Learning (ML) frameworks are supported through pip, through conda, or inside a container. We are working on providing support through conda or a container for all of these packages. We do not plan to provide pip wheels. The installation methods are detailed below.

ML Framework Pip Conda Container
PyTorch
HuggingFace
TensorFlow
JAX

Prerequisites

Please see the relevant documentation for using conda or for running containers. The containers listed below are the Nvidia optimised images available in Nvidia GPU Cloud. When using containers, ensure that you are using images that support the ARM64 architecture.
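
As a quick sanity check that the shell you are in (on the host or inside a container) is running on ARM64, a minimal Python sketch using only the standard library might look like the following. The helper name `is_arm64` and the set of accepted machine strings are illustrative assumptions, not part of any tool described on this page.

```python
import platform

# Machine strings commonly reported for 64-bit ARM
# (Linux typically reports "aarch64", macOS "arm64"). Assumed list.
ARM64_NAMES = {"aarch64", "arm64"}


def is_arm64(machine=None):
    """Return True if the given (or current) machine string denotes 64-bit ARM."""
    name = machine if machine is not None else platform.machine()
    return name.lower() in ARM64_NAMES


if __name__ == "__main__":
    print(f"machine={platform.machine()} arm64={is_arm64()}")
```

Running the same script on the host and inside the container should report the same architecture.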

Installation instructions for each framework are given below.

PyTorch does not currently build pip wheels with GPU support for aarch64. The release compatibility matrix can be found here.

However, a community build with GPU support is available through conda-forge:

$ conda install -c conda-forge pytorch-gpu

We recommend these PyTorch images from Nvidia GPU Cloud.

For example, you can pull a container that provides PyTorch GPU support and run it like so:

$ singularity pull --arch aarch64 docker://nvcr.io/nvidia/pytorch:24.07-py3
$ singularity run --nv pytorch_24.07-py3.sif python3 -c "import torch; print(torch.cuda.is_available())"
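
The one-line check above only reports whether CUDA is visible. A slightly fuller sketch, which can be run in either the conda environment or the container, also lists the devices and runs a small matrix multiply on the GPU. The function name `torch_gpu_report` and the keys of the returned dictionary are hypothetical; the function returns `None` when PyTorch is not installed.

```python
def torch_gpu_report():
    """Describe PyTorch GPU visibility, or return None if torch is missing."""
    try:
        import torch
    except ImportError:
        return None

    available = torch.cuda.is_available()
    count = torch.cuda.device_count()
    report = {
        "torch_version": torch.__version__,
        "cuda_available": available,
        "device_count": count,
        "device_names": [],
    }
    if available:
        report["device_names"] = [torch.cuda.get_device_name(i) for i in range(count)]
        # Multiply two 4x4 matrices of ones on the GPU; every entry should be 4.
        x = torch.ones((4, 4), device="cuda")
        expected = torch.full((4, 4), 4.0, device="cuda")
        report["matmul_ok"] = bool(torch.allclose(x @ x, expected))
    return report


if __name__ == "__main__":
    print(torch_gpu_report())
```

If `cuda_available` is `False` inside the container, a likely first thing to check is that the `--nv` flag was passed to `singularity run`.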

TensorFlow does not currently provide support through pip. You can see the available pip wheels here.

TensorFlow support under conda is currently a work in progress.

TensorFlow can currently be run with GPU support inside a container. The recommended container can be found here:

$ singularity pull --arch aarch64 docker://nvcr.io/nvidia/tensorflow:24.07-tf2-py3
$ singularity run --nv tensorflow_24.07-tf2-py3.sif
Singularity> python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
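
For a check that can be saved and re-run inside the container, the one-liner above can be expanded into a small script. This is a sketch: the function name `tf_gpu_report` and the dictionary keys are hypothetical, and the function returns `None` when TensorFlow is not installed.

```python
def tf_gpu_report():
    """Describe TensorFlow GPU visibility, or return None if tensorflow is missing."""
    try:
        import tensorflow as tf
    except ImportError:
        return None

    gpus = tf.config.list_physical_devices("GPU")
    return {
        "tf_version": tf.__version__,
        # Whether this TensorFlow build was compiled with CUDA support.
        "built_with_cuda": tf.test.is_built_with_cuda(),
        # Physical device names, for example "/physical_device:GPU:0".
        "gpu_names": [gpu.name for gpu in gpus],
    }


if __name__ == "__main__":
    print(tf_gpu_report())
```

An empty `gpu_names` list together with `built_with_cuda` being `True` suggests the build supports GPUs but none are visible to the process.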

JAX provides aarch64 and GPU compatibility out of the box. You can install it with pip:

$ pip install "jax[cuda12]"

JAX can be installed in conda using:

$ conda install -c conda-forge "jaxlib=*=*cuda*" jax

We recommend these JAX images from Nvidia GPU Cloud. We use the --no-home flag to ensure that Python does not clash with your conda environment, if you have one. Note that you will have to unset the XLA_FLAGS environment variable due to issues with the Nvidia container image.

$ singularity pull --arch aarch64 docker://nvcr.io/nvidia/jax:24.04-py3
$ singularity run --nv --no-home jax_24.04-py3.sif
Singularity> unset XLA_FLAGS
Singularity> python3 -c "import jax; print(jax.default_backend())"
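
The same check can be written as a short script. The sketch below removes XLA_FLAGS from the process environment before JAX is imported, on the assumption that this has the same effect as the `unset` step in the shell, and then runs a small computation on the default backend. The function name `jax_gpu_report`, its `clear_xla_flags` parameter, and the dictionary keys are hypothetical; the function returns `None` when JAX is not installed.

```python
import os


def jax_gpu_report(clear_xla_flags=True):
    """Describe the JAX backend, or return None if jax is missing."""
    if clear_xla_flags:
        # Assumption: dropping the variable before importing jax mirrors
        # running `unset XLA_FLAGS` in the shell.
        os.environ.pop("XLA_FLAGS", None)
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None

    x = jnp.arange(8.0)
    # Dot product of [0..7] with itself: 0 + 1 + 4 + ... + 49 = 140.
    total = float(jnp.dot(x, x))
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
        "dot_result": total,
    }


if __name__ == "__main__":
    print(jax_gpu_report())
```

A `backend` value of `gpu` indicates that JAX found the GPU. A value of `cpu` means it fell back to the CPU.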

Hugging Face can run many different frameworks as its backend. Please follow the instructions for the respective framework.
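
To see which of the backends covered on this page are present in the current environment, a minimal sketch using only the standard library could look like this. It looks the modules up without importing them, so it is quick to run. The helper name `installed_backends` and the candidate list are illustrative assumptions.

```python
import importlib.util

# Top-level module names of the frameworks covered on this page.
CANDIDATE_BACKENDS = ("torch", "tensorflow", "jax")


def installed_backends(candidates=CANDIDATE_BACKENDS):
    """Return the candidate modules that can be located without importing them."""
    return [name for name in candidates if importlib.util.find_spec(name) is not None]


if __name__ == "__main__":
    found = installed_backends()
    print("Available backends:", ", ".join(found) if found else "none")
```

This only shows that a package is installed, not that it can see a GPU, so the framework-specific checks above still apply.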