
Machine Learning Packages

GPU Accelerated Support

This matrix shows which GPU-accelerated Machine Learning (ML) packages are supported under pip, Conda, or inside a container on Linux Arm64 (aarch64). We are working on providing support through Conda or a container for all of these packages. We do not plan to provide pip wheels ourselves. The installation methods are detailed below.

ML Framework      Pip   Conda   Container
PyTorch           ✗     ✓       ✓
HuggingFace       (follows the chosen backend framework)
TensorFlow        ✗     WIP     ✓
JAX               ✓     ✓       ✓
Flash-Attention   ✗     ✓       ✓

Prerequisites

Please see the relevant documentation for using Conda or for running containers. The containers listed below are the Nvidia-optimised images available from Nvidia GPU Cloud (NGC). When using containers, ensure that you are using images that support the Arm64 (aarch64) architecture.

Tip

The Arm architecture and Hopper GPUs used in Isambard-AI typically require modern versions throughout the machine learning software stack. Where possible, prefer more recent releases of machine learning packages as this usually ensures easy installation and optimal performance for your applications.

Installation instructions for each package are given in the sections below.

PyTorch

PyTorch does not currently publish GPU-enabled pip wheels for aarch64. The release compatibility matrix can be found here.

conda-forge provides a PyTorch 2.5 package with aarch64, Cuda (GPU), and Numpy support:

$ conda install conda-forge::pytorch
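
Once installed, a short script like the following (a minimal sketch, assuming a GPU is visible on the node) confirms that the CUDA build is working:

import torch

print(torch.__version__)           # e.g. 2.5.x
print(torch.version.cuda)          # CUDA version the build was compiled against
print(torch.cuda.is_available())   # True if a GPU is visible
if torch.cuda.is_available():
    x = torch.randn(1024, 1024, device="cuda")
    print((x @ x).sum().item())    # runs a matmul on the GPU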

We recommend these PyTorch images from Nvidia GPU Cloud.

For example, you can pull and run a container that provides PyTorch GPU support like so:

$ singularity pull --arch aarch64 docker://nvcr.io/nvidia/pytorch:24.07-py3
$ singularity run --nv pytorch_24.07-py3.sif python3 -c "import torch; print(torch.cuda.is_available())"

TensorFlow

TensorFlow does not currently provide GPU-enabled aarch64 wheels through pip; the wheels that are available can be seen here:

Linux Arm64 builds of TensorFlow from conda-forge are currently a work in progress.

TensorFlow can currently be run with GPU support inside a container. The recommended container can be found here:

$ singularity pull --arch aarch64 docker://nvcr.io/nvidia/tensorflow:24.07-tf2-py3
$ singularity run --nv tensorflow_24.07-tf2-py3.sif 
Singularity> python3 -c "import tensorflow as tf; print(tf.config.list_physical_devices('GPU'))"
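
Inside the container, a short check along these lines (a minimal sketch, assuming the GPU is visible to TensorFlow) confirms that work is actually placed on the GPU:

import tensorflow as tf

print(tf.config.list_physical_devices("GPU"))  # expect one entry per GPU
with tf.device("/GPU:0"):                      # place the computation explicitly
    a = tf.random.normal((1024, 1024))
    b = tf.random.normal((1024, 1024))
    print(tf.reduce_sum(tf.matmul(a, b)).numpy())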

JAX

JAX provides aarch64 and GPU compatibility out of the box. You can install it with pip:

$ pip install "jax[cuda12]"

JAX can also be installed from conda-forge; to get a GPU-enabled build, request a CUDA variant of jaxlib alongside it:

$ conda install -c conda-forge jax "jaxlib=*=*cuda*"

We recommend these JAX images from Nvidia GPU Cloud. We use the --no-home flag to ensure Python inside the container does not clash with your Conda environment, if you have one. Note that you will have to unset the XLA_FLAGS environment variable due to issues with the Nvidia container image.

$ singularity pull --arch aarch64 docker://nvcr.io/nvidia/jax:24.04-py3
$ singularity run --nv --no-home jax_24.04-py3.sif 
Singularity> unset XLA_FLAGS
Singularity> python3 -c "import jax; print(jax.default_backend())"
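
After unsetting XLA_FLAGS, a short check such as this sketch confirms that JAX dispatches to the GPU backend:

import jax
import jax.numpy as jnp

print(jax.default_backend())  # expect "gpu"
print(jax.devices())          # expect CUDA device entries

x = jnp.ones((1024, 1024))
y = jax.jit(lambda a: a @ a)(x)  # compiled for and executed on the default device
print(float(y.sum()))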

HuggingFace

Hugging Face libraries can use several of the frameworks above as their backend. Please follow the installation instructions for the respective framework.
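
As an illustration, with the PyTorch backend installed as above and the transformers package added to the same environment (pip install transformers), a pipeline can be placed on the GPU as follows; the model name here is just an example:

from transformers import pipeline

classifier = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",  # example model
    device=0,   # first CUDA device; use device=-1 to stay on the CPU
)
print(classifier("Running inference on the GPU works."))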

Flash-Attention

Flash Attention data types for Hopper GPUs

The versions of Flash-Attention that support sm_90 (Hopper GPUs) do not support fp8.

Please use either the bfloat16 or fp16 floating-point data types.
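
For example, the core kernel exposed by the flash-attn package expects CUDA tensors in one of these data types (a minimal sketch using the public flash_attn_func API; the shapes are illustrative):

import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
# q, k, v must be fp16 or bf16 CUDA tensors; fp8 is not supported on sm_90
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)
print(out.shape, out.dtype)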

Flash-Attention does not provide pip wheels for aarch64, so installing the package with pip causes it to be built from source. We recommend installing Flash-Attention through the flash-attn pip package inside a Conda environment.

Flash-Attention's build dependencies can be provided by a Conda environment such as the following:

flash_attn_conda_env.yml
name: flash_env
channels:
  - conda-forge
  - nodefaults
dependencies:
  - python=3.10
  - pytorch=2.5.1
  - ninja
  - gcc=12.3.0
  - gxx=12.3.0
  - pip
  - pip:
      - flash-attn==2.7.0.post2

We recommend building on a compute node after setting MAX_JOBS, since the build is demanding. If MAX_JOBS is set too high, the build may exhaust the memory on the node. Please note that this is a slow build.

$ module load cudatoolkit
$ export MAX_JOBS=10
$ conda env create -f flash_attn_conda_env.yml
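
Once the environment has been created, a quick import on a GPU node verifies that the extension built correctly and reports the pinned version:

$ conda activate flash_env
$ python3 -c "import flash_attn; print(flash_attn.__version__)"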

NGC PyTorch containers come with Flash-Attention pre-installed.

$ singularity pull --arch aarch64 docker://nvcr.io/nvidia/pytorch:24.04-py3
$ singularity run --nv pytorch_24.04-py3.sif python3 -m pip list | grep flash
flash-attn                2.4.2