JAX

JAX is a library for array-oriented numerical computation (similar to NumPy), with automatic differentiation (AD) and just-in-time (JIT) compilation enabling hardware-accelerated numerical computing.

JAX is heavily used for large-scale machine learning research, but many of its benefits extend to scientific computing as a whole. In the context of Multidisciplinary Optimization (MDO), we use JAX to avoid implementing by hand the derivatives of objective functions and constraints with respect to the optimization variables, which enables the use of gradient-based optimizers at no extra implementation cost.
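
For instance, with plain JAX, the gradient of a scalar function is obtained with jax.grad; the minimal illustration below is independent of GEMSEO-JAX:

from jax import grad
from jax.numpy import exp


def objective(x):
    return x**2 + exp(-x)


# AD gives d(objective)/dx without any manual derivation.
objective_gradient = grad(objective)
print(objective_gradient(1.0))  # 2 * 1.0 - exp(-1.0)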

Other libraries provide AD in Python (Autograd, TensorFlow, PyTorch), but JAX was chosen for:

  • its wide community and ecosystem of libraries, projects and other associated resources;
  • its hardware-aware optimizations on CPU, GPU and TPU;
  • its focus on general scientific computing rather than machine-learning problems.

In our experience, writing MDO problems with GEMSEO-JAX requires less code and usually yields faster programs than equivalent NumPy implementations (see Is JAX faster than NumPy?).

For an initial introduction to JAX, we recommend reading the Quickstart, Key Concepts and How to Think in JAX notebooks. For a better grasp of JAX functionalities and how to use them efficiently, we recommend Just-in-time Compilation and The Autodiff Cookbook.

GEMSEO-JAX overview

The plugin is centered around JAXDiscipline, which wraps a JAX function with built-in automatic differentiation. This class provides several useful features, illustrated in the sketch at the end of this overview:

  • filtering of the Jacobian computation graph for specific inputs/outputs,
  • jit-compiling the function and Jacobian calls to lower the cost of re-evaluation,
  • performing pre-runs to trigger and log compilation times.

AutoJAXDiscipline is a special JAXDiscipline that infers the input names, output names and default input values from the signature of the JAX function, in the manner of AutoPyDiscipline.

JAXChain is a JAXDiscipline that assembles a series of JAXDisciplines and executes them all within JAX. This is useful to avoid unnecessary JAX-to/from-NumPy conversions.
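
As an illustration, these JAXDiscipline features can be exercised on any instance; the sketch below builds one with AutoJAXDiscipline and reuses the method names shown in the Quick guide below (a minimal, hypothetical example):

from gemseo_jax.auto_jax_discipline import AutoJAXDiscipline


def compute_z(x=1.0, y=2.0):
    z = x * y
    return z


discipline = AutoJAXDiscipline(compute_z)
# Restrict AD to the outputs of interest.
discipline.add_differentiated_outputs(["z"])
# Jit-compile the function and Jacobian, with a pre-run to log compilation times.
discipline.compile_jit(pre_run=True)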

Quick guide

from jax.numpy import exp, sqrt
from gemseo_jax.auto_jax_discipline import AutoJAXDiscipline


def compute_y_1(y_2=1.0, x_local=1.0, x_shared_1=1.0, x_shared_2=3.0):
    y_1 = x_shared_1**2 + x_shared_2 + x_local - 0.2 * y_2
    return y_1


def compute_y_2(y_1=1.0, x_shared_1=1.0, x_shared_2=3.0):
    y_2 = sqrt(abs(y_1)) + x_shared_1 + x_shared_2
    return y_2


def compute_obj_c_1_c_2(y_1=1.0, y_2=1.0, x_shared_2=3.0, x_local=1.0):
    obj = x_local**2 + x_shared_2 + y_1**2 + exp(-y_2)
    c_1 = 3.16 - y_1**2
    c_2 = y_2 - 24.0
    return obj, c_1, c_2


sellar_1 = AutoJAXDiscipline(compute_y_1, name="Sellar1")
sellar_2 = AutoJAXDiscipline(compute_y_2, name="Sellar2")
sellar_system = AutoJAXDiscipline(compute_obj_c_1_c_2, name="SellarSystem")
Here, the JAX functions are defined and automatically wrapped into GEMSEO disciplines with AutoJAXDiscipline. The Jacobians are automatically calculated using Automatic Differentiation (AD). By default we "agressively" promote double precision in all JAXDiscipline, which differs from JAX's default single-precision.

These disciplines can already be used in a GEMSEO process, but this may lead to sub-optimal performance due to an excessive number of conversions between NumPy (GEMSEO) and JAX arrays. To avoid this, a JAXChain can keep the communication between disciplines entirely inside JAX; GEMSEO is then only used to generate the sequence of discipline executions.

from gemseo_jax.jax_chain import JAXChain


disciplines = [sellar_1, sellar_2, sellar_system]
jax_chain = JAXChain(disciplines, name="SellarChain")
We can also filter the Jacobian function to ensure AD is only made for some outputs of interest. In practice, this means JAX views fewer operations to trace and apply AD over.

jax_chain.add_differentiated_outputs(["obj", "c_1", "c_2"])

Finally, we may jit-compile the output and Jacobian functions that will be used. This incurs an extra compilation time, but significantly lowers the cost of subsequent function calls.

As compilation is just-in-time, by default we perform a pre-run of the jitted functions and log compilation times. This ensures compilation time is not added to the execution timings, although comprehensive benchmarks may want to turn this off.

jax_chain.compile_jit(pre_run=True)
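
The compiled chain then behaves like any other GEMSEO discipline. As a usage sketch (assuming the standard GEMSEO Discipline.execute API; not part of the original example):

from numpy import array

# Execute with default inputs, overriding only the local design variable.
output_data = jax_chain.execute({"x_local": array([2.0])})
print(output_data["obj"])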