Note
Click here to download the full example code
Analysis of the scalable Sellar problem with JAX.
from __future__ import annotations
from datetime import timedelta
from timeit import default_timer
from typing import TYPE_CHECKING
from gemseo import configure
from gemseo import configure_logger
from gemseo import create_mda
from gemseo.core.discipline.discipline import Discipline
from gemseo.problems.mdo.sellar.sellar_1 import Sellar1
from gemseo.problems.mdo.sellar.sellar_2 import Sellar2
from gemseo.problems.mdo.sellar.sellar_system import SellarSystem
from gemseo.problems.mdo.sellar.utils import get_initial_data
from matplotlib.pyplot import show
from matplotlib.pyplot import subplots
from numpy import array
from numpy.random import default_rng
from gemseo_jax.jax_chain import DifferentiationMethod
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem
if TYPE_CHECKING:
from gemseo.mda.base_mda import BaseMDA
from gemseo.typing import RealArray
# Deactivate some checkers to speed up calculations in presence of cheap disciplines.
# NOTE(review): the seven booleans are passed positionally, so which checkers are
# disabled depends on the parameter order of ``gemseo.configure`` in the installed
# GEMSEO version; passing them by keyword would make this explicit — confirm.
configure(False, False, True, False, False, False, False)
# Configure GEMSEO's logging so that its INFO/WARNING messages are displayed.
configure_logger()
def get_random_input_data(n: int) -> dict[str, RealArray]:
    """Return a random input value for [JAX]SellarSystem.

    Every default value of the Sellar problem of dimension ``n`` is rescaled
    by the same factor ``1.5 * u``, where ``u`` is a single draw in [0, 1)
    from a freshly created NumPy random generator.

    Args:
        n: The dimension of the Sellar problem.

    Returns:
        The rescaled input data, keyed by variable name.
    """
    scale = 1.5 * default_rng().random()
    initial_data = get_initial_data(n=n)
    return {name: scale * value for name, value in initial_data.items()}
def get_numpy_disciplines(n: int) -> list[Discipline]:
    """Create the three Sellar disciplines implemented with NumPy.

    Args:
        n: The dimension passed to each discipline.

    Returns:
        ``Sellar1``, ``Sellar2`` and ``SellarSystem``, in this order.
    """
    numpy_classes = (Sellar1, Sellar2, SellarSystem)
    return [discipline_class(n=n) for discipline_class in numpy_classes]
def get_jax_disciplines(
    n: int, differentiation_method=DifferentiationMethod.AUTO
) -> list[Discipline]:
    """Create the three Sellar disciplines implemented with JAX.

    Each discipline is given a simple cache and is JIT-compiled before being
    returned.

    Args:
        n: The dimension passed to each discipline.
        differentiation_method: The JAX differentiation method of each discipline.

    Returns:
        ``JAXSellar1``, ``JAXSellar2`` and ``JAXSellarSystem``, in this order.
    """
    jax_classes = (JAXSellar1, JAXSellar2, JAXSellarSystem)
    jax_disciplines = [
        discipline_class(n=n, differentiation_method=differentiation_method)
        for discipline_class in jax_classes
    ]
    for jax_discipline in jax_disciplines:
        jax_discipline.set_cache(Discipline.CacheType.SIMPLE)
        # Compile up front (the log reports one compilation per discipline).
        jax_discipline.compile_jit()
    return jax_disciplines
Initial setup for comparison
Here we compare the original NumPy implementation with the JAX one. We therefore create the original MDA, which uses analytical Jacobians, and one JAX MDA for each configuration under test. In particular, we compare encapsulating the whole Sellar chain in a single JAXChain against chaining the individual JAX disciplines, each with forward- and reverse-mode automatic differentiation.
def get_analytical_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the reference Sellar MDA from the NumPy disciplines.

    These disciplines provide analytical Jacobians.

    Args:
        n: The dimension of the Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    numpy_disciplines = get_numpy_disciplines(n)
    analytical_mda = create_mda(
        mda_name=mda_name,
        disciplines=numpy_disciplines,
        max_mda_iter=max_mda_iter,
        name="Analytical SellarChain",
    )
    analytical_mda.set_cache(Discipline.CacheType.SIMPLE)
    return analytical_mda
def get_forward_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the Sellar MDA over the individual JAX disciplines (forward-mode AD).

    Args:
        n: The dimension of the Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    forward_disciplines = get_jax_disciplines(n, DifferentiationMethod.FORWARD)
    forward_mda = create_mda(
        mda_name=mda_name,
        disciplines=forward_disciplines,
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    forward_mda.set_cache(Discipline.CacheType.SIMPLE)
    return forward_mda
def get_chained_forward_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Build the Sellar MDA over a single ``JAXSellarChain`` (forward-mode AD).

    Args:
        n: The dimension of the Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    chain = JAXSellarChain(n=n, differentiation_method=DifferentiationMethod.FORWARD)
    # Request the derivatives of all the chain outputs w.r.t. all its inputs.
    chain.add_differentiated_inputs(chain.input_grammar.names)
    chain.add_differentiated_outputs(chain.output_grammar.names)
    chained_mda = create_mda(
        mda_name=mda_name,
        disciplines=[chain],
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    chained_mda.set_cache(Discipline.CacheType.SIMPLE)
    return chained_mda
def get_reverse_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the Sellar MDA over the individual JAX disciplines (reverse-mode AD).

    Args:
        n: The dimension of the Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    reverse_disciplines = get_jax_disciplines(n, DifferentiationMethod.REVERSE)
    reverse_mda = create_mda(
        mda_name=mda_name,
        disciplines=reverse_disciplines,
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    reverse_mda.set_cache(Discipline.CacheType.SIMPLE)
    return reverse_mda
def get_chained_reverse_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Build the Sellar MDA over a single ``JAXSellarChain`` (reverse-mode AD).

    Args:
        n: The dimension of the Sellar problem.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, equipped with a simple cache.
    """
    chain = JAXSellarChain(n=n, differentiation_method=DifferentiationMethod.REVERSE)
    # Request the derivatives of all the chain outputs w.r.t. all its inputs.
    chain.add_differentiated_inputs(chain.input_grammar.names)
    chain.add_differentiated_outputs(chain.output_grammar.names)
    chained_mda = create_mda(
        mda_name=mda_name,
        disciplines=[chain],
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    chained_mda.set_cache(Discipline.CacheType.SIMPLE)
    return chained_mda
# Factories of the MDAs to compare, keyed by the label of each configuration.
# Every factory is called as ``factory(n, mda_name=..., max_mda_iter=...)``
# and returns an MDA equipped with a simple cache.
mdas = {
    "MDOChain[NumPy] - Analytical": get_analytical_mda,  # this is the reference
    "JAXChain - Forward AD": get_chained_forward_ad_mda,
    "JAXChain - Reverse AD": get_chained_reverse_ad_mda,
    "MDOChain[JAX] - Forward AD": get_forward_ad_mda,
    "MDOChain[JAX] - Reverse AD": get_reverse_ad_mda,
}
Execution and linearization scalability
Let's write a function that executes and linearizes an MDA and returns the elapsed times. Each operation is repeated several times and the mean time is reported, to reduce the noise in the measurements:
def run_and_log(
    get_mda, dimension: int, n_repeat: int = 7, **mda_options
) -> tuple[timedelta, timedelta]:
    """Return the mean execution and linearization times of an MDA.

    The MDA is executed ``n_repeat`` times, then linearized ``n_repeat`` times.
    Each call receives freshly drawn random input data. The input data are
    generated before the timers start, so only the MDA calls are timed.

    Args:
        get_mda: A factory called as ``get_mda(dimension, **mda_options)``
            that returns the MDA to benchmark.
        dimension: The dimension of the Sellar problem.
        n_repeat: The number of calls averaged for each timing.
        **mda_options: The options passed to ``get_mda``.

    Returns:
        The mean execution time and the mean linearization time.

    Raises:
        ValueError: When ``n_repeat`` is lower than 1.
    """
    if n_repeat < 1:
        msg = f"n_repeat must be at least 1, got {n_repeat}."
        raise ValueError(msg)

    mda = get_mda(dimension, **mda_options)
    input_names = mda.input_grammar.names

    def draw_input_data() -> list[dict]:
        # Return one random input per call, restricted to the MDA inputs.
        # The values differ from call to call, presumably so that the MDA cache
        # cannot short-circuit consecutive calls.
        return [
            {
                name: value
                for name, value in get_random_input_data(dimension).items()
                if name in input_names
            }
            for _ in range(n_repeat)
        ]

    # Draw the data before starting the timer: generating random arrays is not
    # part of what is being benchmarked.
    execution_data = draw_input_data()
    start = default_timer()
    for input_data in execution_data:
        mda.execute(input_data)
    t_execute = timedelta(seconds=default_timer() - start) / float(n_repeat)

    # NOTE(review): no differentiated inputs/outputs are declared on the MDA itself
    # and ``compute_all_jacobians`` is not passed to ``linearize``; confirm that
    # Jacobians are actually computed here rather than only the MDA execution.
    linearization_data = draw_input_data()
    start = default_timer()
    for input_data in linearization_data:
        mda.linearize(input_data)
    t_linearize = timedelta(seconds=default_timer() - start) / float(n_repeat)

    return t_execute, t_linearize
Run each MDA for several numbers of dimensions
# Dimensions ``n`` of the Sellar problem at which every MDA is benchmarked.
dimensions = [1, 10, 100, 1000]
# Maps each label of ``mdas`` to a pair of arrays
# (mean execution times, mean linearization times), one entry per dimension.
times = {}
# A single Gauss-Seidel iteration per execution, presumably to give every
# configuration a fixed amount of work. This is why GEMSEO warns that the
# maximum number of iterations is reached before the tolerance.
mda_config = {"mda_name": "MDAGaussSeidel", "max_mda_iter": 1}
for mda_name, mda_func in mdas.items():
    # Here ``mda_name`` is the display label of the configuration,
    # not the MDA class name stored in ``mda_config``.
    time_exec = []
    time_lin = []
    for dimension in dimensions:
        t_e, t_l = run_and_log(mda_func, dimension, **mda_options_placeholder) if False else run_and_log(mda_func, dimension, **mda_config)
        time_exec.append(t_e)
        time_lin.append(t_l)
    # Store the timings as NumPy arrays, presumably for the plots that follow.
    times[mda_name] = (array(time_exec), array(time_lin))
Out:
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5546886292069908 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0986070997661397 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.620118226172266 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7048358556528266 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3648964298332587 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.25980956033730035 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9516636622427919 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3395020760450006 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9150608020023999 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.04758781336242676 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5660006640828333 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9668009455572535 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5673820160431442 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.31721138662974613 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5028104394333263 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6756343250168605 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.176223158607883 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.0030013402689324513 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7746933028741678 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7082704420320178 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.669177942729462 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5038237757177729 is still above the tolerance 1e-06.
WARNING - 08:28:56: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0048616618567556 is still above the tolerance 1e-06.
WARNING - 08:28:57: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1445325710580079 is still above the tolerance 1e-06.
INFO - 08:28:57: Compilation of the output function JAXSellarChain: 0:00:00.037103 seconds.
INFO - 08:28:57: Compilation of the Jacobian function JAXSellarChain: 0:00:00.042440 seconds.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7896396368998957 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.48695639881204306 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0300474677814973 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.30285989906836885 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8580817335051943 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.23599716439065963 is still above the tolerance 1e-06.
INFO - 08:28:57: Compilation of the output function JAXSellarChain: 0:00:00.062856 seconds.
INFO - 08:28:57: Compilation of the Jacobian function JAXSellarChain: 0:00:00.146684 seconds.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.469730283312027 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6997925328881764 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.252352661355927 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.4133054640375904 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.9800727448080906 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.450638235009209 is still above the tolerance 1e-06.
INFO - 08:28:57: Compilation of the output function JAXSellarChain: 0:00:00.069301 seconds.
INFO - 08:28:57: Compilation of the Jacobian function JAXSellarChain: 0:00:00.141928 seconds.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.8046244014934893 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.173441766175061 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.8959033875877964 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3895398950840101 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7038856216160984 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7699469960652549 is still above the tolerance 1e-06.
INFO - 08:28:57: Compilation of the output function JAXSellarChain: 0:00:00.068091 seconds.
INFO - 08:28:57: Compilation of the Jacobian function JAXSellarChain: 0:00:00.176125 seconds.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.106646829679672 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.4268154368193056 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.872809318086947 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.382903231181133 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.1840028044465605 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.1264229970075226 is still above the tolerance 1e-06.
INFO - 08:28:57: Compilation of the output function JAXSellarChain: 0:00:00.033875 seconds.
INFO - 08:28:57: Compilation of the Jacobian function JAXSellarChain: 0:00:00.094772 seconds.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 6.6120700328781545 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.930822001433046 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 6.496117168540394 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.363439891639536 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 7.253730237079061 is still above the tolerance 1e-06.
WARNING - 08:28:57: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 4.365639817612439 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellarChain: 0:00:00.054878 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellarChain: 0:00:00.097151 seconds.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.24508272068731643 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.49980944608771855 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2693811780699665 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3683268549779166 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5206822346320639 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2449563634011011 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellarChain: 0:00:00.067232 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellarChain: 0:00:00.116384 seconds.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0226940986818058 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8780402160210831 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.367580867900315 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.32245088168757474 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8388684157406607 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9833466447477932 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellarChain: 0:00:00.066038 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellarChain: 0:00:00.120726 seconds.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.765435881773645 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.101856123643968 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2845146232881204 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.5637186581284337 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9365743824660403 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.21149143583364763 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellar1: 0:00:00.017449 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar1: 0:00:00.000098 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellar2: 0:00:00.015129 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar2: 0:00:00.000116 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellarSystem: 0:00:00.024700 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000114 seconds.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8343327070461993 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.578865321893849 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.15877218374331 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.144094857758517 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.013018066099392187 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.860826526881003 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellar1: 0:00:00.020012 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar1: 0:00:00.000120 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellar2: 0:00:00.018158 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar2: 0:00:00.000094 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellarSystem: 0:00:00.044499 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000135 seconds.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.7180948681055423 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.5228131905017115 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.5207094533313286 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1667966768133644 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2287914566128084 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2581727081326972 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellar1: 0:00:00.020716 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar1: 0:00:00.000097 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellar2: 0:00:00.019797 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar2: 0:00:00.000094 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellarSystem: 0:00:00.051673 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000109 seconds.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1822918952461472 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2754893422285849 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0611084726253808 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.36765185170669024 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9895425347523426 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5901900512551367 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellar1: 0:00:00.018445 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar1: 0:00:00.000094 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellar2: 0:00:00.036387 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar2: 0:00:00.000094 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellarSystem: 0:00:00.046621 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000118 seconds.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.061068023379994 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6522475106635245 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3340714270133879 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.537594906118384 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.43841302972833623 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5002127020996002 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellar1: 0:00:00.017439 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar1: 0:00:00.000114 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellar2: 0:00:00.015567 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar2: 0:00:00.000095 seconds.
INFO - 08:28:58: Compilation of the output function JAXSellarSystem: 0:00:00.023970 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000110 seconds.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.600216695181937 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.520386830672144 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8216047666345632 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.05898395014122103 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9450846369438449 is still above the tolerance 1e-06.
WARNING - 08:28:58: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9418856242646887 is still above the tolerance 1e-06.
INFO - 08:28:58: Compilation of the output function JAXSellar1: 0:00:00.019292 seconds.
INFO - 08:28:58: Compilation of the Jacobian function JAXSellar1: 0:00:00.000117 seconds.
INFO - 08:28:59: Compilation of the output function JAXSellar2: 0:00:00.017229 seconds.
INFO - 08:28:59: Compilation of the Jacobian function JAXSellar2: 0:00:00.000091 seconds.
INFO - 08:28:59: Compilation of the output function JAXSellarSystem: 0:00:00.042825 seconds.
INFO - 08:28:59: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000111 seconds.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.03153088408933349 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.642422401185542 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.219046297927667 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 6.758177308518381 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 3.3615582756185027 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.9570094280832104 is still above the tolerance 1e-06.
INFO - 08:28:59: Compilation of the output function JAXSellar1: 0:00:00.020958 seconds.
INFO - 08:28:59: Compilation of the Jacobian function JAXSellar1: 0:00:00.000109 seconds.
INFO - 08:28:59: Compilation of the output function JAXSellar2: 0:00:00.019420 seconds.
INFO - 08:28:59: Compilation of the Jacobian function JAXSellar2: 0:00:00.000095 seconds.
INFO - 08:28:59: Compilation of the output function JAXSellarSystem: 0:00:00.051350 seconds.
INFO - 08:28:59: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000114 seconds.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 14.183939028324621 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 15.808276272852707 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 8.681827923711703 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6955868709388913 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 10.352736623816092 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 12.431057379495849 is still above the tolerance 1e-06.
INFO - 08:28:59: Compilation of the output function JAXSellar1: 0:00:00.018951 seconds.
INFO - 08:28:59: Compilation of the Jacobian function JAXSellar1: 0:00:00.000101 seconds.
INFO - 08:28:59: Compilation of the output function JAXSellar2: 0:00:00.019499 seconds.
INFO - 08:28:59: Compilation of the Jacobian function JAXSellar2: 0:00:00.000097 seconds.
INFO - 08:28:59: Compilation of the output function JAXSellarSystem: 0:00:00.049886 seconds.
INFO - 08:28:59: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000121 seconds.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.28719902233293154 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.062286506267661905 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.2843294493565955 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.12939879587869 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1263114491264071 is still above the tolerance 1e-06.
WARNING - 08:28:59: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7341429045834756 is still above the tolerance 1e-06.
Now let's visualize our results:
# The first MDA of ``mdas`` (dicts preserve insertion order) is the reference;
# the figure title below assumes it is the NumPy one with analytical derivatives.
mda_ref = next(iter(mdas))
t_ref_execution, t_ref_linearization = times[mda_ref]

# Speedup = reference time / compared time, so a value above 1 means
# "faster than the reference" (the previous code computed the inverse ratio,
# contradicting both the variable name and the figure title).
# NOTE(review): assumes each value of ``times`` is a pair of NumPy arrays
# (execution times, linearization times) aligned with ``dimensions``.
speedup = {
    mda_name: (t_ref_execution / t_e, t_ref_linearization / t_l)
    for mda_name, (t_e, t_l) in times.items()
}

fig, (ax_execution, ax_linearization) = subplots(
    2, 1, layout="constrained", figsize=(6, 8)
)
fig.suptitle("Speedup compared to NumPy Analytical")
for mda_name in mdas:
    # Dotted line for the reference (constant speedup of 1), solid otherwise.
    linestyle = ":" if mda_name == mda_ref else "-"
    speedup_e, speedup_l = speedup[mda_name]
    ax_execution.plot(dimensions, speedup_e, linestyle, label=mda_name)
    ax_linearization.plot(dimensions, speedup_l, linestyle, label=mda_name)

ax_execution.legend(bbox_to_anchor=(0.9, -0.1))
ax_execution.set_ylabel("Execution")
ax_linearization.set_ylabel("Linearization")
ax_linearization.set_xlabel("Dimension")
# The dimensions span several orders of magnitude, hence the log scale.
for ax in (ax_execution, ax_linearization):
    ax.set_xscale("log")
show()

Conclusion
JAX automatic differentiation (AD) is as fast as analytical derivatives with NumPy. Encapsulating the disciplines in a JAXChain slows down execution but speeds up linearization. This speedup is maintained even at higher dimensions.
Total running time of the script: (0 minutes 2.738 seconds)
Download Python source code: plot_sellar_scalable.py