Solve the Sellar MDO problem with JAX.

from __future__ import annotations

from gemseo import configure_logger
from gemseo import create_scenario
from gemseo.core.mdo_functions.mdo_function import MDOFunction
from gemseo.problems.mdo.sellar.sellar_design_space import SellarDesignSpace

from gemseo_jax.jax_chain import JAXChain
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem

configure_logger()

Out:

<RootLogger root (INFO)>

Create the disciplines:

sellar_1 = JAXSellar1()
sellar_2 = JAXSellar2()
sellar_system = JAXSellarSystem()

Assemble the three disciplines into a JAXChain, so that data passed between them stays in JAX and is not converted back to NumPy:

disciplines = [sellar_1, sellar_2, sellar_system]
jax_chain = JAXChain(disciplines, name="SellarChain")

Declare the outputs to differentiate, so that the computation graph of the Jacobian covers only the objective and the constraints:

jax_chain.add_differentiated_outputs(["obj", "c_1", "c_2"])
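The effect of restricting the differentiated outputs can be illustrated in plain JAX. This is only a sketch of the idea, not gemseo-jax's implementation: the function `chain`, the helper `restrict`, the input layout and the Sellar equations written here are assumptions made for illustration, and the objective is left out for brevity.

```python
import jax
import jax.numpy as jnp

# Double precision is an assumption of this sketch.
jax.config.update("jax_enable_x64", True)


def chain(x):
    """One sequential pass through the assumed Sellar equations.

    x = [x_1, x_2, x_shared[0], x_shared[1], initial guess for y_2].
    """
    x_1, x_2, x_shared, y_2 = x[0], x[1], x[2:4], x[4]
    y_1 = jnp.sqrt(x_shared[0] ** 2 + x_shared[1] + x_1 - 0.2 * y_2)
    y_2 = jnp.abs(y_1) + x_shared[0] + x_shared[1] - x_2
    return {"y_1": y_1, "y_2": y_2, "c_1": 3.16 - y_1**2, "c_2": y_2 - 24.0}


def restrict(names):
    """Return a function that keeps only the named outputs of ``chain``."""

    def restricted(x):
        outputs = chain(x)
        return {name: outputs[name] for name in names}

    return restricted


x = jnp.array([1.0, 1.0, 4.0, 3.0, 1.0])

# Differentiating the full chain yields one Jacobian row per output.
full_jacobian = jax.jacfwd(chain)(x)

# Differentiating the restricted function yields only the requested rows.
constraint_jacobian = jax.jacfwd(restrict(["c_1", "c_2"]))(x)

print(sorted(full_jacobian))
print(sorted(constraint_jacobian))
print(constraint_jacobian["c_1"])
```

The restricted Jacobian has entries for `c_1` and `c_2` only, so the derivatives of the coupling outputs `y_1` and `y_2` are never requested.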

Compile the functions. This adds a one-time compilation cost but lowers the cost of each re-evaluation:

jax_chain.compile_jit()

Out:

    INFO - 12:32:27: Compilation of the output function SellarChain: 0:00:00.041348 seconds.
    INFO - 12:32:27: Compilation of the Jacobian function SellarChain: 0:00:00.060866 seconds.
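The trade-off between a one-time compilation and cheaper re-evaluations can be observed directly with `jax.jit`. The sketch below is plain JAX and does not show what `compile_jit` does internally; the function `sellar_1` and the list `trace_log` are hypothetical names. The Python side effect runs only while JAX traces the function, so it records how often tracing (and hence compilation) is triggered.

```python
import jax
import jax.numpy as jnp

trace_log = []  # grows by one entry each time JAX traces the function


def sellar_1(x_1, x_shared, y_2):
    """Assumed form of the first Sellar coupling equation."""
    trace_log.append("traced")  # Python side effect: executed during tracing only
    return jnp.sqrt(x_shared[0] ** 2 + x_shared[1] + x_1 - 0.2 * y_2)


fast_sellar_1 = jax.jit(sellar_1)

x_shared = jnp.array([4.0, 3.0])
y_2 = jnp.array(1.0)

# The first call traces and compiles; the second call has the same input
# shapes and dtypes, so it reuses the compiled function.
first = fast_sellar_1(jnp.array(1.0), x_shared, y_2)
second = fast_sellar_1(jnp.array(2.0), x_shared, y_2)

print(len(trace_log), float(first), float(second))
```

Calling the function again with a different input shape or dtype would trigger a new trace, which is why compilation pays off mainly when the same function is re-evaluated many times, as in an optimization loop.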

Create the MDO scenario with an MDF formulation:

design_space = SellarDesignSpace()
scenario = create_scenario(
    jax_chain,
    "obj",
    design_space,
    formulation_name="MDF",
    main_mda_settings={"inner_mda_name": "MDAGaussSeidel"},
)
scenario.add_constraint(["c_1", "c_2"], MDOFunction.ConstraintType.INEQ)

Out:

/builds/gemseo/dev/gemseo-jax/.tox/doc/lib64/python3.12/site-packages/gemseo/algos/design_space.py:426: ComplexWarning: Casting complex values to real discards the imaginary part
  self.__current_value[name] = array_value.astype(
    INFO - 12:32:27: Variable y_1 was removed from the Design Space, it is not an input of any discipline.
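With the MDF formulation, the coupling variables `y_1` and `y_2` are resolved by a multidisciplinary analysis at each design point, here a Gauss-Seidel iteration. The pure-Python sketch below illustrates that fixed-point loop under assumed Sellar equations; the function names, the tolerance and the initial guess are illustrative choices, not GEMSEO's settings. The equations were chosen because, at the initial design of this example, they should reproduce the constraint values reported in the optimization log to about six digits.

```python
import math


def sellar_1(x_1, x_shared, y_2, gamma=0.2):
    """Assumed first coupling equation: y_1 as a function of y_2."""
    return math.sqrt(x_shared[0] ** 2 + x_shared[1] + x_1 - gamma * y_2)


def sellar_2(x_2, x_shared, y_1):
    """Assumed second coupling equation: y_2 as a function of y_1."""
    return abs(y_1) + x_shared[0] + x_shared[1] - x_2


def gauss_seidel(x_1, x_2, x_shared, y_2=1.0, tol=1e-12, max_iter=50):
    """Run the disciplines in sequence until y_2 stops changing."""
    for _ in range(max_iter):
        y_1 = sellar_1(x_1, x_shared, y_2)      # uses the latest y_2
        new_y_2 = sellar_2(x_2, x_shared, y_1)  # uses the y_1 just computed
        converged = abs(new_y_2 - y_2) < tol
        y_2 = new_y_2
        if converged:
            break
    return y_1, y_2


# Initial design of the example: x_1 = 1, x_2 = 1, x_shared = [4, 3].
y_1, y_2 = gauss_seidel(1.0, 1.0, (4.0, 3.0))
c_1 = 3.16 - y_1**2  # assumed form of the first inequality constraint
c_2 = y_2 - 24.0     # assumed form of the second inequality constraint
print(y_1, y_2, c_1, c_2)
```

Each sweep reuses the most recent value of every coupling variable, which is what distinguishes Gauss-Seidel from a Jacobi iteration, where both disciplines would be evaluated from the previous sweep's values.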

Execute the scenario and post-process the results:

scenario.execute(algo_name="SLSQP", max_iter=10)
scenario.post_process(post_name="OptHistoryView", save=False, show=True)
    [Figure: evolution of the optimization variables, of the objective value, of the distance to the optimum, and of the inequality constraints]

Out:

    INFO - 12:32:27: *** Start MDOScenario execution ***
    INFO - 12:32:27: MDOScenario
    INFO - 12:32:27:    Disciplines: SellarChain
    INFO - 12:32:27:    MDO formulation: MDF
    INFO - 12:32:27: Optimization problem:
    INFO - 12:32:27:    minimize obj(x_1, x_2, x_shared)
    INFO - 12:32:27:    with respect to x_1, x_2, x_shared
    INFO - 12:32:27:    under the inequality constraints
    INFO - 12:32:27:       c_1_c_2(x_1, x_2, x_shared) <= 0
    INFO - 12:32:27:    over the design space:
    INFO - 12:32:27:       +-------------+-------------+-------+-------------+-------+
    INFO - 12:32:27:       | Name        | Lower bound | Value | Upper bound | Type  |
    INFO - 12:32:27:       +-------------+-------------+-------+-------------+-------+
    INFO - 12:32:27:       | x_1         |      0      |   1   |      10     | float |
    INFO - 12:32:27:       | x_2         |      0      |   1   |      10     | float |
    INFO - 12:32:27:       | x_shared[0] |     -10     |   4   |      10     | float |
    INFO - 12:32:27:       | x_shared[1] |      0      |   3   |      10     | float |
    INFO - 12:32:27:       +-------------+-------------+-------+-------------+-------+
    INFO - 12:32:27: Solving optimization problem with algorithm SLSQP:
    INFO - 12:32:27:     10%|         | 1/10 [00:00<00:00, 520.13 it/sec, obj=23]
    INFO - 12:32:27:     20%|██        | 2/10 [00:00<00:01,  6.67 it/sec, obj=56.9]
    INFO - 12:32:27:     30%|███       | 3/10 [00:00<00:00,  9.93 it/sec, obj=27.1]
    INFO - 12:32:27:     40%|████      | 4/10 [00:00<00:00, 13.15 it/sec, obj=23.9]
    INFO - 12:32:27:     50%|█████     | 5/10 [00:00<00:00, 16.33 it/sec, obj=23.2]
    INFO - 12:32:27:     60%|██████    | 6/10 [00:00<00:00, 19.47 it/sec, obj=23]
    INFO - 12:32:27:     70%|███████   | 7/10 [00:00<00:00, 22.56 it/sec, obj=23]
    INFO - 12:32:27:     80%|████████  | 8/10 [00:00<00:00, 25.60 it/sec, obj=23]
    INFO - 12:32:27:     90%|█████████ | 9/10 [00:00<00:00, 28.61 it/sec, obj=23]
    INFO - 12:32:27:    100%|██████████| 10/10 [00:00<00:00, 31.59 it/sec, obj=23]
    INFO - 12:32:27: Optimization result:
    INFO - 12:32:27:    Optimizer info:
    INFO - 12:32:27:       Status: None
    INFO - 12:32:27:       Message: Maximum number of iterations reached. GEMSEO stopped the driver.
    INFO - 12:32:27:       Number of calls to the objective function by the optimizer: 0
    INFO - 12:32:27:    Solution:
    INFO - 12:32:27:       The solution is feasible.
    INFO - 12:32:27:       Objective: 22.952625867476453
    INFO - 12:32:27:       Standardized constraints:
    INFO - 12:32:27:          c_1_c_2 = [-14.79259005 -13.76295031]
    INFO - 12:32:27:       Design space:
    INFO - 12:32:27:          +-------------+-------------+-------------------+-------------+-------+
    INFO - 12:32:27:          | Name        | Lower bound |       Value       | Upper bound | Type  |
    INFO - 12:32:27:          +-------------+-------------+-------------------+-------------+-------+
    INFO - 12:32:27:          | x_1         |      0      |         1         |      10     | float |
    INFO - 12:32:27:          | x_2         |      0      |         1         |      10     | float |
    INFO - 12:32:27:          | x_shared[0] |     -10     | 4.000000000000002 |      10     | float |
    INFO - 12:32:27:          | x_shared[1] |      0      |         3         |      10     | float |
    INFO - 12:32:27:          +-------------+-------------+-------------------+-------------+-------+
    INFO - 12:32:27: *** End MDOScenario execution ***
/builds/gemseo/dev/gemseo-jax/.tox/doc/lib64/python3.12/site-packages/gemseo/post/correlations.py:36: DeprecationWarning: numpy.core.shape_base is deprecated and has been renamed to numpy._core.shape_base. The numpy._core namespace contains private NumPy internals and its use is discouraged, as NumPy internals can change without warning in any release. In practice, most real-world usage of numpy.core is to access functionality in the public NumPy API. If that is the case, use the public NumPy API. If not, you are using NumPy internals. If you would still like to access an internal attribute, use numpy._core.shape_base.hstack.
  from numpy.core.shape_base import hstack

<gemseo.post.opt_history_view.OptHistoryView object at 0x797e8d5d2420>

Total running time of the script: ( 0 minutes 1.289 seconds)
