Note: click here to download the full example code.
Analysis of the scalable Sellar problem with JAX.
from __future__ import annotations
from datetime import timedelta
from timeit import default_timer
from typing import TYPE_CHECKING
from gemseo import configure
from gemseo import configure_logger
from gemseo import create_mda
from gemseo.core.discipline.discipline import Discipline
from gemseo.problems.mdo.sellar.sellar_1 import Sellar1
from gemseo.problems.mdo.sellar.sellar_2 import Sellar2
from gemseo.problems.mdo.sellar.sellar_system import SellarSystem
from gemseo.problems.mdo.sellar.utils import get_initial_data
from matplotlib.pyplot import show
from matplotlib.pyplot import subplots
from numpy import array
from numpy.random import default_rng
from gemseo_jax.jax_chain import DifferentiationMethod
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem
if TYPE_CHECKING:
from gemseo.mda.base_mda import BaseMDA
from gemseo.typing import RealArray
# Deactivate some checkers to speed up calculations in the presence of cheap disciplines.
# NOTE(review): the seven flags are passed positionally, in gemseo.configure's
# parameter order; only the third one is left enabled. Check the signature of the
# installed gemseo version to confirm what each position controls.
configure(False, False, True, False, False, False, False)
# Set up gemseo's logger (presumably so that the INFO/WARNING messages are printed).
configure_logger()
def get_random_input_data(n: int) -> dict[str, RealArray]:
    """Return a random input value for [JAX]SellarSystem.

    A single random factor is drawn and applied to every default Sellar input.

    Args:
        n: The size passed as ``n`` to ``get_initial_data``.

    Returns:
        The default Sellar input data, each value scaled by ``1.5 * r``
        where ``r`` is drawn uniformly in ``[0, 1)``.
    """
    factor = 1.5 * default_rng().random()
    initial_data = get_initial_data(n=n)
    return {name: factor * value for name, value in initial_data.items()}
def get_numpy_disciplines(n: int) -> list[Discipline]:
    """Return the NumPy-based Sellar disciplines.

    Args:
        n: The size passed as ``n`` to each discipline.

    Returns:
        The disciplines ``Sellar1``, ``Sellar2`` and ``SellarSystem``, in this order.
    """
    numpy_classes = (Sellar1, Sellar2, SellarSystem)
    return [numpy_class(n=n) for numpy_class in numpy_classes]
def get_jax_disciplines(
    n: int, differentiation_method=DifferentiationMethod.AUTO
) -> list[Discipline]:
    """Return the JAX-based Sellar disciplines.

    Every discipline gets a simple cache and is JIT-compiled before being returned.

    Args:
        n: The size passed as ``n`` to each discipline.
        differentiation_method: The differentiation method shared by all disciplines.

    Returns:
        The disciplines ``JAXSellar1``, ``JAXSellar2`` and ``JAXSellarSystem``,
        in this order.
    """
    jax_classes = (JAXSellar1, JAXSellar2, JAXSellarSystem)
    disciplines = [
        jax_class(n=n, differentiation_method=differentiation_method)
        for jax_class in jax_classes
    ]
    for discipline in disciplines:
        discipline.set_cache(Discipline.CacheType.SIMPLE)
        # Compile up front, presumably to keep JIT compilation out of the timings.
        discipline.compile_jit()
    return disciplines
Initial setup for comparison
Here we compare the original NumPy implementation with the JAX one. We therefore need to create the original MDA and one JAX MDA for each configuration under test. In this example, we compare the performance of the JAXChain encapsulation as well as the forward and reverse modes of automatic differentiation.
def _create_sellar_mda(
    disciplines: list[Discipline], name: str, mda_name: str, max_mda_iter: int
) -> BaseMDA:
    """Create an MDA named ``name`` over ``disciplines`` and give it a simple cache."""
    mda = create_mda(
        mda_name=mda_name,
        disciplines=disciplines,
        max_mda_iter=max_mda_iter,
        name=name,
    )
    mda.set_cache(Discipline.CacheType.SIMPLE)
    return mda


def _get_jax_sellar_chain(n: int, differentiation_method) -> JAXSellarChain:
    """Return a JAXSellarChain differentiated with respect to all its inputs/outputs."""
    discipline = JAXSellarChain(
        n=n,
        differentiation_method=differentiation_method,
    )
    discipline.add_differentiated_inputs(discipline.input_grammar.names)
    discipline.add_differentiated_outputs(discipline.output_grammar.names)
    return discipline


def get_analytical_mda(
    n: int, mda_name: str = "MDAGaussSeidel", max_mda_iter: int = 5
) -> BaseMDA:
    """Return the Sellar MDA with analytical NumPy Jacobian.

    Args:
        n: The size passed as ``n`` to the disciplines.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, with a simple cache.
    """
    return _create_sellar_mda(
        get_numpy_disciplines(n), "Analytical SellarChain", mda_name, max_mda_iter
    )


def get_forward_ad_mda(
    n: int, mda_name: str = "MDAGaussSeidel", max_mda_iter: int = 5
) -> BaseMDA:
    """Return the Sellar MDA with JAX forward-mode AD Jacobian.

    Args:
        n: The size passed as ``n`` to the disciplines.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, with a simple cache.
    """
    return _create_sellar_mda(
        get_jax_disciplines(n, DifferentiationMethod.FORWARD),
        "JAX SellarChain",
        mda_name,
        max_mda_iter,
    )


def get_chained_forward_ad_mda(
    n: int, mda_name: str = "MDAGaussSeidel", max_mda_iter: int = 5
) -> BaseMDA:
    """Return the Sellar MDA with JAXChain encapsulation and forward-mode Jacobian.

    Args:
        n: The size passed as ``n`` to the JAXSellarChain.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA over the single JAXSellarChain discipline, with a simple cache.
    """
    discipline = _get_jax_sellar_chain(n, DifferentiationMethod.FORWARD)
    return _create_sellar_mda([discipline], "JAX SellarChain", mda_name, max_mda_iter)


def get_reverse_ad_mda(
    n: int, mda_name: str = "MDAGaussSeidel", max_mda_iter: int = 5
) -> BaseMDA:
    """Return the Sellar MDA with JAX reverse-mode AD Jacobian.

    Args:
        n: The size passed as ``n`` to the disciplines.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA, with a simple cache.
    """
    return _create_sellar_mda(
        get_jax_disciplines(n, DifferentiationMethod.REVERSE),
        "JAX SellarChain",
        mda_name,
        max_mda_iter,
    )


def get_chained_reverse_ad_mda(
    n: int, mda_name: str = "MDAGaussSeidel", max_mda_iter: int = 5
) -> BaseMDA:
    """Return the Sellar MDA with JAXChain encapsulation and reverse-mode Jacobian.

    Args:
        n: The size passed as ``n`` to the JAXSellarChain.
        mda_name: The name of the MDA class passed to ``create_mda``.
        max_mda_iter: The maximum number of MDA iterations.

    Returns:
        The MDA over the single JAXSellarChain discipline, with a simple cache.
    """
    discipline = _get_jax_sellar_chain(n, DifferentiationMethod.REVERSE)
    return _create_sellar_mda([discipline], "JAX SellarChain", mda_name, max_mda_iter)
# Insertion order matters: the first entry is used as the reference for the speedups.
mdas = {}
mdas["MDOChain[NumPy] - Analytical"] = get_analytical_mda  # this is the reference
mdas["JAXChain - Forward AD"] = get_chained_forward_ad_mda
mdas["JAXChain - Reverse AD"] = get_chained_reverse_ad_mda
mdas["MDOChain[JAX] - Forward AD"] = get_forward_ad_mda
mdas["MDOChain[JAX] - Reverse AD"] = get_reverse_ad_mda
Execution and linearization scalability
Let's write a function that executes and linearizes an MDA while measuring the elapsed times. Each operation is repeated several times and the times are averaged to reduce noise:
def run_and_log(get_mda, dimension, n_repeat=7, **mda_options):
    """Measure the average execution and linearization times of an MDA.

    The random input data are drawn before starting the timers,
    so that only the MDA calls are measured.

    Args:
        get_mda: The MDA factory, called as ``get_mda(dimension, **mda_options)``.
        dimension: The problem dimension passed to ``get_mda``
            and to ``get_random_input_data``.
        n_repeat: The number of calls to average over, for each operation.
        **mda_options: The options passed to ``get_mda``.

    Returns:
        The average execution time and the average linearization time.
    """
    mda = get_mda(dimension, **mda_options)
    input_names = set(mda.input_grammar.names)

    def draw_input_data():
        # Keep only the variables that the MDA takes as inputs.
        return {
            name: value
            for name, value in get_random_input_data(dimension).items()
            if name in input_names
        }

    # A fresh random point per call, generated outside the timed regions
    # (the generation previously ran inside the loops and biased the timings).
    execution_inputs = [draw_input_data() for _ in range(n_repeat)]
    linearization_inputs = [draw_input_data() for _ in range(n_repeat)]

    start = default_timer()
    for input_data in execution_inputs:
        mda.execute(input_data)
    t_execute = timedelta(seconds=default_timer() - start) / float(n_repeat)

    # NOTE(review): `linearize` is called without `compute_all_jacobians` and no
    # differentiated inputs/outputs are declared on the MDA itself by the factories.
    # The captured log shows only n_repeat MDA warnings per dimension (the execute
    # loop), which suggests these calls may return before computing any Jacobian.
    # Confirm, and declare the differentiated variables on the MDA if so.
    start = default_timer()
    for input_data in linearization_inputs:
        mda.linearize(input_data)
    t_linearize = timedelta(seconds=default_timer() - start) / float(n_repeat)

    return t_execute, t_linearize
Run each MDA for several dimensions
dimensions = [1, 10, 100, 1000]
# One Gauss-Seidel iteration per execution: the MDA does not converge, hence the
# warnings in the log; presumably intended to time a single iteration.
mda_config = {"mda_name": "MDAGaussSeidel", "max_mda_iter": 1}
times = {}
for label, factory in mdas.items():
    # One (execution time, linearization time) pair per dimension.
    measurements = [run_and_log(factory, dim, **mda_config) for dim in dimensions]
    execution_times, linearization_times = zip(*measurements)
    times[label] = (array(execution_times), array(linearization_times))
Out:
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 0.15645019260222157 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 0.6270221415510441 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 0.3703403844773184 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 0.31260773913647055 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 0.5929074857727761 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 0.4304972712307708 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 6.16201716256071 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 30.105777697476288 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 14.721940954495464 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 24.16610494848708 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 29.996737407353926 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 26.98093121215529 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 6.243472425469544 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 6.4747665653998725 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 12.435679805746306 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 9.384026680568596 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 6.537448235931721 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 0.8005284609745523 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 2.7822010673468762 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 19.047190817844243 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 11.472676118868774 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 0.24504375029051115 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 21.699228577400795 is still above the tolerance 1e-06.
WARNING - 13:07:06: Analytical SellarChain has reached its maximum number of iterations but the normed residual 10.738983903854573 is still above the tolerance 1e-06.
INFO - 13:07:06: Compilation of the output function JAXSellarChain: 0:00:00.035777 seconds.
INFO - 13:07:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.051886 seconds.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.5905289425541702 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.077850700037392 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.7675322815269972 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 5.865757994838282 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 2.4530760566352448 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 5.735982966477577 is still above the tolerance 1e-06.
INFO - 13:07:06: Compilation of the output function JAXSellarChain: 0:00:00.063203 seconds.
INFO - 13:07:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.164498 seconds.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.057078474261425156 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.218039117520956 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.457386212866514 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.4444239262702856 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.423388950237638 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.1635035339876243 is still above the tolerance 1e-06.
INFO - 13:07:06: Compilation of the output function JAXSellarChain: 0:00:00.067206 seconds.
INFO - 13:07:06: Compilation of the Jacobian function JAXSellarChain: 0:00:00.163735 seconds.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.42034906252499005 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.1907990925743115 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.8795693925359636 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.980711322874636 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 2.0503657593017284 is still above the tolerance 1e-06.
WARNING - 13:07:06: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.22831395010679814 is still above the tolerance 1e-06.
INFO - 13:07:07: Compilation of the output function JAXSellarChain: 0:00:00.069508 seconds.
INFO - 13:07:07: Compilation of the Jacobian function JAXSellarChain: 0:00:00.215236 seconds.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 336.4889874652211 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 176.57343516007123 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 6.076669703932627 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 91.74295851256814 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 457.7670391902996 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 64.91505899021577 is still above the tolerance 1e-06.
INFO - 13:07:07: Compilation of the output function JAXSellarChain: 0:00:00.036592 seconds.
INFO - 13:07:07: Compilation of the Jacobian function JAXSellarChain: 0:00:00.090182 seconds.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.1802048299408419 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.35677528174007517 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.1408249519976819 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.1307532507301117 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.1700128006386539 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.36258551114908394 is still above the tolerance 1e-06.
INFO - 13:07:07: Compilation of the output function JAXSellarChain: 0:00:00.057662 seconds.
INFO - 13:07:07: Compilation of the Jacobian function JAXSellarChain: 0:00:00.107873 seconds.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.7675162746095148 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0809390979602116 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.704815940174381 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.10907833187398362 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.8756838198802682 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.49632541052977924 is still above the tolerance 1e-06.
INFO - 13:07:07: Compilation of the output function JAXSellarChain: 0:00:00.062264 seconds.
INFO - 13:07:07: Compilation of the Jacobian function JAXSellarChain: 0:00:00.130473 seconds.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.08467965860271073 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.8466493533683936 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.319325162155551 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 4.954864318329517 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.2709725651143005 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 3.0404356081104718 is still above the tolerance 1e-06.
INFO - 13:07:07: Compilation of the output function JAXSellarChain: 0:00:00.063597 seconds.
INFO - 13:07:07: Compilation of the Jacobian function JAXSellarChain: 0:00:00.145311 seconds.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.9486514981745048 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.24486550234169496 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.6276337271559553 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.3445061088528841 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.12572359926553556 is still above the tolerance 1e-06.
WARNING - 13:07:07: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.8045793338069069 is still above the tolerance 1e-06.
INFO - 13:07:07: Compilation of the output function JAXSellar1: 0:00:00.020468 seconds.
INFO - 13:07:07: Compilation of the Jacobian function JAXSellar1: 0:00:00.000094 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellar2: 0:00:00.017345 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar2: 0:00:00.000152 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellarSystem: 0:00:00.036879 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000102 seconds.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.11392108612590067 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.33201354768285496 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.5150350926819272 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 2.098231165255954 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.4258060006844228 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.109486066937409 is still above the tolerance 1e-06.
INFO - 13:07:08: Compilation of the output function JAXSellar1: 0:00:00.026126 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar1: 0:00:00.000096 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellar2: 0:00:00.019922 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar2: 0:00:00.000093 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellarSystem: 0:00:00.038848 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000105 seconds.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.10088252297350085 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.1457639424039268 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.8823224570641361 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.243022397452796 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0896916117606323 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.3048558392580727 is still above the tolerance 1e-06.
INFO - 13:07:08: Compilation of the output function JAXSellar1: 0:00:00.021111 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar1: 0:00:00.000090 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellar2: 0:00:00.018333 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar2: 0:00:00.000087 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellarSystem: 0:00:00.046291 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000105 seconds.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 4.313961926043567 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 4.396812314317121 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 2.3526506026115466 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 4.878869332401476 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.5235177333835672 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.7068615818317028 is still above the tolerance 1e-06.
INFO - 13:07:08: Compilation of the output function JAXSellar1: 0:00:00.021427 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar1: 0:00:00.000094 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellar2: 0:00:00.020298 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar2: 0:00:00.000094 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellarSystem: 0:00:00.048391 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000108 seconds.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.13653315301425117 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.7288338365089271 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.26884468326686856 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.3892393395653002 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.650492596662699 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.39230693086603263 is still above the tolerance 1e-06.
INFO - 13:07:08: Compilation of the output function JAXSellar1: 0:00:00.020241 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar1: 0:00:00.000095 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellar2: 0:00:00.016788 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar2: 0:00:00.000092 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellarSystem: 0:00:00.026360 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000103 seconds.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.7673555228938249 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.3855287570048787 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.1733986654790818 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.7261806996323679 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.348705081446188 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.245130844434205 is still above the tolerance 1e-06.
INFO - 13:07:08: Compilation of the output function JAXSellar1: 0:00:00.027618 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar1: 0:00:00.000096 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellar2: 0:00:00.019661 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar2: 0:00:00.000088 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellarSystem: 0:00:00.040515 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000098 seconds.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 9.803852288678877 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 14.805133944406151 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 4.291856907360072 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 12.609001801043162 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 11.441798544634315 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.12397643686445965 is still above the tolerance 1e-06.
INFO - 13:07:08: Compilation of the output function JAXSellar1: 0:00:00.021565 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar1: 0:00:00.000100 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellar2: 0:00:00.018489 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar2: 0:00:00.000129 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellarSystem: 0:00:00.046559 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000109 seconds.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.7468824598169091 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.4409475133255714 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.000402011001018926 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.4637268999300698 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.37735129558117086 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.10075337972869863 is still above the tolerance 1e-06.
INFO - 13:07:08: Compilation of the output function JAXSellar1: 0:00:00.020708 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar1: 0:00:00.000098 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellar2: 0:00:00.018939 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellar2: 0:00:00.000090 seconds.
INFO - 13:07:08: Compilation of the output function JAXSellarSystem: 0:00:00.044310 seconds.
INFO - 13:07:08: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000106 seconds.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 1.0 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 11.961795682051548 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 11.708642529432945 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 9.548273495662393 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 14.210221501046426 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 0.7205296806254416 is still above the tolerance 1e-06.
WARNING - 13:07:08: JAX SellarChain has reached its maximum number of iterations but the normed residual 15.434954566947555 is still above the tolerance 1e-06.
Now let's visualize our results:
# The first MDA (NumPy with analytical derivatives) is the reference.
mda_ref = next(iter(mdas))
t_ref = times[mda_ref]
# Speedup = reference time / time: values above 1 mean faster than the reference,
# which is therefore constant at 1. (Previously the inverse ratio was plotted
# under the "Speedup" title, i.e. a relative time.)
speedup = {
    mda_name: (t_ref[0] / t_e, t_ref[1] / t_l) for mda_name, (t_e, t_l) in times.items()
}
fig, axes = subplots(2, 1, layout="constrained", figsize=(6, 8))
fig.suptitle("Speedup compared to NumPy Analytical")
for mda_name in mdas:
    # Dotted line for the reference, solid lines for the other configurations.
    linestyle = ":" if mda_name == mda_ref else "-"
    speedup_e, speedup_l = speedup[mda_name]
    axes[0].plot(dimensions, speedup_e, linestyle, label=mda_name)
    axes[1].plot(dimensions, speedup_l, linestyle, label=mda_name)
axes[0].legend(bbox_to_anchor=(0.9, -0.1))
axes[0].set_ylabel("Execution")
axes[0].set_xscale("log")
axes[1].set_ylabel("Linearization")
axes[1].set_xlabel("Dimension")
axes[1].set_xscale("log")
show()
Conclusion
JAX AD is as fast as analytical derivatives with NumPy. Encapsulation with JAXChain slows down execution but speeds up linearization. The speedup is maintained even at higher dimensions.
Total running time of the script: (0 minutes 3.009 seconds)
Download Python source code: plot_sellar_scalable.py