Skip to content

Note

Click here to download the full example code

Analysis of the scalable Sellar problem with JAX.

from __future__ import annotations

from datetime import timedelta
from timeit import default_timer
from typing import TYPE_CHECKING

from gemseo import configure
from gemseo import configure_logger
from gemseo import create_mda
from gemseo.core.discipline.discipline import Discipline
from gemseo.problems.mdo.sellar.sellar_1 import Sellar1
from gemseo.problems.mdo.sellar.sellar_2 import Sellar2
from gemseo.problems.mdo.sellar.sellar_system import SellarSystem
from gemseo.problems.mdo.sellar.utils import get_initial_data
from matplotlib.pyplot import show
from matplotlib.pyplot import subplots
from numpy import array
from numpy.random import default_rng

from gemseo_jax.jax_chain import DifferentiationMethod
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem

if TYPE_CHECKING:
    from gemseo.mda.base_mda import BaseMDA
    from gemseo.typing import RealArray

# Deactivate some checkers to speed up calculations in presence of cheap disciplines.
# NOTE(review): the seven positional flags map onto gemseo.configure's parameters
# in declaration order (only the third one is enabled here). Passing them by
# keyword would make the intent explicit; confirm the parameter names against the
# installed gemseo version before doing so.
configure(False, False, True, False, False, False, False)
# Set up GEMSEO logging; its INFO/WARNING messages appear in the output below.
configure_logger()


def get_random_input_data(n: int) -> dict[str, RealArray]:
    """Build randomized input data for [JAX]SellarSystem.

    One random factor, drawn uniformly in [0, 1) from a fresh unseeded
    generator, is shared by all the variables: every value returned by
    ``get_initial_data(n=n)`` is multiplied by ``1.5 * factor``.

    Args:
        n: The size parameter forwarded to ``get_initial_data``.

    Returns:
        The scaled copy of the initial data, keyed by variable name.
    """
    scale = 1.5 * default_rng().random()
    initial_data = get_initial_data(n=n)
    return {name: scale * value for name, value in initial_data.items()}


def get_numpy_disciplines(n: int) -> list[Discipline]:
    """Instantiate the three NumPy-based Sellar disciplines.

    Args:
        n: The size parameter passed to each discipline constructor.

    Returns:
        ``Sellar1``, ``Sellar2`` and ``SellarSystem``, in that order.
    """
    discipline_classes = (Sellar1, Sellar2, SellarSystem)
    return [discipline_class(n=n) for discipline_class in discipline_classes]


def get_jax_disciplines(
    n: int, differentiation_method=DifferentiationMethod.AUTO
) -> list[Discipline]:
    """Instantiate the three JAX-based Sellar disciplines, ready to run.

    All the disciplines are created first; then each one gets a ``SIMPLE``
    cache and is JIT-compiled.

    Args:
        n: The size parameter passed to each discipline constructor.
        differentiation_method: The differentiation method given to every
            discipline.

    Returns:
        ``JAXSellar1``, ``JAXSellar2`` and ``JAXSellarSystem``, in that order.
    """
    jax_disciplines = [
        discipline_class(n=n, differentiation_method=differentiation_method)
        for discipline_class in (JAXSellar1, JAXSellar2, JAXSellarSystem)
    ]
    for jax_discipline in jax_disciplines:
        jax_discipline.set_cache(Discipline.CacheType.SIMPLE)
        jax_discipline.compile_jit()
    return jax_disciplines

Initial setup for comparison

Here we compare the original NumPy implementation with the JAX one. We therefore need to create the original MDA and one JAX MDA for each configuration being tested. In this example, we compare the performance of the JAXChain encapsulation, as well as the forward and reverse modes of automatic differentiation.

def get_analytical_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the reference Sellar MDA over the NumPy disciplines.

    The Jacobians are the analytical ones implemented by the NumPy disciplines.

    Args:
        n: The size parameter of the Sellar disciplines.
        mda_name: The MDA class name passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations passed to
            ``create_mda``.

    Returns:
        The MDA, equipped with a ``SIMPLE`` cache.
    """
    numpy_disciplines = get_numpy_disciplines(n)
    analytical_mda = create_mda(
        mda_name=mda_name,
        disciplines=numpy_disciplines,
        max_mda_iter=max_mda_iter,
        name="Analytical SellarChain",
    )
    analytical_mda.set_cache(Discipline.CacheType.SIMPLE)
    return analytical_mda


def get_forward_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the Sellar MDA over the JAX disciplines, using forward-mode AD.

    Args:
        n: The size parameter of the Sellar disciplines.
        mda_name: The MDA class name passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations passed to
            ``create_mda``.

    Returns:
        The MDA, equipped with a ``SIMPLE`` cache.
    """
    mda_settings = {
        "mda_name": mda_name,
        "disciplines": get_jax_disciplines(n, DifferentiationMethod.FORWARD),
        "max_mda_iter": max_mda_iter,
        "name": "JAX SellarChain",
    }
    forward_mda = create_mda(**mda_settings)
    forward_mda.set_cache(Discipline.CacheType.SIMPLE)
    return forward_mda


def get_chained_forward_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Build the Sellar MDA over a single JAXChain, using forward-mode AD.

    The whole Sellar problem is wrapped in one ``JAXSellarChain`` discipline.
    All of its outputs are declared as differentiated with respect to all of
    its inputs.

    Args:
        n: The size parameter of the Sellar problem.
        mda_name: The MDA class name passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations passed to
            ``create_mda``.

    Returns:
        The MDA, equipped with a ``SIMPLE`` cache.
    """
    jax_chain = JAXSellarChain(
        n=n, differentiation_method=DifferentiationMethod.FORWARD
    )
    # Differentiate every output of the chain with respect to every input.
    jax_chain.add_differentiated_inputs(jax_chain.input_grammar.names)
    jax_chain.add_differentiated_outputs(jax_chain.output_grammar.names)

    chained_mda = create_mda(
        mda_name=mda_name,
        disciplines=[jax_chain],
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    chained_mda.set_cache(Discipline.CacheType.SIMPLE)
    return chained_mda


def get_reverse_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the Sellar MDA over the JAX disciplines, using reverse-mode AD.

    Args:
        n: The size parameter of the Sellar disciplines.
        mda_name: The MDA class name passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations passed to
            ``create_mda``.

    Returns:
        The MDA, equipped with a ``SIMPLE`` cache.
    """
    mda_settings = {
        "mda_name": mda_name,
        "disciplines": get_jax_disciplines(n, DifferentiationMethod.REVERSE),
        "max_mda_iter": max_mda_iter,
        "name": "JAX SellarChain",
    }
    reverse_mda = create_mda(**mda_settings)
    reverse_mda.set_cache(Discipline.CacheType.SIMPLE)
    return reverse_mda


def get_chained_reverse_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Build the Sellar MDA over a single JAXChain, using reverse-mode AD.

    The whole Sellar problem is wrapped in one ``JAXSellarChain`` discipline.
    All of its outputs are declared as differentiated with respect to all of
    its inputs.

    Args:
        n: The size parameter of the Sellar problem.
        mda_name: The MDA class name passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations passed to
            ``create_mda``.

    Returns:
        The MDA, equipped with a ``SIMPLE`` cache.
    """
    jax_chain = JAXSellarChain(
        n=n, differentiation_method=DifferentiationMethod.REVERSE
    )
    # Differentiate every output of the chain with respect to every input.
    jax_chain.add_differentiated_inputs(jax_chain.input_grammar.names)
    jax_chain.add_differentiated_outputs(jax_chain.output_grammar.names)

    chained_mda = create_mda(
        mda_name=mda_name,
        disciplines=[jax_chain],
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    chained_mda.set_cache(Discipline.CacheType.SIMPLE)
    return chained_mda


# Factories of the MDA variants to benchmark, keyed by the label under which
# their timings are reported. Insertion order is kept: the first entry is the
# reference implementation.
mdas = {}
mdas["MDOChain[NumPy] - Analytical"] = get_analytical_mda  # this is the reference
mdas["JAXChain - Forward AD"] = get_chained_forward_ad_mda
mdas["JAXChain - Reverse AD"] = get_chained_reverse_ad_mda
mdas["MDOChain[JAX] - Forward AD"] = get_forward_ad_mda
mdas["MDOChain[JAX] - Reverse AD"] = get_reverse_ad_mda

Execution and linearization scalability

Let's write a function that executes and linearizes an MDA while measuring the elapsed times. Each operation is repeated several times and the mean time is returned, to reduce noise in the results:

def run_and_log(get_mda, dimension, n_repeat=7, **mda_options):
    """Measure the mean execution and linearization times of an MDA.

    The MDA is built once with ``get_mda(dimension, **mda_options)``. It is
    then executed ``n_repeat`` times and linearized ``n_repeat`` times. Every
    call receives fresh random inputs from ``get_random_input_data(dimension)``,
    restricted to the MDA input names.

    The random inputs are generated before starting the corresponding timer.
    The measured times therefore cover only the MDA calls, not the input
    generation and filtering.

    Args:
        get_mda: The factory building the MDA from ``dimension`` and
            ``mda_options``.
        dimension: The size parameter forwarded to ``get_mda`` and
            ``get_random_input_data``.
        n_repeat: The number of calls over which each time is averaged.
        **mda_options: The options forwarded to ``get_mda``.

    Returns:
        The mean execution time and the mean linearization time, as
        ``timedelta`` objects.

    Raises:
        ValueError: When ``n_repeat`` is lower than 1.
    """
    if n_repeat < 1:
        msg = f"n_repeat must be at least 1; got {n_repeat}."
        raise ValueError(msg)

    mda = get_mda(dimension, **mda_options)
    # The input names are looked up once instead of once per variable and call.
    input_names = set(mda.input_grammar.names)

    def draw_input_data():
        """Return ``n_repeat`` random input dicts restricted to the MDA inputs."""
        return [
            {
                name: value
                for name, value in get_random_input_data(dimension).items()
                if name in input_names
            }
            for _ in range(n_repeat)
        ]

    execute_data = draw_input_data()
    t0 = default_timer()
    for input_data in execute_data:
        mda.execute(input_data)
    t_execute = timedelta(seconds=default_timer() - t0) / n_repeat

    # NOTE(review): these inputs differ from the last executed ones, so
    # ``linearize`` presumably runs the MDA again before differentiating. The
    # linearization time may thus include an execution; confirm this is intended.
    linearize_data = draw_input_data()
    t2 = default_timer()
    for input_data in linearize_data:
        mda.linearize(input_data)
    t_linearize = timedelta(seconds=default_timer() - t2) / n_repeat
    return t_execute, t_linearize

Run each of the MDAs for several problem dimensions:

# Problem sizes at which every MDA variant is timed.
dimensions = [1, 10, 100, 1000]
# Maps each label of ``mdas`` to a pair of arrays: the mean execution times and
# the mean linearization times, with one entry per dimension.
times = {}
# Options forwarded to every MDA factory. With a single Gauss-Seidel iteration,
# the MDAs never converge, hence the "reached its maximum number of iterations"
# warnings in the output below. Presumably this is deliberate, so that every
# variant performs the same amount of work.
mda_config = {"mda_name": "MDAGaussSeidel", "max_mda_iter": 1}
# NOTE: the loop variable ``mda_name`` is the label of a variant in ``mdas``.
# It is not the MDA class name stored under the "mda_name" key of ``mda_config``.
for mda_name, mda_func in mdas.items():
    time_exec = []
    time_lin = []
    for dimension in dimensions:
        t_e, t_l = run_and_log(mda_func, dimension, **mda_config)
        time_exec.append(t_e)
        time_lin.append(t_l)
    # Store the timings of this variant as NumPy arrays, ordered as ``dimensions``.
    times[mda_name] = (array(time_exec), array(time_lin))

Out:

 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0017626268968514 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5490814097199339 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 7.982008573505464 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.9660036761568098 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.266635682412455 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.3680660664816737 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1583143200993138 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3756535944519814 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6629268720820956 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.146227410286926 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6329134859030284 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.40367504393253373 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7798147902236297 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6547630241504995 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0812850734538066 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0295503220122724 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9603941295542312 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.770766630472956 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9826225940604202 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2792157760529035 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.078822293313204 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.19618268236377787 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9578517329495027 is still above the tolerance 1e-06.
 WARNING - 12:32:30: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6882782691411637 is still above the tolerance 1e-06.
    INFO - 12:32:30: Compilation of the output function JAXSellarChain: 0:00:00.036878 seconds.
    INFO - 12:32:30: Compilation of the Jacobian function JAXSellarChain: 0:00:00.043214 seconds.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0153586038526248 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9244178584044795 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.221834619227134 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.26444795982149116 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8826216164820436 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5612831233979966 is still above the tolerance 1e-06.
    INFO - 12:32:30: Compilation of the output function JAXSellarChain: 0:00:00.064981 seconds.
    INFO - 12:32:30: Compilation of the Jacobian function JAXSellarChain: 0:00:00.152638 seconds.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.711942594699366 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9975788686267378 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7824303243072015 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.17135823035876915 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6985051937541462 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2754635068259765 is still above the tolerance 1e-06.
    INFO - 12:32:30: Compilation of the output function JAXSellarChain: 0:00:00.075539 seconds.
    INFO - 12:32:30: Compilation of the Jacobian function JAXSellarChain: 0:00:00.153766 seconds.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.300526981537987 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.07371335595689667 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9018360644465165 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.292998476839272 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.43369765805663635 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.20852203835536595 is still above the tolerance 1e-06.
    INFO - 12:32:30: Compilation of the output function JAXSellarChain: 0:00:00.075869 seconds.
    INFO - 12:32:30: Compilation of the Jacobian function JAXSellarChain: 0:00:00.185739 seconds.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4557850473869791 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.810140869981412 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5814521930810372 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6557194185660281 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8657544654809559 is still above the tolerance 1e-06.
 WARNING - 12:32:30: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7251031156187584 is still above the tolerance 1e-06.
    INFO - 12:32:31: Compilation of the output function JAXSellarChain: 0:00:00.040932 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellarChain: 0:00:00.081095 seconds.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 5.671077282005725 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 9.75916530816569 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.6241737545410775 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 11.479849224270604 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 14.164523003972102 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 9.909893801672371 is still above the tolerance 1e-06.
    INFO - 12:32:31: Compilation of the output function JAXSellarChain: 0:00:00.060412 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellarChain: 0:00:00.098568 seconds.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7648551374335328 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.12670597625369961 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.4804170753535906 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.034291552121516 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5390084806625017 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3690673047895956 is still above the tolerance 1e-06.
    INFO - 12:32:31: Compilation of the output function JAXSellarChain: 0:00:00.070307 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellarChain: 0:00:00.135509 seconds.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.587623588836288 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.715896291039531 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.262875447305985 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.021409150860409 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.481249163824006 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.770954830211123 is still above the tolerance 1e-06.
    INFO - 12:32:31: Compilation of the output function JAXSellarChain: 0:00:00.068544 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellarChain: 0:00:00.119753 seconds.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 13.549027279415084 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 9.377124011081483 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 5.107958702432196 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 24.15432779084192 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.100231366090937 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 7.77789488525749 is still above the tolerance 1e-06.
    INFO - 12:32:31: Compilation of the output function JAXSellar1: 0:00:00.018312 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellar1: 0:00:00.000267 seconds.
    INFO - 12:32:31: Compilation of the output function JAXSellar2: 0:00:00.017079 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellar2: 0:00:00.000218 seconds.
    INFO - 12:32:31: Compilation of the output function JAXSellarSystem: 0:00:00.025906 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000124 seconds.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 148.68966778386482 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 172.19331134736817 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 115.93759021007122 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 140.8340570090415 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 30.6023862961941 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 79.66702182978814 is still above the tolerance 1e-06.
    INFO - 12:32:31: Compilation of the output function JAXSellar1: 0:00:00.020983 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellar1: 0:00:00.000102 seconds.
    INFO - 12:32:31: Compilation of the output function JAXSellar2: 0:00:00.019569 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellar2: 0:00:00.000093 seconds.
    INFO - 12:32:31: Compilation of the output function JAXSellarSystem: 0:00:00.059662 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000118 seconds.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.1776489131043273 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4295481386238217 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.13716353047485633 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2568041388976283 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4678809826703939 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5517870696518166 is still above the tolerance 1e-06.
    INFO - 12:32:31: Compilation of the output function JAXSellar1: 0:00:00.020523 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellar1: 0:00:00.000138 seconds.
    INFO - 12:32:31: Compilation of the output function JAXSellar2: 0:00:00.020351 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellar2: 0:00:00.000096 seconds.
    INFO - 12:32:31: Compilation of the output function JAXSellarSystem: 0:00:00.052165 seconds.
    INFO - 12:32:31: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000113 seconds.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.044886183183311286 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3245341128310147 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6312704921448212 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3344600801349285 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0684477798947167 is still above the tolerance 1e-06.
 WARNING - 12:32:31: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.31676870899022946 is still above the tolerance 1e-06.
    INFO - 12:32:32: Compilation of the output function JAXSellar1: 0:00:00.020176 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar1: 0:00:00.000108 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellar2: 0:00:00.019718 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar2: 0:00:00.000101 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellarSystem: 0:00:00.052206 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000122 seconds.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5578637965005567 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.2563027499272725 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3662130564103718 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0182168619701872 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.10710665622489159 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.5173390766740615 is still above the tolerance 1e-06.
    INFO - 12:32:32: Compilation of the output function JAXSellar1: 0:00:00.018787 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar1: 0:00:00.000102 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellar2: 0:00:00.016504 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar2: 0:00:00.000096 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellarSystem: 0:00:00.025915 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000123 seconds.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3807239716123294 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.569451792989252 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0579847558097708 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2394532129026652 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.92783897106372 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.60460685930412 is still above the tolerance 1e-06.
    INFO - 12:32:32: Compilation of the output function JAXSellar1: 0:00:00.020962 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar1: 0:00:00.000107 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellar2: 0:00:00.019872 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar2: 0:00:00.000097 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellarSystem: 0:00:00.045883 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000134 seconds.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.07769768600313676 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1373394843172262 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8915925532140356 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.584104007716749 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5200713048190595 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2143734574330927 is still above the tolerance 1e-06.
    INFO - 12:32:32: Compilation of the output function JAXSellar1: 0:00:00.021384 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar1: 0:00:00.000107 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellar2: 0:00:00.019978 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar2: 0:00:00.000098 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellarSystem: 0:00:00.054417 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000144 seconds.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7267279142049554 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3868909863376825 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6370148052280161 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4906975644575014 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4184441514323642 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.38799505441220683 is still above the tolerance 1e-06.
    INFO - 12:32:32: Compilation of the output function JAXSellar1: 0:00:00.047428 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar1: 0:00:00.000113 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellar2: 0:00:00.020616 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellar2: 0:00:00.000106 seconds.
    INFO - 12:32:32: Compilation of the output function JAXSellarSystem: 0:00:00.052615 seconds.
    INFO - 12:32:32: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000150 seconds.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.223370036261753 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2923110109997984 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3606902755973218 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.30484854231673003 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2808817022073873 is still above the tolerance 1e-06.
 WARNING - 12:32:32: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.8210433250547942 is still above the tolerance 1e-06.

Now let's visualize our results:

# The first MDA inserted in ``mdas`` is the reference for all comparisons
# (presumably the NumPy/analytical one, matching the figure title below).
mda_ref = next(iter(mdas))
t_ref = times[mda_ref]
# Speedup = reference time / candidate time, so a value above 1 means the
# candidate is faster than the reference and a value below 1 means it is slower.
# Each entry holds (execution speedup, linearization speedup) per dimension.
speedup = {
    mda_name: (t_ref[0] / t_e, t_ref[1] / t_l) for mda_name, (t_e, t_l) in times.items()
}

fig, axes = subplots(2, 1, layout="constrained", figsize=(6, 8))
fig.suptitle("Speedup compared to NumPy Analytical")
for mda_name in mdas:
    # Draw the reference (constant speedup of 1) dotted so it reads as a baseline.
    linestyle = ":" if mda_name == mda_ref else "-"
    speedup_e, speedup_l = speedup[mda_name]
    axes[0].plot(dimensions, speedup_e, linestyle, label=mda_name)
    axes[1].plot(dimensions, speedup_l, linestyle, label=mda_name)
axes[0].legend(bbox_to_anchor=(0.9, -0.1))
axes[0].set_ylabel("Execution")
axes[0].set_xscale("log")
axes[1].set_ylabel("Linearization")
axes[1].set_xlabel("Dimension")
axes[1].set_xscale("log")
show()

Speedup compared to NumPy Analytical

Conclusion

JAX automatic differentiation is as fast as analytical derivatives with NumPy. Encapsulating the disciplines in a JAXChain slows down execution but speeds up linearization. This speedup is maintained even at higher dimensions.

Total running time of the script: ( 0 minutes 2.879 seconds)

Download Python source code: plot_sellar_scalable.py

Download Jupyter notebook: plot_sellar_scalable.ipynb

Gallery generated by mkdocs-gallery