Experimenting with Sellar MDA with JAX.

from __future__ import annotations

from datetime import timedelta
from logging import getLogger
from timeit import default_timer
from typing import TYPE_CHECKING

from gemseo import configure_logger
from gemseo import create_mda

from gemseo_jax.jax_chain import JAXChain
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem

if TYPE_CHECKING:
    from gemseo.mda.base_mda import BaseMDA

LOGGER = getLogger(__name__)


configure_logger()

Out:

<RootLogger root (INFO)>

There are several options when combining JAXDisciplines for execution:

  • JAXChain encapsulation: This wraps a set of disciplines and executes them as a single monolithic discipline. Pros: Execution and compilation are faster. Cons: Incompatible with MDANewtonRaphson.
  • .compile_jit(): This compiles the functions used to compute the outputs and the Jacobian. Pros: Faster execution. Cons: Compilation time can exceed the total execution time if the function is called only a few times. Note: Compilation does not take place in this method itself, but only when a jit-compiled function is first executed, which is why a pre-run may be added.
    • pre-run: Executes the output and Jacobian functions once to trigger compilation. Pros: Ensures benchmarks include only execution times. Cons: The compilation time is spent in this first evaluation.
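The effect of a pre-run on a benchmark can be illustrated without JAX. The sketch below is a pure-Python analogy, not the gemseo-jax implementation: the hypothetical `LazyCompiled` class defers compilation (here, Python's built-in `compile`) to its first call, much as a jit-compiled function is only compiled when first executed, and the hypothetical `pre_run` helper triggers that first call ahead of the timed section.

```python
from __future__ import annotations

from timeit import default_timer


class LazyCompiled:
    """Hypothetical stand-in for a jit-compiled function (not a JAX API)."""

    def __init__(self, source: str) -> None:
        self.source = source
        self.code = None  # nothing is compiled at construction time
        self.n_compilations = 0

    def __call__(self, x: float) -> float:
        if self.code is None:
            # Compilation happens lazily, on the first call only.
            self.code = compile(self.source, "<lazy>", "eval")
            self.n_compilations += 1
        return eval(self.code, {"x": x})


def pre_run(func: LazyCompiled) -> None:
    """Call the function once so that later calls skip compilation."""
    func(0.0)


def timed_call(func: LazyCompiled, x: float) -> tuple[float, float]:
    """Return the result of func(x) and the elapsed wall-clock time."""
    t0 = default_timer()
    result = func(x)
    return result, default_timer() - t0


cold = LazyCompiled("x**2 + 1.0")
warm = LazyCompiled("x**2 + 1.0")
pre_run(warm)  # compilation cost is paid here, outside the timed call

cold_result, cold_time = timed_call(cold, 3.0)  # timing includes compilation
warm_result, warm_time = timed_call(warm, 3.0)  # timing covers execution only
print(cold_result, warm_result)
```

Both calls return the same value; only the timed section of `cold` contains the compilation step, which is the situation of the "WITHOUT pre-run" cases in this example.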

Create the disciplines with AutoJAXDiscipline:

def get_disciplines():
    """Get new instances of the disciplines of the Sellar problem."""
    return [JAXSellar1(), JAXSellar2(), JAXSellarSystem()]

Create the function to run MDA and log execution times:

def execute_and_log_mda(name: str, mda_chain: BaseMDA) -> None:
    """Execute mda and log total execution time."""
    t0 = default_timer()
    mda_chain.execute()
    t1 = default_timer()
    # mda_chain.plot_residual_history(show=True, save=False)
    LOGGER.info(
        "MDA execution %s: %s seconds.",
        name,
        timedelta(seconds=t1 - t0),
    )
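The durations in the logs of this example are printed as `H:MM:SS.ffffff` because a `timedelta` is passed to the `%s` placeholder. A minimal sketch of that formatting, using a hypothetical helper `format_duration` and a hard-coded duration for illustration:

```python
from datetime import timedelta


def format_duration(seconds: float) -> str:
    """Format a duration like the log message does (hypothetical helper)."""
    return "%s seconds." % timedelta(seconds=seconds)


message = format_duration(0.072983)
print(message)

# The plain number of seconds can be recovered with total_seconds().
elapsed = timedelta(seconds=0.072983).total_seconds()
```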

No JAXChain encapsulation

MDA over separate disciplines WITHOUT compilation

disciplines = get_disciplines()
mda = create_mda(
    "MDAChain",
    disciplines,  # separate disciplines
    inner_mda_name="MDAJacobi",
)
execute_and_log_mda("separate disciplines (no jit)", mda)

Out:

    INFO - 13:07:05: MDA execution separate disciplines (no jit): 0:00:00.072983 seconds.

MDA over separate disciplines WITH compilation WITHOUT pre-run

disciplines = get_disciplines()
for disc in disciplines:
    disc.compile_jit(pre_run=False)
mda = create_mda(
    "MDAChain",
    disciplines,  # separate disciplines
    inner_mda_name="MDAJacobi",
)
execute_and_log_mda("separate disciplines (jit, no pre-run)", mda)

Out:

    INFO - 13:07:05: MDA execution separate disciplines (jit, no pre-run): 0:00:00.062108 seconds.

MDA over separate disciplines WITH compilation WITH pre-run (standard)

disciplines = get_disciplines()
for disc in disciplines:
    disc.compile_jit()
mda = create_mda(
    "MDAChain",
    disciplines,  # separate disciplines
    inner_mda_name="MDAJacobi",
)
execute_and_log_mda("separate disciplines (jit, pre-run)", mda)

Out:

    INFO - 13:07:05: Compilation of the output function JAXSellar1: 0:00:00.020623 seconds.
    INFO - 13:07:05: Compilation of the Jacobian function JAXSellar1: 0:00:00.000108 seconds.
    INFO - 13:07:05: Compilation of the output function JAXSellar2: 0:00:00.017226 seconds.
    INFO - 13:07:05: Compilation of the Jacobian function JAXSellar2: 0:00:00.000092 seconds.
    INFO - 13:07:05: Compilation of the output function JAXSellarSystem: 0:00:00.027900 seconds.
    INFO - 13:07:05: Compilation of the Jacobian function JAXSellarSystem: 0:00:00.000104 seconds.
    INFO - 13:07:05: MDA execution separate disciplines (jit, pre-run): 0:00:00.006799 seconds.

Conclusion

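In the run logged above, the MDA is about 1.2x faster with JIT compilation when the compilation time is included (0.0730 s vs. 0.0621 s). If compilation times are excluded from the benchmark by using a pre-run, the speedup is about 10.7x (0.0730 s vs. 0.0068 s). Timings vary from run to run.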
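The ratios for the run logged above can be recomputed from the three logged MDA execution times. A minimal sketch, with the durations copied by hand from the logs and a hypothetical `speedup` helper:

```python
# Durations (in seconds) copied from the three "MDA execution" log lines above.
no_jit = 0.072983
jit_no_pre_run = 0.062108  # compilation happens during the timed execution
jit_pre_run = 0.006799  # compilation done beforehand by the pre-run


def speedup(reference: float, candidate: float) -> float:
    """Return how many times faster the candidate is than the reference."""
    return reference / candidate


speedup_with_compilation = speedup(no_jit, jit_no_pre_run)
speedup_without_compilation = speedup(no_jit, jit_pre_run)
print(f"{speedup_with_compilation:.1f}x, {speedup_without_compilation:.1f}x")
```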

With JAXChain encapsulation

MDA over JAXChain WITHOUT compilation

jax_chain = JAXChain(get_disciplines())
mda = create_mda(
    "MDAChain",
    [jax_chain],  # chain as single discipline
    inner_mda_name="MDAJacobi",
)
execute_and_log_mda("chained disciplines (no jit)", mda)

Out:

    INFO - 13:07:05: MDA execution chained disciplines (no jit): 0:00:00.083352 seconds.

MDA over JAXChain WITH compilation WITHOUT pre-run

jax_chain = JAXSellarChain(pre_run=False)
mda = create_mda(
    "MDAChain",
    [jax_chain],  # chain as single discipline
    inner_mda_name="MDAJacobi",
)
execute_and_log_mda("chained disciplines (jit, no pre-run)", mda)

Out:

    INFO - 13:07:05: MDA execution chained disciplines (jit, no pre-run): 0:00:00.043600 seconds.

MDA over JAXChain WITH compilation WITH pre-run

jax_chain = JAXSellarChain()
jax_chain.compile_jit()
mda = create_mda(
    "MDAChain",
    [jax_chain],  # chain as single discipline
    inner_mda_name="MDAJacobi",
)
execute_and_log_mda("chained disciplines (jit, pre-run)", mda)

Out:

    INFO - 13:07:05: Compilation of the output function JAXSellarChain: 0:00:00.037024 seconds.
    INFO - 13:07:05: Compilation of the Jacobian function JAXSellarChain: 0:00:00.054044 seconds.
    INFO - 13:07:05: Compilation of the output function JAXSellarChain: 0:00:00.000134 seconds.
    INFO - 13:07:05: Compilation of the Jacobian function JAXSellarChain: 0:00:00.000133 seconds.
    INFO - 13:07:05: MDA execution chained disciplines (jit, pre-run): 0:00:00.039438 seconds.

Conclusion

In the run logged above, encapsulation with JAXChain without JIT gives no speedup: the chained MDA takes 0.083 s against 0.073 s for separate disciplines. JIT compilation of the chain gives a 1.9x speedup relative to the un-jitted JAXChain and 1.7x relative to un-jitted separate disciplines. With a pre-run, these speedups are 2.1x and 1.9x, respectively. Timings vary from run to run.
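The ratios between the timings logged in this example can be recomputed directly from the log lines. A minimal sketch, assuming the log format shown above; the `LOG` string, `PATTERN` and `parse_durations` are hypothetical helpers for illustration, not part of GEMSEO:

```python
from __future__ import annotations

import re

# Log lines copied from the outputs of this example.
LOG = """\
INFO - 13:07:05: MDA execution separate disciplines (no jit): 0:00:00.072983 seconds.
INFO - 13:07:05: MDA execution chained disciplines (no jit): 0:00:00.083352 seconds.
INFO - 13:07:05: MDA execution chained disciplines (jit, no pre-run): 0:00:00.043600 seconds.
INFO - 13:07:05: MDA execution chained disciplines (jit, pre-run): 0:00:00.039438 seconds.
"""

PATTERN = re.compile(r"MDA execution (.+): (\d+):(\d+):(\d+\.\d+) seconds\.")


def parse_durations(log: str) -> dict[str, float]:
    """Map each MDA name found in the log to its duration in seconds."""
    durations = {}
    for match in PATTERN.finditer(log):
        name, hours, minutes, seconds = match.groups()
        durations[name] = int(hours) * 3600 + int(minutes) * 60 + float(seconds)
    return durations


durations = parse_durations(LOG)
separate = durations["separate disciplines (no jit)"]
chained = durations["chained disciplines (no jit)"]
for name in ("chained disciplines (jit, no pre-run)", "chained disciplines (jit, pre-run)"):
    jitted = durations[name]
    print(f"{name}: {chained / jitted:.1f}x vs chain, {separate / jitted:.1f}x vs separate")
```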

Total running time of the script: ( 0 minutes 0.518 seconds)
