Skip to content

Note

Click here to download the full example code

Analysis of the scalable Sellar problem with JAX.

from __future__ import annotations

from datetime import timedelta
from timeit import default_timer
from typing import TYPE_CHECKING

from gemseo import configure
from gemseo import configure_logger
from gemseo import create_mda
from gemseo.core.discipline.discipline import Discipline
from gemseo.problems.mdo.sellar.sellar_1 import Sellar1
from gemseo.problems.mdo.sellar.sellar_2 import Sellar2
from gemseo.problems.mdo.sellar.sellar_system import SellarSystem
from gemseo.problems.mdo.sellar.utils import get_initial_data
from matplotlib.pyplot import show
from matplotlib.pyplot import subplots
from numpy import array
from numpy.random import default_rng

from gemseo_jax.jax_chain import DifferentiationMethod
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem

if TYPE_CHECKING:
    from gemseo.mda.base_mda import BaseMDA
    from gemseo.typing import RealArray

# Deactivate some checkers to speed up calculations in presence of cheap disciplines.
# NOTE(review): the seven positional booleans are opaque; presumably they follow the
# parameter order of gemseo's ``configure`` (only the third flag is left enabled) —
# confirm against the installed gemseo version and consider keyword arguments.
# The captured output of this page shows that both ``configure()`` and
# ``configure_logger()`` emit a DeprecationWarning pointing to ``gemseo.configuration``.
configure(False, False, True, False, False, False, False)
# Set up gemseo logging so that MDA warnings and JAX compilation times are displayed.
configure_logger()


def get_random_input_data(n: int) -> dict[str, RealArray]:
    """Build a randomly scaled copy of the default [JAX]SellarSystem inputs.

    A single uniform draw ``r`` in [0, 1) is taken from a fresh, unseeded NumPy
    generator, and every default input array of dimension ``n`` is multiplied by
    the same factor ``1.5 * r``.

    Args:
        n: The dimension of the scalable Sellar problem.

    Returns:
        The scaled input data, keyed by input name.
    """
    factor = default_rng().random()
    scaled_data = {}
    for input_name, default_value in get_initial_data(n=n).items():
        scaled_data[input_name] = 1.5 * factor * default_value
    return scaled_data


def get_numpy_disciplines(n: int) -> list[Discipline]:
    """Instantiate the three NumPy-based Sellar disciplines for dimension ``n``.

    Args:
        n: The dimension of the scalable Sellar problem.

    Returns:
        ``Sellar1``, ``Sellar2`` and ``SellarSystem``, in this order.
    """
    numpy_classes = (Sellar1, Sellar2, SellarSystem)
    return [numpy_class(n=n) for numpy_class in numpy_classes]


def get_jax_disciplines(
    n: int, differentiation_method=DifferentiationMethod.AUTO
) -> list[Discipline]:
    """Instantiate, cache and JIT-compile the JAX-based Sellar disciplines.

    All three disciplines are created first. Then each one gets a simple cache
    and its JAX functions are compiled ahead of the first execution.

    Args:
        n: The dimension of the scalable Sellar problem.
        differentiation_method: The JAX differentiation method shared by all
            three disciplines.

    Returns:
        ``JAXSellar1``, ``JAXSellar2`` and ``JAXSellarSystem``, in this order.
    """
    jax_classes = (JAXSellar1, JAXSellar2, JAXSellarSystem)
    jax_disciplines = [
        jax_class(n=n, differentiation_method=differentiation_method)
        for jax_class in jax_classes
    ]
    for jax_discipline in jax_disciplines:
        jax_discipline.set_cache(Discipline.CacheType.SIMPLE)
        # Compile eagerly so that later timings exclude JIT compilation.
        jax_discipline.compile_jit()
    return jax_disciplines

Out:

/builds/gemseo/dev/gemseo-jax/docs/examples/sellar/plot_sellar_scalable.py:49: DeprecationWarning: configure() is deprecated; use gemseo.configuration instead.
  configure(False, False, True, False, False, False, False)
/builds/gemseo/dev/gemseo-jax/docs/examples/sellar/plot_sellar_scalable.py:50: DeprecationWarning: configure_logger() is deprecated; use gemseo.configuration.logging instead.
  configure_logger()

Initial setup for comparison

Here we compare the original NumPy implementation with the JAX one. We therefore need to create the original MDA and one JAX MDA for each configuration under test. In this example, we compare the performance of the JAXChain encapsulation as well as the forward and reverse modes of automatic differentiation.

def get_analytical_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the reference MDA over the NumPy disciplines (analytical Jacobians).

    Args:
        n: The dimension of the scalable Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    mda_options = {
        "mda_name": mda_name,
        "disciplines": get_numpy_disciplines(n),
        "max_mda_iter": max_mda_iter,
        "name": "Analytical SellarChain",
    }
    reference_mda = create_mda(**mda_options)
    reference_mda.set_cache(Discipline.CacheType.SIMPLE)
    return reference_mda


def get_forward_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build an MDA over the separate JAX disciplines using forward-mode AD.

    Args:
        n: The dimension of the scalable Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    forward_disciplines = get_jax_disciplines(n, DifferentiationMethod.FORWARD)
    forward_mda = create_mda(
        mda_name=mda_name,
        disciplines=forward_disciplines,
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    forward_mda.set_cache(Discipline.CacheType.SIMPLE)
    return forward_mda


def get_chained_forward_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Build an MDA over a single JAXChain of the Sellar disciplines (forward AD).

    Every output of the chain is declared differentiable with respect to every
    input of the chain, so that the MDA can linearize it.

    Args:
        n: The dimension of the scalable Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    chain = JAXSellarChain(n=n, differentiation_method=DifferentiationMethod.FORWARD)
    chain.add_differentiated_inputs(chain.input_grammar.names)
    chain.add_differentiated_outputs(chain.output_grammar.names)

    chained_mda = create_mda(
        mda_name=mda_name,
        disciplines=[chain],
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    chained_mda.set_cache(Discipline.CacheType.SIMPLE)
    return chained_mda


def get_reverse_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build an MDA over the separate JAX disciplines using reverse-mode AD.

    Args:
        n: The dimension of the scalable Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    reverse_disciplines = get_jax_disciplines(n, DifferentiationMethod.REVERSE)
    reverse_mda = create_mda(
        mda_name=mda_name,
        disciplines=reverse_disciplines,
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    reverse_mda.set_cache(Discipline.CacheType.SIMPLE)
    return reverse_mda


def get_chained_reverse_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Build an MDA over a single JAXChain of the Sellar disciplines (reverse AD).

    Every output of the chain is declared differentiable with respect to every
    input of the chain, so that the MDA can linearize it.

    Args:
        n: The dimension of the scalable Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    chain = JAXSellarChain(n=n, differentiation_method=DifferentiationMethod.REVERSE)
    chain.add_differentiated_inputs(chain.input_grammar.names)
    chain.add_differentiated_outputs(chain.output_grammar.names)

    chained_mda = create_mda(
        mda_name=mda_name,
        disciplines=[chain],
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    chained_mda.set_cache(Discipline.CacheType.SIMPLE)
    return chained_mda


# Registry of the MDA factories to benchmark, keyed by display label.
# Insertion order matters: the NumPy analytical MDA comes first as the reference.
mdas = {}
mdas["MDOChain[NumPy] - Analytical"] = get_analytical_mda
mdas["JAXChain - Forward AD"] = get_chained_forward_ad_mda
mdas["JAXChain - Reverse AD"] = get_chained_reverse_ad_mda
mdas["MDOChain[JAX] - Forward AD"] = get_forward_ad_mda
mdas["MDOChain[JAX] - Reverse AD"] = get_reverse_ad_mda

Execution and linearization scalability

Let's write a function that executes and linearizes an MDA while measuring the elapsed times. We average over several repetitions to reduce timing noise:

def run_and_log(get_mda, dimension, n_repeat=7, **mda_options):
    """Measure the mean execution and linearization times of an MDA.

    The MDA is built once with ``get_mda(dimension, **mda_options)``. It is then
    executed ``n_repeat`` times and linearized ``n_repeat`` times, each call on a
    fresh random input point restricted to the MDA's input names.

    The random input points are generated *before* each timer starts. The
    measured durations therefore cover only ``mda.execute`` / ``mda.linearize``,
    not the cost of building and filtering the input data.

    Args:
        get_mda: A factory called as ``get_mda(dimension, **mda_options)``
            that returns the MDA to benchmark.
        dimension: The dimension of the scalable Sellar problem.
        n_repeat: The number of repetitions averaged for each timing.
        **mda_options: The options passed to ``get_mda``.

    Returns:
        The mean execution time and the mean linearization time, as
        ``timedelta`` objects.

    Raises:
        ValueError: When ``n_repeat`` is lower than 1.
    """
    if n_repeat < 1:
        msg = f"n_repeat must be at least 1, got {n_repeat}."
        raise ValueError(msg)

    mda = get_mda(dimension, **mda_options)
    # The input names do not change between calls: read them once.
    input_names = set(mda.input_grammar.names)

    def draw_input_points():
        """Return ``n_repeat`` random input points restricted to the MDA inputs."""
        return [
            {
                name: value
                for name, value in get_random_input_data(dimension).items()
                if name in input_names
            }
            for _ in range(n_repeat)
        ]

    execute_points = draw_input_points()
    t0 = default_timer()
    for input_data in execute_points:
        mda.execute(input_data)
    t1 = default_timer()
    t_execute = timedelta(seconds=t1 - t0) / float(n_repeat)

    linearize_points = draw_input_points()
    t2 = default_timer()
    for input_data in linearize_points:
        mda.linearize(input_data)
    t3 = default_timer()
    t_linearize = timedelta(seconds=t3 - t2) / float(n_repeat)
    return t_execute, t_linearize

Run each MDA for several problem dimensions:

# Problem dimensions n at which every MDA variant of ``mdas`` is timed.
dimensions = [1, 10, 100, 1000]
# Maps each label of ``mdas`` to a pair (execution times, linearization times):
# arrays built from the ``timedelta`` values returned by ``run_and_log``,
# with one entry per dimension.
times = {}
# A single Gauss-Seidel iteration per MDA call: the benchmark measures the cost of
# the calls, not convergence, which is why "maximum number of iterations" warnings
# are logged.
mda_config = {"mda_name": "MDAGaussSeidel", "max_mda_iter": 1}
for mda_name, mda_func in mdas.items():
    time_exec = []
    time_lin = []
    for dimension in dimensions:
        # A new MDA is built for each dimension inside ``run_and_log``.
        t_e, t_l = run_and_log(mda_func, dimension, **mda_config)
        time_exec.append(t_e)
        time_lin.append(t_l)
    times[mda_name] = (array(time_exec), array(time_lin))

Out:

 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0196199738908978 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0196199738908978 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.32623504510131635 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.32623504510131635 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2572739627791674 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2572739627791674 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9894555395357275 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9894555395357275 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1473907170119917 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1473907170119917 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3163285013103006 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3163285013103006 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8460840494800351 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8460840494800351 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7750276699978446 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7750276699978446 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7466169792674311 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7466169792674311 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5785469771419497 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5785469771419497 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7621165177330704 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7621165177330704 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9475005691540478 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9475005691540478 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.10736409821793287 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.10736409821793287 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.08619188678949952 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.08619188678949952 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8280613761648348 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8280613761648348 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.11571353277398841 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.11571353277398841 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.683535956605644 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.683535956605644 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5614815801670866 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5614815801670866 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.46201878242294053 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.46201878242294053 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.652730816457202 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.652730816457202 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5726405066623661 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5726405066623661 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.036850355085621 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.036850355085621 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.05664032999041831 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.05664032999041831 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.1608586207945852 is still above the tolerance 1e-06.
 WARNING - 02:44:05: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.1608586207945852 is still above the tolerance 1e-06.
    INFO - 02:44:05: Compilation of the output function JAXSellarChain: 0:00:00.035137 seconds.
    INFO - 02:44:05: Compilation of the output function JAXSellarChain: 0:00:00.035137 seconds.
    INFO - 02:44:05: Compilation of the Jacobian function JAXSellarChain: 0:00:00.043264 seconds.
    INFO - 02:44:05: Compilation of the Jacobian function JAXSellarChain: 0:00:00.043264 seconds.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4956387960708724 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4956387960708724 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7437162920842937 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7437162920842937 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5467424605753093 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5467424605753093 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.210419338983432 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.210419338983432 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.231425943196176 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.231425943196176 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.08601421696666016 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.08601421696666016 is still above the tolerance 1e-06.
    INFO - 02:44:05: Compilation of the output function JAXSellarChain: 0:00:00.064808 seconds.
    INFO - 02:44:05: Compilation of the output function JAXSellarChain: 0:00:00.064808 seconds.
    INFO - 02:44:05: Compilation of the Jacobian function JAXSellarChain: 0:00:00.139561 seconds.
    INFO - 02:44:05: Compilation of the Jacobian function JAXSellarChain: 0:00:00.139561 seconds.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.938937971870738 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.938937971870738 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.743004737605165 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.743004737605165 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1098498589615426 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1098498589615426 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2347664037468666 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2347664037468666 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5264725750887225 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5264725750887225 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6390840858717894 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6390840858717894 is still above the tolerance 1e-06.
    INFO - 02:44:05: Compilation of the output function JAXSellarChain: 0:00:00.069671 seconds.
    INFO - 02:44:05: Compilation of the output function JAXSellarChain: 0:00:00.069671 seconds.
    INFO - 02:44:05: Compilation of the Jacobian function JAXSellarChain: 0:00:00.148114 seconds.
    INFO - 02:44:05: Compilation of the Jacobian function JAXSellarChain: 0:00:00.148114 seconds.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.47239385706309456 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.47239385706309456 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.043687228119033 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.043687228119033 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7238949924989856 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7238949924989856 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.1512685070232425 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.1512685070232425 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1825056577102717 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1825056577102717 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.289498815983237 is still above the tolerance 1e-06.
 WARNING - 02:44:05: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.289498815983237 is still above the tolerance 1e-06.
    INFO - 02:44:06: Compilation of the output function JAXSellarChain: 0:00:00.067940 seconds.
    INFO - 02:44:06: Compilation of the output function JAXSellarChain: 0:00:00.067940 seconds.
    INFO - 02:44:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.184899 seconds.
    INFO - 02:44:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.184899 seconds.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.33249055313244247 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.33249055313244247 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0000424010470885 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0000424010470885 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7969252541053791 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7969252541053791 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8683111945254725 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8683111945254725 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.18094174161040044 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.18094174161040044 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6255193987611076 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6255193987611076 is still above the tolerance 1e-06.
    INFO - 02:44:06: Compilation of the output function JAXSellarChain: 0:00:00.034595 seconds.
    INFO - 02:44:06: Compilation of the output function JAXSellarChain: 0:00:00.034595 seconds.
    INFO - 02:44:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.101418 seconds.
    INFO - 02:44:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.101418 seconds.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 137.8478403898255 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 137.8478403898255 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 5.729481385969917 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 5.729481385969917 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 236.63958793339836 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 236.63958793339836 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 276.8309183836523 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 276.8309183836523 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 70.05795250247795 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 70.05795250247795 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 116.4801218459037 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 116.4801218459037 is still above the tolerance 1e-06.
    INFO - 02:44:06: Compilation of the output function JAXSellarChain: 0:00:00.060631 seconds.
    INFO - 02:44:06: Compilation of the output function JAXSellarChain: 0:00:00.060631 seconds.
    INFO - 02:44:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.099005 seconds.
    INFO - 02:44:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.099005 seconds.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.608896424060815 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.608896424060815 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7389175730502303 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7389175730502303 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.981862023288328 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.981862023288328 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1547631260192108 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1547631260192108 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.8333197597370452 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.8333197597370452 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7736506599312667 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7736506599312667 is still above the tolerance 1e-06.
    INFO - 02:44:06: Compilation of the output function JAXSellarChain: 0:00:00.084241 seconds.
    INFO - 02:44:06: Compilation of the output function JAXSellarChain: 0:00:00.084241 seconds.
    INFO - 02:44:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.161178 seconds.
    INFO - 02:44:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.161178 seconds.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9957908579404036 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9957908579404036 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5240966755652399 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5240966755652399 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.216922578674246 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.216922578674246 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.629585317099367 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.629585317099367 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.375652574769338 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.375652574769338 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.151686493089283 is still above the tolerance 1e-06.
 WARNING - 02:44:06: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.151686493089283 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellarChain: 0:00:00.066006 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarChain: 0:00:00.066006 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarChain: 0:00:00.134530 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarChain: 0:00:00.134530 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1372529578951081 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1372529578951081 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2471307502594895 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2471307502594895 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2577624702608348 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2577624702608348 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0211417476069573 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0211417476069573 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.210540296158669 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.210540296158669 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8003786124851122 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8003786124851122 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.016247 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.016247 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000076 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000076 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.013781 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.013781 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000077 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000077 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.022711 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.022711 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000089 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000089 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.7415305257982086 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.7415305257982086 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2350815152118355 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2350815152118355 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.160146443560269 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.160146443560269 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8258650882729287 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8258650882729287 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4732085436160605 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4732085436160605 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7051553282203107 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7051553282203107 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018586 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018586 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000088 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000088 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.017000 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.017000 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000077 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000077 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.045973 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.045973 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000101 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000101 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9339943880948812 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9339943880948812 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8222547632346351 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8222547632346351 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8036425471837777 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8036425471837777 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2845256543434771 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2845256543434771 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.45694746474150144 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.45694746474150144 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5934122230999254 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5934122230999254 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018802 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018802 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000086 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000086 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.018340 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.018340 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000076 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000076 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.051647 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.051647 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000091 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000091 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.788151978498167 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.788151978498167 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4183060048592023 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4183060048592023 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2829960526819817 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2829960526819817 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7158021966712136 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7158021966712136 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.369434250973646 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.369434250973646 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.025139204522386018 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.025139204522386018 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.019161 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.019161 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000085 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000085 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.034341 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.034341 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000076 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000076 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.051017 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.051017 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000097 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000097 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.705578548017265 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.705578548017265 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2355508828055848 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2355508828055848 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.664198479652921 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.664198479652921 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0108167550739897 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0108167550739897 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.04383270402125358 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.04383270402125358 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2044719603477205 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2044719603477205 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.017277 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.017277 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000080 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000080 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.014784 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.014784 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000075 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000075 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.024639 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.024639 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000137 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000137 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.579888724659396 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.579888724659396 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 16.44763580199732 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 16.44763580199732 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 15.76331509733214 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 15.76331509733214 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 7.584546710362932 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 7.584546710362932 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 5.014782450433781 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 5.014782450433781 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 8.955152538631513 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 8.955152538631513 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018790 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018790 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000088 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000088 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.018308 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.018308 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000074 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000074 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.039649 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.039649 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000118 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000118 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1883531330765285 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1883531330765285 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8414529142300701 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8414529142300701 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0609127236124913 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0609127236124913 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.03736808400747718 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.03736808400747718 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7474020321423281 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7474020321423281 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.11136594233748599 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.11136594233748599 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018316 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018316 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000079 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000079 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.018090 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.018090 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000109 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000109 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.048240 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.048240 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000093 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000093 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9533526518790265 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9533526518790265 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6480948931845716 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6480948931845716 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2534129969733071 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2534129969733071 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.06804641043054446 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.06804641043054446 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1040527589040883 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1040527589040883 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0213884475366504 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0213884475366504 is still above the tolerance 1e-06.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018639 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar1: 0:00:00.018639 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000086 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000086 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.018136 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellar2: 0:00:00.018136 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000089 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellar2: 0:00:00.000089 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.046762 seconds.
    INFO - 02:44:07: Compilation of the output function JAXSellarSystem: 0:00:00.046762 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000099 seconds.
    INFO - 02:44:07: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000099 seconds.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1623627628165063 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1623627628165063 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.42364313560631645 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.42364313560631645 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6363721644718494 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6363721644718494 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3001024426814596 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3001024426814596 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.34986516119071315 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.34986516119071315 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3501727936894313 is still above the tolerance 1e-06.
 WARNING - 02:44:07: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3501727936894313 is still above the tolerance 1e-06.

Now let's visualize the speedup of each MDA with respect to the reference one (NumPy with analytical derivatives):

# The first MDA (NumPy disciplines with analytical derivatives) is the reference.
mda_ref = next(iter(mdas))
t_ref = times[mda_ref]
# Speedup = reference time / time of the considered MDA,
# so that a value above 1 means "faster than the reference".
# array() makes the element-wise division work even if timings are plain sequences.
speedup = {
    mda_name: (array(t_ref[0]) / array(t_e), array(t_ref[1]) / array(t_l))
    for mda_name, (t_e, t_l) in times.items()
}

fig, axes = subplots(2, 1, layout="constrained", figsize=(6, 8))
fig.suptitle("Speedup compared to NumPy Analytical")
for mda_name in mdas:
    # Dotted line for the reference MDA, whose speedup is 1 by construction.
    linestyle = ":" if mda_name == mda_ref else "-"
    speedup_e, speedup_l = speedup[mda_name]
    axes[0].plot(dimensions, speedup_e, linestyle, label=mda_name)
    axes[1].plot(dimensions, speedup_l, linestyle, label=mda_name)
axes[0].legend(bbox_to_anchor=(0.9, -0.1))
axes[0].set_ylabel("Execution")
axes[0].set_xscale("log")
axes[1].set_ylabel("Linearization")
axes[1].set_xlabel("Dimension")
axes[1].set_xscale("log")
show()

Speedup compared to NumPy Analytical

Conclusion

JAX automatic differentiation (AD) is as fast as analytical derivatives with NumPy. Encapsulating the disciplines in a JAXChain slows down execution but speeds up linearization. This speedup is maintained even at higher dimensions.

Total running time of the script: ( 0 minutes 2.798 seconds)

Download Python source code: plot_sellar_scalable.py

Download Jupyter notebook: plot_sellar_scalable.ipynb

Gallery generated by mkdocs-gallery