{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Experimenting with the Sellar MDA using JAX.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from __future__ import annotations\n\nfrom datetime import timedelta\nfrom logging import getLogger\nfrom timeit import default_timer\nfrom typing import TYPE_CHECKING\n\nfrom gemseo import configure_logger\nfrom gemseo import create_mda\n\nfrom gemseo_jax.jax_chain import JAXChain\nfrom gemseo_jax.problems.sellar.sellar_1 import JAXSellar1\nfrom gemseo_jax.problems.sellar.sellar_2 import JAXSellar2\nfrom gemseo_jax.problems.sellar.sellar_chain import JAXSellarChain\nfrom gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem\n\nif TYPE_CHECKING:\n    from gemseo.mda.base_mda import BaseMDA\n\nLOGGER = getLogger(__name__)\n\n\nconfigure_logger()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "There are several options for combining JAXDisciplines for execution:\n\n- JAXChain encapsulation: wraps a set of disciplines and executes them as a single monolithic discipline.\n  - Pros: faster execution and compilation.\n  - Cons: incompatible with MDANewtonRaphson.\n- `.compile_jit()`: compiles the functions used to compute the outputs and the Jacobian.\n  - Pros: faster execution.\n  - Cons: the compilation time can exceed the total execution time if the function is called only a few times.\n  - Note: compilation does not take place in this method itself, but only when a jit-compiled function is first executed; this is why a pre-run may be added.\n- Pre-run: executes the output and Jacobian functions once to trigger compilation.\n  - Pros: benchmarks then measure execution times only.\n  - Cons: the compilation time is paid during this first (pre-run) evaluation.\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Create the disciplines with `AutoJAXDiscipline`:\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_disciplines():\n    \"\"\"Get new instances of the disciplines of the Sellar problem.\"\"\"\n    return [JAXSellar1(), JAXSellar2(), JAXSellarSystem()]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Define a function that executes an MDA and logs its execution time:\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def execute_and_log_mda(name: str, mda_chain: BaseMDA) -> None:\n    \"\"\"Execute an MDA and log its total execution time.\"\"\"\n    t0 = default_timer()\n    mda_chain.execute()\n    t1 = default_timer()\n    # mda_chain.plot_residual_history(show=True, save=False)\n    LOGGER.info(\n        \"MDA execution %s: %s.\",\n        name,\n        timedelta(seconds=t1 - t0),\n    )" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# No JAXChain encapsulation\n\n## MDA over separate disciplines WITHOUT compilation\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "disciplines = get_disciplines()\nmda = create_mda(\n    \"MDAChain\",\n    disciplines,  # separate disciplines\n    inner_mda_name=\"MDAJacobi\",\n)\nexecute_and_log_mda(\"separate disciplines (no jit)\", mda)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## MDA over separate disciplines WITH compilation WITHOUT pre-run\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "disciplines = get_disciplines()\nfor disc in disciplines:\n    disc.compile_jit(pre_run=False)\nmda = create_mda(\n    \"MDAChain\",\n    disciplines,  # separate disciplines\n    inner_mda_name=\"MDAJacobi\",\n)\nexecute_and_log_mda(\"separate disciplines (jit, no pre-run)\", mda)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## MDA over separate disciplines WITH compilation WITH pre-run (standard)\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "disciplines = get_disciplines()\nfor disc in disciplines:\n    disc.compile_jit()\nmda = create_mda(\n    \"MDAChain\",\n    disciplines,  # separate disciplines\n    inner_mda_name=\"MDAJacobi\",\n)\nexecute_and_log_mda(\"separate disciplines (jit, pre-run)\", mda)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\nThe MDA is 1.8x faster with JIT compilation. If compilation times are excluded from the\nbenchmark, the speedup is 10x!\n\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# With JAXChain encapsulation\n\n## MDA over JAXChain WITHOUT compilation\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "jax_chain = JAXChain(get_disciplines())\nmda = create_mda(\n    \"MDAChain\",\n    [jax_chain],  # chain as a single discipline\n    inner_mda_name=\"MDAJacobi\",\n)\nexecute_and_log_mda(\"chained disciplines (no jit)\", mda)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## MDA over JAXChain WITH compilation WITHOUT pre-run\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "jax_chain = JAXSellarChain(pre_run=False)\nmda = create_mda(\n    \"MDAChain\",\n    [jax_chain],  # chain as a single discipline\n    inner_mda_name=\"MDAJacobi\",\n)\nexecute_and_log_mda(\"chained disciplines (jit, no pre-run)\", mda)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## MDA over JAXChain WITH compilation WITH pre-run\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "jax_chain = JAXSellarChain()\njax_chain.compile_jit()\nmda = create_mda(\n    \"MDAChain\",\n    [jax_chain],  # chain as a single discipline\n    inner_mda_name=\"MDAJacobi\",\n)\nexecute_and_log_mda(\"chained disciplines (jit, pre-run)\", mda)" ] }, {
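"cell_type": "markdown", "metadata": {}, "source": [ "## Aside: compilation vs. execution time in plain JAX\n\nThe difference between speedups measured with and without compilation times stems from how `jax.jit` works: the first call of a jit-compiled function traces and compiles it, while later calls with inputs of the same shapes and dtypes reuse the cached executable. The next cell is a minimal sketch of this effect, independent of GEMSEO and gemseo-jax; the toy function `f` is hypothetical and only stands in for a discipline's computation, and the printed times depend on the hardware. Because JAX dispatches work asynchronously, `block_until_ready()` is called so that the timer measures the actual computation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from timeit import default_timer\n\nimport jax\nimport jax.numpy as jnp\n\n\ndef f(x):\n    \"\"\"Hypothetical toy function standing in for a discipline's computation.\"\"\"\n    return jnp.sum(jnp.sin(x) ** 2 + x * jnp.cos(x))\n\n\n# jax.jit only wraps the functions; compilation happens at the first call.\nf_jit = jax.jit(f)\n# Gradient of the scalar output, i.e. its Jacobian, also jit-compiled.\njac_jit = jax.jit(jax.grad(f))\nx = jnp.linspace(0.0, 1.0, 1_000)\n\nfor label, func in ((\"output\", f_jit), (\"Jacobian\", jac_jit)):\n    t0 = default_timer()\n    func(x).block_until_ready()  # 1st call: tracing + compilation + execution\n    t1 = default_timer()\n    func(x).block_until_ready()  # 2nd call: cached executable, execution only\n    t2 = default_timer()\n    print(f\"{label}: 1st call {t1 - t0:.2e} s, 2nd call {t2 - t1:.2e} s\")" ] }, {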
"cell_type": "markdown", "metadata": {}, "source": [ "# Conclusion\nEncapsulating the disciplines in a JAXChain (without JIT) yields a 1.4x speedup.\nJIT compilation yields a 1.8x speedup relative to the un-jitted JAXChain and a 2.6x\nspeedup relative to the un-jitted separate disciplines.\nIf compilation times are excluded from the benchmark, these speedups become 2.2x and\n3.2x, respectively.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }