Aerodynamics

aerodynamics

Aerodynamics discipline for Sobieski's SSBJ use case.

Classes

JAXSobieskiAerodynamics

JAXSobieskiAerodynamics()

Bases: BaseJAXSobieskiDiscipline

Aerodynamics discipline for Sobieski's SSBJ use case.

Parameters:

  • function (Callable[[NumberLike, ..., Any, ...], tuple[NumberLike]]) –

    The JAX function.

  • static_args (Mapping[str, Any], default: READ_ONLY_EMPTY_DICT ) –

    The names and values of the static arguments of the JAX function. These arguments remain constant during the execution of the discipline. Non-numeric arguments can also be included.
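To illustrate the `function` and `static_args` parameters, the sketch below uses a hypothetical function `drag_force` with two numeric inputs and two static arguments, one numeric and one non-numeric. It assumes that static arguments are bound as keyword arguments before execution (here with `functools.partial`), so that only the numeric inputs vary from one execution to the next; this illustrates the idea and is not the library's implementation. Plain Python floats stand in for JAX arrays.

```python
from functools import partial


def drag_force(density, velocity, drag_area=1.0, regime="subsonic"):
    """Hypothetical discipline function returning a tuple of outputs."""
    # A non-numeric static argument can select a code branch.
    factor = 1.0 if regime == "subsonic" else 1.2
    return (0.5 * density * velocity**2 * drag_area * factor,)


# Static arguments: constant during every execution of the discipline.
static_args = {"drag_area": 2.0, "regime": "subsonic"}
bound_function = partial(drag_force, **static_args)

# Only the numeric inputs remain free.
(drag,) = bound_function(1.2, 10.0)
```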

Source code in src/gemseo_jax/problems/sobieski/aerodynamics.py
def __init__(self) -> None:  # noqa: D107
    super().__init__()
    # Flags and bounds of the polynomial approximations.
    self.__flag1 = array([[0.95, 1.0, 1.05], [0.95, 1.0, 1.05]])
    self.__bound1 = array([0.25, 0.25])
    self.__flag2 = array([[1.0025, 1.0, 1.0025]])
    self.__bound2 = array([0.25])
    self.__flag3 = array([[0.95, 1.0, 1.05]])
    self.__bound3 = array([0.25])
    # Initial values of the variables esf, cf, twist and tc.
    self.__esf_cf_initial = array([self.esf_initial, self.cf_initial])
    self.__twist_initial = array([self.twist_initial])
    self.__tc_initial = array([self.tc_initial])
    # Default value of the constant input c_4.
    self.default_input_data["c_4"] = np_array([self.constants[4]])
    # Coefficients a0, ai and aij of the polynomial approximations fo1, fo2 and g2.
    self.__a0_fo1 = 1.0
    self.__ai_fo1 = array([0.2, 0.2])
    self.__aij_fo1 = array([[0.0, 0.0], [0.0, 0.0]])
    self.__a0_fo2 = 1.0
    self.__ai_fo2 = array([0.0])
    self.__aij_fo2 = array([[0.02]])
    self.__a0_g2 = 1.0
    self.__ai_g2 = array([0.2])
    self.__aij_g2 = array([[0.0]])
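The coefficient names `a0`, `ai` and `aij` suggest quadratic polynomial approximations. A simplified sketch, assuming a quadratic form in the relative deviations of the inputs from their initial values, is shown below. The helper name `quadratic_approximation` and the reference values are hypothetical, and the flags and bounds used by the actual model are deliberately left out.

```python
def quadratic_approximation(s_new, s_ref, a0, ai, aij):
    """Hypothetical quadratic approximation a0 + ai.ds + 0.5 * ds^T aij ds.

    ds is the relative deviation of the inputs from their reference values.
    The flags and bounds of the actual SSBJ model are not modeled here.
    """
    ds = [new / ref - 1.0 for new, ref in zip(s_new, s_ref)]
    linear = sum(a * d for a, d in zip(ai, ds))
    quadratic = sum(
        aij[i][j] * ds[i] * ds[j] for i in range(len(ds)) for j in range(len(ds))
    )
    return a0 + linear + 0.5 * quadratic


# Coefficients mirroring those of fo1 in the constructor (two inputs).
a0, ai, aij = 1.0, [0.2, 0.2], [[0.0, 0.0], [0.0, 0.0]]
s_ref = [1.0, 1.0]  # hypothetical reference values

at_reference = quadratic_approximation(s_ref, s_ref, a0, ai, aij)  # equals a0
perturbed = quadratic_approximation([1.1, 1.0], s_ref, a0, ai, aij)
```

At the reference point all deviations vanish, so the approximation reduces to `a0`; the linear and quadratic terms only matter away from it.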
Attributes
jax_out_func instance-attribute
jax_out_func: Callable[[DataType], DataType] = function

The JAX function to compute the outputs from the inputs.
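`jax_out_func` maps a data structure of inputs to a data structure of outputs. The minimal sketch below assumes that `DataType` behaves like a dictionary from variable names to values and that the wrapped function returns its outputs in a known order; the helper `make_out_func` and all variable names are hypothetical.

```python
from typing import Any, Callable, Mapping


def make_out_func(
    function: Callable[..., tuple],
    input_names: list,
    output_names: list,
) -> Callable[[Mapping[str, Any]], dict]:
    """Wrap a positional function into a names-to-values function."""

    def out_func(input_data: Mapping[str, Any]) -> dict:
        # Pass the inputs positionally, in the declared order.
        outputs = function(*(input_data[name] for name in input_names))
        # Label each returned value with its output name.
        return dict(zip(output_names, outputs))

    return out_func


def area_and_perimeter(width, height):
    return width * height, 2 * (width + height)


out_func = make_out_func(area_and_perimeter, ["width", "height"], ["area", "perimeter"])
result = out_func({"width": 3.0, "height": 2.0})
```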

Classes
DifferentiationMethod

Bases: StrEnum

The method to compute the Jacobian.
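Because `DifferentiationMethod` derives from `StrEnum`, its members behave as plain strings. The sketch below uses hypothetical members `AUTO`, `FORWARD` and `REVERSE` (not confirmed by this page) and a `(str, Enum)` base so that it also runs before Python 3.11, where `enum.StrEnum` was introduced. The hypothetical `choose_mode` helper applies the usual automatic-differentiation heuristic: forward mode needs one pass per input and reverse mode one pass per output, so forward mode tends to be cheaper when there are fewer inputs than outputs.

```python
from enum import Enum


class DifferentiationMethod(str, Enum):
    """Hypothetical stand-in for the library's DifferentiationMethod."""

    AUTO = "auto"
    FORWARD = "forward"
    REVERSE = "reverse"


def choose_mode(method, n_inputs, n_outputs):
    """Resolve AUTO into a concrete mode from the Jacobian shape."""
    method = DifferentiationMethod(method)  # accepts members or plain strings
    if method is DifferentiationMethod.AUTO:
        # Forward mode: one pass per input; reverse mode: one pass per output.
        if n_inputs <= n_outputs:
            return DifferentiationMethod.FORWARD
        return DifferentiationMethod.REVERSE
    return method
```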

Functions
add_differentiated_inputs
add_differentiated_inputs(
    input_names: Iterable[str] = (),
) -> None
Notes

If this call adds new input names, the Jacobian is filtered again so that the non-differentiated variables are viewed as static.

Source code in src/gemseo_jax/jax_discipline.py
def add_differentiated_inputs(
    self,
    input_names: Iterable[str] = (),
) -> None:
    """
    Notes:
        The Jacobian is also filtered to view non-differentiated variables as static.
    """  # noqa: D205, D212, D415
    old_differentiated_inputs = self._differentiated_input_names.copy()
    super().add_differentiated_inputs(input_names=input_names)
    refilter = any(
        input_name not in old_differentiated_inputs
        for input_name in self._differentiated_input_names
    )
    if refilter:
        self._filter_jacobian()
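The refiltering rule above fires only when the call actually adds a name that was not differentiated before. The minimal sketch below reproduces that rule with a hypothetical `RefilterSketch` class that counts the filtering calls; it does not reproduce the GEMSEO base method (for instance, what happens when `input_names` is empty) and simply appends the given names.

```python
class RefilterSketch:
    """Hypothetical stand-in mimicking the refilter condition."""

    def __init__(self):
        self.differentiated_input_names = []
        self.filter_calls = 0

    def add_differentiated_inputs(self, input_names=()):
        old_names = list(self.differentiated_input_names)
        # Simplified stand-in for the base method: append unknown names.
        for name in input_names:
            if name not in self.differentiated_input_names:
                self.differentiated_input_names.append(name)
        # Same condition as in the discipline: refilter only on new names.
        if any(name not in old_names for name in self.differentiated_input_names):
            self._filter_jacobian()

    def _filter_jacobian(self):
        self.filter_calls += 1


sketch = RefilterSketch()
sketch.add_differentiated_inputs(["x"])  # new name: refilter
sketch.add_differentiated_inputs(["x"])  # already differentiated: no refilter
sketch.add_differentiated_inputs(["y"])  # new name: refilter
```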
add_differentiated_outputs
add_differentiated_outputs(
    output_names: Iterable[str] = (),
) -> None
Notes

If this call adds new output names, the Jacobian is filtered again so that the non-differentiated variables are viewed as static.

Source code in src/gemseo_jax/jax_discipline.py
def add_differentiated_outputs(
    self,
    output_names: Iterable[str] = (),
) -> None:
    """
    Notes:
        The Jacobian is also filtered to view non-differentiated variables as static.
    """  # noqa: D205, D212, D415
    old_differentiated_outputs = self._differentiated_output_names.copy()
    super().add_differentiated_outputs(output_names=output_names)
    refilter = any(
        output_name not in old_differentiated_outputs
        for output_name in self._differentiated_output_names
    )
    if refilter:
        self._filter_jacobian()
compile_jit
compile_jit(pre_run: bool = True) -> None

Apply JIT compilation to the output function and the Jacobian function.

Parameters:

  • pre_run (bool, default: True ) –

    Whether to call the jitted callables once to trigger their compilation and log the elapsed times.

Warning

add_differentiated_inputs and add_differentiated_outputs must be called before compile_jit.

Source code in src/gemseo_jax/jax_discipline.py
def compile_jit(
    self,
    pre_run: bool = True,
) -> None:
    """Apply jit compilation over function and jacobian.

    Args:
        pre_run: Whether to call jitted callables once to trigger compilation and
            log times.

    Warning:
        Calling
        [add_differentiated_inputs][gemseo_jax.jax_discipline.JAXDiscipline.add_differentiated_inputs]
        and
        [add_differentiated_outputs][gemseo_jax.jax_discipline.JAXDiscipline.add_differentiated_outputs]
        must be done before calling
        [compile_jit][gemseo_jax.jax_discipline.JAXDiscipline.compile_jit].
    """
    self.jax_out_func = jit(self.jax_out_func)
    self.__jax_jac_func = jit(self.__jax_jac_func)
    if pre_run:
        self._jit_pre_run()
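The `pre_run` option makes the compilation happen eagerly rather than at the first real execution. A minimal sketch of that idea with plain JAX follows, assuming a hypothetical output function `outputs` and a hypothetical `pre_run` helper, and using `jax.jacfwd` for the Jacobian (the discipline may use another differentiation mode). The first call to each jitted callable traces and compiles it; `block_until_ready` makes the measured time include the asynchronous execution.

```python
import logging
import time

import jax
import jax.numpy as jnp

logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)


def outputs(x):
    # Hypothetical output function standing in for a discipline's JAX function.
    return jnp.sin(x) * x**2


def pre_run(func, *args):
    """Call a jitted callable once so that tracing and compilation happen now."""
    start = time.perf_counter()
    # Wait for the asynchronous computation before stopping the clock.
    result = func(*args).block_until_ready()
    LOGGER.info("First call (compilation + run): %.3f s", time.perf_counter() - start)
    return result


jitted_outputs = jax.jit(outputs)
jitted_jacobian = jax.jit(jax.jacfwd(outputs))

x = jnp.array([0.5, 1.0, 1.5])
pre_run(jitted_outputs, x)
jacobian = pre_run(jitted_jacobian, x)
```

Later calls with inputs of the same shape and dtype reuse the compiled code, which is why paying the compilation cost up front can make subsequent timings more representative.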