🏋️ Train a Model¶
The ODETrainer class is the engine of node-fdm. It manages data loading, the Neural ODE integration loop, loss calculation, and checkpointing.
📋 Prerequisites¶
Before starting the training loop, ensure you have:
- Processed Data: Parquet files generated by the preprocessing pipeline (see Pipelines).
- Split File: A `dataset_split.csv` containing at least `filepath`, `aircraft_type`, and `split` columns.
- Registered Architecture: The architecture name (e.g., `opensky_2025`) must be registered in `architectures.mapping`.
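As a sketch of the expected split-file layout, here is one way to build a minimal `dataset_split.csv` with pandas. The rows below are purely illustrative; real entries come from the preprocessing pipeline:

```python
import pandas as pd

# Illustrative rows only -- actual filepaths are produced by the
# preprocessing pipeline, not written by hand.
split_df = pd.DataFrame(
    {
        "filepath": ["flights/a320_0001.parquet", "flights/a320_0002.parquet"],
        "aircraft_type": ["A320", "A320"],
        "split": ["train", "val"],
    }
)
split_df.to_csv("dataset_split.csv", index=False)
```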
⚡ Minimal Training Script¶
This script loads the configuration, filters the dataset for a specific aircraft type (e.g., A320), and launches the training.
import pandas as pd
import yaml
from pathlib import Path
from node_fdm.ode_trainer import ODETrainer
# 1. Load Project Configuration
# Assumes running from repository root
with open("scripts/opensky/config.yaml") as f:
cfg = yaml.safe_load(f)
paths = cfg["paths"]
process_dir = Path(paths["data_dir"]) / paths["process_dir"]
models_dir = Path(paths["data_dir"]) / paths["models_dir"]
# 2. Select Data Scope
TARGET_ACFT = "A320"
TARGET_ARCH = "opensky_2025" # Must match architectures.mapping.valid_names
# Load the split file and filter for the target aircraft
split_df = pd.read_csv(process_dir / "dataset_split.csv")
data_df = split_df[split_df.aircraft_type == TARGET_ACFT]
# 3. Define Hyperparameters
# These control how data is sliced for the ODE
train_config = dict(
architecture_name=TARGET_ARCH,
model_name=f"{TARGET_ARCH}_{TARGET_ACFT}",
step=4, # Sampling period (seconds)
shift=60, # Stride between windows
seq_len=60, # Context window length (steps)
lr=1e-3, # Learning Rate
weight_decay=1e-4,
model_params=[3, 2, 48], # Architecture-specific args (see model.py)
loading_args=(False, False), # (Load Weights, Load Optimizer)
batch_size=512,
num_workers=4,
)
# 4. Initialize Trainer
trainer = ODETrainer(
data_df=data_df,
model_config=train_config,
model_dir=models_dir,
num_workers=train_config["num_workers"],
load_parallel=True, # Use multi-CPU
)
# 5. Execute Training Loop
trainer.train(
epochs=10,
batch_size=train_config["batch_size"],
val_batch_size=10_000,
method="euler", # Solver: "euler" or "rk4"
alpha_dict=None # Optional loss balancing
)
⚙️ Configuration Reference¶
Data Slicing (model_config)¶
| Parameter | Description | Typical Value |
|---|---|---|
| `step` | The time delta (\(dt\)) between two consecutive data points, in seconds. | 4 (OpenSky) or 1 (QAR) |
| `seq_len` | Number of time steps fed to the ODE integration. | 60 to 120 |
| `shift` | Sliding window stride. A lower value means more overlap and more training windows. | 60 (non-overlapping) |
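The interaction between `seq_len` and `shift` can be sketched with a small helper. `num_windows` is hypothetical (node-fdm does this slicing internally); it only illustrates how many windows a flight yields:

```python
def num_windows(n_points: int, seq_len: int, shift: int) -> int:
    """Number of training windows a flight of n_points samples yields.

    Hypothetical helper -- shown only to illustrate the
    seq_len/shift interaction, not part of the node-fdm API.
    """
    if n_points < seq_len:
        return 0
    return (n_points - seq_len) // shift + 1

# A 10-minute flight sampled every 4 s gives 150 points.
# Non-overlapping windows (shift == seq_len):
print(num_windows(150, seq_len=60, shift=60))  # -> 2
# Halving the stride increases overlap and yields more windows:
print(num_windows(150, seq_len=60, shift=30))  # -> 4
```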
Solver Options (trainer.train)¶
| Parameter | Description |
|---|---|
| `method` | The integration method. `euler` is faster but less precise; `rk4` (Runge-Kutta 4) is more stable but roughly 4x slower. |
| `alpha_dict` | A dictionary mapping column names to loss weights. Useful if one variable (e.g., altitude) dominates the loss. |
💾 Outputs & Artifacts¶
Upon completion, the trainer generates the following structure in your models_dir:
models/opensky_2025_A320/
├── meta.json # ⚠️ CRITICAL: Contains scaling stats & hyperparams
├── training_losses.csv # Log of Train/Val loss per epoch
├── training_curve.png # Visualization of convergence
├── trajectory.pt # Checkpoint for Physics/Trajectory layers
└── data_ode.pt # Checkpoint for Neural/Data layers
Do not delete meta.json
The meta.json file stores the mean and standard deviation of the training data. The inference engine (NodeFDMPredictor) requires this file to normalize new data exactly as the model expects.
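To see why, here is a sketch of how an inference step consumes those statistics. The field names (`mean`, `std`) and values below are assumptions for illustration, not the exact schema node-fdm writes; inspect your own meta.json for the real keys:

```python
import json
from pathlib import Path

# Write an illustrative meta.json -- key names are assumptions,
# not the exact schema node-fdm produces.
Path("meta.json").write_text(json.dumps({
    "mean": {"altitude": 18500.0},
    "std": {"altitude": 11200.0},
}))

meta = json.loads(Path("meta.json").read_text())
# New data must be normalized with the *training* statistics,
# which is why losing this file breaks inference:
z = (20000.0 - meta["mean"]["altitude"]) / meta["std"]["altitude"]
```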
💡 Advanced Tips¶
To restart training from an existing checkpoint, update the `loading_args` entry in your config:
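For instance, flipping both flags in the `train_config` dict from the minimal script above (the tuple order follows the script's own comment):

```python
train_config = dict(
    # ...same hyperparameters as in the minimal training script...
    loading_args=(True, True),  # (load weights, load optimizer state)
)
# Both True resumes training from the last checkpoint in models_dir
# instead of starting from randomly initialized weights.
```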
If your model struggles to learn specific dynamic variables (e.g., vertical speed `vz`), you can increase their weight in the loss function:
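For example, a weight dictionary passed to `trainer.train`. The column names here are assumptions for illustration; they must match the columns of your processed dataset:

```python
# Hypothetical column names -- use the ones in your own Parquet files.
alpha_dict = {
    "vz": 5.0,        # up-weight vertical speed in the loss
    "altitude": 1.0,  # leave other variables at their default weight
}

# Then pass it to the training loop:
# trainer.train(epochs=10, batch_size=512, method="euler",
#               alpha_dict=alpha_dict)
```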
🚀 Next Steps¶
- Run Inference: Use your trained model to generate trajectory rollouts.