🏋️ Train a Model

The ODETrainer class is the engine of node-fdm. It manages data loading, the Neural ODE integration loop, loss calculation, and checkpointing.


📋 Prerequisites

Before starting the training loop, ensure you have:

  1. Processed Data: Parquet files generated by the preprocessing pipeline (see Pipelines).
  2. Split File: A dataset_split.csv containing at least filepath, aircraft_type, and split columns.
  3. Registered Architecture: The architecture name (e.g., opensky_2025) must be registered in architectures.mapping.
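Before launching a long training run, it can be worth failing fast on a malformed split file. The helper below is an illustrative sketch (not part of node-fdm) that checks the three required columns:

```python
import pandas as pd

REQUIRED_COLS = {"filepath", "aircraft_type", "split"}

def validate_split_file(df: pd.DataFrame) -> None:
    """Raise early if the split file is missing required columns."""
    missing = REQUIRED_COLS - set(df.columns)
    if missing:
        raise ValueError(f"dataset_split.csv is missing columns: {sorted(missing)}")

# Example: a minimal split file in memory
df = pd.DataFrame({
    "filepath": ["flights/a320_001.parquet"],
    "aircraft_type": ["A320"],
    "split": ["train"],
})
validate_split_file(df)  # passes silently
```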

⚡ Minimal Training Script

This script loads the configuration, filters the dataset for a specific aircraft type (e.g., A320), and launches the training.

scripts/train_custom.py

```python
import pandas as pd
import yaml
from pathlib import Path
from node_fdm.ode_trainer import ODETrainer

# 1. Load Project Configuration
# Assumes running from the repository root
with open("scripts/opensky/config.yaml") as f:
    cfg = yaml.safe_load(f)

paths = cfg["paths"]
process_dir = Path(paths["data_dir"]) / paths["process_dir"]
models_dir = Path(paths["data_dir"]) / paths["models_dir"]

# 2. Select Data Scope
TARGET_ACFT = "A320"
TARGET_ARCH = "opensky_2025"  # Must match architectures.mapping.valid_names

# Load the split file and filter for the target aircraft
split_df = pd.read_csv(process_dir / "dataset_split.csv")
data_df = split_df[split_df.aircraft_type == TARGET_ACFT]

# 3. Define Hyperparameters
# These control how data is sliced for the ODE
train_config = dict(
    architecture_name=TARGET_ARCH,
    model_name=f"{TARGET_ARCH}_{TARGET_ACFT}",
    step=4,                       # Sampling period (seconds)
    shift=60,                     # Stride between windows
    seq_len=60,                   # Context window length (steps)
    lr=1e-3,                      # Learning rate
    weight_decay=1e-4,
    model_params=[3, 2, 48],      # Architecture-specific args (see model.py)
    loading_args=(False, False),  # (Load weights, load optimizer state)
    batch_size=512,
    num_workers=4,
)

# 4. Initialize Trainer
trainer = ODETrainer(
    data_df=data_df,
    model_config=train_config,
    model_dir=models_dir,
    num_workers=train_config["num_workers"],
    load_parallel=True,           # Use multi-CPU loading
)

# 5. Execute Training Loop
trainer.train(
    epochs=10,
    batch_size=train_config["batch_size"],
    val_batch_size=10_000,
    method="euler",               # Solver: "euler" or "rk4"
    alpha_dict=None,              # Optional loss balancing
)
```

⚙️ Configuration Reference

Data Slicing (model_config)

| Parameter | Description | Typical Value |
| --- | --- | --- |
| `step` | The time delta (\(dt\)) between two consecutive data points, in seconds. | `4` (OpenSky) or `1` (QAR) |
| `seq_len` | Number of time steps fed to the ODE integration. | `60` to `120` |
| `shift` | Sliding-window stride. A lower value means more overlap and more training samples. | `60` (non-overlapping) |
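The interaction of `seq_len` and `shift` determines how many training windows each flight yields. A small arithmetic sketch (not node-fdm code, just the standard sliding-window count):

```python
def num_windows(n_points: int, seq_len: int, shift: int) -> int:
    """Count sliding windows of seq_len points with stride shift."""
    if n_points < seq_len:
        return 0
    return (n_points - seq_len) // shift + 1

# A one-hour flight resampled at step=4 s yields 900 points.
n_points = 3600 // 4
num_windows(n_points, seq_len=60, shift=60)   # non-overlapping: 15 windows
num_windows(n_points, seq_len=60, shift=30)   # 50% overlap: 29 windows
```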

Solver Options (trainer.train)

| Parameter | Description |
| --- | --- |
| `method` | Integration method. `euler` is faster but less precise; `rk4` (Runge-Kutta 4) is more stable but roughly 4x slower. |
| `alpha_dict` | A dictionary mapping column names to loss weights. Useful when one variable (e.g., altitude) dominates the loss. |
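To see why `rk4` costs about 4x more per step but pays off in accuracy, here is a minimal comparison on a toy ODE (\(\dot{y} = -y\)), independent of node-fdm's own solver:

```python
import math

def f(t, y):          # dy/dt = -y, exact solution y(t) = exp(-t)
    return -y

def euler_step(f, t, y, dt):
    return y + dt * f(t, y)               # 1 function evaluation

def rk4_step(f, t, y, dt):                # 4 function evaluations
    k1 = f(t, y)
    k2 = f(t + dt / 2, y + dt * k1 / 2)
    k3 = f(t + dt / 2, y + dt * k2 / 2)
    k4 = f(t + dt, y + dt * k3)
    return y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

def integrate(step_fn, y0, dt, n):
    t, y = 0.0, y0
    for _ in range(n):
        y = step_fn(f, t, y, dt)
        t += dt
    return y

exact = math.exp(-1.0)
err_euler = abs(integrate(euler_step, 1.0, 0.1, 10) - exact)
err_rk4 = abs(integrate(rk4_step, 1.0, 0.1, 10) - exact)
# rk4's error is orders of magnitude smaller for the same step size
```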

💾 Outputs & Artifacts

Upon completion, the trainer generates the following structure in your models_dir:

```
models/opensky_2025_A320/
├── meta.json                # ⚠️ CRITICAL: contains scaling stats & hyperparameters
├── training_losses.csv      # Log of train/val loss per epoch
├── training_curve.png       # Visualization of convergence
├── trajectory.pt            # Checkpoint for physics/trajectory layers
└── data_ode.pt              # Checkpoint for neural/data layers
```

Do not delete meta.json

The meta.json file stores the mean and standard deviation of the training data. The inference engine (NodeFDMPredictor) requires this file to normalize new data exactly as the model expects.
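To make the role of those stored statistics concrete, here is a sketch of z-score normalization driven by a hypothetical `meta.json` structure (the real file's schema is defined by node-fdm and may differ):

```python
# Hypothetical meta.json content; the real file is written by the trainer.
meta = {"mean": {"altitude": 10500.0, "vz": 0.0},
        "std": {"altitude": 3200.0, "vz": 6.5}}

def normalize(row: dict, meta: dict) -> dict:
    """Apply the same z-score scaling the model saw at training time."""
    return {k: (v - meta["mean"][k]) / meta["std"][k] for k, v in row.items()}

normalize({"altitude": 13700.0, "vz": 3.25}, meta)
# {'altitude': 1.0, 'vz': 0.5}
```

If `meta.json` is lost, new inputs cannot be scaled consistently with training, and predictions become meaningless even though the checkpoints still load.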


💡 Advanced Tips

To restart training from an existing checkpoint, update the loading_args in your config:

```python
model_config = dict(
    # ...
    # (Load weights=True, load optimizer state=True)
    loading_args=(True, True),
)
```

If your model struggles to learn specific dynamic variables (e.g., vertical speed `vz`), you can increase their weight in the loss function:

```python
# Penalize errors on 'vz' 5x more than other variables
trainer.train(
    # ...
    alpha_dict={"vz": 5.0},
)
```
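node-fdm's actual loss implementation is internal to `ODETrainer`, but the effect of `alpha_dict` can be illustrated with a weighted per-column MSE sketch (assumed behavior, pure-Python for clarity):

```python
def weighted_mse(pred, target, columns, alpha_dict=None):
    """Mean squared error per column, scaled by alpha weights (default 1.0)."""
    alphas = alpha_dict or {}
    total, weight_sum = 0.0, 0.0
    for j, col in enumerate(columns):
        w = alphas.get(col, 1.0)
        mse = sum((p[j] - t[j]) ** 2 for p, t in zip(pred, target)) / len(pred)
        total += w * mse
        weight_sum += w
    return total / weight_sum

# With alpha_dict={"vz": 5.0}, a vz error of the same magnitude
# contributes 5x more gradient pressure than any other column.
```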

You can train multiple models sequentially by iterating over the architecture names and dataset subsets:

```python
for arch in ["opensky_2025", "my_custom_arch"]:
    # Update config
    train_config["architecture_name"] = arch
    # Re-initialize trainer...
```

🚀 Next Steps

  • Run Inference: Use your trained model to generate trajectory rollouts.