Skip to content

Note

Click here to download the full example code

Analysis of the scalable Sellar problem with JAX.

from __future__ import annotations

from datetime import timedelta
from timeit import default_timer
from typing import TYPE_CHECKING

from gemseo import configure
from gemseo import configure_logger
from gemseo import create_mda
from gemseo.core.discipline.discipline import Discipline
from gemseo.problems.mdo.sellar.sellar_1 import Sellar1
from gemseo.problems.mdo.sellar.sellar_2 import Sellar2
from gemseo.problems.mdo.sellar.sellar_system import SellarSystem
from gemseo.problems.mdo.sellar.utils import get_initial_data
from matplotlib.pyplot import show
from matplotlib.pyplot import subplots
from numpy import array
from numpy.random import default_rng

from gemseo_jax.jax_chain import DifferentiationMethod
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem

if TYPE_CHECKING:
    from gemseo.mda.base_mda import BaseMDA
    from gemseo.typing import RealArray

# Disable some of gemseo's checks to speed up calculations with cheap disciplines.
# NOTE(review): the seven flags below are passed positionally to gemseo's
# ``configure`` in signature order; only the third one is True. Check that
# signature for the meaning of each position and consider switching to keyword
# arguments for readability.
configure(False, False, True, False, False, False, False)
# Set up gemseo's logger so that its INFO/WARNING messages are displayed.
configure_logger()


def get_random_input_data(n: int) -> dict[str, RealArray]:
    """Draw a random input point for [JAX]SellarSystem.

    Every entry of the default Sellar input data of size ``n`` is multiplied by
    the same random factor, drawn uniformly in ``[0, 1.5)``.

    Args:
        n: The size parameter forwarded to ``get_initial_data``.

    Returns:
        The scaled input data, keyed by variable name.
    """
    scale = 1.5 * default_rng().random()
    initial_data = get_initial_data(n=n)
    return {name: scale * value for name, value in initial_data.items()}


def get_numpy_disciplines(n: int) -> list[Discipline]:
    """Instantiate the Sellar disciplines implemented with NumPy.

    Args:
        n: The size parameter forwarded to each discipline constructor.

    Returns:
        The disciplines ``Sellar1``, ``Sellar2`` and ``SellarSystem``, in this order.
    """
    discipline_classes = (Sellar1, Sellar2, SellarSystem)
    return [discipline_class(n=n) for discipline_class in discipline_classes]


def get_jax_disciplines(
    n: int, differentiation_method=DifferentiationMethod.AUTO
) -> list[Discipline]:
    """Build the JAX-based Sellar disciplines, ready to be executed.

    Each discipline gets a simple cache and is JIT-compiled before being returned.

    Args:
        n: The size parameter forwarded to each discipline constructor.
        differentiation_method: The JAX differentiation method given to each
            discipline.

    Returns:
        The disciplines ``JAXSellar1``, ``JAXSellar2`` and ``JAXSellarSystem``,
        in this order.
    """
    jax_disciplines = [
        discipline_class(n=n, differentiation_method=differentiation_method)
        for discipline_class in (JAXSellar1, JAXSellar2, JAXSellarSystem)
    ]
    for jax_discipline in jax_disciplines:
        jax_discipline.set_cache(Discipline.CacheType.SIMPLE)
        # Compile ahead of time; presumably to keep the JIT compilation cost
        # out of the subsequent executions.
        jax_discipline.compile_jit()
    return jax_disciplines

Initial setup for comparison

Here we compare the original NumPy implementation with the JAX one. We therefore need to create the original MDA and one JAX MDA for each configuration under test. In this example, we compare the performance with and without the JAXChain encapsulation, as well as the forward and reverse modes of automatic differentiation.

def get_analytical_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the reference Sellar MDA from the NumPy disciplines.

    The Jacobians come from the analytical derivatives of the NumPy disciplines.

    Args:
        n: The size parameter forwarded to the Sellar disciplines.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, with a simple cache.
    """
    numpy_disciplines = get_numpy_disciplines(n)
    analytical_mda = create_mda(
        mda_name=mda_name,
        disciplines=numpy_disciplines,
        max_mda_iter=max_mda_iter,
        name="Analytical SellarChain",
    )
    analytical_mda.set_cache(Discipline.CacheType.SIMPLE)
    return analytical_mda


def get_forward_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Return the Sellar MDA with JAX forward-mode AD Jacobian.

    The MDA couples the JAX-based Sellar disciplines (an MDOChain of separate
    disciplines), each created with the forward differentiation method.

    Args:
        n: The size parameter forwarded to the Sellar disciplines.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, with a simple cache.
    """
    mda = create_mda(
        mda_name=mda_name,
        disciplines=get_jax_disciplines(n, DifferentiationMethod.FORWARD),
        max_mda_iter=max_mda_iter,
        # A configuration-specific name, so that the log messages of this MDA
        # can be told apart from those of the other JAX-based MDAs.
        name="JAX SellarChain (MDOChain, forward AD)",
    )
    mda.set_cache(Discipline.CacheType.SIMPLE)
    return mda


def get_chained_forward_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Return the Sellar MDA with JAXChain encapsulation and forward-mode Jacobian.

    The MDA is built around a single ``JAXSellarChain`` discipline (JAXChain
    encapsulation), created with the forward differentiation method.

    Args:
        n: The size parameter forwarded to ``JAXSellarChain``.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, with a simple cache.
    """
    discipline = JAXSellarChain(
        n=n,
        differentiation_method=DifferentiationMethod.FORWARD,
    )
    # Request the derivatives of all the outputs with respect to all the inputs.
    discipline.add_differentiated_inputs(discipline.input_grammar.names)
    discipline.add_differentiated_outputs(discipline.output_grammar.names)

    mda = create_mda(
        mda_name=mda_name,
        disciplines=[discipline],
        max_mda_iter=max_mda_iter,
        # A configuration-specific name, so that the log messages of this MDA
        # can be told apart from those of the other JAX-based MDAs.
        name="JAX SellarChain (JAXChain, forward AD)",
    )
    mda.set_cache(Discipline.CacheType.SIMPLE)
    return mda


def get_reverse_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Return the Sellar MDA with JAX reverse-mode AD Jacobian.

    The MDA couples the JAX-based Sellar disciplines (an MDOChain of separate
    disciplines), each created with the reverse differentiation method.

    Args:
        n: The size parameter forwarded to the Sellar disciplines.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, with a simple cache.
    """
    mda = create_mda(
        mda_name=mda_name,
        disciplines=get_jax_disciplines(n, DifferentiationMethod.REVERSE),
        max_mda_iter=max_mda_iter,
        # A configuration-specific name, so that the log messages of this MDA
        # can be told apart from those of the other JAX-based MDAs.
        name="JAX SellarChain (MDOChain, reverse AD)",
    )
    mda.set_cache(Discipline.CacheType.SIMPLE)
    return mda


def get_chained_reverse_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Return the Sellar MDA with JAXChain encapsulation and reverse-mode Jacobian.

    The MDA is built around a single ``JAXSellarChain`` discipline (JAXChain
    encapsulation), created with the reverse differentiation method.

    Args:
        n: The size parameter forwarded to ``JAXSellarChain``.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, with a simple cache.
    """
    discipline = JAXSellarChain(
        n=n,
        differentiation_method=DifferentiationMethod.REVERSE,
    )
    # Request the derivatives of all the outputs with respect to all the inputs.
    discipline.add_differentiated_inputs(discipline.input_grammar.names)
    discipline.add_differentiated_outputs(discipline.output_grammar.names)

    mda = create_mda(
        mda_name=mda_name,
        disciplines=[discipline],
        max_mda_iter=max_mda_iter,
        # A configuration-specific name, so that the log messages of this MDA
        # can be told apart from those of the other JAX-based MDAs.
        name="JAX SellarChain (JAXChain, reverse AD)",
    )
    mda.set_cache(Discipline.CacheType.SIMPLE)
    return mda


# Map each benchmark label to the factory building the corresponding MDA.
# Each factory is called with a problem dimension and MDA options.
mdas = {
    "MDOChain[NumPy] - Analytical": get_analytical_mda,  # this is the reference
    "JAXChain - Forward AD": get_chained_forward_ad_mda,
    "JAXChain - Reverse AD": get_chained_reverse_ad_mda,
    "MDOChain[JAX] - Forward AD": get_forward_ad_mda,
    "MDOChain[JAX] - Reverse AD": get_reverse_ad_mda,
}

Execution and linearization scalability

Let's write a function that executes and linearizes an MDA and measures the time taken by each operation. Each operation is repeated several times and the mean duration is returned, to reduce timing noise:

def run_and_log(get_mda, dimension: int, n_repeat: int = 7, **mda_options):
    """Measure the mean execution and linearization times of an MDA.

    The MDA is built once. It is then executed ``n_repeat`` times and
    linearized ``n_repeat`` times, each call at a new random input point.
    All the random inputs of a series are drawn before its timer starts, so
    only the MDA calls themselves are timed.

    Args:
        get_mda: A factory called as ``get_mda(dimension, **mda_options)``
            that returns the MDA to benchmark.
        dimension: The problem dimension, passed to the factory and to the
            random input generator.
        n_repeat: The number of timed executions and of timed linearizations.
        **mda_options: Extra keyword arguments forwarded to ``get_mda``.

    Returns:
        The mean duration of one execution and of one linearization.

    Raises:
        ValueError: When ``n_repeat`` is not positive.
    """
    if n_repeat < 1:
        # Without this check, 0 would raise ZeroDivisionError and a negative
        # value would silently yield a negative mean duration.
        msg = f"n_repeat must be a positive integer, got {n_repeat}."
        raise ValueError(msg)

    mda = get_mda(dimension, **mda_options)
    # Build the set once instead of scanning the grammar names for every item.
    input_names = set(mda.input_grammar.names)

    def _draw_inputs():
        """Draw ``n_repeat`` random input points restricted to the MDA inputs."""
        return [
            {
                name: value
                for name, value in get_random_input_data(dimension).items()
                if name in input_names
            }
            for _ in range(n_repeat)
        ]

    execution_inputs = _draw_inputs()
    t0 = default_timer()
    for input_data in execution_inputs:
        mda.execute(input_data)
    t1 = default_timer()
    t_execute = timedelta(seconds=t1 - t0) / float(n_repeat)

    linearization_inputs = _draw_inputs()
    t2 = default_timer()
    for input_data in linearization_inputs:
        mda.linearize(input_data)
    t3 = default_timer()
    t_linearize = timedelta(seconds=t3 - t2) / float(n_repeat)
    return t_execute, t_linearize

Run each of the MDAs for several problem dimensions:

dimensions = [1, 10, 100, 1000]
times = {}
# With max_mda_iter=1, each MDA stops after a single iteration; this is why the
# output below contains "reached its maximum number of iterations" warnings.
mda_config = {"mda_name": "MDAGaussSeidel", "max_mda_iter": 1}
# For every MDA variant, store the mean execution and linearization times as
# two arrays with one entry per dimension.
for mda_name, mda_func in mdas.items():
    timings = [
        run_and_log(mda_func, dimension, **mda_config) for dimension in dimensions
    ]
    time_exec = [t_e for t_e, _ in timings]
    time_lin = [t_l for _, t_l in timings]
    times[mda_name] = (array(time_exec), array(time_lin))

Out:

 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.39364291855087 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.9534234284979757 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.675020367604622 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2047636603086361 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.398361923727751 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6945059763804774 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 16.92718729949267 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 5.56898194611222 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.925291885818622 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.354826399201734 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 16.824703062780745 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 9.516373246207907 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.583703014575516 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9011736823389935 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.8151237353023584 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.9404504263062154 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5371719806020074 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6099966111525569 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2946921057860779 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8593265740496864 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9989277287445051 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.21617283273903928 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5151470603794642 is still above the tolerance 1e-06.
 WARNING - 13:32:07: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3394395256053435 is still above the tolerance 1e-06.
    INFO - 13:32:08: Compilation of the output function JAXSellarChain: 0:00:00.031645 seconds.
    INFO - 13:32:08: Compilation of the Jacobian function JAXSellarChain: 0:00:00.043666 seconds.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.678166982299423 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9763856484955736 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.02927113568576234 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0247041565800417 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7493656186750602 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4765214680221626 is still above the tolerance 1e-06.
    INFO - 13:32:08: Compilation of the output function JAXSellarChain: 0:00:00.053851 seconds.
    INFO - 13:32:08: Compilation of the Jacobian function JAXSellarChain: 0:00:00.142314 seconds.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.126029737777008 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.352812789737491 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8837919481941034 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.033587341273381 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.13171624802745543 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3658491107436914 is still above the tolerance 1e-06.
    INFO - 13:32:08: Compilation of the output function JAXSellarChain: 0:00:00.063318 seconds.
    INFO - 13:32:08: Compilation of the Jacobian function JAXSellarChain: 0:00:00.155936 seconds.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7425264212880237 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4586568442479196 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.51346851236008 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4008021955165494 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0790664409884514 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.704602543258376 is still above the tolerance 1e-06.
    INFO - 13:32:08: Compilation of the output function JAXSellarChain: 0:00:00.060707 seconds.
    INFO - 13:32:08: Compilation of the Jacobian function JAXSellarChain: 0:00:00.191866 seconds.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7956653748244913 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7427131568541528 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6087602753655843 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1927596528408008 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7112934168748672 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2903902715145544 is still above the tolerance 1e-06.
    INFO - 13:32:08: Compilation of the output function JAXSellarChain: 0:00:00.029547 seconds.
    INFO - 13:32:08: Compilation of the Jacobian function JAXSellarChain: 0:00:00.077648 seconds.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8835664244367091 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8741650489164234 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.509374221413033 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2357968664978707 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3985578375657744 is still above the tolerance 1e-06.
 WARNING - 13:32:08: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.43408182281638624 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellarChain: 0:00:00.046240 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarChain: 0:00:00.094508 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.064310120293589 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.9996787479829177 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1210985382270575 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.325384914172193 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.9329331307178428 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6796712756593788 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellarChain: 0:00:00.058935 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarChain: 0:00:00.113328 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.11157556881930582 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.16516724217422377 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7499208652987416 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6480813009263496 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1831947444099682 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5341351428681627 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellarChain: 0:00:00.057942 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarChain: 0:00:00.129053 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3056464599528547 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.91753576122124 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0480898252472721 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 5.071641102625985 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7882162741344225 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.3643471035435604 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellar1: 0:00:00.016799 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar1: 0:00:00.000078 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellar2: 0:00:00.014375 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar2: 0:00:00.000116 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellarSystem: 0:00:00.024604 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000148 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.637327217871406 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0313494632306281 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4814959826722152 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5118496780464132 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.16700680607037027 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7800166857941228 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellar1: 0:00:00.020592 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar1: 0:00:00.000097 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellar2: 0:00:00.016549 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar2: 0:00:00.000073 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellarSystem: 0:00:00.039580 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000097 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5215504735083683 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.966770550840666 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.631360238346746 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.776128798608644 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.48401744194432 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.9825100302617473 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellar1: 0:00:00.018788 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar1: 0:00:00.000077 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellar2: 0:00:00.017283 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar2: 0:00:00.000099 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellarSystem: 0:00:00.044921 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000101 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1064608573643537 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5609652594030418 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.37183401487949297 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5113590565086871 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6331357100426023 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.23619151251660714 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellar1: 0:00:00.018443 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar1: 0:00:00.000089 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellar2: 0:00:00.017324 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar2: 0:00:00.000079 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellarSystem: 0:00:00.042531 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000101 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1536269998313766 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7770858687509218 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5841895648807953 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3801701148244149 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0172391068411897 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0547102065611327 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellar1: 0:00:00.015929 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar1: 0:00:00.000074 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellar2: 0:00:00.012761 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar2: 0:00:00.000074 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellarSystem: 0:00:00.019056 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000078 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4226123168235516 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4560837361308947 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2106228204751595 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0985713068848777 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.14204133444616 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8791507056028839 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellar1: 0:00:00.019712 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar1: 0:00:00.000082 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellar2: 0:00:00.015963 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar2: 0:00:00.000076 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellarSystem: 0:00:00.033897 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000088 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.45472843464555335 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.07689941610222455 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.20710029700972898 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.47769200910404197 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.18166234758014613 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.07260834916740014 is still above the tolerance 1e-06.
    INFO - 13:32:09: Compilation of the output function JAXSellar1: 0:00:00.017764 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar1: 0:00:00.000080 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellar2: 0:00:00.028124 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellar2: 0:00:00.000075 seconds.
    INFO - 13:32:09: Compilation of the output function JAXSellarSystem: 0:00:00.044984 seconds.
    INFO - 13:32:09: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000100 seconds.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3068159104197259 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6757267384277751 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6294775286517358 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9853641861128816 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.012953242625579795 is still above the tolerance 1e-06.
 WARNING - 13:32:09: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8905461117482197 is still above the tolerance 1e-06.
    INFO - 13:32:10: Compilation of the output function JAXSellar1: 0:00:00.018225 seconds.
    INFO - 13:32:10: Compilation of the Jacobian function JAXSellar1: 0:00:00.000084 seconds.
    INFO - 13:32:10: Compilation of the output function JAXSellar2: 0:00:00.016939 seconds.
    INFO - 13:32:10: Compilation of the Jacobian function JAXSellar2: 0:00:00.000072 seconds.
    INFO - 13:32:10: Compilation of the output function JAXSellarSystem: 0:00:00.042119 seconds.
    INFO - 13:32:10: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000116 seconds.
 WARNING - 13:32:10: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 13:32:10: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7297273680817066 is still above the tolerance 1e-06.
 WARNING - 13:32:10: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8111852478085803 is still above the tolerance 1e-06.
 WARNING - 13:32:10: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4651331388563826 is still above the tolerance 1e-06.
 WARNING - 13:32:10: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6119381910104982 is still above the tolerance 1e-06.
 WARNING - 13:32:10: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0454008349811905 is still above the tolerance 1e-06.
 WARNING - 13:32:10: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.19163293399342432 is still above the tolerance 1e-06.

Now let's visualize our results:

# The first MDA registered in `mdas` serves as the reference for the speedup.
mda_ref = next(iter(mdas))
# NOTE(review): `times[name]` is assumed to be a pair
# (execution times, linearization times), each presumably an array aligned
# with `dimensions` -- confirm against the timing loop above.
t_ref = times[mda_ref]
# Speedup = reference time / MDA time, so a value above 1 means the MDA is
# faster than the reference (the reference itself is constantly 1).
speedup = {
    mda_name: (t_ref[0] / t_e, t_ref[1] / t_l) for mda_name, (t_e, t_l) in times.items()
}

fig, axes = subplots(2, 1, layout="constrained", figsize=(6, 8))
fig.suptitle("Speedup compared to NumPy Analytical")
for mda_name in mdas:
    # Draw the reference curve dotted to distinguish it from the others.
    linestyle = ":" if mda_name == mda_ref else "-"
    speedup_e, speedup_l = speedup[mda_name]
    axes[0].plot(dimensions, speedup_e, linestyle, label=mda_name)
    axes[1].plot(dimensions, speedup_l, linestyle, label=mda_name)
axes[0].legend(bbox_to_anchor=(0.9, -0.1))
axes[0].set_ylabel("Execution")
axes[0].set_xscale("log")
axes[1].set_ylabel("Linearization")
axes[1].set_xlabel("Dimension")
axes[1].set_xscale("log")
show()

Speedup compared to NumPy Analytical

Conclusion

JAX automatic differentiation (AD) is as fast as analytical derivatives with NumPy. Encapsulating the disciplines in a JAXChain slows down execution but speeds up linearization. This speedup is maintained even at higher dimensions.

Total running time of the script: ( 0 minutes 2.681 seconds)

Download Python source code: plot_sellar_scalable.py

Download Jupyter notebook: plot_sellar_scalable.ipynb

Gallery generated by mkdocs-gallery