JAX chain
jax_chain
Module for executing a chain of JAXDisciplines at once in JAX.
The MDOCouplingStructure of the JAXDisciplines is used to determine the
correct sequence of function calls, according to the dependencies among the functions.
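The ordering step can be pictured as a topological sort of a dependency graph. The sketch below does not use GEMSEO: it replaces MDOCouplingStructure with the standard library's `graphlib.TopologicalSorter` (Python 3.9+) and describes each hypothetical discipline only by the names of its inputs and outputs.

```python
from graphlib import TopologicalSorter

# Hypothetical disciplines: name -> (input names, output names).
DISCIPLINES = {
    "d3": ({"y2"}, {"f"}),
    "d1": ({"x"}, {"y1"}),
    "d2": ({"y1"}, {"y2"}),
}


def execution_order(disciplines):
    """Return discipline names sorted so that producers run before consumers."""
    # Map each output name to the discipline that computes it.
    producer = {
        output: name
        for name, (_, outputs) in disciplines.items()
        for output in outputs
    }
    # A discipline depends on every other discipline producing one of its inputs.
    graph = {
        name: {producer[i] for i in inputs if i in producer and producer[i] != name}
        for name, (inputs, _) in disciplines.items()
    }
    return list(TopologicalSorter(graph).static_order())


print(execution_order(DISCIPLINES))  # ['d1', 'd2', 'd3']
```

This toy version raises `graphlib.CycleError` when the disciplines are coupled, whereas a coupled chain is handled as described in the Note.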
Note
If the disciplines are coupled with one another, the resulting chain is self-coupled, i.e., some variables are both inputs and outputs of the chain, and one chain execution corresponds to one fixed-point iteration.
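To illustrate the Note, the plain-Python sketch below chains two hypothetical coupled functions. Because `y1` depends on `y2` and vice versa, the chain takes `y2` as an input and returns an updated `y2`, so one chain execution is one fixed-point iteration and repeated executions converge to the coupled solution. The functions and the outer loop are illustrative assumptions, not part of gemseo_jax.

```python
def discipline_1(x, y2):
    # y1 depends on the coupling variable y2.
    return x - 0.5 * y2


def discipline_2(y1):
    # y2 depends on y1, which closes the coupling loop.
    return 0.5 * y1


def chain(x, y2):
    """One chain execution: y2 is both an input and an output."""
    y1 = discipline_1(x, y2)
    return y1, discipline_2(y1)


def solve_by_fixed_point(x, y2=0.0, tol=1e-12, max_iter=100):
    """Repeat chain executions until y2 stops changing."""
    for _ in range(max_iter):
        y1, new_y2 = chain(x, y2)
        if abs(new_y2 - y2) < tol:
            return y1, new_y2
        y2 = new_y2
    raise RuntimeError("fixed-point iteration did not converge")


y1, y2 = solve_by_fixed_point(1.0)  # approximately 0.8 and 0.4
```

The outer loop is what an MDA solver would add around a self-coupled chain; the chain itself only performs a single update.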
Classes
JAXChain
JAXChain(
disciplines: Sequence[JAXDiscipline],
differentiation_method: DifferentiationMethod = AUTO,
differentiate_at_execution: bool = False,
name: str | None = None,
)
Bases: JAXDiscipline
A chain of JAX disciplines.
Parameters:
- disciplines (Sequence[JAXDiscipline]) – The JAX disciplines over which to create the chain.
Source code in src/gemseo_jax/jax_chain.py
Attributes
jax_out_func
instance-attribute
jax_out_func: Callable[[DataType], DataType] = function
The JAX function to compute the outputs from the inputs.
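One way to picture the output function of a chain is as a composition of the disciplines' output functions in execution order, each reading from the values computed so far. The sketch below uses plain dictionaries as a stand-in for DataType; `make_chain_function`, `f1` and `f2` are hypothetical, and the real jax_out_func is built by gemseo_jax from the disciplines' own JAX functions.

```python
from typing import Callable, Dict, List

DataType = Dict[str, float]  # Stand-in for the real DataType.


def make_chain_function(
    functions: List[Callable[[DataType], DataType]],
) -> Callable[[DataType], DataType]:
    """Compose output functions that are already in execution order."""

    def chain_function(inputs: DataType) -> DataType:
        data = dict(inputs)
        outputs: DataType = {}
        for function in functions:
            # Each function reads what it needs from the data computed so far.
            new_outputs = function(data)
            data.update(new_outputs)
            outputs.update(new_outputs)
        return outputs

    return chain_function


def f1(data: DataType) -> DataType:
    return {"y1": 2.0 * data["x"]}


def f2(data: DataType) -> DataType:
    return {"y2": data["y1"] + 1.0}


chain_function = make_chain_function([f1, f2])
print(chain_function({"x": 3.0}))  # {'y1': 6.0, 'y2': 7.0}
```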
Classes
DifferentiationMethod
Bases: StrEnum
The method to compute the Jacobian.
Functions
add_differentiated_inputs
add_differentiated_inputs(
input_names: Iterable[str] = (),
) -> None
Notes
The Jacobian is also filtered so that non-differentiated variables are treated as static.
Source code in src/gemseo_jax/jax_discipline.py
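To picture what treating non-differentiated inputs as static means, the sketch below fixes the inputs that are not differentiated once, so that only the selected inputs remain free arguments of the reduced function. It uses forward finite differences purely as a dependency-free stand-in for JAX automatic differentiation; `jacobian_wrt` and `example` are hypothetical names.

```python
from typing import Callable, Dict, Iterable

Data = Dict[str, float]


def jacobian_wrt(
    function: Callable[[Data], Data],
    inputs: Data,
    differentiated_inputs: Iterable[str],
    step: float = 1e-7,
) -> Dict[str, Dict[str, float]]:
    """Approximate d(output)/d(input) for the differentiated inputs only."""
    names = list(differentiated_inputs)
    # The other inputs are frozen: they are static for the derivative.
    static = {k: v for k, v in inputs.items() if k not in names}

    def reduced(values: Data) -> Data:
        return function({**static, **values})

    point = {name: inputs[name] for name in names}
    reference = reduced(point)
    jacobian: Dict[str, Dict[str, float]] = {output: {} for output in reference}
    for name in names:
        # Forward difference along one differentiated input.
        shifted = reduced({**point, name: point[name] + step})
        for output, value in shifted.items():
            jacobian[output][name] = (value - reference[output]) / step
    return jacobian


def example(data: Data) -> Data:
    return {"f": data["a"] * data["x"] ** 2}


jac = jacobian_wrt(example, {"a": 3.0, "x": 2.0}, ["x"])
print(jac)  # approximately {'f': {'x': 12.0}}; no column for 'a'
```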
add_differentiated_outputs
add_differentiated_outputs(
output_names: Iterable[str] = (),
) -> None
Notes
The Jacobian is also filtered so that non-differentiated variables are treated as static.
Source code in src/gemseo_jax/jax_discipline.py
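The effect of selecting outputs and inputs on a Jacobian can be sketched as keeping only the requested rows (outputs) and columns (inputs). The nested-dictionary layout and the `filter_jacobian` helper are hypothetical illustrations, not the gemseo_jax implementation.

```python
from typing import Dict, Iterable

Jacobian = Dict[str, Dict[str, float]]


def filter_jacobian(
    jacobian: Jacobian,
    output_names: Iterable[str],
    input_names: Iterable[str],
) -> Jacobian:
    """Keep only the blocks d(output)/d(input) that were requested."""
    outputs = set(output_names)
    inputs = set(input_names)
    return {
        output: {name: value for name, value in row.items() if name in inputs}
        for output, row in jacobian.items()
        if output in outputs
    }


full = {
    "f": {"x": 1.0, "z": 2.0},
    "g": {"x": 3.0, "z": 4.0},
}
print(filter_jacobian(full, ["f"], ["x"]))  # {'f': {'x': 1.0}}
```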
compile_jit
compile_jit(pre_run: bool = True) -> None
Apply JIT compilation to the function and its Jacobian.
Parameters:
- pre_run (bool, default: True) – Whether to call the jitted callables once to trigger compilation and log the times taken.
Warning
add_differentiated_inputs and add_differentiated_outputs must be called before compile_jit.
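One way such an ordering constraint arises is that a compiled callable captures its configuration when it is built. In the hypothetical `ToyDiscipline` below (not the gemseo_jax API), `compile_jit` builds a Jacobian callable that closes over a snapshot of the differentiated inputs and, mimicking `pre_run`, calls it once and logs the duration. Inputs selected afterwards are invisible to the compiled callable.

```python
import logging
import time
from typing import Callable, Dict, List, Optional

Data = Dict[str, float]
LOGGER = logging.getLogger(__name__)


class ToyDiscipline:
    """Hypothetical stand-in, not the gemseo_jax JAXDiscipline."""

    def __init__(self) -> None:
        self.differentiated_inputs: List[str] = []
        self.compiled_jacobian: Optional[Callable[[Data], Data]] = None

    def add_differentiated_inputs(self, names: List[str]) -> None:
        self.differentiated_inputs.extend(names)

    def compile_jit(self, pre_run: bool = True) -> None:
        # Snapshot the current selection; the callable never sees later changes.
        selected = tuple(self.differentiated_inputs)

        def jacobian(inputs: Data) -> Data:
            # Gradient of sum(inputs), restricted to the selected inputs.
            return {name: 1.0 for name in selected}

        self.compiled_jacobian = jacobian
        if pre_run:
            # Call once (where a real JIT would compile) and log the duration.
            start = time.perf_counter()
            jacobian({name: 0.0 for name in selected})
            LOGGER.info("First call took %.2e s", time.perf_counter() - start)


discipline = ToyDiscipline()
discipline.add_differentiated_inputs(["x"])
discipline.compile_jit()
discipline.add_differentiated_inputs(["z"])  # Too late: not captured.
print(discipline.compiled_jacobian({"x": 1.0, "z": 2.0}))  # {'x': 1.0}
```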
Source code in src/gemseo_jax/jax_discipline.py