{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Create a JAXDiscipline from a discipline using JAX.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from __future__ import annotations\n\nfrom typing import TYPE_CHECKING\n\nfrom gemseo.core.discipline.discipline import Discipline\nfrom jax.numpy import sqrt\nfrom numpy import array\n\nfrom gemseo_jax.utils import create_jax_discipline_from_discipline\n\nif TYPE_CHECKING:\n    from gemseo_jax.jax_discipline import DataType" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This short example illustrates\nhow to create a [JAXDiscipline][gemseo_jax.jax_discipline.JAXDiscipline]\nfrom a standard [Discipline][gemseo.core.discipline.discipline.Discipline]\nusing JAX instead of NumPy and SciPy.\n\nFirst,\nlet us create such a discipline\nwhose single output is twice the square root of its single input:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class DummyDisciplineUsingJAX(Discipline):\n    \"\"\"A dummy discipline computing ``out = 2 * sqrt(in)`` with JAX.\"\"\"\n\n    default_grammar_type = Discipline.GrammarType.SIMPLER\n\n    def __init__(self) -> None:\n        \"\"\"Declare the input ``in``, its default value and the output ``out``.\"\"\"\n        super().__init__()\n        self.io.input_grammar.update_from_names((\"in\",))\n        self.io.output_grammar.update_from_names((\"out\",))\n        self.io.input_grammar.defaults = {\"in\": array([1.0])}\n\n    def _run(self, input_data: dict[str, DataType]) -> dict[str, DataType]:\n        \"\"\"Compute ``out`` as twice the square root of ``in``.\n\n        Args:\n            input_data: The input data, read at the key ``in``.\n\n        Returns:\n            The output data, with the single key ``out``.\n        \"\"\"\n        # sqrt is jax.numpy.sqrt (see the imports cell), not numpy.sqrt.\n        return {\"out\": 2 * sqrt(input_data[\"in\"])}\n\n\ndiscipline_using_jax = DummyDisciplineUsingJAX()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then,\nwe use the function\n[create_jax_discipline_from_discipline][gemseo_jax.utils.create_jax_discipline_from_discipline]\nto create a 
[JAXDiscipline][gemseo_jax.jax_discipline.JAXDiscipline]:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "jax_discipline = create_jax_discipline_from_discipline(discipline_using_jax)\njax_discipline.add_differentiated_inputs([\"in\"])\njax_discipline.add_differentiated_outputs([\"out\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now,\nyou can use `jax_discipline`\nlike any other [JAXDiscipline][gemseo_jax.jax_discipline.JAXDiscipline].\nTo execute it from default input values:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "jax_discipline.execute()\njax_discipline.io.data[\"out\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To execute it from new input values:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "jax_discipline.execute({\"in\": array([3.0])})\njax_discipline.io.data[\"out\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To compute its Jacobian:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "jax_discipline.linearize({\"in\": array([3.0])})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "!!! note\n    This [JAXDiscipline][gemseo_jax.jax_discipline.JAXDiscipline]\n    is also compatible with JIT compilation.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 0 }