Solve the Sellar MDO problem with JAX.
from __future__ import annotations
from gemseo import configure_logger
from gemseo import create_scenario
from gemseo.core.mdo_functions.mdo_function import MDOFunction
from gemseo.problems.mdo.sellar.sellar_design_space import SellarDesignSpace
from gemseo_jax.jax_chain import JAXChain
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem
configure_logger()
Out:
<RootLogger root (INFO)>
Create the disciplines:
sellar_1 = JAXSellar1()
sellar_2 = JAXSellar2()
sellar_system = JAXSellarSystem()
Make a JAXChain to assemble the three disciplines without converting intermediate data back to NumPy:
disciplines = [sellar_1, sellar_2, sellar_system]
jax_chain = JAXChain(disciplines, name="SellarChain")
Declare the outputs to differentiate, which reduces the computation graph of the Jacobian to these outputs only:
jax_chain.add_differentiated_outputs(["obj", "c_1", "c_2"])
Compile the functions with JIT. This adds a one-off compilation cost but lowers the cost of each re-evaluation:
jax_chain.compile_jit()
Out:
INFO - 13:32:05: Compilation of the output function SellarChain: 0:00:00.036432 seconds.
INFO - 13:32:05: Compilation of the Jacobian function SellarChain: 0:00:00.066468 seconds.
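The chaining, output selection and compilation steps above can be sketched in plain JAX. This is only an illustration of the idea, not the gemseo-jax implementation: the three functions are hypothetical toy disciplines, and `chain`, `selected_outputs` and `DIFFERENTIATED` are names invented for this sketch.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy disciplines (NOT the gemseo-jax Sellar classes).
def discipline_1(x, y_2):
    return jnp.sqrt(x**2 + 1.0 - 0.2 * y_2)  # returns y_1


def discipline_2(x, y_1):
    return jnp.abs(y_1) + x  # returns y_2


def system(x, y_1, y_2):
    return {
        "obj": x**2 + y_1**2 + jnp.exp(-y_2),
        "c_1": 3.16 - y_1**2,
        "c_2": y_2 - 24.0,
    }


def chain(x, y_2):
    """Run the three functions in sequence; every value stays a JAX array."""
    y_1 = discipline_1(x, y_2)
    new_y_2 = discipline_2(x, y_1)
    outputs = system(x, y_1, new_y_2)
    outputs.update(y_1=y_1, y_2=new_y_2)
    return outputs


# Only these outputs are differentiated; y_1 and y_2 are left out.
DIFFERENTIATED = ("obj", "c_1", "c_2")


def selected_outputs(x, y_2):
    outputs = chain(x, y_2)
    return {name: outputs[name] for name in DIFFERENTIATED}


# Compile once; later calls with inputs of the same shape and dtype
# reuse the compiled code.
chain_jit = jax.jit(chain)
jacobian_jit = jax.jit(jax.jacfwd(selected_outputs, argnums=(0, 1)))

x = jnp.array(2.0)
y_2 = jnp.array(1.0)
values = chain_jit(x, y_2)
jacobian = jacobian_jit(x, y_2)  # {"obj": (d/dx, d/dy_2), "c_1": ..., "c_2": ...}
print(sorted(values), sorted(jacobian))
```

Restricting the differentiated function to a subset of outputs before calling `jax.jacfwd` means the compiled Jacobian only has to return derivatives for those outputs, which is the effect the step above aims for.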
Create the MDO scenario with an MDF formulation:
design_space = SellarDesignSpace()
scenario = create_scenario(
    jax_chain,
    "obj",
    design_space,
    formulation_name="MDF",
    main_mda_settings={"inner_mda_name": "MDAGaussSeidel"},
)
scenario.add_constraint(["c_1", "c_2"], MDOFunction.ConstraintType.INEQ)
Out:
/builds/gemseo/dev/gemseo-jax/.tox/doc/lib64/python3.9/site-packages/gemseo/algos/design_space.py:426: ComplexWarning: Casting complex values to real discards the imaginary part
self.__current_value[name] = array_value.astype(
INFO - 13:32:05: Variable y_1 was removed from the Design Space, it is not an input of any discipline.
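The MDF formulation solves the coupling between the disciplines with an MDA at each design point, here a Gauss-Seidel one. A minimal standard-library sketch of such a fixed-point loop follows. The equations are an assumption, since this page does not print them: they are the usual Sellar expressions with constants 3.16, 24 and 0.2. The function names and the `gauss_seidel_mda` helper are hypothetical and do not reflect how GEMSEO implements its MDA.

```python
import math


# Assumed Sellar equations (not printed in this page).
def sellar_1(x_1, x_shared, y_2, gamma=0.2):
    return math.sqrt(x_shared[0] ** 2 + x_shared[1] + x_1 - gamma * y_2)


def sellar_2(x_2, x_shared, y_1):
    return abs(y_1) + x_shared[0] + x_shared[1] - x_2


def sellar_system(x_1, x_2, x_shared, y_1, y_2, alpha=3.16, beta=24.0):
    obj = x_1**2 + x_2**2 + x_shared[1] + y_1**2 + math.exp(-y_2)
    return obj, alpha - y_1**2, y_2 - beta


def gauss_seidel_mda(x_1, x_2, x_shared, y_2=1.0, tol=1e-12, max_iter=50):
    """Update y_1, then y_2 using the fresh y_1, until y_2 stops changing."""
    for iteration in range(1, max_iter + 1):
        y_1 = sellar_1(x_1, x_shared, y_2)
        new_y_2 = sellar_2(x_2, x_shared, y_1)
        residual = abs(new_y_2 - y_2)
        y_2 = new_y_2
        if residual < tol:
            break
    return y_1, y_2, iteration


# Initial design point of the design space: x_1 = 1, x_2 = 1, x_shared = [4, 3].
x_1, x_2, x_shared = 1.0, 1.0, (4.0, 3.0)
y_1, y_2, n_iter = gauss_seidel_mda(x_1, x_2, x_shared)
obj, c_1, c_2 = sellar_system(x_1, x_2, x_shared, y_1, y_2)
print(f"converged in {n_iter} iterations: obj={obj}, c_1={c_1}, c_2={c_2}")
```

Under these assumed equations, the converged values at this design point can be compared with the objective and constraint values reported in the optimization log below.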
Execute the scenario and post-process the results:
scenario.execute(algo_name="SLSQP", max_iter=10)
scenario.post_process(post_name="OptHistoryView", save=False, show=True)
Out:
INFO - 13:32:05:
INFO - 13:32:05: *** Start MDOScenario execution ***
INFO - 13:32:05: MDOScenario
INFO - 13:32:05: Disciplines: SellarChain
INFO - 13:32:05: MDO formulation: MDF
INFO - 13:32:05: Optimization problem:
INFO - 13:32:05: minimize obj(x_1, x_2, x_shared)
INFO - 13:32:05: with respect to x_1, x_2, x_shared
INFO - 13:32:05: subject to constraints:
INFO - 13:32:05: c_1_c_2(x_1, x_2, x_shared) <= 0
INFO - 13:32:05: over the design space:
INFO - 13:32:05: +-------------+-------------+-------+-------------+-------+
INFO - 13:32:05: | Name | Lower bound | Value | Upper bound | Type |
INFO - 13:32:05: +-------------+-------------+-------+-------------+-------+
INFO - 13:32:05: | x_1 | 0 | 1 | 10 | float |
INFO - 13:32:05: | x_2 | 0 | 1 | 10 | float |
INFO - 13:32:05: | x_shared[0] | -10 | 4 | 10 | float |
INFO - 13:32:05: | x_shared[1] | 0 | 3 | 10 | float |
INFO - 13:32:05: +-------------+-------------+-------+-------------+-------+
INFO - 13:32:05: Solving optimization problem with algorithm SLSQP:
INFO - 13:32:05: 10%|█ | 1/10 [00:00<00:00, 570.50 it/sec, obj=23]
INFO - 13:32:05: 20%|██ | 2/10 [00:00<00:01, 7.60 it/sec, obj=56.9]
INFO - 13:32:05: 30%|███ | 3/10 [00:00<00:00, 11.33 it/sec, obj=27.1]
INFO - 13:32:05: 40%|████ | 4/10 [00:00<00:00, 14.98 it/sec, obj=23.9]
INFO - 13:32:05: 50%|█████ | 5/10 [00:00<00:00, 18.58 it/sec, obj=23.2]
INFO - 13:32:05: 60%|██████ | 6/10 [00:00<00:00, 22.13 it/sec, obj=23]
INFO - 13:32:05: 70%|███████ | 7/10 [00:00<00:00, 25.63 it/sec, obj=23]
INFO - 13:32:05: 80%|████████ | 8/10 [00:00<00:00, 29.07 it/sec, obj=23]
INFO - 13:32:05: 90%|█████████ | 9/10 [00:00<00:00, 32.47 it/sec, obj=23]
INFO - 13:32:05: 100%|██████████| 10/10 [00:00<00:00, 35.81 it/sec, obj=23]
INFO - 13:32:05: Optimization result:
INFO - 13:32:05: Optimizer info:
INFO - 13:32:05: Status: None
INFO - 13:32:05: Message: Maximum number of iterations reached. GEMSEO stopped the driver.
INFO - 13:32:05: Number of calls to the objective function by the optimizer: 0
INFO - 13:32:05: Solution:
INFO - 13:32:05: The solution is feasible.
INFO - 13:32:05: Objective: 22.952625867476453
INFO - 13:32:05: Standardized constraints:
INFO - 13:32:05: c_1_c_2 = [-14.79259005 -13.76295031]
INFO - 13:32:05: Design space:
INFO - 13:32:05: +-------------+-------------+-------------------+-------------+-------+
INFO - 13:32:05: | Name | Lower bound | Value | Upper bound | Type |
INFO - 13:32:05: +-------------+-------------+-------------------+-------------+-------+
INFO - 13:32:05: | x_1 | 0 | 1 | 10 | float |
INFO - 13:32:05: | x_2 | 0 | 1 | 10 | float |
INFO - 13:32:05: | x_shared[0] | -10 | 4.000000000000002 | 10 | float |
INFO - 13:32:05: | x_shared[1] | 0 | 3 | 10 | float |
INFO - 13:32:05: +-------------+-------------+-------------------+-------------+-------+
INFO - 13:32:05: *** End MDOScenario execution (time: 0:00:00.283994) ***
<gemseo.post.opt_history_view.OptHistoryView object at 0x7884e19ffaf0>
Total running time of the script: ( 0 minutes 1.658 seconds)