
Jax chain

jax_chain

Module for executing a chain of JAXDisciplines at once in JAX.

The CouplingStructure of the JAXDisciplines is used to determine the correct sequence of function calls from the dependencies among the disciplines.

Note

If the disciplines are coupled to one another, the resulting chain is self-coupled: some variables are both inputs and outputs of the chain, and one chain execution corresponds to one fixed-point iteration.
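The note can be illustrated without JAX or GEMSEO. The following is a minimal standard-library sketch, assuming two hypothetical disciplines d1 and d2 hard-coded into one function that stands in for the chain. Because d1 runs before d2 but reads d2's output y2, y2 is both an input and an output of the chain, and calling the chain repeatedly with its own output fed back is a fixed-point iteration.

```python
def chain(x: float, y2: float) -> dict[str, float]:
    """Hypothetical self-coupled chain: d1 then d2, executed once."""
    # d1 runs first and reads the coupling y2 from the previous iteration.
    y1 = 0.5 * y2 + x
    # d2 runs second and produces a new value of y2.
    return {"y1": y1, "y2": 0.25 * y1}


def fixed_point(x: float, y2: float = 0.0, tol: float = 1e-12, max_iter: int = 100):
    """Execute the chain repeatedly; each execution is one fixed-point iteration."""
    for iteration in range(1, max_iter + 1):
        outputs = chain(x, y2)
        if abs(outputs["y2"] - y2) < tol:
            return outputs, iteration
        y2 = outputs["y2"]  # feed the self-coupled output back as an input
    raise RuntimeError("The fixed-point iteration did not converge.")


outputs, n_iterations = fixed_point(x=1.0)
print(outputs, n_iterations)
```

Here the iteration converges to y1 = 8/7 and y2 = 2/7, because each pass contracts the error in y2 by a factor of 0.125. An MDA wrapped around a self-coupled chain plays the role of this hand-written loop.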

Classes

JAXChain

JAXChain(
    disciplines: Sequence[JAXDiscipline],
    differentiation_method: DifferentiationMethod = AUTO,
    differentiate_at_execution: bool = False,
    name: str | None = None,
)

Bases: JAXDiscipline

A chain of JAX disciplines.

Initialize the JAX chain.

Parameters:

  • disciplines (Sequence[JAXDiscipline]) –

    The JAX disciplines to create the chain over.

  • differentiation_method (DifferentiationMethod, default: AUTO ) –

    The method to compute the Jacobian.

  • differentiate_at_execution (bool, default: False ) –

    Whether to compute the Jacobian when executing the discipline.

  • name (str | None, default: None ) –

    The name of the discipline. If None, use the name of the class.

Source code in src/gemseo_jax/jax_chain.py
def __init__(
    self,
    disciplines: Sequence[JAXDiscipline],
    differentiation_method: DifferentiationMethod = DifferentiationMethod.AUTO,
    differentiate_at_execution: bool = False,
    name: str | None = None,
) -> None:
    """
    Args:
        disciplines: The JAX disciplines to create the chain over.
    """  # noqa: D205, D212, D415, D417
    self.__sequence = CouplingStructure(disciplines).sequence

    # Generate input and output names following the execution sequence; this
    # adds coupling variables as inputs (as they may be required before computation)
    input_names = []
    output_names = []
    for mdas_at_priority in self.__sequence:
        for mda_at_mdas in mdas_at_priority:
            for disc in mda_at_mdas:
                input_names.extend([
                    var
                    for var in disc.input_grammar.names
                    if var not in output_names
                ])
                output_names.extend(disc.output_grammar.names)

    default_inputs = {}
    for discipline in disciplines:
        default_inputs.update({
            input_name: input_value
            for input_name, input_value in discipline.default_input_data.items()
            if input_name in input_names
        })

    super().__init__(
        function=self.__compute_all,
        input_names=input_names,
        output_names=output_names,
        default_inputs=default_inputs,
        differentiation_method=differentiation_method,
        differentiate_at_execution=differentiate_at_execution,
        name=name,
    )

    # Add self-coupled variables into differentiated inputs and outputs for MDA
    self_coupled_vars = set(self.input_grammar.names) & set(
        self.output_grammar.names
    )
    self.add_differentiated_inputs(self_coupled_vars)
    self.add_differentiated_outputs(self_coupled_vars)
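The name bookkeeping in the constructor can be sketched in plain Python. As a simplifying assumption, the nested sequence returned by CouplingStructure (priority level, then MDA group, then discipline) is flattened here into a hypothetical list of (input_names, output_names) pairs in execution order. A variable becomes a chain input unless an earlier discipline produces it, and the self-coupled variables are the names that end up on both sides.

```python
from collections.abc import Sequence


def chain_names(sequence: Sequence[tuple[list[str], list[str]]]):
    """Derive the chain's inputs, outputs and self-coupled variables.

    ``sequence`` is a flat, hypothetical list of (input_names, output_names)
    pairs, one per discipline, in execution order.
    """
    input_names: list[str] = []
    output_names: list[str] = []
    for disc_inputs, disc_outputs in sequence:
        # An input is a chain input unless an earlier discipline produces it.
        input_names.extend(name for name in disc_inputs if name not in output_names)
        output_names.extend(disc_outputs)
    self_coupled = set(input_names) & set(output_names)
    return input_names, output_names, self_coupled


# d1 reads y2, which d2 (executed later) produces, so y2 is self-coupled.
print(chain_names([(["x", "y2"], ["y1"]), (["y1"], ["y2"])]))
# (['x', 'y2'], ['y1', 'y2'], {'y2'})

# Feed-forward chain: y1 is produced before d2 needs it, so nothing is self-coupled.
print(chain_names([(["x"], ["y1"]), (["y1", "z"], ["y2"])]))
# (['x', 'z'], ['y1', 'y2'], set())
```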
Attributes
jax_out_func instance-attribute
jax_out_func: Callable[[DataType], DataType] = function

The JAX function to compute the outputs from the inputs.

Classes
DifferentiationMethod

Bases: StrEnum

The method to compute the Jacobian.

Functions
add_differentiated_inputs
add_differentiated_inputs(
    input_names: Iterable[str] = (),
) -> None

Add the inputs with respect to which to differentiate the outputs.

The inputs that do not represent continuous numbers are filtered out.

Parameters:

  • input_names (Iterable[str], default: () ) –

    The input variables with respect to which to differentiate the outputs. If empty, use all the inputs.

Raises:

  • ValueError

    When an input name is not the name of a discipline input.

Notes

The Jacobian function is also re-filtered so that the non-differentiated variables are treated as static.
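One way to read this note is that the function is reduced to the differentiated inputs, with the remaining inputs frozen as constants before the Jacobian is built. The standard-library sketch below shows that idea under explicit assumptions: `functools.partial` freezes the non-differentiated inputs, and a forward finite difference stands in for JAX automatic differentiation. The names `func` and `filtered_jacobian` are hypothetical and are not part of gemseo_jax.

```python
from functools import partial


def func(a: float, b: float, n: int) -> dict[str, float]:
    # "n" plays the role of a non-continuous input that is never differentiated.
    return {"y": a * b + n}


def filtered_jacobian(func, inputs, differentiated, step=1e-6):
    """Jacobian of ``func`` with respect to ``differentiated`` inputs only."""
    # Freeze the non-differentiated inputs: they are "static" from now on.
    static = {name: value for name, value in inputs.items() if name not in differentiated}
    reduced = partial(func, **static)
    base = {name: inputs[name] for name in differentiated}
    f0 = reduced(**base)
    jac = {output: {} for output in f0}
    for name in differentiated:
        # Forward finite difference (stand-in for automatic differentiation).
        perturbed = dict(base)
        perturbed[name] += step
        f1 = reduced(**perturbed)
        for output in f0:
            jac[output][name] = (f1[output] - f0[output]) / step
    return jac


jac = filtered_jacobian(func, {"a": 2.0, "b": 3.0, "n": 1}, ["a"])
print(jac)  # only the column for "a" exists; "b" and "n" are held fixed
```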

Source code in src/gemseo_jax/jax_discipline.py
def add_differentiated_inputs(
    self,
    input_names: Iterable[str] = (),
) -> None:
    """
    Notes:
        The Jacobian is also filtered to view non-differentiated static.
    """  # noqa: D205, D212, D415
    old_differentiated_inputs = self._differentiated_input_names.copy()
    super().add_differentiated_inputs(input_names=input_names)
    refilter = any(
        input_name not in old_differentiated_inputs
        for input_name in self._differentiated_input_names
    )
    if refilter:
        self._filter_jacobian()
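The method above only calls `_filter_jacobian` when the set of differentiated names actually grows, because re-filtering can be costly. The guard can be sketched with a small hypothetical class. It also mirrors the documented behaviours that an empty argument means "all names" and that an unknown name raises ValueError. The filtering of non-continuous variables done by the GEMSEO base class is deliberately left out.

```python
class DifferentiatedNames:
    """Hypothetical stand-in for the differentiated-name bookkeeping."""

    def __init__(self, all_names):
        self.all_names = list(all_names)
        self.differentiated: list[str] = []
        self.filter_calls = 0

    def add(self, names=()):
        names = list(names) or self.all_names  # empty means "all names"
        unknown = [name for name in names if name not in self.all_names]
        if unknown:
            raise ValueError(f"Unknown names: {unknown}")
        old = self.differentiated.copy()
        for name in names:
            if name not in self.differentiated:
                self.differentiated.append(name)
        # Rebuild the filtered Jacobian only if a new name was added.
        if any(name not in old for name in self.differentiated):
            self._filter_jacobian()

    def _filter_jacobian(self):
        self.filter_calls += 1  # placeholder for the actual re-filtering


names = DifferentiatedNames(["x", "y"])
names.add(["x"])  # new name: re-filter
names.add(["x"])  # nothing new: skip
names.add()       # adds "y": re-filter
print(names.differentiated, names.filter_calls)  # ['x', 'y'] 2
```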
add_differentiated_outputs
add_differentiated_outputs(
    output_names: Iterable[str] = (),
) -> None

Add the outputs to be differentiated.

The outputs that do not represent continuous numbers are filtered out.

Parameters:

  • output_names (Iterable[str], default: () ) –

    The outputs to be differentiated. If empty, use all the outputs.

Raises:

  • ValueError

    When an output name is not the name of a discipline output.

Notes

The Jacobian function is also re-filtered so that the non-differentiated variables are treated as static.

Source code in src/gemseo_jax/jax_discipline.py
def add_differentiated_outputs(
    self,
    output_names: Iterable[str] = (),
) -> None:
    """
    Notes:
        The Jacobian is also filtered to view non-differentiated static.
    """  # noqa: D205, D212, D415
    old_differentiated_outputs = self._differentiated_output_names.copy()
    super().add_differentiated_outputs(output_names=output_names)
    refilter = any(
        output_name not in old_differentiated_outputs
        for output_name in self._differentiated_output_names
    )
    if refilter:
        self._filter_jacobian()
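On the output side, restricting the differentiated outputs amounts to differentiating a function that returns only those outputs, so the Jacobian has no rows for the others. The sketch below is a hypothetical wrapper written in plain Python to show the idea. It is not the gemseo_jax implementation, and `restrict_outputs` is an illustrative name.

```python
from collections.abc import Callable, Iterable


def restrict_outputs(
    func: Callable[..., dict], differentiated_outputs: Iterable[str]
) -> Callable[..., dict]:
    """Wrap ``func`` so that it only returns the differentiated outputs."""
    names = list(differentiated_outputs)

    def restricted(**inputs):
        outputs = func(**inputs)
        return {name: outputs[name] for name in names}

    return restricted


def func(x: float) -> dict:
    # "count" plays the role of a non-continuous output that is not differentiated.
    return {"y": 2.0 * x, "z": x**2, "count": 3}


restricted = restrict_outputs(func, ["y", "z"])
print(restricted(x=1.5))  # {'y': 3.0, 'z': 2.25}
```

Any differentiation tool applied to `restricted` then produces Jacobian rows for "y" and "z" only.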
compile_jit
compile_jit(pre_run: bool = True) -> None

Apply JIT compilation to the output function and the Jacobian function.

Parameters:

  • pre_run (bool, default: True ) –

    Whether to call the jitted callables once to trigger their compilation and log the compilation times.

Source code in src/gemseo_jax/jax_discipline.py
def compile_jit(
    self,
    pre_run: bool = True,
) -> None:
    """Apply jit compilation over function and jacobian.

    Args:
        pre_run: Whether to call jitted callables once to trigger compilation and
            log times.
    """
    self.jax_out_func = jit(self.jax_out_func)
    self.__jax_jac_func = jit(self.__jax_jac_func)
    if pre_run:
        self._jit_pre_run()
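The effect of `compile_jit` can be reproduced with plain JAX, assuming JAX is installed. `jax.jit` only wraps a callable, and the actual tracing and compilation happen on the first call. That first call is what a pre-run triggers, so that later executions are fast. In this sketch, `jax.jacfwd` is an arbitrary choice for building the Jacobian function. Which autodiff mode gemseo_jax uses depends on `differentiation_method`, and `outputs` and `pre_run` are hypothetical names.

```python
import time

import jax
import jax.numpy as jnp


def outputs(x):
    """Hypothetical output function: y_i = x_i * sin(x_i)."""
    return {"y": x * jnp.sin(x)}


# Wrap both callables with jax.jit; nothing is compiled yet.
jitted_outputs = jax.jit(outputs)
jitted_jacobian = jax.jit(jax.jacfwd(outputs))


def pre_run(func, x):
    """Call a jitted function once so that tracing and compilation happen now."""
    start = time.perf_counter()
    result = func(x)
    # JAX dispatches asynchronously: wait for every output array to be ready.
    jax.tree_util.tree_map(lambda array: array.block_until_ready(), result)
    return time.perf_counter() - start


x0 = jnp.array([1.0, 2.0])
print(f"outputs: first call took {pre_run(jitted_outputs, x0):.3f} s")
print(f"jacobian: first call took {pre_run(jitted_jacobian, x0):.3f} s")
```

The measured time covers compilation plus one execution. A second call with arrays of the same shape and dtype reuses the compiled code.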