Create a JAXDiscipline from a discipline using JAX.


from __future__ import annotations

from typing import TYPE_CHECKING

from gemseo.core.discipline.discipline import Discipline
from jax.numpy import sqrt
from numpy import array

from gemseo_jax.utils import create_jax_discipline_from_discipline

if TYPE_CHECKING:
    from gemseo_jax.jax_discipline import DataType

This short example illustrates how to create a JAXDiscipline from a standard Discipline whose computation is written with JAX (jax.numpy) instead of NumPy and SciPy.

First, let us create such a discipline, whose single output is twice the square root of its single input:

class DummyDisciplineUsingJAX(Discipline):
    """A dummy discipline using JAX."""

    default_grammar_type = Discipline.GrammarType.SIMPLER

    def __init__(self) -> None:
        super().__init__()
        self.io.input_grammar.update_from_names(("in",))
        self.io.output_grammar.update_from_names(("out",))
        self.io.input_grammar.defaults = {"in": array([1.0])}

    def _run(self, input_data: dict[str, DataType]) -> dict[str, DataType]:
        return {"out": 2 * sqrt(input_data["in"])}


discipline_using_jax = DummyDisciplineUsingJAX()
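In formula form, this discipline implements a scalar function whose derivative accounts for the Jacobian value computed later in this example:

```latex
\text{out} = f(\text{in}) = 2\sqrt{\text{in}}, \qquad
\frac{\partial\,\text{out}}{\partial\,\text{in}} = \frac{1}{\sqrt{\text{in}}}, \qquad
f(1) = 2,\quad f(3) = 2\sqrt{3} \approx 3.4641,\quad f'(3) = \frac{1}{\sqrt{3}} \approx 0.5774.
```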

Then, we use the function create_jax_discipline_from_discipline to create a JAXDiscipline from this discipline, and we declare which inputs and outputs must be differentiated:

jax_discipline = create_jax_discipline_from_discipline(discipline_using_jax)
jax_discipline.add_differentiated_inputs(["in"])
jax_discipline.add_differentiated_outputs(["out"])
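Differentiation is possible because the computation in _run is a pure JAX function mapping a dictionary of arrays to a dictionary of arrays. The following is a minimal sketch of that mechanism using plain JAX, assuming only that JAX is installed; it is not the gemseo_jax implementation, and the function name run is hypothetical. jax.jacfwd differentiates through dictionaries (pytrees) and nests the result as {output_name: {input_name: array}}, the same layout as the linearize output shown in this example.

```python
# A minimal sketch of the underlying JAX mechanism; NOT gemseo_jax's code.
import jax
import jax.numpy as jnp

# Use double precision, matching the float64 arrays printed in this example.
jax.config.update("jax_enable_x64", True)


def run(input_data):
    """Hypothetical pure function with the same computation as _run."""
    return {"out": 2 * jnp.sqrt(input_data["in"])}


# jax.jacfwd handles dict inputs/outputs: the result is nested as
# {output_name: {input_name: d(output)/d(input)}}.
jacobian = jax.jacfwd(run)({"in": jnp.array([3.0])})
print(jacobian["out"]["in"])  # a (1, 1) array holding 1 / sqrt(3)
```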

Now, jax_discipline can be used like any other JAXDiscipline. To execute it with the default input values:

jax_discipline.execute()
jax_discipline.io.data["out"]

Out:

array([2.])

To execute it with new input values:

jax_discipline.execute({"in": array([3.0])})
jax_discipline.io.data["out"]

Out:

array([3.46410162])
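Both printed outputs follow directly from out = 2 * sqrt(in). A quick sanity check with plain NumPy, assuming nothing beyond the formula used in _run:

```python
# Recompute the two outputs above with NumPy, assuming out = 2 * sqrt(in).
from numpy import allclose, array, sqrt

default_out = 2 * sqrt(array([1.0]))  # default input value in = 1
new_out = 2 * sqrt(array([3.0]))      # new input value in = 3

print(default_out)  # [2.]
print(new_out)      # [3.46410162]
assert allclose(new_out, [3.46410162])
```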

To compute its Jacobian at the same input values:

jax_discipline.linearize({"in": array([3.0])})

Out:

{'out': {'in': Array([[0.57735027]], dtype=float64)}}
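The Jacobian value matches the analytic derivative 1 / sqrt(in) at in = 3. As a minimal sketch, independent of gemseo_jax, a central finite difference of the same function with NumPy gives the same number up to truncation and rounding errors; the function name f and the step size are illustrative choices.

```python
# Compare the JAX Jacobian above with a central finite difference (NumPy only).
from numpy import array, sqrt


def f(x):
    """Hypothetical NumPy version of DummyDisciplineUsingJAX._run."""
    return 2 * sqrt(x)


x = array([3.0])
step = 1e-6
fd_derivative = (f(x + step) - f(x - step)) / (2 * step)
analytic_derivative = 1 / sqrt(x)

print(fd_derivative)        # close to [0.57735027]
print(analytic_derivative)  # [0.57735027]
```

A central difference is used because its truncation error scales with step squared, so it agrees with the JAX value to many more digits than a one-sided difference would.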

Note

This JAXDiscipline is also compatible with JIT compilation.
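For reference, JIT compilation in JAX means tracing a pure function and compiling it with XLA. The sketch below uses plain jax.jit and jax.jacfwd on a hypothetical function run with the same computation; it assumes only that JAX is installed and does not show how gemseo_jax exposes JIT compilation for a JAXDiscipline.

```python
# A minimal sketch of JIT compilation with plain JAX; gemseo_jax's own
# JIT options are not shown here.
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def run(input_data):
    """Hypothetical pure function with the same computation as _run."""
    return {"out": 2 * jnp.sqrt(input_data["in"])}


# jax.jit traces run once per input shape/dtype and compiles it;
# composing it with jax.jacfwd compiles the Jacobian computation as well.
run_jit = jax.jit(run)
jac_jit = jax.jit(jax.jacfwd(run))

x = {"in": jnp.array([3.0])}
print(run_jit(x)["out"])        # 2 * sqrt(3)
print(jac_jit(x)["out"]["in"])  # [[1 / sqrt(3)]]
```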

Total running time of the script: ( 0 minutes 0.187 seconds)
