Solve the Sellar MDO problem with JAX.
from __future__ import annotations
from gemseo import configure_logger
from gemseo import create_scenario
from gemseo.core.mdo_functions.mdo_function import MDOFunction
from gemseo.problems.mdo.sellar.sellar_design_space import SellarDesignSpace
from gemseo_jax.jax_chain import JAXChain
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem
configure_logger()
Out:
<RootLogger root (INFO)>
Create the disciplines:
sellar_1 = JAXSellar1()
sellar_2 = JAXSellar2()
sellar_system = JAXSellarSystem()
Make a JAXChain to assemble the three disciplines without reconverting to NumPy:
disciplines = [sellar_1, sellar_2, sellar_system]
jax_chain = JAXChain(disciplines, name="SellarChain")
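Conceptually, a chain evaluates its disciplines in order and makes the outputs of each one available to the next. The plain-Python sketch below illustrates only that idea, with two hypothetical toy functions (`double` and `shift`) and a hypothetical `chain` helper; it is not the JAXChain implementation, which works on JAX arrays.

```python
def chain(functions):
    """Return a function that runs the given functions in order on shared data."""

    def run(inputs):
        data = dict(inputs)  # copy, so the caller's inputs are left untouched
        for function in functions:
            # Each function reads what it needs and contributes its outputs.
            data.update(function(data))
        return data

    return run


# Hypothetical toy disciplines, unrelated to the Sellar ones.
def double(data):
    return {"b": 2.0 * data["a"]}


def shift(data):
    return {"c": data["b"] + 1.0}


toy_chain = chain([double, shift])
result = toy_chain({"a": 3.0})
print(result)  # {'a': 3.0, 'b': 6.0, 'c': 7.0}
```

Because the intermediate value `b` never leaves the chain, nothing has to be converted between the two steps, which is the benefit the chain above aims for with JAX arrays.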
Add the differentiated outputs to reduce the computation graph of the Jacobian:
jax_chain.add_differentiated_outputs(["obj", "c_1", "c_2"])
Compile the functions; this adds a one-time compilation cost but lowers the cost of each re-evaluation:
jax_chain.compile_jit()
Out:
INFO - 12:32:27: Compilation of the output function SellarChain: 0:00:00.041348 seconds.
INFO - 12:32:27: Compilation of the Jacobian function SellarChain: 0:00:00.060866 seconds.
Create the MDO scenario with an MDF formulation:
design_space = SellarDesignSpace()
scenario = create_scenario(
    jax_chain,
    "obj",
    design_space,
    formulation_name="MDF",
    main_mda_settings={"inner_mda_name": "MDAGaussSeidel"},
)
scenario.add_constraint(["c_1", "c_2"], MDOFunction.ConstraintType.INEQ)
Out:
/builds/gemseo/dev/gemseo-jax/.tox/doc/lib64/python3.12/site-packages/gemseo/algos/design_space.py:426: ComplexWarning: Casting complex values to real discards the imaginary part
self.__current_value[name] = array_value.astype(
INFO - 12:32:27: Variable y_1 was removed from the Design Space, it is not an input of any discipline.
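With MDF, the couplings between the disciplines are resolved by an MDA at each design point, here a Gauss-Seidel one. As a rough illustration of that fixed-point idea, the plain-Python sketch below iterates on two coupling equations. The equations and the constants 0.2, 3.16 and 24 are assumptions about the Sellar variant used here, not taken from this example; with them, the constraint values at the starting design agree with those reported in the optimization log. The function names are hypothetical and this is not GEMSEO's MDAGaussSeidel implementation.

```python
import math


def coupling_1(x_1, x_shared, y_2, gamma=0.2):
    """Assumed first coupling equation: y_1 as a function of y_2."""
    return math.sqrt(x_shared[0] ** 2 + x_shared[1] + x_1 - gamma * y_2)


def coupling_2(x_2, x_shared, y_1):
    """Assumed second coupling equation: y_2 as a function of y_1."""
    return abs(y_1) + x_shared[0] + x_shared[1] - x_2


def gauss_seidel(x_1, x_2, x_shared, y_2=1.0, tol=1e-12, max_iter=100):
    """Alternate between the two equations until y_2 stops changing."""
    for iteration in range(1, max_iter + 1):
        y_1 = coupling_1(x_1, x_shared, y_2)
        y_2_new = coupling_2(x_2, x_shared, y_1)  # uses the freshly updated y_1
        if abs(y_2_new - y_2) < tol:
            return y_1, y_2_new, iteration
        y_2 = y_2_new
    raise RuntimeError("Gauss-Seidel iteration did not converge")


# Starting design of the scenario: x_1 = 1, x_2 = 1, x_shared = (4, 3).
y_1, y_2, n_iter = gauss_seidel(1.0, 1.0, (4.0, 3.0))

# Assumed constraint definitions.
c_1 = 3.16 - y_1**2
c_2 = y_2 - 24.0
print(n_iter, c_1, c_2)  # c_1 is about -14.7926 and c_2 about -13.7630
```

Each pass reuses the latest value of `y_1` when updating `y_2`, which is what distinguishes a Gauss-Seidel sweep from a Jacobi one.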
Execute the scenario and post-process the results:
scenario.execute(algo_name="SLSQP", max_iter=10)
scenario.post_process(post_name="OptHistoryView", save=False, show=True)
Out:
INFO - 12:32:27: *** Start MDOScenario execution ***
INFO - 12:32:27: MDOScenario
INFO - 12:32:27: Disciplines: SellarChain
INFO - 12:32:27: MDO formulation: MDF
INFO - 12:32:27: Optimization problem:
INFO - 12:32:27: minimize obj(x_1, x_2, x_shared)
INFO - 12:32:27: with respect to x_1, x_2, x_shared
INFO - 12:32:27: under the inequality constraints
INFO - 12:32:27: c_1_c_2(x_1, x_2, x_shared) <= 0
INFO - 12:32:27: over the design space:
INFO - 12:32:27: +-------------+-------------+-------+-------------+-------+
INFO - 12:32:27: | Name        | Lower bound | Value | Upper bound | Type  |
INFO - 12:32:27: +-------------+-------------+-------+-------------+-------+
INFO - 12:32:27: | x_1         |      0      |   1   |      10     | float |
INFO - 12:32:27: | x_2         |      0      |   1   |      10     | float |
INFO - 12:32:27: | x_shared[0] |     -10     |   4   |      10     | float |
INFO - 12:32:27: | x_shared[1] |      0      |   3   |      10     | float |
INFO - 12:32:27: +-------------+-------------+-------+-------------+-------+
INFO - 12:32:27: Solving optimization problem with algorithm SLSQP:
INFO - 12:32:27: 10%|█ | 1/10 [00:00<00:00, 520.13 it/sec, obj=23]
INFO - 12:32:27: 20%|██ | 2/10 [00:00<00:01, 6.67 it/sec, obj=56.9]
INFO - 12:32:27: 30%|███ | 3/10 [00:00<00:00, 9.93 it/sec, obj=27.1]
INFO - 12:32:27: 40%|████ | 4/10 [00:00<00:00, 13.15 it/sec, obj=23.9]
INFO - 12:32:27: 50%|█████ | 5/10 [00:00<00:00, 16.33 it/sec, obj=23.2]
INFO - 12:32:27: 60%|██████ | 6/10 [00:00<00:00, 19.47 it/sec, obj=23]
INFO - 12:32:27: 70%|███████ | 7/10 [00:00<00:00, 22.56 it/sec, obj=23]
INFO - 12:32:27: 80%|████████ | 8/10 [00:00<00:00, 25.60 it/sec, obj=23]
INFO - 12:32:27: 90%|█████████ | 9/10 [00:00<00:00, 28.61 it/sec, obj=23]
INFO - 12:32:27: 100%|██████████| 10/10 [00:00<00:00, 31.59 it/sec, obj=23]
INFO - 12:32:27: Optimization result:
INFO - 12:32:27: Optimizer info:
INFO - 12:32:27: Status: None
INFO - 12:32:27: Message: Maximum number of iterations reached. GEMSEO stopped the driver.
INFO - 12:32:27: Number of calls to the objective function by the optimizer: 0
INFO - 12:32:27: Solution:
INFO - 12:32:27: The solution is feasible.
INFO - 12:32:27: Objective: 22.952625867476453
INFO - 12:32:27: Standardized constraints:
INFO - 12:32:27: c_1_c_2 = [-14.79259005 -13.76295031]
INFO - 12:32:27: Design space:
INFO - 12:32:27: +-------------+-------------+-------------------+-------------+-------+
INFO - 12:32:27: | Name        | Lower bound |       Value       | Upper bound | Type  |
INFO - 12:32:27: +-------------+-------------+-------------------+-------------+-------+
INFO - 12:32:27: | x_1         |      0      |         1         |      10     | float |
INFO - 12:32:27: | x_2         |      0      |         1         |      10     | float |
INFO - 12:32:27: | x_shared[0] |     -10     | 4.000000000000002 |      10     | float |
INFO - 12:32:27: | x_shared[1] |      0      |         3         |      10     | float |
INFO - 12:32:27: +-------------+-------------+-------------------+-------------+-------+
INFO - 12:32:27: *** End MDOScenario execution ***
/builds/gemseo/dev/gemseo-jax/.tox/doc/lib64/python3.12/site-packages/gemseo/post/correlations.py:36: DeprecationWarning: numpy.core.shape_base is deprecated and has been renamed to numpy._core.shape_base. The numpy._core namespace contains private NumPy internals and its use is discouraged, as NumPy internals can change without warning in any release. In practice, most real-world usage of numpy.core is to access functionality in the public NumPy API. If that is the case, use the public NumPy API. If not, you are using NumPy internals. If you would still like to access an internal attribute, use numpy._core.shape_base.hstack.
from numpy.core.shape_base import hstack
<gemseo.post.opt_history_view.OptHistoryView object at 0x797e8d5d2420>
Total running time of the script: ( 0 minutes 1.289 seconds)