JAX discipline
jax_discipline
A discipline interfacing a JAX function.
Classes
JAXDiscipline
JAXDiscipline(
function: Callable[[DataType], DataType],
input_names: Sequence[str],
output_names: Sequence[str],
default_inputs: Mapping[str, NumberLike],
differentiation_method: DifferentiationMethod = AUTO,
differentiate_at_execution: bool = False,
name: str | None = None,
)
Bases: Discipline
A discipline interfacing a JAX function.
Initialize the JAXDiscipline.
Parameters:
- function (Callable[[DataType], DataType]) – The JAX function that takes a dictionary {input_name: input_value, ...} as argument and returns a dictionary {output_name: output_value, ...}.
- input_names (Sequence[str]) – The names of the input variables.
- output_names (Sequence[str]) – The names of the output variables.
- default_inputs (Mapping[str, NumberLike]) – The default values of the input variables.
- differentiation_method (DifferentiationMethod, default: AUTO) – The method to compute the Jacobian.
- differentiate_at_execution (bool, default: False) – Whether to compute the Jacobian when executing the discipline.
- name (str | None, default: None) – The name of the discipline. If empty, use the name of the class.
Source code in src/gemseo_jax/jax_discipline.py
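As an illustration of the contract expected for the function parameter, here is a minimal sketch of a dictionary-to-dictionary JAX function with two inputs and two outputs. The names compute, x, y, z and w are hypothetical. The call to jax.jacfwd only shows how JAX can differentiate such a function; it is an assumption, not a description of how JAXDiscipline computes its Jacobian internally.

```python
import jax
import jax.numpy as jnp


def compute(inputs):
    """Map {input_name: input_value} to {output_name: output_value}."""
    x = inputs["x"]
    y = inputs["y"]
    return {"z": x**2 + y, "w": jnp.sin(x) * y}


# Default input values, as floats so that JAX can differentiate them.
default_inputs = {"x": jnp.array(2.0), "y": jnp.array(3.0)}

outputs = compute(default_inputs)

# Illustration only (assumption): a forward-mode Jacobian of a dict-to-dict
# function, nested as jacobian[output_name][input_name].
jacobian = jax.jacfwd(compute)(default_inputs)

print(float(outputs["z"]))        # 7.0
print(float(jacobian["z"]["x"]))  # 4.0
print(float(jacobian["z"]["y"]))  # 1.0
```

Following the constructor signature above, this function could then be wrapped as JAXDiscipline(compute, ["x", "y"], ["z", "w"], default_inputs); that call is not exercised in the sketch.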
Attributes
jax_out_func
instance-attribute
jax_out_func: Callable[[DataType], DataType] = function
The JAX function to compute the outputs from the inputs.
Classes
DifferentiationMethod
Bases: StrEnum
The method to compute the Jacobian.
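This page does not list the members of DifferentiationMethod apart from the default AUTO. The sketch below assumes hypothetical FORWARD and REVERSE members mapped to jax.jacfwd and jax.jacrev, and an equally hypothetical rule for AUTO based on the usual cost argument: forward mode needs one pass per input, reverse mode one pass per output. The class Method and the function jacobian_transform are illustrative stand-ins, not the gemseo-jax API.

```python
from enum import Enum

import jax
import jax.numpy as jnp


# Hypothetical stand-in for DifferentiationMethod; only AUTO is documented.
# (str, Enum) is used instead of StrEnum so it also runs before Python 3.11.
class Method(str, Enum):
    AUTO = "auto"
    FORWARD = "forward"
    REVERSE = "reverse"


def jacobian_transform(method, n_inputs, n_outputs):
    """Return jax.jacfwd or jax.jacrev for the requested method."""
    if method == Method.AUTO:
        # Assumed rule: forward mode costs one pass per input and reverse
        # mode one pass per output, so pick the cheaper of the two.
        method = Method.FORWARD if n_inputs <= n_outputs else Method.REVERSE
    return jax.jacfwd if method == Method.FORWARD else jax.jacrev


def f(x):
    # R^3 -> R: three inputs, one output.
    return jnp.sum(x**2)


transform = jacobian_transform(Method.AUTO, n_inputs=3, n_outputs=1)
print(transform is jax.jacrev)                    # True
print(transform(f)(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```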
Functions
add_differentiated_inputs
Add the inputs with respect to which to differentiate the outputs.
The inputs that do not represent continuous numbers are filtered out.
Parameters:
- input_names (Iterable[str], default: ()) – The input variables with respect to which to differentiate the outputs. If empty, use all the inputs.
Raises:
- ValueError – When an input name is not the name of a discipline input.
Notes
The Jacobian is also filtered so that the non-differentiated variables are treated as static.
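The documented behavior of add_differentiated_inputs (validate the names, use all the inputs when none are given, drop the inputs that are not continuous numbers, keep the others static) can be sketched with plain JAX. The helpers differentiable_names and jacobian_wrt are hypothetical, and treating the non-differentiated inputs as constants captured in a closure is an assumed technique, not necessarily the one used by gemseo-jax.

```python
import jax
import jax.numpy as jnp


def differentiable_names(inputs, names=()):
    """Hypothetical helper: validate names and keep the continuous inputs."""
    names = tuple(names) or tuple(inputs)  # empty means all the inputs
    unknown = [name for name in names if name not in inputs]
    if unknown:
        raise ValueError(f"{unknown} are not discipline inputs.")
    # Keep only floating-point (continuous) inputs.
    return [
        name
        for name in names
        if jnp.issubdtype(jnp.asarray(inputs[name]).dtype, jnp.floating)
    ]


def jacobian_wrt(function, inputs, names):
    """Hypothetical helper: differentiate only with respect to names."""
    static = {k: v for k, v in inputs.items() if k not in names}
    variables = {k: inputs[k] for k in names}
    # The other inputs are captured by the closure as constants (static).
    return jax.jacfwd(lambda v: function({**static, **v}))(variables)


def compute(inputs):
    # "n" is an integer input, so it does not represent a continuous number.
    return {"z": inputs["x"] * inputs["n"] + inputs["y"]}


inputs = {"x": jnp.array(2.0), "y": jnp.array(3.0), "n": jnp.array(3)}
names = differentiable_names(inputs)
jac = jacobian_wrt(compute, inputs, names)
print(names)                 # ['x', 'y']
print(float(jac["z"]["x"]))  # 3.0
```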
add_differentiated_outputs
Add the outputs to be differentiated.
The outputs that do not represent continuous numbers are filtered out.
Parameters:
- output_names (Iterable[str], default: ()) – The outputs to be differentiated. If empty, use all the outputs.
Raises:
- ValueError – When an output name is not the name of a discipline output.
Notes
The Jacobian is also filtered so that the non-differentiated variables are treated as static.
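Restricting the differentiated outputs can be sketched by wrapping the function so that it returns only the requested outputs before applying reverse-mode differentiation. The helper select_outputs is hypothetical; it uses jax.eval_shape to read the output dtypes without evaluating the function, which is an assumed way of detecting outputs that are not continuous numbers.

```python
import jax
import jax.numpy as jnp


def select_outputs(function, inputs, requested=()):
    """Hypothetical helper: restrict function to continuous requested outputs.

    Return the kept output names and a wrapped function returning only them.
    """
    # Output shapes and dtypes, obtained without running the computation.
    shapes = jax.eval_shape(function, inputs)
    requested = tuple(requested) or tuple(shapes)  # empty means all outputs
    unknown = [name for name in requested if name not in shapes]
    if unknown:
        raise ValueError(f"{unknown} are not discipline outputs.")
    kept = [n for n in requested if jnp.issubdtype(shapes[n].dtype, jnp.floating)]

    def restricted(x):
        outputs = function(x)
        return {name: outputs[name] for name in kept}

    return kept, restricted


def compute(inputs):
    x, y = inputs["x"], inputs["y"]
    # "flag" is a boolean output, so it is not a continuous number.
    return {"z": x**2 + y, "w": x * y, "flag": x > y}


inputs = {"x": jnp.array(2.0), "y": jnp.array(3.0)}
kept, restricted = select_outputs(compute, inputs, ["w", "flag"])
jac = jax.jacrev(restricted)(inputs)
print(kept)                  # ['w']
print(float(jac["w"]["y"]))  # 2.0
```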
compile_jit
compile_jit(pre_run: bool = True) -> None
Apply JIT compilation to the function and its Jacobian.
Parameters:
- pre_run (bool, default: True) – Whether to call the jitted callables once to trigger their compilation and log the time taken.
Warning
add_differentiated_inputs and add_differentiated_outputs must be called before compile_jit.
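What compile_jit does with pre_run=True can be approximated as follows: JIT-compile the function and its Jacobian, then call each once so that tracing and XLA compilation happen immediately, timing these first calls. The helper compile_with_pre_run is hypothetical, uses jax.jacfwd as an assumed Jacobian transform, and returns the timings instead of logging them.

```python
import time

import jax
import jax.numpy as jnp


def compile_with_pre_run(function, inputs, pre_run=True):
    """Hypothetical helper: jit a function and its Jacobian.

    With pre_run, each jitted callable is called once so that tracing and
    XLA compilation happen now; the duration of these calls is returned.
    """
    jitted_function = jax.jit(function)
    jitted_jacobian = jax.jit(jax.jacfwd(function))
    timings = {}
    if pre_run:
        callables = (("function", jitted_function), ("jacobian", jitted_jacobian))
        for label, jitted in callables:
            start = time.perf_counter()
            # Wait for JAX's asynchronous dispatch so the timing is meaningful.
            jax.block_until_ready(jitted(inputs))
            timings[label] = time.perf_counter() - start
    return jitted_function, jitted_jacobian, timings


def compute(inputs):
    return {"z": inputs["x"] ** 2 + inputs["y"]}


inputs = {"x": jnp.array(2.0), "y": jnp.array(3.0)}
function, jacobian, timings = compile_with_pre_run(compute, inputs)
print(sorted(timings))                    # ['function', 'jacobian']
print(float(jacobian(inputs)["z"]["x"]))  # 4.0
```

This also suggests one plausible reason for the warning above: the jitted Jacobian is built from a given function, so if the set of differentiated inputs or outputs changes afterwards, a new Jacobian function has to be created and compiled.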