
JAX discipline

jax_discipline

A discipline interfacing a JAX function.

Classes

JAXDiscipline

JAXDiscipline(
    function: Callable[[DataType], DataType],
    input_names: Sequence[str],
    output_names: Sequence[str],
    default_inputs: Mapping[str, NumberLike],
    differentiation_method: DifferentiationMethod = AUTO,
    differentiate_at_execution: bool = False,
    name: str | None = None,
)

Bases: Discipline

A discipline interfacing a JAX function.

Initialize the JAXDiscipline.

Parameters:

  • function (Callable[[DataType], DataType]) –

    The JAX function that takes a dictionary {input_name: input_value, ...} as argument and returns a dictionary {output_name: output_value, ...}.

  • input_names (Sequence[str]) –

    The names of the input variables.

  • output_names (Sequence[str]) –

    The names of the output variables.

  • default_inputs (Mapping[str, NumberLike]) –

    The default values of the input variables.

  • differentiation_method (DifferentiationMethod, default: AUTO ) –

    The method to compute the Jacobian.

  • differentiate_at_execution (bool, default: False ) –

    Whether to compute the Jacobian when executing the discipline.

  • name (str | None, default: None ) –

    The name of the discipline. If empty, use the name of the class.
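
As a minimal sketch of the `function` contract, the hypothetical `compute_outputs` below maps a dictionary of inputs to a dictionary of outputs whose keys match `output_names`. It uses only arithmetic operators, so it runs on NumPy arrays here and would also accept JAX arrays, which support the same operators; in practice you would typically write it with `jax.numpy`. The commented-out constructor call assumes the import path `gemseo_jax.jax_discipline`, inferred from the source path shown on this page.

```python
import numpy as np


# Hypothetical example function following the documented contract:
# {input_name: input_value, ...} -> {output_name: output_value, ...}.
def compute_outputs(inputs):
    x = inputs["x"]
    z = inputs["z"]
    return {
        "y": x**2 + 3.0 * z,  # elementwise operations only
        "w": x * z,
    }


input_names = ["x", "z"]
output_names = ["y", "w"]
default_inputs = {"x": np.array([1.0, 2.0]), "z": np.array([0.5, 0.5])}

outputs = compute_outputs(default_inputs)
# The returned keys must match the declared output names.
assert sorted(outputs) == sorted(output_names)

# With the plugin installed, the discipline would be built along these lines
# (positional order taken from the signature above):
# from gemseo_jax.jax_discipline import JAXDiscipline
# discipline = JAXDiscipline(
#     compute_outputs, input_names, output_names, default_inputs
# )
```
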

Source code in src/gemseo_jax/jax_discipline.py
def __init__(
    self,
    function: Callable[[DataType], DataType],
    input_names: Sequence[str],
    output_names: Sequence[str],
    default_inputs: Mapping[str, NumberLike],
    differentiation_method: DifferentiationMethod = DifferentiationMethod.AUTO,
    differentiate_at_execution: bool = False,
    name: str | None = None,
) -> None:
    """Initialize the JAXDiscipline.

    Args:
        function: The JAX function that takes a dictionary
            ``{input_name: input_value, ...}`` as argument and returns a dictionary
            ``{output_name: output_value, ...}``.
        input_names: The names of the input variables.
        output_names: The names of the output variables.
        default_inputs: The default values of the input variables.
        differentiation_method: The method to compute the Jacobian.
        differentiate_at_execution: Whether to compute the Jacobian when executing
            the discipline.
    """  # noqa: D205, D212, D415
    super().__init__(name=name)
    self.input_grammar.update_from_names(input_names)
    self.output_grammar.update_from_names(output_names)
    self.default_input_data = {
        input_name: np_array(input_value)
        if isinstance(input_value, JAXArray)
        else input_value
        for input_name, input_value in default_inputs.items()
    }
    self.__differentiate_at_execution = differentiate_at_execution
    self.jax_out_func = function
    self.__differentiation_method = differentiation_method
    self.__jax_jac_func = self.__create_jacobian_function(self.jax_out_func)
    self.__sizes = {}
    self.__jac_shape = {}

Attributes

jax_out_func instance-attribute

jax_out_func: Callable[[DataType], DataType] = function

The JAX function to compute the outputs from the inputs.

Classes

DifferentiationMethod

Bases: StrEnum

The method to compute the Jacobian.

Functions

add_differentiated_inputs

add_differentiated_inputs(
    input_names: Iterable[str] = (),
) -> None

Add the inputs with respect to which to differentiate the outputs.

The inputs that do not represent continuous numbers are filtered out.

Parameters:

  • input_names (Iterable[str], default: () ) –

    The input variables with respect to which to differentiate the outputs. If empty, use all the inputs.

Raises:

  • ValueError

    When an input name is not the name of a discipline input.

Notes

When this call adds new inputs, the Jacobian is also refiltered so that the non-differentiated variables are treated as static.
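
One way to picture "treating the non-differentiated inputs as static" is partial application: the values of the inputs that are not differentiated are frozen, and the resulting function only takes the differentiated inputs, so an autodiff transform applied to it only yields derivatives with respect to those. The helper `freeze_static_inputs` below is hypothetical and is a simplified illustration under that assumption, not the plugin's private `_filter_jacobian` method.

```python
from collections.abc import Callable, Iterable

import numpy as np

Data = dict[str, np.ndarray]


def freeze_static_inputs(
    function: Callable[[Data], Data],
    differentiated_names: Iterable[str],
    default_inputs: Data,
) -> Callable[[Data], Data]:
    """Return a function of the differentiated inputs only (hypothetical helper)."""
    differentiated = set(differentiated_names)
    # Inputs that are not differentiated are captured once as constants.
    static_inputs = {
        name: value
        for name, value in default_inputs.items()
        if name not in differentiated
    }

    def filtered_function(differentiated_inputs: Data) -> Data:
        # Merge the frozen static values with the differentiated ones.
        return function({**static_inputs, **differentiated_inputs})

    return filtered_function


def compute_outputs(inputs: Data) -> Data:
    return {"y": inputs["x"] ** 2 + 3.0 * inputs["z"]}


defaults = {"x": np.array([1.0]), "z": np.array([2.0])}
f_x = freeze_static_inputs(compute_outputs, ["x"], defaults)
print(f_x({"x": np.array([3.0])}))  # "z" stays frozen at its default value
```
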

Source code in src/gemseo_jax/jax_discipline.py
def add_differentiated_inputs(
    self,
    input_names: Iterable[str] = (),
) -> None:
    """
    Notes:
        The Jacobian is also filtered to view non-differentiated static.
    """  # noqa: D205, D212, D415
    old_differentiated_inputs = self._differentiated_input_names.copy()
    super().add_differentiated_inputs(input_names=input_names)
    refilter = any(
        input_name not in old_differentiated_inputs
        for input_name in self._differentiated_input_names
    )
    if refilter:
        self._filter_jacobian()
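
The refiltering condition in the code above only fires when the call actually adds a name that was not already differentiated. A stand-alone sketch of that check, using a hypothetical `needs_refilter` helper with plain lists in place of the discipline's internal attributes:

```python
def needs_refilter(old_names, new_names):
    """Return True if new_names contains a name absent from old_names."""
    old = set(old_names)
    return any(name not in old for name in new_names)


print(needs_refilter(["x"], ["x"]))  # False: nothing new
print(needs_refilter(["x"], ["x", "z"]))  # True: "z" was added
```

Assuming the base-class call only ever adds names, checking the updated list against the saved copy is enough to avoid recomputing the filtered Jacobian on redundant calls.
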

add_differentiated_outputs

add_differentiated_outputs(
    output_names: Iterable[str] = (),
) -> None

Add the outputs to be differentiated.

The outputs that do not represent continuous numbers are filtered out.

Parameters:

  • output_names (Iterable[str], default: () ) –

    The outputs to be differentiated. If empty, use all the outputs.

Raises:

  • ValueError

    When an output name is not the name of a discipline output.

Notes

When this call adds new outputs, the Jacobian is also refiltered so that the non-differentiated variables are treated as static.

Source code in src/gemseo_jax/jax_discipline.py
def add_differentiated_outputs(
    self,
    output_names: Iterable[str] = (),
) -> None:
    """
    Notes:
        The Jacobian is also filtered to view non-differentiated static.
    """  # noqa: D205, D212, D415
    old_differentiated_outputs = self._differentiated_output_names.copy()
    super().add_differentiated_outputs(output_names=output_names)
    refilter = any(
        output_name not in old_differentiated_outputs
        for output_name in self._differentiated_output_names
    )
    if refilter:
        self._filter_jacobian()
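
GEMSEO stores a discipline Jacobian as a nested dictionary `{output_name: {input_name: array}}`, where each block has shape `(output size, input size)`. As a simplified picture of what restricting the Jacobian to the differentiated variables means, the hypothetical `restrict_jacobian` helper below keeps only the blocks for the given outputs and inputs of an already computed Jacobian; it is an illustration, not the plugin's private `_filter_jacobian` method.

```python
import numpy as np


def restrict_jacobian(jacobian, output_names, input_names):
    """Keep only the blocks d(output)/d(input) for the given names (hypothetical)."""
    return {
        output_name: {
            input_name: jacobian[output_name][input_name]
            for input_name in input_names
        }
        for output_name in output_names
    }


# Full Jacobian of y = x**2 + 3*z and w = x*z at x = [1, 2], z = [0.5, 0.5].
x = np.array([1.0, 2.0])
z = np.array([0.5, 0.5])
full_jacobian = {
    "y": {"x": np.diag(2.0 * x), "z": 3.0 * np.eye(2)},
    "w": {"x": np.diag(z), "z": np.diag(x)},
}
partial = restrict_jacobian(full_jacobian, ["y"], ["x"])
print(partial["y"]["x"])  # 2x2 block dy/dx
```
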

compile_jit

compile_jit(pre_run: bool = True) -> None

Apply JIT compilation to the output function and to the Jacobian function.

Parameters:

  • pre_run (bool, default: True ) –

    Whether to call the jitted callables once to trigger their compilation and log the compilation times.

Source code in src/gemseo_jax/jax_discipline.py
def compile_jit(
    self,
    pre_run: bool = True,
) -> None:
    """Apply jit compilation over function and jacobian.

    Args:
        pre_run: Whether to call jitted callables once to trigger compilation and
            log times.
    """
    self.jax_out_func = jit(self.jax_out_func)
    self.__jax_jac_func = jit(self.__jax_jac_func)
    if pre_run:
        self._jit_pre_run()
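
The `pre_run` option exists because a jitted callable is traced and compiled on its first call for a given input structure, so that call is much slower than later ones. A generic sketch of timing this with `time.perf_counter`, using a hypothetical `time_first_two_calls` helper: with JAX you would pass the result of `jax.jit(function)`, whereas here any callable works, so the two timings need not differ.

```python
import time
from collections.abc import Callable
from typing import Any


def time_first_two_calls(function: Callable[[Any], Any], inputs: Any):
    """Call function twice; return (result, first duration, second duration)."""
    start = time.perf_counter()
    function(inputs)  # with jax.jit, tracing and compilation happen here
    first_duration = time.perf_counter() - start

    start = time.perf_counter()
    result = function(inputs)  # reuses the compiled version, if any
    second_duration = time.perf_counter() - start
    # JAX dispatches asynchronously: for accurate timings, call
    # .block_until_ready() on the output arrays before reading the clock.
    return result, first_duration, second_duration


result, t_first, t_second = time_first_two_calls(
    lambda data: {"y": data["x"] * 2.0}, {"x": 3.0}
)
print(result, f"{t_first:.2e} s", f"{t_second:.2e} s")
```
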