Sellar 2
sellar_2
The second discipline of the Sellar problem in JAX.
Classes
JAXSellar2
JAXSellar2(
n: int = 1,
k: float = 1.0,
static_args: Mapping[str, Any] = READ_ONLY_EMPTY_DICT,
differentiation_method: DifferentiationMethod = AUTO,
differentiate_at_execution: bool = False,
)
Bases: BaseJAXSellar
The discipline to compute the coupling variable y_2 in JAX.
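For orientation, here is a minimal pure-Python sketch of the classic scalar Sellar discipline 2, y_2 = sqrt(y_1) + z_1 + z_2. This is an assumption for illustration only: JAXSellar2 generalizes the problem with the size n and the coefficient k, and its exact formula is not reproduced here. The sketch also checks one Jacobian entry, the analytic partial derivative dy_2/dy_1 = 1/(2 sqrt(y_1)), against a central finite difference; in the JAX discipline such derivatives are obtained by automatic differentiation instead.

```python
import math


def sellar_2_classic(y_1: float, z_1: float, z_2: float) -> float:
    """Classic scalar Sellar discipline 2 (not the generalized JAX version)."""
    return math.sqrt(y_1) + z_1 + z_2


def dy2_dy1(y_1: float) -> float:
    """Analytic partial derivative of y_2 with respect to y_1."""
    return 0.5 / math.sqrt(y_1)


def central_difference(f, x: float, h: float = 1e-6) -> float:
    """Approximate the derivative of f at x with a central finite difference."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


if __name__ == "__main__":
    y_1, z_1, z_2 = 4.0, 1.0, 2.0
    print(sellar_2_classic(y_1, z_1, z_2))  # sqrt(4) + 1 + 2 = 5.0
    approx = central_difference(lambda y: sellar_2_classic(y, z_1, z_2), y_1)
    print(dy2_dy1(y_1), approx)  # both close to 0.25
```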
Initialize the JAXDiscipline.
Parameters:
- n (int, default: 1) – The size of the local design variables and coupling variables.
- k (float, default: 1.0) – The shared coefficient controlling the coupling strength.
- static_args (Mapping[str, Any], default: READ_ONLY_EMPTY_DICT) – The names and values of the static arguments of the JAX function. These arguments remain constant during the discipline execution and may include non-numeric values.
- differentiation_method (DifferentiationMethod, default: AUTO) – The method to compute the Jacobian.
- differentiate_at_execution (bool, default: False) – Whether to compute the Jacobian when executing the discipline.
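The static_args parameter fixes some arguments of the JAX function so that only the numeric inputs vary between executions. A minimal stdlib sketch of the idea, assuming a hypothetical function scaled_sum with a varying numeric argument x and two static arguments (one of them non-numeric), is to bind the constant arguments once with functools.partial. In JAX itself, a comparable effect is obtained with the static_argnames argument of jax.jit. This sketch does not show how JAXDiscipline stores static_args internally.

```python
from functools import partial
from typing import Any, Mapping


def scaled_sum(x: list[float], *, scale: float, mode: str) -> float:
    """Hypothetical function: 'x' varies, 'scale' and 'mode' are static."""
    total = sum(x)
    if mode == "mean":
        total /= len(x)
    return scale * total


# Static arguments, including a non-numeric one, fixed before execution.
static_args: Mapping[str, Any] = {"scale": 2.0, "mode": "mean"}

# Bind the static arguments once; the bound callable only takes the varying input.
bound = partial(scaled_sum, **static_args)

print(bound([1.0, 2.0, 3.0]))  # 2.0 * mean([1, 2, 3]) = 4.0
```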
Source code in src/gemseo_jax/problems/sellar/sellar_2.py
Attributes
jax_out_func (instance attribute)
jax_out_func: Callable[[DataType], DataType] = function
The JAX function to compute the outputs from the inputs.
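The annotation Callable[[DataType], DataType] means that the function maps named inputs to named outputs. The sketch below assumes that DataType behaves like a mapping from variable names to values and models it as a plain dict of floats; the names x and y and the formula are hypothetical. In the actual discipline the values would be JAX arrays.

```python
from typing import Callable, Dict

# Assumption: DataType is modeled here as a plain dict of names to floats.
DataType = Dict[str, float]


def out_func(data: DataType) -> DataType:
    """Hypothetical output function: compute 'y' from the input 'x'."""
    return {"y": 3.0 * data["x"] + 1.0}


jax_out_func: Callable[[DataType], DataType] = out_func

print(jax_out_func({"x": 2.0}))  # {'y': 7.0}
```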
Classes
DifferentiationMethod
Bases: StrEnum
The method to compute the Jacobian.
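The enumeration selects how the Jacobian is computed. The sketch below is hypothetical: it assumes members named AUTO, FORWARD and REVERSE (only AUTO is confirmed by this page) and a common heuristic for AUTO. Forward mode (jax.jacfwd) needs one pass per input dimension and reverse mode (jax.jacrev) one pass per output dimension, so the cheaper mode is forward when there are no more inputs than outputs. A str-based Enum is used so that the sketch also runs on Python versions without enum.StrEnum.

```python
from enum import Enum


class DifferentiationMethod(str, Enum):
    """Hypothetical members; only AUTO is confirmed by the documentation."""

    AUTO = "auto"
    FORWARD = "forward"
    REVERSE = "reverse"


def resolve(
    method: DifferentiationMethod, n_inputs: int, n_outputs: int
) -> DifferentiationMethod:
    """Resolve AUTO into a concrete mode with a dimension heuristic.

    Forward mode needs one pass per input dimension and reverse mode one pass
    per output dimension, so the mode with fewer passes is chosen.
    """
    if method is not DifferentiationMethod.AUTO:
        return method
    if n_inputs <= n_outputs:
        return DifferentiationMethod.FORWARD
    return DifferentiationMethod.REVERSE


print(resolve(DifferentiationMethod.AUTO, n_inputs=2, n_outputs=10).value)  # forward
print(resolve(DifferentiationMethod.AUTO, n_inputs=10, n_outputs=1).value)  # reverse
```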
Functions
add_differentiated_inputs
Add the inputs with respect to which to differentiate the outputs.
The inputs that do not represent continuous numbers are filtered out.
Parameters:
- input_names (Iterable[str], default: ()) – The input variables with respect to which to differentiate the outputs. If empty, use all the inputs.
Raises:
- ValueError – When an input name is not the name of a discipline input.
Notes
The Jacobian function is also filtered so that the non-differentiated variables are viewed as static.
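The validation and filtering described above can be sketched in plain Python. This is an illustration under assumptions, not the library's implementation: the helper select_differentiated_inputs is hypothetical, the discipline inputs are modeled as a dict of default values, and "continuous number" is approximated as a float or complex value, or a non-empty list of them.

```python
from typing import Any, Iterable, Mapping


def _is_continuous(value: Any) -> bool:
    """Approximate 'continuous number': a float/complex or a list of them."""
    if isinstance(value, (list, tuple)):
        return len(value) > 0 and all(_is_continuous(v) for v in value)
    # int, bool and str values fail this check, so they are filtered out.
    return isinstance(value, (float, complex))


def select_differentiated_inputs(
    default_inputs: Mapping[str, Any], input_names: Iterable[str] = ()
) -> list[str]:
    """Hypothetical helper mirroring the documented behavior."""
    names = list(input_names) or list(default_inputs)  # empty: use all inputs
    for name in names:
        if name not in default_inputs:
            raise ValueError(f"{name!r} is not the name of a discipline input.")
    return [name for name in names if _is_continuous(default_inputs[name])]


inputs = {"x": [1.0, 2.0], "k": 3.0, "n": 2, "label": "sellar"}
print(select_differentiated_inputs(inputs))  # ['x', 'k']
print(select_differentiated_inputs(inputs, ["k", "label"]))  # ['k']
```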
Source code in src/gemseo_jax/jax_discipline.py
add_differentiated_outputs
Add the outputs to be differentiated.
The outputs that do not represent continuous numbers are filtered out.
Parameters:
- output_names (Iterable[str], default: ()) – The outputs to be differentiated. If empty, use all the outputs.
Raises:
- ValueError – When an output name is not the name of a discipline output.
Notes
The Jacobian function is also filtered so that the non-differentiated variables are viewed as static.
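As a hedged illustration of restricting a Jacobian to the differentiated variables, assume that the Jacobian is stored as a nested mapping jacobian[output_name][input_name] and that each block is a float. Keeping only the differentiated outputs and inputs then amounts to a nested dictionary comprehension. The helper filter_jacobian is hypothetical and is not part of the documented API.

```python
from typing import Dict, Iterable

# Assumed layout: jacobian[output_name][input_name] -> block (floats here).
Jacobian = Dict[str, Dict[str, float]]


def filter_jacobian(
    jacobian: Jacobian, output_names: Iterable[str], input_names: Iterable[str]
) -> Jacobian:
    """Keep only the blocks of the differentiated outputs and inputs."""
    outputs = set(output_names)
    inputs = set(input_names)
    return {
        output: {name: block for name, block in blocks.items() if name in inputs}
        for output, blocks in jacobian.items()
        if output in outputs
    }


full = {
    "y_1": {"x": 1.0, "k": 0.5},
    "y_2": {"x": 2.0, "k": -1.0},
}
print(filter_jacobian(full, ["y_2"], ["x"]))  # {'y_2': {'x': 2.0}}
```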
Source code in src/gemseo_jax/jax_discipline.py
compile_jit
compile_jit(pre_run: bool = True) -> None
Apply JIT compilation to the output function and the Jacobian function.
Parameters:
- pre_run (bool, default: True) – Whether to call the jitted callables once to trigger compilation and log the compilation times.
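The pre_run option follows a common JAX pattern: the first call of a jitted function triggers tracing and compilation, so calling it once up front and timing that call moves the compilation cost out of the first real execution. The sketch below is a generic helper (pre_run_and_log is hypothetical). With JAX, fn would be the result of jax.jit(...); because JAX dispatches work asynchronously, the helper waits on results that expose block_until_ready() before stopping the timer.

```python
import logging
import time
from typing import Any, Callable, Tuple

LOGGER = logging.getLogger(__name__)


def pre_run_and_log(fn: Callable[..., Any], *args: Any) -> Tuple[Any, float]:
    """Call fn once, log the elapsed wall-clock time, and return both."""
    start = time.perf_counter()
    result = fn(*args)
    # JAX dispatches asynchronously: wait for array results before timing.
    if hasattr(result, "block_until_ready"):
        result.block_until_ready()
    elapsed = time.perf_counter() - start
    LOGGER.info("Pre-run of %s took %.6f s", getattr(fn, "__name__", fn), elapsed)
    return result, elapsed


logging.basicConfig(level=logging.INFO)
result, elapsed = pre_run_and_log(sum, [1.0, 2.0, 3.0])
print(result)  # 6.0
```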
Source code in src/gemseo_jax/jax_discipline.py