Mission

mission

Mission discipline for Sobieski's SSBJ use case.

Classes

JAXSobieskiMission

JAXSobieskiMission()

Bases: BaseJAXSobieskiDiscipline

Mission discipline for Sobieski's SSBJ use case.

Initialize the JAXDiscipline.

Source code in src/gemseo_jax/problems/sobieski/base.py
def __init__(self) -> None:  # noqa: D107
    super().__init__(self._jax_func)
    self.default_input_data = SobieskiProblem().get_default_inputs(
        self.io.input_grammar.names
    )
    base = SobieskiBase(SobieskiBase.DataType.FLOAT)
    self.constants = base.constants
    self._coeff_mtrix = array(
        [
            [0.2736, 0.3970, 0.8152, 0.9230, 0.1108],
            [0.4252, 0.4415, 0.6357, 0.7435, 0.1138],
            [0.0329, 0.8856, 0.8390, 0.3657, 0.0019],
            [0.0878, 0.7248, 0.1978, 0.0200, 0.0169],
            [0.8955, 0.4568, 0.8075, 0.9239, 0.2525],
        ],
    )
    (
        self.x_initial,
        self.tc_initial,
        self.half_span_initial,
        self.aero_center_initial,
        self.cf_initial,
        self.mach_initial,
        self.h_initial,
        self.throttle_initial,
        self.lift_initial,
        self.twist_initial,
        self.esf_initial,
    ) = base.get_initial_values()
Attributes
jax_out_func instance-attribute
jax_out_func: Callable[[DataType], DataType] = function

The JAX function to compute the outputs from the inputs.

Classes
DifferentiationMethod

Bases: StrEnum

The method to compute the Jacobian.

Functions
add_differentiated_inputs
add_differentiated_inputs(
    input_names: Iterable[str] = (),
) -> None

Add the inputs with respect to which to differentiate the outputs.

The inputs that do not represent continuous numbers are filtered out.

Parameters:

  • input_names (Iterable[str], default: () ) –

    The input variables with respect to which to differentiate the outputs. If empty, use all the inputs.

Raises:

  • ValueError

    When an input name is not the name of a discipline input.

Notes

The Jacobian is also re-filtered so that the non-differentiated variables are treated as static.

Source code in src/gemseo_jax/jax_discipline.py
def add_differentiated_inputs(
    self,
    input_names: Iterable[str] = (),
) -> None:
    """
    Notes:
        The Jacobian is also filtered to view non-differentiated static.
    """  # noqa: D205, D212, D415
    old_differentiated_inputs = self._differentiated_input_names.copy()
    super().add_differentiated_inputs(input_names=input_names)
    refilter = any(
        input_name not in old_differentiated_inputs
        for input_name in self._differentiated_input_names
    )
    if refilter:
        self._filter_jacobian()
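
The re-filtering decision in the method above can be illustrated in isolation. The following is a minimal sketch, not the library's implementation: `RefilterSketch` and `filter_count` are hypothetical names, the parent-class behaviour (an empty argument selects all inputs, duplicates are ignored) is an assumption based on the parameter description, and `_filter_jacobian` is replaced by a counter because its body is not shown here.

```python
from collections.abc import Iterable


class RefilterSketch:
    """Hypothetical stand-in that mimics only the re-filtering decision."""

    def __init__(self, input_names: Iterable[str]) -> None:
        self._all_input_names = list(input_names)
        self._differentiated_input_names: list[str] = []
        self.filter_count = 0  # number of times the mock filter has run

    def _filter_jacobian(self) -> None:
        # Placeholder for the real filtering step, which is not shown here.
        self.filter_count += 1

    def add_differentiated_inputs(self, input_names: Iterable[str] = ()) -> None:
        old = self._differentiated_input_names.copy()
        # Assumed parent behaviour: empty means "all inputs"; no duplicates.
        names = list(input_names) or self._all_input_names
        for name in names:
            if name not in self._all_input_names:
                raise ValueError(f"{name!r} is not a discipline input.")
            if name not in self._differentiated_input_names:
                self._differentiated_input_names.append(name)
        # Re-filter only when at least one differentiated name is new.
        if any(n not in old for n in self._differentiated_input_names):
            self._filter_jacobian()


sketch = RefilterSketch(["x", "y", "z"])
sketch.add_differentiated_inputs(["x"])  # new name: the filter runs
sketch.add_differentiated_inputs(["x"])  # nothing new: the filter is skipped
sketch.add_differentiated_inputs()       # adds "y" and "z": the filter runs
print(sketch.filter_count)
```

Comparing against a copy of the previous names avoids rebuilding the filtered Jacobian when a call does not change the set of differentiated variables.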
add_differentiated_outputs
add_differentiated_outputs(
    output_names: Iterable[str] = (),
) -> None

Add the outputs to be differentiated.

The outputs that do not represent continuous numbers are filtered out.

Parameters:

  • output_names (Iterable[str], default: () ) –

    The outputs to be differentiated. If empty, use all the outputs.

Raises:

  • ValueError

    When an output name is not the name of a discipline output.

Notes

The Jacobian is also re-filtered so that the non-differentiated variables are treated as static.

Source code in src/gemseo_jax/jax_discipline.py
def add_differentiated_outputs(
    self,
    output_names: Iterable[str] = (),
) -> None:
    """
    Notes:
        The Jacobian is also filtered to view non-differentiated static.
    """  # noqa: D205, D212, D415
    old_differentiated_outputs = self._differentiated_output_names.copy()
    super().add_differentiated_outputs(output_names=output_names)
    refilter = any(
        output_name not in old_differentiated_outputs
        for output_name in self._differentiated_output_names
    )
    if refilter:
        self._filter_jacobian()
compile_jit
compile_jit(pre_run: bool = True) -> None

Apply JIT compilation to the output function and the Jacobian function.

Parameters:

  • pre_run (bool, default: True ) –

    Whether to call jitted callables once to trigger compilation and log times.

Source code in src/gemseo_jax/jax_discipline.py
def compile_jit(
    self,
    pre_run: bool = True,
) -> None:
    """Apply jit compilation over function and jacobian.

    Args:
        pre_run: Whether to call jitted callables once to trigger compilation and
            log times.
    """
    self.jax_out_func = jit(self.jax_out_func)
    self.__jax_jac_func = jit(self.__jax_jac_func)
    if pre_run:
        self._jit_pre_run()
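
The effect of jit-wrapping followed by a pre-run can be sketched with plain JAX. This is an illustrative sketch under stated assumptions rather than the discipline's own code: `out_func` is a hypothetical output function, the Jacobian is built with `jax.jacfwd` (the method used by the library may differ), and the timing printout only stands in for the logging done by the pre-run.

```python
import time

import jax
import jax.numpy as jnp


def out_func(x):
    # Hypothetical output function standing in for the discipline's JAX function.
    return jnp.sin(x) * x**2


# Wrap the function and its forward-mode Jacobian with jit.
jit_out = jax.jit(out_func)
jit_jac = jax.jit(jax.jacfwd(out_func))

x0 = jnp.array([0.5, 1.0, 2.0])

# Pre-run: the first call traces and compiles each callable.
start = time.perf_counter()
jit_out(x0).block_until_ready()
jit_jac(x0).block_until_ready()
print(f"pre-run time: {time.perf_counter() - start:.4f} s")

# Later calls with the same input shape and dtype reuse the compiled code.
jac = jit_jac(x0)
```

Without a pre-run, the compilation cost is paid on the first real evaluation instead, which can distort timing comparisons between disciplines.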