Skip to content

🧠 Model Wrappers API

The node_fdm.models namespace provides the high-level PyTorch nn.Module wrappers that encapsulate the Neural ODE logic.

These classes serve as the core mathematical engine of the framework. They handle the forward pass integration, the batch processing of trajectories, and the orchestration of the various layers defined in your architecture.


📘 Class Reference

Flight Dynamics Model (Base)

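The primary wrapper used during training. It routes the current state, control, and environment inputs through the architecture's layers, assembles the state derivatives that the ODE solver integrates, and records every intermediate output in a history buffer.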

flight_dynamics_model

Neural flight dynamics model assembled from architecture layers.

FlightDynamicsModel

Bases: Module

Compute state derivatives using a layered flight dynamics architecture.

Source code in src/node_fdm/models/flight_dynamics_model.py
class FlightDynamicsModel(nn.Module):
    """Compute state derivatives using a layered flight dynamics architecture.

    The layers listed in ``architecture`` are evaluated in order. A dictionary
    of named tensors is threaded through them, and the requested derivative
    columns are stacked into the ODE right-hand side. Every value produced
    during ``forward`` is also appended to ``self.history``.
    """

    def __init__(
        self,
        architecture: Sequence[Any],
        stats_dict: Dict[Any, Dict[str, float]],
        model_cols: Tuple[Any, Any, Any, Any, Any],
        model_params: Sequence[int] = (2, 1, 48),
    ) -> None:
        """Initialize the model with architecture definition and statistics.

        Args:
            architecture: Iterable of layer definitions `(name, class, inputs, outputs, structured_flag)`.
                Structured layers are built via ``create_structured_layer``;
                the others are instantiated with no arguments.
            stats_dict: Mapping from column to normalization/denormalization statistics.
            model_cols: Tuple of model column groups (state, control, env, env_extra, derivatives).
                The derivatives group holds ``(coeff, col)`` pairs.
            model_params: Sequence defining backbone depth, head depth, and hidden width.
        """
        super().__init__()
        self.architecture = architecture
        self.stats_dict = stats_dict
        self.x_cols, self.u_cols, self.e0_cols, self.e_cols, self.dx_cols = model_cols
        self.backbone_depth, self.head_depth, self.neurons_num = model_params
        self.layers_dict = nn.ModuleDict({})
        self.layers_name = []

        for name, layer_class, input_cols, output_cols, structured in self.architecture:
            self.layers_name.append(name)
            if structured:
                self.layers_dict[name] = self.create_structured_layer(
                    input_cols,
                    output_cols,
                    layer_class=layer_class,
                )
            else:
                self.layers_dict[name] = layer_class()

        # ``forward`` appends to ``self.history``. Create it up front so the
        # model can be called directly, not only through a wrapper that
        # calls ``reset_history()`` first.
        self.reset_history()

    def reset_history(self):
        """Reset internal history buffers.

        Clears stored layer outputs used for debugging or analysis between runs.
        """
        self.history = {}

    def create_structured_layer(
        self,
        input_cols: Sequence[Any],
        output_cols: Sequence[Any],
        layer_class: Any = StructuredLayer,
    ) -> nn.Module:
        """Build a structured layer with normalization and denormalization stats.

        Args:
            input_cols: Columns consumed by the layer. Only columns whose
                ``normalize_mode`` is not None contribute ``mean``/``std`` stats.
            output_cols: Columns produced by the layer. Only columns whose
                ``denormalize_mode`` is not None contribute ``mean``/``std``/``max`` stats.
            layer_class: Layer implementation to instantiate.

        Returns:
            Configured structured layer instance.
        """
        input_stats = [
            {
                col.col_name: self.stats_dict[col][metric]
                for col in input_cols
                if col.normalize_mode is not None
            }
            for metric in ["mean", "std"]
        ]
        output_stats = [
            {
                col.col_name: self.stats_dict[col][metric]
                for col in output_cols
                if col.denormalize_mode is not None
            }
            for metric in ["mean", "std", "max"]
        ]

        layer = layer_class(
            input_cols,
            input_stats,
            output_cols,
            output_stats,
            backbone_dim=self.neurons_num,
            backbone_depth=self.backbone_depth,
            head_dim=self.neurons_num // 2,
            head_depth=self.head_depth,
        )

        return layer

    def forward(
        self, x: torch.Tensor, u_t: torch.Tensor, e_t: torch.Tensor
    ) -> torch.Tensor:
        """Compute state derivatives for the current batch.

        Args:
            x: State tensor.
            u_t: Control tensor interpolated at current time.
            e_t: Environment tensor interpolated at current time.
                All three are concatenated along dim 1, and column ``i`` of
                the result is named after ``(x_cols + u_cols + e0_cols)[i]``.
                This assumes 2-D ``(batch, features)`` inputs (TODO confirm).

        Returns:
            Tensor of state derivatives: ``coeff * value[col]`` for each
            ``(coeff, col)`` in ``dx_cols``, stacked along dim 1.

        Side effects:
            Appends every named value (unsqueezed at dim 1) to ``self.history``.
        """

        vects = torch.cat([x, u_t, e_t], dim=1)
        vect_dict = dict()
        for i, col in enumerate(self.x_cols + self.u_cols + self.e0_cols):
            vect_dict[col] = vects[..., i]

        # Each layer sees every value computed so far and adds its own outputs.
        for name in self.layers_name:
            vect_dict = vect_dict | self.layers_dict[name](vect_dict)

        ode_output = torch.stack(
            [coeff * vect_dict[col] for coeff, col in self.dx_cols],
            dim=1,
        )

        # NOTE(review): history tensors are not detached, so during training
        # they keep autograd references until reset_history() is called.
        # Each call also re-concatenates the full history.
        for col, vect in vect_dict.items():
            if col in self.history:
                self.history[col] = torch.cat(
                    [self.history[col], vect.unsqueeze(1)], dim=1
                )
            else:
                self.history[col] = vect.unsqueeze(1)

        return ode_output

__init__(architecture, stats_dict, model_cols, model_params=(2, 1, 48))

Initialize the model with architecture definition and statistics.

Parameters:

Name Type Description Default
architecture Sequence[Any]

Iterable of layer definitions (name, class, inputs, outputs, structured_flag).

required
stats_dict Dict[Any, Dict[str, float]]

Mapping from column to normalization/denormalization statistics.

required
model_cols Tuple[Any, Any, Any, Any, Any]

Tuple of model column groups (state, control, env, env_extra, derivatives).

required
model_params Sequence[int]

Sequence defining backbone depth, head depth, and hidden width.

(2, 1, 48)
Source code in src/node_fdm/models/flight_dynamics_model.py
def __init__(
    self,
    architecture: Sequence[Any],
    stats_dict: Dict[Any, Dict[str, float]],
    model_cols: Tuple[Any, Any, Any, Any, Any],
    model_params: Sequence[int] = (2, 1, 48),
) -> None:
    """Initialize the model with architecture definition and statistics.

    Args:
        architecture: Iterable of layer definitions `(name, class, inputs, outputs, structured_flag)`.
            Structured layers are built via ``create_structured_layer``;
            the others are instantiated with no arguments.
        stats_dict: Mapping from column to normalization/denormalization statistics.
        model_cols: Tuple of model column groups (state, control, env, env_extra, derivatives).
            The derivatives group holds ``(coeff, col)`` pairs.
        model_params: Sequence defining backbone depth, head depth, and hidden width.
    """
    super().__init__()
    self.architecture = architecture
    self.stats_dict = stats_dict
    self.x_cols, self.u_cols, self.e0_cols, self.e_cols, self.dx_cols = model_cols
    self.backbone_depth, self.head_depth, self.neurons_num = model_params
    self.layers_dict = nn.ModuleDict({})
    self.layers_name = []

    for name, layer_class, input_cols, output_cols, structured in self.architecture:
        self.layers_name.append(name)
        if structured:
            self.layers_dict[name] = self.create_structured_layer(
                input_cols,
                output_cols,
                layer_class=layer_class,
            )
        else:
            self.layers_dict[name] = layer_class()

    # ``forward`` appends to ``self.history``. Create it up front so the
    # model can be called directly, not only after ``reset_history()``.
    self.reset_history()

create_structured_layer(input_cols, output_cols, layer_class=StructuredLayer)

Build a structured layer with normalization and denormalization stats.

Parameters:

Name Type Description Default
input_cols Sequence[Any]

Columns consumed by the layer.

required
output_cols Sequence[Any]

Columns produced by the layer.

required
layer_class Any

Layer implementation to instantiate.

StructuredLayer

Returns:

Type Description
Module

Configured structured layer instance.

Source code in src/node_fdm/models/flight_dynamics_model.py
def create_structured_layer(
    self,
    input_cols: Sequence[Any],
    output_cols: Sequence[Any],
    layer_class: Any = StructuredLayer,
) -> nn.Module:
    """Instantiate ``layer_class`` with the statistics its columns need.

    Args:
        input_cols: Column objects fed into the layer. Only those whose
            ``normalize_mode`` is not ``None`` contribute input statistics.
        output_cols: Column objects produced by the layer. Only those whose
            ``denormalize_mode`` is not ``None`` contribute output statistics.
        layer_class: Class to instantiate (``StructuredLayer`` by default).

    Returns:
        The new layer. Input statistics are passed as ``[means, stds]`` and
        output statistics as ``[means, stds, maxes]``; each entry is a dict
        keyed by ``col_name``.
    """

    def collect(cols, mode_attr, metrics):
        # One ``{col_name: stat}`` dict per metric, restricted to columns
        # whose ``mode_attr`` is set.
        return [
            {
                col.col_name: self.stats_dict[col][metric]
                for col in cols
                if getattr(col, mode_attr) is not None
            }
            for metric in metrics
        ]

    in_stats = collect(input_cols, "normalize_mode", ["mean", "std"])
    out_stats = collect(output_cols, "denormalize_mode", ["mean", "std", "max"])

    width = self.neurons_num
    return layer_class(
        input_cols,
        in_stats,
        output_cols,
        out_stats,
        backbone_dim=width,
        backbone_depth=self.backbone_depth,
        head_dim=width // 2,
        head_depth=self.head_depth,
    )

forward(x, u_t, e_t)

Compute state derivatives for the current batch.

Parameters:

Name Type Description Default
x Tensor

State tensor.

required
u_t Tensor

Control tensor interpolated at current time.

required
e_t Tensor

Environment tensor interpolated at current time.

required

Returns:

Type Description
Tensor

Tensor of state derivatives assembled from architecture outputs.

Source code in src/node_fdm/models/flight_dynamics_model.py
def forward(
    self, x: torch.Tensor, u_t: torch.Tensor, e_t: torch.Tensor
) -> torch.Tensor:
    """Compute state derivatives for the current batch.

    Args:
        x: State tensor.
        u_t: Control tensor interpolated at current time.
        e_t: Environment tensor interpolated at current time.
            All three are concatenated along dim 1, and column ``i`` of the
            result is named after ``(x_cols + u_cols + e0_cols)[i]``. This
            assumes 2-D ``(batch, features)`` inputs (TODO confirm).

    Returns:
        Tensor of state derivatives: ``coeff * value[col]`` for each
        ``(coeff, col)`` in ``dx_cols``, stacked along dim 1.

    Side effects:
        Appends every named value (unsqueezed at dim 1) to ``self.history``.
    """

    vects = torch.cat([x, u_t, e_t], dim=1)
    vect_dict = dict()
    for i, col in enumerate(self.x_cols + self.u_cols + self.e0_cols):
        vect_dict[col] = vects[..., i]

    # Each layer sees every value computed so far and adds its own outputs.
    for name in self.layers_name:
        vect_dict = vect_dict | self.layers_dict[name](vect_dict)

    ode_output = torch.stack(
        [coeff * vect_dict[col] for coeff, col in self.dx_cols],
        dim=1,
    )

    # NOTE(review): history tensors are not detached, so during training they
    # keep autograd references until reset_history() is called.
    for col, vect in vect_dict.items():
        if col in self.history:
            self.history[col] = torch.cat(
                [self.history[col], vect.unsqueeze(1)], dim=1
            )
        else:
            self.history[col] = vect.unsqueeze(1)

    return ode_output

reset_history()

Reset internal history buffers.

Clears stored layer outputs used for debugging or analysis between runs.

Source code in src/node_fdm/models/flight_dynamics_model.py
def reset_history(self):
    """Discard all recorded per-column outputs.

    Afterwards ``self.history`` is a brand-new, empty dict, ready for the
    next run of ``forward`` calls.
    """
    empty_history = dict()
    self.history = empty_history

Batch Neural ODE

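The adapter that turns a model into an ODE right-hand side f(t, x) over batches of time sequences. At each evaluation time requested by the numerical solver (e.g. Euler, RK4), it linearly interpolates the control and environment sequences on the time grid and calls the wrapped model.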

batch_neural_ode

Batch-compatible Neural ODE wrapper that interpolates inputs over time.

BatchNeuralODE

Bases: Module

Wrap a neural ODE with batched control and environment inputs.

Source code in src/node_fdm/models/batch_neural_ode.py
class BatchNeuralODE(nn.Module):
    """Wrap a neural ODE with batched control and environment inputs.

    Instances act as an ODE right-hand side ``f(t, x)``: at every evaluation
    time the control and environment sequences are linearly interpolated on
    ``t_grid`` and passed, together with the state, to the wrapped model.
    """

    def __init__(
        self,
        model: nn.Module,
        u_seq: torch.Tensor,
        e_seq: torch.Tensor,
        t_grid: torch.Tensor,
    ) -> None:
        """Initialize the ODE wrapper and reset model history.

        Args:
            model: Base neural ODE model taking `(x, u_t, e_t)` and providing
                ``reset_history()``.
            u_seq: Control inputs over time with shape `(batch, time, features)`.
            e_seq: Environment inputs over time with shape `(batch, time, features)`.
            t_grid: Monotonic (ascending) 1-D time grid corresponding to
                `u_seq` and `e_seq`.
        """
        super().__init__()
        self.model = model
        # Start each wrapper with an empty model history so recorded outputs
        # belong to this integration only.
        self.model.reset_history()
        self.u_seq = u_seq
        self.e_seq = e_seq
        self.t_grid = t_grid

    def forward(self, t: torch.Tensor, x: torch.Tensor) -> Any:
        """Evaluate the ODE dynamics at time `t` with linear interpolation.

        Inputs are interpolated between the two grid samples that bracket
        ``t``. Outside ``[t_grid[0], t_grid[-1]]`` the first or last sample is
        held constant (no extrapolation).

        Args:
            t: Scalar tensor (or plain Python number) containing the evaluation time.
            x: Current state tensor.

        Returns:
            Model output of the wrapped dynamics at time `t`.
        """
        grid = self.t_grid
        # ``float`` accepts one-element tensors as well as plain numbers.
        t_val = float(t)
        # Query with the grid's own floating dtype. A default float32 query
        # against a float64 grid would lose precision near grid points.
        query_dtype = grid.dtype if grid.is_floating_point() else None
        query = torch.tensor(t_val, dtype=query_dtype, device=grid.device)
        # Left search: first index with grid[idx] >= t.
        idx = int(torch.searchsorted(grid, query).item())
        idx0 = max(0, idx - 1)
        idx1 = min(idx, grid.shape[0] - 1)

        t0, t1 = grid[idx0].item(), grid[idx1].item()
        # Both indices collapse to the same sample at or beyond the grid ends.
        alpha = 0 if t1 == t0 else (t_val - t0) / (t1 - t0)

        u0, u1 = self.u_seq[:, idx0, :], self.u_seq[:, idx1, :]
        e0, e1 = self.e_seq[:, idx0, :], self.e_seq[:, idx1, :]

        u_t = (1 - alpha) * u0 + alpha * u1
        e_t = (1 - alpha) * e0 + alpha * e1

        return self.model(x, u_t, e_t)

__init__(model, u_seq, e_seq, t_grid)

Initialize the ODE wrapper and reset model history.

Parameters:

Name Type Description Default
model Module

Base neural ODE model taking (x, u_t, e_t).

required
u_seq Tensor

Control inputs over time with shape (batch, time, features).

required
e_seq Tensor

Environment inputs over time with shape (batch, time, features).

required
t_grid Tensor

Monotonic time grid corresponding to u_seq and e_seq.

required
Source code in src/node_fdm/models/batch_neural_ode.py
def __init__(
    self,
    model: nn.Module,
    u_seq: torch.Tensor,
    e_seq: torch.Tensor,
    t_grid: torch.Tensor,
) -> None:
    """Store the model and its time-indexed inputs, clearing the model's history.

    Args:
        model: Dynamics model called as ``model(x, u_t, e_t)``. It must also
            provide ``reset_history()``.
        u_seq: Control inputs laid out as ``(batch, time, features)``.
        e_seq: Environment inputs laid out as ``(batch, time, features)``.
        t_grid: Monotonic time stamps matching the time axis of ``u_seq``
            and ``e_seq``.
    """
    super().__init__()
    self.model = model
    # Clear any outputs recorded by a previous run before this wrapper is used.
    model.reset_history()
    self.u_seq, self.e_seq, self.t_grid = u_seq, e_seq, t_grid

forward(t, x)

Evaluate the ODE dynamics at time t with linear interpolation.

Parameters:

Name Type Description Default
t Tensor

Scalar tensor containing the evaluation time.

required
x Tensor

Current state tensor.

required

Returns:

Type Description
Any

Model output of the wrapped dynamics at time t.

Source code in src/node_fdm/models/batch_neural_ode.py
def forward(self, t: torch.Tensor, x: torch.Tensor) -> Any:
    """Evaluate the ODE dynamics at time `t` with linear interpolation.

    Inputs are interpolated between the two grid samples that bracket ``t``.
    Outside ``[t_grid[0], t_grid[-1]]`` the first or last sample is held
    constant (no extrapolation).

    Args:
        t: Scalar tensor (or plain Python number) containing the evaluation time.
        x: Current state tensor.

    Returns:
        Model output of the wrapped dynamics at time `t`.
    """
    grid = self.t_grid
    # ``float`` accepts one-element tensors as well as plain numbers.
    t_val = float(t)
    # Query with the grid's own floating dtype. A default float32 query
    # against a float64 grid would lose precision near grid points.
    query_dtype = grid.dtype if grid.is_floating_point() else None
    query = torch.tensor(t_val, dtype=query_dtype, device=grid.device)
    # Left search: first index with grid[idx] >= t.
    idx = int(torch.searchsorted(grid, query).item())
    idx0 = max(0, idx - 1)
    idx1 = min(idx, grid.shape[0] - 1)

    t0, t1 = grid[idx0].item(), grid[idx1].item()
    # Both indices collapse to the same sample at or beyond the grid ends.
    alpha = 0 if t1 == t0 else (t_val - t0) / (t1 - t0)

    u0, u1 = self.u_seq[:, idx0, :], self.u_seq[:, idx1, :]
    e0, e1 = self.e_seq[:, idx0, :], self.e_seq[:, idx1, :]

    u_t = (1 - alpha) * u0 + alpha * u1
    e_t = (1 - alpha) * e0 + alpha * e1

    return self.model(x, u_t, e_t)

Production Model

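A wrapper designed for inference. It rebuilds the architecture from meta.json, loads each layer's pretrained weights from its checkpoint, puts those layers in evaluation mode, and runs a dictionary-in, dictionary-out forward pass without the history recording performed by the training model.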

flight_dynamics_model_prod

Production-ready flight dynamics model loader and evaluator.

FlightDynamicsModelProd

Bases: Module

Load pretrained flight dynamics layers and expose an evaluation interface.

Source code in src/node_fdm/models/flight_dynamics_model_prod.py
class FlightDynamicsModelProd(nn.Module):
    """Load pretrained flight dynamics layers and expose an evaluation interface.

    The architecture, column groups, hyper-parameters, and statistics are read
    from ``meta.json``. Every layer except ``trajectory`` then receives its
    weights from ``<model_path>/<layer_name>.pt`` and is put in eval mode.
    """

    def __init__(
        self,
        model_path: Any,
    ) -> None:
        """Initialize and load pretrained layers from a model directory.

        Args:
            model_path: Path-like pointing to the directory containing checkpoints and meta.json.
                A plain ``str`` is converted to ``pathlib.Path`` so that
                ``meta.json`` can be located with the ``/`` operator.
        """
        from pathlib import Path  # local import keeps this block self-contained

        super().__init__()
        if isinstance(model_path, str):
            # ``str`` has no ``/`` operator; building ``meta_path`` below
            # would otherwise raise TypeError for string paths.
            model_path = Path(model_path)
        self.model_path = model_path
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        meta_path = model_path / "meta.json"
        self.architecture, self.model_cols, model_params, self.stats_dict = (
            get_architecture_params_from_meta(meta_path)
        )
        self.backbone_depth, self.head_depth, self.neurons_num = model_params
        self.layers_dict = nn.ModuleDict({})
        self.layers_name = []
        for name, layer_class, input_cols, output_cols, structured in self.architecture:
            self.layers_name.append(name)
            if structured:
                self.layers_dict[name] = self.create_structured_layer(
                    input_cols,
                    output_cols,
                    layer_class=layer_class,
                )
            else:
                self.layers_dict[name] = layer_class()
            # No checkpoint is loaded for the "trajectory" layer (presumably it
            # has no trained weights; confirm).
            if name != "trajectory":
                checkpoint = self.load_layer_checkpoint(name)
                # strict=False: missing or unexpected keys are ignored
                # silently, so a mismatched checkpoint will not raise here.
                self.layers_dict[name].load_state_dict(
                    checkpoint["layer_state"], strict=False
                )
                self.layers_dict[name] = self.layers_dict[name].eval()
        # NOTE(review): checkpoints are mapped to ``self.device``, but the
        # layers themselves are not moved there by this constructor.

    def load_layer_checkpoint(self, layer_name: str) -> Any:
        """Load checkpoint for a given layer.

        Args:
            layer_name: Name of the layer whose weights should be loaded.

        Returns:
            Loaded checkpoint dictionary, read from ``<model_path>/<layer_name>.pt``
            and mapped onto ``self.device``.
        """
        path = os.path.join(self.model_path, f"{layer_name}.pt")
        # torch.load is pickle-based: only load checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device)
        return checkpoint

    def create_structured_layer(
        self,
        input_cols: Sequence[Any],
        output_cols: Sequence[Any],
        layer_class: Any = StructuredLayer,
    ) -> nn.Module:
        """Build a structured layer with normalization and denormalization stats.

        Args:
            input_cols: Columns consumed by the layer. Only columns whose
                ``normalize_mode`` is not None contribute ``mean``/``std`` stats.
            output_cols: Columns produced by the layer. Only columns whose
                ``denormalize_mode`` is not None contribute ``mean``/``std``/``max`` stats.
            layer_class: Layer implementation to instantiate.

        Returns:
            Configured structured layer instance.
        """
        input_stats = [
            {
                col.col_name: self.stats_dict[col][metric]
                for col in input_cols
                if col.normalize_mode is not None
            }
            for metric in ["mean", "std"]
        ]
        output_stats = [
            {
                col.col_name: self.stats_dict[col][metric]
                for col in output_cols
                if col.denormalize_mode is not None
            }
            for metric in ["mean", "std", "max"]
        ]

        layer = layer_class(
            input_cols,
            input_stats,
            output_cols,
            output_stats,
            backbone_dim=self.neurons_num,
            backbone_depth=self.backbone_depth,
            head_dim=self.neurons_num // 2,
            head_depth=self.head_depth,
        )

        return layer

    def forward(self, vect_dict: dict) -> dict:
        """Run a forward pass through all layers.

        Args:
            vect_dict: Mapping from column identifiers to tensors. It is
                updated in place with each layer's outputs.

        Returns:
            The same mapping object, now including the newly computed columns.
        """
        for name in self.layers_name:
            res = self.layers_dict[name](vect_dict)
            vect_dict |= res

        return vect_dict

__init__(model_path)

Initialize and load pretrained layers from a model directory.

Parameters:

Name Type Description Default
model_path Any

Path-like pointing to the directory containing checkpoints and meta.json.

required
Source code in src/node_fdm/models/flight_dynamics_model_prod.py
def __init__(
    self,
    model_path: Any,
) -> None:
    """Initialize and load pretrained layers from a model directory.

    Args:
        model_path: Path-like pointing to the directory containing checkpoints and meta.json.
            A plain ``str`` is converted to ``pathlib.Path`` so that
            ``meta.json`` can be located with the ``/`` operator.
    """
    from pathlib import Path  # local import keeps this block self-contained

    super().__init__()
    if isinstance(model_path, str):
        # ``str`` has no ``/`` operator; building ``meta_path`` below would
        # otherwise raise TypeError for string paths.
        model_path = Path(model_path)
    self.model_path = model_path
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    meta_path = model_path / "meta.json"
    self.architecture, self.model_cols, model_params, self.stats_dict = (
        get_architecture_params_from_meta(meta_path)
    )
    self.backbone_depth, self.head_depth, self.neurons_num = model_params
    self.layers_dict = nn.ModuleDict({})
    self.layers_name = []
    for name, layer_class, input_cols, output_cols, structured in self.architecture:
        self.layers_name.append(name)
        if structured:
            self.layers_dict[name] = self.create_structured_layer(
                input_cols,
                output_cols,
                layer_class=layer_class,
            )
        else:
            self.layers_dict[name] = layer_class()
        # No checkpoint is loaded for the "trajectory" layer (presumably it has
        # no trained weights; confirm).
        if name != "trajectory":
            checkpoint = self.load_layer_checkpoint(name)
            # strict=False: missing or unexpected keys are ignored silently.
            self.layers_dict[name].load_state_dict(
                checkpoint["layer_state"], strict=False
            )
            self.layers_dict[name] = self.layers_dict[name].eval()

create_structured_layer(input_cols, output_cols, layer_class=StructuredLayer)

Build a structured layer with normalization and denormalization stats.

Parameters:

Name Type Description Default
input_cols Sequence[Any]

Columns consumed by the layer.

required
output_cols Sequence[Any]

Columns produced by the layer.

required
layer_class Any

Layer implementation to instantiate.

StructuredLayer

Returns:

Type Description
Module

Configured structured layer instance.

Source code in src/node_fdm/models/flight_dynamics_model_prod.py
def create_structured_layer(
    self,
    input_cols: Sequence[Any],
    output_cols: Sequence[Any],
    layer_class: Any = StructuredLayer,
) -> nn.Module:
    """Instantiate ``layer_class`` with per-column statistics.

    Args:
        input_cols: Column objects fed into the layer. Those with a
            non-``None`` ``normalize_mode`` contribute ``mean`` and ``std``.
        output_cols: Column objects produced by the layer. Those with a
            non-``None`` ``denormalize_mode`` contribute ``mean``, ``std``
            and ``max``.
        layer_class: Class to instantiate (``StructuredLayer`` by default).

    Returns:
        The freshly built layer.
    """
    input_stats = []
    for metric in ["mean", "std"]:
        per_col = {}
        for col in input_cols:
            if col.normalize_mode is not None:
                key = col.col_name
                per_col[key] = self.stats_dict[col][metric]
        input_stats.append(per_col)

    output_stats = []
    for metric in ["mean", "std", "max"]:
        per_col = {}
        for col in output_cols:
            if col.denormalize_mode is not None:
                key = col.col_name
                per_col[key] = self.stats_dict[col][metric]
        output_stats.append(per_col)

    hidden = self.neurons_num
    return layer_class(
        input_cols,
        input_stats,
        output_cols,
        output_stats,
        backbone_dim=hidden,
        backbone_depth=self.backbone_depth,
        head_dim=hidden // 2,
        head_depth=self.head_depth,
    )

forward(vect_dict)

Run a forward pass through all layers.

Parameters:

Name Type Description Default
vect_dict dict

Mapping from column identifiers to tensors.

required

Returns:

Type Description
dict

Updated mapping with newly computed columns.

Source code in src/node_fdm/models/flight_dynamics_model_prod.py
def forward(self, vect_dict: dict) -> dict:
    """Feed ``vect_dict`` through every layer in architecture order.

    Each layer receives the dictionary accumulated so far, and its result is
    merged into that same dictionary (the caller's dict is mutated).

    Args:
        vect_dict: Mapping from column identifiers to tensors.

    Returns:
        The very same ``vect_dict`` object, extended with the layer outputs.
    """
    for layer_name in self.layers_name:
        layer = self.layers_dict[layer_name]
        vect_dict.update(layer(vect_dict))
    return vect_dict

load_layer_checkpoint(layer_name)

Load checkpoint for a given layer.

Parameters:

Name Type Description Default
layer_name str

Name of the layer whose weights should be loaded.

required

Returns:

Type Description
Any

Loaded checkpoint dictionary.

Source code in src/node_fdm/models/flight_dynamics_model_prod.py
def load_layer_checkpoint(self, layer_name: str) -> Any:
    """Read the saved checkpoint of one layer.

    The file is ``<model_path>/<layer_name>.pt`` and its tensors are mapped
    onto ``self.device``.

    Args:
        layer_name: Name of the layer whose weights should be loaded.

    Returns:
        Whatever object was stored in the checkpoint file.
    """
    checkpoint_file = os.path.join(self.model_path, "{}.pt".format(layer_name))
    # torch.load is pickle-based: only load checkpoints from trusted sources.
    return torch.load(checkpoint_file, map_location=self.device)