Utils

utils

Utilities.

Functions

create_jax_discipline_from_discipline

create_jax_discipline_from_discipline(
    discipline: Discipline, *args: Any, **kwargs: Any
) -> JAXDiscipline

Create a JAXDiscipline from a discipline whose computations use JAX instead of NumPy and SciPy.

The JAXDiscipline uses the same input variables, output variables and default input values as discipline.
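The point of such a conversion is presumably to let JAX compute derivatives by automatic differentiation, as the warning below suggests. The minimal pure-JAX sketch below illustrates that idea; it is not the gemseo-jax implementation. It assumes a hypothetical function `compute` that maps a dictionary of named inputs to a dictionary of named outputs. The sketch evaluates `compute` at hypothetical default values and obtains every partial derivative with `jax.jacrev`:

```python
import jax
import jax.numpy as jnp


# Hypothetical JAX computation standing in for a discipline:
# named inputs in, named outputs out (dictionaries of arrays).
def compute(inputs: dict[str, jax.Array]) -> dict[str, jax.Array]:
    x, y = inputs["x"], inputs["y"]
    return {"z": jnp.sin(x) * y, "w": x + y**2}


# Hypothetical default input values.
defaults = {"x": jnp.array([1.0]), "y": jnp.array([2.0])}

outputs = compute(defaults)

# Reverse-mode automatic differentiation with respect to the whole
# input dictionary: jacobian[output_name][input_name] is the block
# d(output)/d(input).
jacobian = jax.jacrev(compute)(defaults)
print(jacobian["z"]["x"])  # d(z)/d(x) = y * cos(x)
```

Each Jacobian block has shape (output size, input size), here (1, 1). The nested-dictionary layout comes from `jax.jacrev` following the pytree structure of the outputs, then of the inputs.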

Parameters:

  • discipline (Discipline) –

    The discipline, whose computations use JAX instead of NumPy and SciPy.

  • *args (Any, default: () ) –

    The positional arguments of JAXDiscipline, except function, input_names, output_names and default_inputs, which are set from discipline.

  • **kwargs (Any, default: {} ) –

    The keyword arguments of JAXDiscipline. If name is not given, the name of discipline is used.

Returns:

  • JAXDiscipline –

    The JAX discipline.

Warning

JAX's automatic differentiation works with Python control flow and logical operators. Using them with just-in-time compilation (see compile_jit) is more complicated, because a traced value cannot be converted to a Python boolean; structured primitives such as jax.lax.cond are then needed. If you have any difficulties, see https://docs.jax.dev/en/latest/control-flow.html.
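The difference described in this warning can be seen in a small pure-JAX sketch, independent of gemseo-jax, using a hypothetical scalar function `f_python`. `jax.grad` accepts a Python `if` on the input value, whereas `jax.jit` rejects it because it cannot turn a traced value into a Python boolean. Rewriting the branch with `jax.lax.cond` (the hypothetical `f_jit`) makes the function compatible with both:

```python
import jax


def f_python(x):
    # Python control flow on the value of x.
    if x > 0:
        return x**2
    return -x


# Without jit, jax.grad works on concrete values, so the Python branch is allowed.
grad_positive = jax.grad(f_python)(3.0)  # derivative of x**2 at 3.0
grad_negative = jax.grad(f_python)(-2.0)  # derivative of -x

# Under jax.jit, x is an abstract tracer: converting x > 0 to a Python
# bool raises a ConcretizationTypeError (or a subclass of it).
try:
    jax.jit(f_python)(3.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True


@jax.jit
def f_jit(x):
    # Structured control flow: both branches are traced and the
    # predicate is evaluated inside the compiled computation.
    return jax.lax.cond(x > 0, lambda v: v**2, lambda v: -v, x)


print(grad_positive, grad_negative, jit_failed)
print(f_jit(3.0), jax.grad(f_jit)(-2.0))
```

For simple elementwise branches, `jax.numpy.where` is a common alternative to `jax.lax.cond`. Note that `where` evaluates both branches.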

Source code in src/gemseo_jax/utils.py
def create_jax_discipline_from_discipline(
    discipline: Discipline, *args: Any, **kwargs: Any
) -> JAXDiscipline:
    """Create a `JAXDiscipline` from a discipline using JAX instead of NumPy and SciPy.

    It will use the same input variables,
    the same output variables
    and the same default input values.

    Args:
        discipline: The discipline using JAX instead of NumPy and SciPy.
        *args: The positional arguments of `JAXDiscipline`,
            except `function`, `input_names`, `output_names` and `default_inputs`.
        **kwargs: The keyword arguments of `JAXDiscipline`.

    Returns:
        The JAX discipline.

    Warning:
        JAX's automatic differentiation works
        with Python control flow and logical operators.
        Using control flow and logical operators with jit
        (see [compile_jit][gemseo_jax.jax_discipline.JAXDiscipline.compile_jit])
        is more complicated.
        If you have any difficulties,
        you can have a look at https://docs.jax.dev/en/latest/control-flow.html.
    """
    return JAXDiscipline(
        _DisciplineBasedJAXFunction(discipline),
        discipline.io.input_grammar,
        discipline.io.output_grammar,
        discipline.io.input_grammar.defaults,
        *args,
        name=kwargs.pop("name", discipline.name),
        **kwargs,
    )