JAX Tips and Tricks

published 2025-04-07 (last changed on 2025-04-09) by Lukas Winkler

This is an unordered list of useful things I found while using JAX that don’t seem to be documented well anywhere else. Partly this is because some of these features are experimental, so don’t depend on them working exactly the same way in future versions of JAX. Nevertheless, they might be useful during development and testing; everything here was tested with JAX 0.5.3 as of writing.

# jax-array-info

This is a collection of tools I wrote that print useful information about JAX arrays and their properties:

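import jax
from jax.sharding import NamedSharding, PartitionSpec as P
from jax_array_info import sharding_info, sharding_vis, simple_array_info, print_array_stats, pretty_memory_stats
N = 128
mesh = jax.make_mesh((8,), ("gpus",))  # assumption: 8 devices are available
some_array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
some_array = jax.device_put(some_array, NamedSharding(mesh, P(None, "gpus")))
sharding_info(some_array, "some_array")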
╭───────────────── some_array ─────────────────╮
│ shape: (128, 128, 128)                       │
│ dtype: float32                               │
│ size: 8.0 MiB                                │
│ NamedSharding: P(None, 'gpus')               │
│ axis 1 is sharded: CPU 0 contains 0:16 (1/8) │
│                    Total size: 128           │
╰──────────────────────────────────────────────╯

You can find a description of all features and many test cases of sharded JAX operations in the GitHub repository.

# Visualizing the HLO Graph

One useful thing during debugging is to render the computational graph JAX creates (especially after JIT compilation) as an image.

import subprocess
from pathlib import Path
import jax
from jax._src.stages import Compiled
from jaxlib import xla_client

def todotgraph(x: str) -> str:
    return xla_client._xla.hlo_module_to_dot_graph(xla_client._xla.hlo_module_from_text(x))

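def write_debug_graph(compiled: Compiled, output: Path):
    # save the HLO text of the compiled function
    hlo = compiled.as_text()
    output.with_suffix(".hlo.txt").write_text(hlo)

    # convert the HLO text to Graphviz dot format
    graph_file = output.with_suffix(".graph.dot")
    graph_file.write_text(todotgraph(hlo))

    # render the dot file as SVG (requires the Graphviz "dot" binary)
    with output.with_suffix(".graph.svg").open("w") as f:
        subprocess.run(["dot", "-Tsvg", graph_file], stdout=f, check=True)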

Using the two helper functions above, we can take a jitted function, lower and compile it, and let XLA generate a Graphviz graph of all operations. Rendering the SVG calls dot -Tsvg, which requires Graphviz to be installed.

def simple_function(x):
    return jax.numpy.sin(x[:3])

simple_function = jax.jit(simple_function)

input = jax.numpy.zeros((20,))

write_debug_graph(
    simple_function.lower(input).compile(), 
    Path("simple_function")
)

Figure: The output graph

If you open the generated SVG image in a web browser, you can also read additional metadata by hovering over nodes.
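
Since the rendering step shells out to dot, it can be convenient to check for Graphviz first and skip the SVG if it is missing. The following is only a sketch of that idea: render_dot_to_svg is a hypothetical helper name, not part of JAX or of the functions above, and it uses dot's -o option to write the output file instead of redirecting stdout.

```python
import shutil
import subprocess
from pathlib import Path


def render_dot_to_svg(dot_file: Path, svg_file: Path) -> bool:
    """Render dot_file to svg_file; return False if Graphviz is not installed."""
    # hypothetical helper: look up the "dot" binary on PATH before calling it
    dot_binary = shutil.which("dot")
    if dot_binary is None:
        print("Graphviz 'dot' not found on PATH, skipping SVG rendering")
        return False
    # -o lets dot write the output file itself; check=True raises if dot fails
    subprocess.run(
        [dot_binary, "-Tsvg", str(dot_file), "-o", str(svg_file)],
        check=True,
    )
    return True
```

With this, the .hlo.txt and .graph.dot files are still written on machines without Graphviz, and the dot file can be rendered later somewhere else.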

# Visualizing the HLO Graph without code modifications

There is an alternative way to achieve the same result as the previous section (compute graph as a figure), but without any code modifications and for all intermediate functions (before and after optimization).

For this, we run our Python script with the XLA_FLAGS environment variable set: --xla_dump_to=tmp dumps all intermediate XLA results into tmp/, and --xla_dump_hlo_as_html=true additionally generates HTML output.

import jax

def simple_function(x):
    return jax.numpy.sin(x[:3])

simple_function = jax.jit(simple_function)
input = jax.numpy.zeros((20,))

simple_function(input)

$ XLA_FLAGS="--xla_dump_hlo_as_html=true --xla_dump_to=tmp" python dot_test2.py

The resulting output in tmp/ when opened in a browser will be similar to the one created in the previous section, but with interactive zoom and pan.
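
The dump directory can fill up with many files quickly, so a quick overview helps to find the HTML ones. This is a minimal sketch that makes no assumption about the exact file names XLA writes; summarize_dump is a hypothetical helper that simply groups whatever is in the directory by file suffix.

```python
from collections import defaultdict
from pathlib import Path


def summarize_dump(dump_dir: Path) -> dict[str, list[str]]:
    """Group the names of all files in dump_dir by their final suffix."""
    groups: dict[str, list[str]] = defaultdict(list)
    for path in sorted(dump_dir.iterdir()):
        if path.is_file():
            groups[path.suffix].append(path.name)
    return dict(groups)


# "tmp" is the directory passed to --xla_dump_to above
dump_dir = Path("tmp")
if dump_dir.is_dir():
    for suffix, names in summarize_dump(dump_dir).items():
        print(f"{suffix or '(no suffix)'}: {len(names)} file(s)")
```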

# Additional XLA options

Only a small subset of the possible XLA options is documented in the JAX documentation. The full list can be extracted from the XLA --help command. Options can also be set from within Python by appending them to the XLA_FLAGS environment variable:

import os

def add_xla_flag(flag: str) -> None:
    if "XLA_FLAGS" not in os.environ:
        os.environ["XLA_FLAGS"] = flag
        return

    os.environ["XLA_FLAGS"] = os.environ["XLA_FLAGS"] + " " + flag

add_xla_flag('--xla_force_host_platform_device_count=4')
add_xla_flag('--xla_dump_to=xla_dump_directory')
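
One thing to watch out for when appending flags is that calling the helper twice with the same option leaves both copies in XLA_FLAGS. A possible variant that replaces an earlier value of the same option is sketched below. set_xla_flag is a hypothetical name, and the sketch assumes that flag values contain no spaces. As far as I can tell, the variable is only read when JAX initializes its backend, so it is safest to set it before importing jax.

```python
import os


def set_xla_flag(name: str, value: str) -> None:
    """Set --name=value in XLA_FLAGS, replacing an earlier value of the same flag."""
    prefix = f"--{name}="
    # assumption: flags are separated by whitespace and contain no spaces themselves
    current = os.environ.get("XLA_FLAGS", "").split()
    kept = [flag for flag in current if not flag.startswith(prefix)]
    kept.append(prefix + value)
    os.environ["XLA_FLAGS"] = " ".join(kept)


set_xla_flag("xla_force_host_platform_device_count", "4")
# the second call replaces the value instead of appending a duplicate
set_xla_flag("xla_force_host_platform_device_count", "8")
print(os.environ["XLA_FLAGS"])
```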

# jax.print_environment_info()

jax.print_environment_info() is a very straightforward way to print some key information about the current setup.

import jax
jax.print_environment_info()
jax:    0.5.3
jaxlib: 0.5.3
numpy:  2.2.4
python: 3.13.2 (main, Mar 29 2025, 10:04:43) [GCC 14.2.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='lukasnotebook', release='6.12.21-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.12.21-1 (2025-03-30)', machine='x86_64')
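
When running experiments, it can be handy to store this report next to the results instead of only printing it. The sketch below uses the return_string argument of jax.print_environment_info(); save_environment_info and the file name are hypothetical choices for illustration.

```python
from pathlib import Path

import jax


def save_environment_info(path: Path) -> str:
    """Write the JAX environment report to path and return it."""
    # return_string=True returns the report instead of printing it
    info = jax.print_environment_info(return_string=True)
    path.write_text(info)
    return info
```

Calling save_environment_info(Path("environment_info.txt")) at the start of a run then records which JAX, jaxlib and NumPy versions produced the results.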

Do you have any feedback or ideas to improve this? Contact me by email or edit the source directly. You can find my other projects at lw1.at.