Create a JAXDiscipline from a discipline using JAX.
from __future__ import annotations
from typing import TYPE_CHECKING
from gemseo.core.discipline.discipline import Discipline
from jax.numpy import sqrt
from numpy import array
from gemseo_jax.utils import create_jax_discipline_from_discipline
if TYPE_CHECKING:
    from gemseo_jax.jax_discipline import DataType
This short example illustrates how to create a JAXDiscipline from a standard Discipline whose computations use JAX instead of NumPy and SciPy.
First, let us create such a discipline, whose single output is twice the square root of its single input:
class DummyDisciplineUsingJAX(Discipline):
    """A dummy discipline using JAX."""
    default_grammar_type = Discipline.GrammarType.SIMPLER
    def __init__(self) -> None:
        super().__init__()
        # Declare a single input "in" and a single output "out".
        self.io.input_grammar.update_from_names(("in",))
        self.io.output_grammar.update_from_names(("out",))
        # Default input value used when execute() is called without arguments.
        self.io.input_grammar.defaults = {"in": array([1.0])}
    def _run(self, input_data: dict[str, DataType]) -> dict[str, DataType]:
        # sqrt is jax.numpy.sqrt, so JAX can trace and differentiate this method.
        return {"out": 2 * sqrt(input_data["in"])}
discipline_using_jax = DummyDisciplineUsingJAX()
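The `_run` method above calls `jax.numpy.sqrt` rather than `numpy.sqrt` so that JAX can trace it. The following standalone sketch uses plain JAX, independently of GEMSEO; the function names `twice_sqrt_jax` and `twice_sqrt_numpy` are illustrative only. It shows the difference: differentiating the `jax.numpy` version works, whereas the NumPy version fails because NumPy cannot convert a JAX tracer into a concrete array.

```python
import jax
import jax.numpy as jnp
import numpy as np


def twice_sqrt_jax(x):
    """Same formula as DummyDisciplineUsingJAX._run, written with jax.numpy."""
    return 2 * jnp.sqrt(x)


def twice_sqrt_numpy(x):
    """Same formula written with NumPy, which JAX cannot trace."""
    return 2 * np.sqrt(x)


# jax.grad traces the function with abstract values; jax.numpy supports this.
grad_ok = jax.grad(twice_sqrt_jax)(3.0)
print(grad_ok)  # approximately 1/sqrt(3)

# NumPy tries to turn the tracer into a concrete array, which JAX forbids.
try:
    jax.grad(twice_sqrt_numpy)(3.0)
    numpy_traceable = True
except jax.errors.TracerArrayConversionError:
    numpy_traceable = False
print(numpy_traceable)  # False
```

This is why, when wrapping an existing discipline, its computations should be written with `jax.numpy` (and JAX counterparts of SciPy functions) rather than NumPy.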
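Then, we use the function create_jax_discipline_from_discipline to create a JAXDiscipline from it, and declare the inputs and outputs with respect to which its Jacobian will be computed:
jax_discipline = create_jax_discipline_from_discipline(discipline_using_jax)
# Differentiate the output "out" with respect to the input "in".
jax_discipline.add_differentiated_inputs(["in"])
jax_discipline.add_differentiated_outputs(["out"])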
Now, jax_discipline can be used like any other JAXDiscipline.
To execute it with its default input values:
jax_discipline.execute()
jax_discipline.io.data["out"]
Out:
array([2.])
To execute it with new input values:
jax_discipline.execute({"in": array([3.0])})
jax_discipline.io.data["out"]
Out:
array([3.46410162])
To compute its Jacobian:
jax_discipline.linearize({"in": array([3.0])})
Out:
{'out': {'in': Array([[0.57735027]], dtype=float64)}}
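As a sanity check, the derivative of out = 2√in is d out / d in = 1/√in, so at in = 3 the Jacobian is 1/√3 ≈ 0.57735, matching the output above. The sketch below reproduces this value with plain `jax.jacfwd` applied to a standalone function equivalent to `_run` (the name `run` is illustrative). It is not claimed to be how `JAXDiscipline.linearize` works internally; enabling JAX's 64-bit mode is an assumption made here to obtain float64 results like those shown above.

```python
import jax
import jax.numpy as jnp

# Use double precision, matching the float64 Jacobian printed above.
jax.config.update("jax_enable_x64", True)


def run(input_data):
    """Standalone equivalent of DummyDisciplineUsingJAX._run."""
    return {"out": 2 * jnp.sqrt(input_data["in"])}


# Forward-mode Jacobian of a dict-to-dict function: the result is nested
# as jacobian[output_name][input_name], like the linearize() output.
jacobian = jax.jacfwd(run)({"in": jnp.array([3.0])})
print(jacobian["out"]["in"])  # approximately [[0.57735027]]
```

The nested dictionary structure comes from JAX's pytree handling: the outer keys follow the outputs of `run` and the inner keys follow its inputs.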
Note
This JAXDiscipline is also compatible with JIT compilation.
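As a rough illustration of what JIT compilation involves in plain JAX, the sketch below compiles a standalone equivalent of `_run` and its Jacobian with `jax.jit`. It does not use the `gemseo_jax` API, and the name `run` is illustrative. The first call traces and compiles each function; later calls with the same input shapes and dtypes reuse the compiled code, which avoids repeated Python overhead.

```python
import jax
import jax.numpy as jnp


def run(input_data):
    """Standalone equivalent of DummyDisciplineUsingJAX._run."""
    return {"out": 2 * jnp.sqrt(input_data["in"])}


# Compile both the function and its forward-mode Jacobian.
run_jit = jax.jit(run)
jacobian_jit = jax.jit(jax.jacfwd(run))

# The first call traces and compiles; subsequent calls with the same
# shapes and dtypes reuse the compiled program.
outputs = run_jit({"in": jnp.array([3.0])})
jacobian = jacobian_jit({"in": jnp.array([3.0])})
print(outputs["out"])         # approximately [3.4641016]
print(jacobian["out"]["in"])  # approximately [[0.57735027]]
```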