Note
Click here to download the full example code
Analysis of the scalable Sellar problem with JAX.
from __future__ import annotations
from datetime import timedelta
from timeit import default_timer
from typing import TYPE_CHECKING
from gemseo import configure
from gemseo import configure_logger
from gemseo import create_mda
from gemseo.core.discipline.discipline import Discipline
from gemseo.problems.mdo.sellar.sellar_1 import Sellar1
from gemseo.problems.mdo.sellar.sellar_2 import Sellar2
from gemseo.problems.mdo.sellar.sellar_system import SellarSystem
from gemseo.problems.mdo.sellar.utils import get_initial_data
from matplotlib.pyplot import show
from matplotlib.pyplot import subplots
from numpy import array
from numpy.random import default_rng
from gemseo_jax.jax_chain import DifferentiationMethod
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem
if TYPE_CHECKING:
from gemseo.mda.base_mda import BaseMDA
from gemseo.typing import RealArray
# Deactivate some checkers to speed up calculations in presence of cheap disciplines.
# NOTE(review): the flags are passed positionally, so their meaning depends on the
# parameter order of gemseo.configure() in the installed version -- TODO confirm
# and consider switching to keyword arguments.
configure(False, False, True, False, False, False, False)
# Set up the logger with its default arguments.
configure_logger()
def get_random_input_data(n: int) -> dict[str, RealArray]:
    """Return a randomly scaled copy of the initial data of the Sellar problem.

    A single random factor is drawn and applied to every variable, so all the
    values returned by one call share the same scaling.

    Args:
        n: The size argument forwarded to ``get_initial_data``.

    Returns:
        The initial data, each value multiplied by ``1.5 * r``
        where ``r`` is one sample of ``default_rng().random()``.
    """
    # One draw per call, taken from a fresh unseeded generator.
    scale = 1.5 * default_rng().random()
    random_data = {}
    for name, value in get_initial_data(n=n).items():
        random_data[name] = scale * value
    return random_data
def get_numpy_disciplines(n: int) -> list[Discipline]:
    """Instantiate the NumPy-based Sellar disciplines.

    Args:
        n: The size argument passed to each discipline.

    Returns:
        A ``Sellar1``, a ``Sellar2`` and a ``SellarSystem``, in that order.
    """
    discipline_classes = (Sellar1, Sellar2, SellarSystem)
    return [discipline_class(n=n) for discipline_class in discipline_classes]
def get_jax_disciplines(
    n: int, differentiation_method=DifferentiationMethod.AUTO
) -> list[Discipline]:
    """Instantiate and prepare the JAX-based Sellar disciplines.

    Args:
        n: The size argument passed to each discipline.
        differentiation_method: The differentiation method passed to each
            discipline.

    Returns:
        A ``JAXSellar1``, a ``JAXSellar2`` and a ``JAXSellarSystem``, in that
        order, each one configured with a simple cache and with
        ``compile_jit()`` already called.
    """
    jax_disciplines = [
        discipline_class(n=n, differentiation_method=differentiation_method)
        for discipline_class in (JAXSellar1, JAXSellar2, JAXSellarSystem)
    ]
    # Configure only once all the disciplines exist, one discipline at a time.
    for jax_discipline in jax_disciplines:
        jax_discipline.set_cache(Discipline.CacheType.SIMPLE)
        jax_discipline.compile_jit()
    return jax_disciplines
Initial setup for comparison¶
Here we compare the original NumPy implementation with the JAX one. We therefore need to create the reference NumPy MDA and one JAX MDA for each configuration under test. In this example, we compare the performance of the JAXChain encapsulation as well as the forward and reverse modes of automatic differentiation.
def get_analytical_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the MDA coupling the NumPy-based Sellar disciplines.

    Args:
        n: The size argument used to create the disciplines.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, named ``"Analytical SellarChain"``, with a simple cache.
    """
    mda_settings = {
        "mda_name": mda_name,
        "disciplines": get_numpy_disciplines(n),
        "max_mda_iter": max_mda_iter,
        "name": "Analytical SellarChain",
    }
    analytical_mda = create_mda(**mda_settings)
    analytical_mda.set_cache(Discipline.CacheType.SIMPLE)
    return analytical_mda
def get_forward_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the MDA coupling the JAX disciplines using forward differentiation.

    Args:
        n: The size argument used to create the disciplines.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, named ``"JAX SellarChain"``, with a simple cache.
    """
    forward_disciplines = get_jax_disciplines(n, DifferentiationMethod.FORWARD)
    forward_mda = create_mda(
        mda_name=mda_name,
        disciplines=forward_disciplines,
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    forward_mda.set_cache(Discipline.CacheType.SIMPLE)
    return forward_mda
def get_chained_forward_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Build the MDA wrapping a single ``JAXSellarChain`` in forward mode.

    Args:
        n: The size argument used to create the chain.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, named ``"JAX SellarChain"``, with a simple cache.
    """
    chain = JAXSellarChain(
        n=n, differentiation_method=DifferentiationMethod.FORWARD
    )
    # Declare every input and every output of the chain as differentiated.
    chain.add_differentiated_inputs(chain.input_grammar.names)
    chain.add_differentiated_outputs(chain.output_grammar.names)
    chained_mda = create_mda(
        mda_name=mda_name,
        disciplines=[chain],
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    chained_mda.set_cache(Discipline.CacheType.SIMPLE)
    return chained_mda
def get_reverse_ad_mda(n: int, mda_name="MDAGaussSeidel", max_mda_iter=5) -> BaseMDA:
    """Build the MDA coupling the JAX disciplines using reverse differentiation.

    Args:
        n: The size argument used to create the disciplines.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, named ``"JAX SellarChain"``, with a simple cache.
    """
    reverse_disciplines = get_jax_disciplines(n, DifferentiationMethod.REVERSE)
    reverse_mda = create_mda(
        mda_name=mda_name,
        disciplines=reverse_disciplines,
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    reverse_mda.set_cache(Discipline.CacheType.SIMPLE)
    return reverse_mda
def get_chained_reverse_ad_mda(
    n: int, mda_name="MDAGaussSeidel", max_mda_iter=5
) -> BaseMDA:
    """Build the MDA wrapping a single ``JAXSellarChain`` in reverse mode.

    Args:
        n: The size argument used to create the chain.
        mda_name: The name of the MDA class to create.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, named ``"JAX SellarChain"``, with a simple cache.
    """
    chain = JAXSellarChain(
        n=n, differentiation_method=DifferentiationMethod.REVERSE
    )
    # Declare every input and every output of the chain as differentiated.
    chain.add_differentiated_inputs(chain.input_grammar.names)
    chain.add_differentiated_outputs(chain.output_grammar.names)
    chained_mda = create_mda(
        mda_name=mda_name,
        disciplines=[chain],
        max_mda_iter=max_mda_iter,
        name="JAX SellarChain",
    )
    chained_mda.set_cache(Discipline.CacheType.SIMPLE)
    return chained_mda
# Registry of the benchmarked configurations: a display label mapped to the
# factory building the corresponding MDA. Each factory takes the dimension
# first, then the MDA options as keyword arguments.
mdas = {
    "MDOChain[NumPy] - Analytical": get_analytical_mda,  # this is the reference
    # A single JAXSellarChain discipline wrapped in the MDA.
    "JAXChain - Forward AD": get_chained_forward_ad_mda,
    "JAXChain - Reverse AD": get_chained_reverse_ad_mda,
    # The three JAX disciplines coupled by the MDA.
    "MDOChain[JAX] - Forward AD": get_forward_ad_mda,
    "MDOChain[JAX] - Reverse AD": get_reverse_ad_mda,
}
Execution and linearization scalability¶
Let's define a function that executes and linearizes an MDA while measuring the elapsed times. Each operation is repeated several times and the measured time is averaged over the repetitions to reduce noise:
def run_and_log(get_mda, dimension, n_repeat=7, **mda_options):
    """Measure the mean execution and linearization times of an MDA.

    The random input samples are drawn and filtered before the timers start,
    so that the reported times cover only the calls to ``mda.execute`` and
    ``mda.linearize`` and not the generation of the input data.

    Args:
        get_mda: A callable building the MDA from the dimension
            and the MDA options passed as keyword arguments.
        dimension: The dimension passed to ``get_mda``
            and to ``get_random_input_data``.
        n_repeat: The number of timed calls for each operation.
        **mda_options: The options forwarded to ``get_mda``.

    Returns:
        The mean time of one execution and the mean time of one linearization,
        as ``timedelta`` objects.
    """
    mda = get_mda(dimension, **mda_options)
    # Loop-invariant: resolve the input names of the MDA only once.
    input_names = mda.input_grammar.names

    def draw_input_data():
        """Return a random sample restricted to the inputs of the MDA."""
        return {
            name: value
            for name, value in get_random_input_data(dimension).items()
            if name in input_names
        }

    # A distinct sample per call, all drawn outside the timed regions.
    execution_samples = [draw_input_data() for _ in range(n_repeat)]
    linearization_samples = [draw_input_data() for _ in range(n_repeat)]

    t0 = default_timer()
    for input_data in execution_samples:
        mda.execute(input_data)
    t1 = default_timer()
    t_execute = timedelta(seconds=t1 - t0) / float(n_repeat)

    t2 = default_timer()
    for input_data in linearization_samples:
        mda.linearize(input_data)
    t3 = default_timer()
    t_linearize = timedelta(seconds=t3 - t2) / float(n_repeat)
    return t_execute, t_linearize
Run each MDA for several problem dimensions¶
# Problem dimensions to benchmark.
dimensions = [1, 10, 100, 1000]
# Options forwarded to every MDA factory: Gauss-Seidel limited to one iteration.
mda_config = {"mda_name": "MDAGaussSeidel", "max_mda_iter": 1}
# For each configuration label: (execution times, linearization times),
# both indexed like ``dimensions``.
times = {}
for mda_name, mda_func in mdas.items():
    timings = [
        run_and_log(mda_func, dimension, **mda_config) for dimension in dimensions
    ]
    time_exec = [t_execute for t_execute, _ in timings]
    time_lin = [t_linearize for _, t_linearize in timings]
    times[mda_name] = (array(time_exec), array(time_lin))
Out:
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.731571242324299 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5749379687917393 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5631209405087193 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5037662734254547 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.01920829060517616 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.48181745959362804 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4875045947678294 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7662130690137707 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9638121349576176 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4301936444453073 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1316581998647564 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4031649493299726 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.12086472549922253 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.296131816731855 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0714488839519931 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0546725197898346 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6749882957206633 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9312048945660178 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7758348876721982 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.062740022232492 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6772109943026285 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9307564699159938 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.06841719810036027 is still above the tolerance 1e-06.
WARNING - 09:58:35: Analytical SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0367076464955671 is still above the tolerance 1e-06.
INFO - 09:58:36: Compilation of the output function JAXSellarChain: 0:00:00.034781 seconds.
INFO - 09:58:36: Compilation of the Jacobian function JAXSellarChain: 0:00:00.042578 seconds.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8928236920159215 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6484017209297558 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.6541743879206379 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.10039161632401486 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7987256131511357 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.9389225503849092 is still above the tolerance 1e-06.
INFO - 09:58:36: Compilation of the output function JAXSellarChain: 0:00:00.061201 seconds.
INFO - 09:58:36: Compilation of the Jacobian function JAXSellarChain: 0:00:00.145521 seconds.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.609348896733363 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.672171712145969 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.03956122395800638 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2991660576502344 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.855283569718096 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.438606202548064 is still above the tolerance 1e-06.
INFO - 09:58:36: Compilation of the output function JAXSellarChain: 0:00:00.074003 seconds.
INFO - 09:58:36: Compilation of the Jacobian function JAXSellarChain: 0:00:00.149375 seconds.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.874712801665954 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.4112100422616383 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5905758034072706 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6143477242903345 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.34694365251407366 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.10103685967127757 is still above the tolerance 1e-06.
INFO - 09:58:36: Compilation of the output function JAXSellarChain: 0:00:00.071255 seconds.
INFO - 09:58:36: Compilation of the Jacobian function JAXSellarChain: 0:00:00.174499 seconds.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.667076836352051 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.003615450513030216 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.8210833976718002 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1499009611109794 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.462966253599552 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2430169176679695 is still above the tolerance 1e-06.
INFO - 09:58:36: Compilation of the output function JAXSellarChain: 0:00:00.046104 seconds.
INFO - 09:58:36: Compilation of the Jacobian function JAXSellarChain: 0:00:00.083009 seconds.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.048731288746746 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.479813297957901 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2887161177673174 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.095575256502335 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1428341328626177 is still above the tolerance 1e-06.
WARNING - 09:58:36: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.08001851661064713 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellarChain: 0:00:00.054317 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarChain: 0:00:00.095768 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.643713957050578 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.0939347025493693 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0768509552633683 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3891497916141767 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8625990175374711 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.28025752675291127 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellarChain: 0:00:00.065994 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarChain: 0:00:00.112838 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.973152345952776 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4124460540162393 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.2264014614334258 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4895028673670563 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.37294428115951744 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.360657627221549 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellarChain: 0:00:00.064353 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarChain: 0:00:00.115507 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.28114243258328364 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.45991335170584996 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9534166771866155 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0407212735390023 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6313069477717125 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.8767586421241225 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellar1: 0:00:00.017663 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar1: 0:00:00.000099 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellar2: 0:00:00.015454 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar2: 0:00:00.000088 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellarSystem: 0:00:00.024068 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000113 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.3235110035257684 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.8077234879316084 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.45167518578860505 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2906563988691306 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4994489808452018 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.3563192708017637 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellar1: 0:00:00.019634 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar1: 0:00:00.000121 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellar2: 0:00:00.017396 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar2: 0:00:00.000106 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellarSystem: 0:00:00.051366 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000109 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3797588268785002 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.6882017533818101 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.347961326636091 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.362006406309854 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9917652673338507 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.8352317214616845 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellar1: 0:00:00.018783 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar1: 0:00:00.000091 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellar2: 0:00:00.018349 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar2: 0:00:00.000086 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellarSystem: 0:00:00.049873 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000106 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4292586104426512 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1981313029656037 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.011231211654108585 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.42882683063104876 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.3508270768314768 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.1890509051420213 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellar1: 0:00:00.018616 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar1: 0:00:00.000093 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellar2: 0:00:00.018237 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar2: 0:00:00.000106 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellarSystem: 0:00:00.049658 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000123 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.04355235672645337 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7883273520959159 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5004773848118363 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7006370162571727 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5158801981919856 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9868624668945293 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellar1: 0:00:00.017053 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar1: 0:00:00.000092 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellar2: 0:00:00.015396 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar2: 0:00:00.000088 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellarSystem: 0:00:00.024030 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000124 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.4104374189032316 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.29328989520588483 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0915386619993674 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3130922268178664 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.0714955872436628 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1997306274405335 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellar1: 0:00:00.019542 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar1: 0:00:00.000187 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellar2: 0:00:00.018148 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar2: 0:00:00.000088 seconds.
INFO - 09:58:37: Compilation of the output function JAXSellarSystem: 0:00:00.041294 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000103 seconds.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9078683950805118 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.47396718230605955 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.7999887027693842 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.9830704663831445 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.634749144323467 is still above the tolerance 1e-06.
WARNING - 09:58:37: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.4351233623111798 is still above the tolerance 1e-06.
INFO - 09:58:37: Compilation of the output function JAXSellar1: 0:00:00.018649 seconds.
INFO - 09:58:37: Compilation of the Jacobian function JAXSellar1: 0:00:00.000089 seconds.
INFO - 09:58:38: Compilation of the output function JAXSellar2: 0:00:00.018258 seconds.
INFO - 09:58:38: Compilation of the Jacobian function JAXSellar2: 0:00:00.000097 seconds.
INFO - 09:58:38: Compilation of the output function JAXSellarSystem: 0:00:00.051903 seconds.
INFO - 09:58:38: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000110 seconds.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.5156093426722732 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.2040300734477092 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.871112765483029 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.8186203837036001 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.4771161033615234 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 2.569518093035286 is still above the tolerance 1e-06.
INFO - 09:58:38: Compilation of the output function JAXSellar1: 0:00:00.046689 seconds.
INFO - 09:58:38: Compilation of the Jacobian function JAXSellar1: 0:00:00.000115 seconds.
INFO - 09:58:38: Compilation of the output function JAXSellar2: 0:00:00.019258 seconds.
INFO - 09:58:38: Compilation of the Jacobian function JAXSellar2: 0:00:00.000095 seconds.
INFO - 09:58:38: Compilation of the output function JAXSellarSystem: 0:00:00.047279 seconds.
INFO - 09:58:38: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000115 seconds.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.0 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.09720000834611164 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.47625423394113 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.14737575540216824 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.1430636159911219 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 1.3053541614220143 is still above the tolerance 1e-06.
WARNING - 09:58:38: JAX SellarChain has reached its maximum number of iterations, but the normalized residual norm 0.5645961890241062 is still above the tolerance 1e-06.
Now let's visualize our results:
# Plot, for every MDA, its execution and linearization timings relative to
# a reference MDA, as a function of the problem dimension.

# The first MDA registered in ``mdas`` is taken as the reference.
mda_ref = next(iter(mdas))
t_ref = times[mda_ref]

# For each MDA, store the pair (execution ratio, linearization ratio), where
# each ratio is that MDA's timing divided by the reference timing.
# NOTE(review): this is t / t_ref, so values below 1 mean "faster than the
# reference"; a conventional speedup would be t_ref / t -- confirm which
# convention the figure is meant to show.
speedup = {}
for mda_name, (t_e, t_l) in times.items():
    speedup[mda_name] = (t_e / t_ref[0], t_l / t_ref[1])

# Two stacked panels: execution on top, linearization below.
fig, axes = subplots(2, 1, layout="constrained", figsize=(6, 8))
fig.suptitle("Speedup compared to NumPy Analytical")

for mda_name in mdas:
    # The reference curve is dotted so it stands out from the solid ones.
    linestyle = "-" if mda_name != mda_ref else ":"
    for axis, ratio in zip(axes, speedup[mda_name]):
        axis.plot(dimensions, ratio, linestyle, label=mda_name)

execution_axis, linearization_axis = axes
execution_axis.legend(bbox_to_anchor=(0.9, -0.1))
execution_axis.set_ylabel("Execution")
linearization_axis.set_ylabel("Linearization")
linearization_axis.set_xlabel("Dimension")
# Both panels use a logarithmic x-axis.
for axis in axes:
    axis.set_xscale("log")

show()
Conclusion
JAX automatic differentiation (AD) is as fast as analytical derivatives with NumPy. Encapsulation with JAXChain slows execution but speeds up linearization. The speedup is maintained even at higher dimensions.
Total running time of the script: ( 0 minutes 2.711 seconds)
Download Python source code: plot_sellar_scalable.py