In [None]:
# Render matplotlib figures (e.g. the post-processing plots below) inline in the notebook.
%matplotlib inline

Solve the Sellar MDO problem with JAX.


In [None]:
from __future__ import annotations

from gemseo import configure_logger
from gemseo import create_scenario
from gemseo.core.mdo_functions.mdo_function import MDOFunction
from gemseo.problems.mdo.sellar.sellar_design_space import SellarDesignSpace

from gemseo_jax.jax_chain import JAXChain
from gemseo_jax.problems.sellar.sellar_1 import JAXSellar1
from gemseo_jax.problems.sellar.sellar_2 import JAXSellar2
from gemseo_jax.problems.sellar.sellar_system import JAXSellarSystem

# Enable GEMSEO's logger so the MDA/optimization progress is printed while the scenario runs.
configure_logger()

Create the disciplines:



In [None]:
# One JAX discipline per Sellar component: the two coupled sub-disciplines
# and the system-level discipline (constructed in this order).
sellar_1, sellar_2, sellar_system = (
    JAXSellar1(),
    JAXSellar2(),
    JAXSellarSystem(),
)

Make a `JAXChain` to assemble the three disciplines without converting back to NumPy between them:



In [None]:
# Keep the list in the notebook namespace; it is handed to JAXChain as-is.
disciplines = [sellar_1, sellar_2, sellar_system]

# Chain the three JAX disciplines into a single discipline named "SellarChain".
jax_chain = JAXChain(
    disciplines,
    name="SellarChain",
)

Declare the outputs to be differentiated, in order to reduce the computation graph of the Jacobian:



In [None]:
# Only the objective and the two constraints are needed by the optimizer,
# so restrict the Jacobian computation graph to these outputs.
jax_chain.add_differentiated_outputs(
    ["obj", "c_1", "c_2"]
)

Compile the functions with JIT. This adds an upfront compilation cost but lowers the cost of each
re-evaluation:



In [None]:
# JIT-compile the chain's functions: this costs compilation time once, in
# exchange for cheaper repeated evaluations during the MDA and optimization.
jax_chain.compile_jit()

Create the MDO scenario with an MDF formulation:



In [None]:
design_space = SellarDesignSpace()

# MDF formulation on the JAX chain, minimizing "obj"; the main MDA is
# configured to use a Gauss-Seidel inner MDA.
scenario = create_scenario(
    jax_chain,
    "obj",
    design_space,
    formulation_name="MDF",
    main_mda_settings=dict(inner_mda_name="MDAGaussSeidel"),
)

# Both Sellar constraints are declared as inequality constraints.
scenario.add_constraint(
    ["c_1", "c_2"],
    MDOFunction.ConstraintType.INEQ,
)

Execute the scenario and post-process the results:



In [None]:
# Run the optimization with SLSQP, limited to 10 iterations.
scenario.execute(
    algo_name="SLSQP",
    max_iter=10,
)

# Display the optimization history plots without saving them to files.
scenario.post_process(
    post_name="OptHistoryView",
    save=False,
    show=True,
)