{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Analysis of the scalable Sellar problem with JAX.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from __future__ import annotations\n\nfrom datetime import timedelta\nfrom timeit import default_timer\nfrom typing import TYPE_CHECKING\n\nfrom gemseo import configure\nfrom gemseo import configure_logger\nfrom gemseo import create_mda\nfrom gemseo.core.discipline.discipline import Discipline\nfrom gemseo.problems.mdo.sellar.sellar_1 import Sellar1\nfrom gemseo.problems.mdo.sellar.sellar_2 import Sellar2\nfrom gemseo.problems.mdo.sellar.sellar_system import SellarSystem\nfrom gemseo.problems.mdo.sellar.utils import get_initial_data\nfrom matplotlib.pyplot import show\nfrom matplotlib.pyplot import subplots\nfrom numpy import array\nfrom numpy.random import default_rng\n\nfrom gemseo_jax.jax_chain import DifferentiationMethod\nfrom gemseo_jax.problems.sellar.sellar_1 import JAXSellar1\nfrom gemseo_jax.problems.sellar.sellar_2 import JAXSellar2\nfrom gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain\nfrom gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem\n\nif TYPE_CHECKING:\n from gemseo.mda.base_mda import BaseMDA\n from gemseo.typing import RealArray\n\n# Deactivate some checkers to speed up calculations in presence of cheap disciplines.\nconfigure(False, False, True, False, False, False, False)\nconfigure_logger()\n\n\ndef get_random_input_data(n: int) -> dict[str, RealArray]:\n \"\"\"Return a random input value for [JAX]SellarSystem.\"\"\"\n r_float = default_rng().random()\n return {\n name: 1.5 * r_float * value for name, value in get_initial_data(n=n).items()\n }\n\n\ndef get_numpy_disciplines(n: int) -> list[Discipline]:\n \"\"\"Return the NumPy-based Sellar 
disciplines.\"\"\"\n return [\n Sellar1(n=n),\n Sellar2(n=n),\n SellarSystem(n=n),\n ]\n\n\ndef get_jax_disciplines(\n n: int, differentiation_method=DifferentiationMethod.AUTO\n) -> list[Discipline]:\n \"\"\"Return the JAX-based Sellar disciplines.\"\"\"\n disciplines = [\n JAXSellar1(n=n, differentiation_method=differentiation_method),\n JAXSellar2(n=n, differentiation_method=differentiation_method),\n JAXSellarSystem(n=n, differentiation_method=differentiation_method),\n ]\n for disc in disciplines:\n disc.set_cache(Discipline.CacheType.SIMPLE)\n disc.compile_jit()\n\n return disciplines" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Initial setup for comparison\nHere we intend to compare the original NumPy implementation with the JAX one.\nWe then need to create the original MDA and one JAX MDA for each configuration we're\ntesting. In this example we compare the performance of the JAXChain encapsulation and\nalso the forward and reverse modes for automatic differentiation.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_analytical_mda(n: int, mda_name=\"MDAGaussSeidel\", max_mda_iter=5) -> BaseMDA:\n \"\"\"Return the Sellar MDA with analytical NumPy Jacobian.\"\"\"\n mda = create_mda(\n mda_name=mda_name,\n disciplines=get_numpy_disciplines(n),\n max_mda_iter=max_mda_iter,\n name=\"Analytical SellarChain\",\n )\n mda.set_cache(Discipline.CacheType.SIMPLE)\n return mda\n\n\ndef get_forward_ad_mda(n: int, mda_name=\"MDAGaussSeidel\", max_mda_iter=5) -> BaseMDA:\n \"\"\"Return the Sellar MDA with JAX forward-mode AD Jacobian.\"\"\"\n mda = create_mda(\n mda_name=mda_name,\n disciplines=get_jax_disciplines(n, DifferentiationMethod.FORWARD),\n max_mda_iter=max_mda_iter,\n name=\"JAX SellarChain\",\n )\n mda.set_cache(Discipline.CacheType.SIMPLE)\n return mda\n\n\ndef get_chained_forward_ad_mda(\n n: int, mda_name=\"MDAGaussSeidel\", max_mda_iter=5\n) -> BaseMDA:\n 
\"\"\"Return the Sellar MDA with JAXChain encapsulation and forward-mode Jacobian.\"\"\"\n discipline = JAXSellarChain(\n n=n,\n differentiation_method=DifferentiationMethod.FORWARD,\n )\n discipline.add_differentiated_inputs(discipline.input_grammar.names)\n discipline.add_differentiated_outputs(discipline.output_grammar.names)\n\n mda = create_mda(\n mda_name=mda_name,\n disciplines=[discipline],\n max_mda_iter=max_mda_iter,\n name=\"JAX SellarChain\",\n )\n mda.set_cache(Discipline.CacheType.SIMPLE)\n return mda\n\n\ndef get_reverse_ad_mda(n: int, mda_name=\"MDAGaussSeidel\", max_mda_iter=5) -> BaseMDA:\n \"\"\"Return the Sellar MDA with JAX reverse-mode AD Jacobian.\"\"\"\n mda = create_mda(\n mda_name=mda_name,\n disciplines=get_jax_disciplines(n, DifferentiationMethod.REVERSE),\n max_mda_iter=max_mda_iter,\n name=\"JAX SellarChain\",\n )\n mda.set_cache(Discipline.CacheType.SIMPLE)\n return mda\n\n\ndef get_chained_reverse_ad_mda(\n n: int, mda_name=\"MDAGaussSeidel\", max_mda_iter=5\n) -> BaseMDA:\n \"\"\"Return the Sellar MDA with JAXChain encapsulation and reverse-mode Jacobian.\"\"\"\n discipline = JAXSellarChain(\n n=n,\n differentiation_method=DifferentiationMethod.REVERSE,\n )\n discipline.add_differentiated_inputs(discipline.input_grammar.names)\n discipline.add_differentiated_outputs(discipline.output_grammar.names)\n\n mda = create_mda(\n mda_name=mda_name,\n disciplines=[discipline],\n max_mda_iter=max_mda_iter,\n name=\"JAX SellarChain\",\n )\n mda.set_cache(Discipline.CacheType.SIMPLE)\n return mda\n\n\nmdas = {\n \"MDOChain[NumPy] - Analytical\": get_analytical_mda, # this is the reference\n \"JAXChain - Forward AD\": get_chained_forward_ad_mda,\n \"JAXChain - Reverse AD\": get_chained_reverse_ad_mda,\n \"MDOChain[JAX] - Forward AD\": get_forward_ad_mda,\n \"MDOChain[JAX] - Reverse AD\": get_reverse_ad_mda,\n}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Execution and linearization scalability\n\nLet's make a function to 
# %% [markdown]
# # Execution and linearization scalability
#
# Let's make a function to execute and linearize an MDA, while logging times.
# Also, we run several repetitions to avoid noisy results.


def _average_time(action, dimension, input_names, n_repeat) -> timedelta:
    """Run an MDA operation ``n_repeat`` times and return the mean wall time.

    A fresh random input is generated for each repetition, filtered to the
    names the MDA actually accepts (original behavior).

    Args:
        action: The MDA method to time (``mda.execute`` or ``mda.linearize``).
        dimension: The dimension of the Sellar problem.
        input_names: The names accepted by the MDA input grammar.
        n_repeat: The number of repetitions to average over.

    Returns:
        The mean wall-clock time per call.
    """
    t0 = default_timer()
    for _ in range(n_repeat):
        action({
            name: value
            for name, value in get_random_input_data(dimension).items()
            if name in input_names
        })
    return timedelta(seconds=default_timer() - t0) / float(n_repeat)


def run_and_log(get_mda, dimension, n_repeat=7, **mda_options):
    """Time the execution and linearization of an MDA.

    Args:
        get_mda: A factory returning the MDA for a given dimension.
        dimension: The dimension of the Sellar problem.
        n_repeat: The number of repetitions to average over.
        **mda_options: Options forwarded to the MDA factory.

    Returns:
        The mean execution time and the mean linearization time per call.
    """
    mda = get_mda(dimension, **mda_options)
    # The duplicated execute/linearize timing loops are factored into a helper.
    t_execute = _average_time(
        mda.execute, dimension, mda.input_grammar.names, n_repeat
    )
    t_linearize = _average_time(
        mda.linearize, dimension, mda.input_grammar.names, n_repeat
    )
    return t_execute, t_linearize


# %% [markdown]
# # Run the MDA for each of the mdas, for several number of dimensions

dimensions = [1, 10, 100, 1000]
times = {}
mda_config = {"mda_name": "MDAGaussSeidel", "max_mda_iter": 1}
for mda_label, mda_factory in mdas.items():
    time_exec = []
    time_lin = []
    for dimension in dimensions:
        t_e, t_l = run_and_log(mda_factory, dimension, **mda_config)
        time_exec.append(t_e)
        time_lin.append(t_l)
    times[mda_label] = (array(time_exec), array(time_lin))

# %% [markdown]
# Now let's visualize our results:

mda_ref = next(iter(mdas))  # first configuration is the NumPy reference
t_ref = times[mda_ref]
# NOTE(review): these are time ratios t / t_ref (lower is better), not speedups
# in the strict t_ref / t sense — confirm the intended metric before relabeling.
speedup = {
    label: (t_e / t_ref[0], t_l / t_ref[1]) for label, (t_e, t_l) in times.items()
}

fig, axes = subplots(2, 1, layout="constrained", figsize=(6, 8))
fig.suptitle("Speedup compared to NumPy Analytical")
for label in mdas:
    # The reference curve is dotted so it stands out from the JAX variants.
    linestyle = ":" if label == mda_ref else "-"
    speedup_e, speedup_l = speedup[label]
    axes[0].plot(dimensions, speedup_e, linestyle, label=label)
    axes[1].plot(dimensions, speedup_l, linestyle, label=label)
axes[0].legend(bbox_to_anchor=(0.9, -0.1))
axes[0].set_ylabel("Execution")
axes[0].set_xscale("log")
axes[1].set_ylabel("Linearization")
axes[1].set_xlabel("Dimension")
axes[1].set_xscale("log")
show()

# %% [markdown]
# # Conclusion
# JAX AD is as fast as analytical derivatives with NumPy.
# Encapsulation with JAXChain slows execution, but speeds up linearization.
# Speedup is maintained even at higher dimensions.