Model Tensor Planning is a sampling-based MPC framework that generates globally diverse trajectory candidates by sampling paths through a randomized M-partite graph and interpolating each path with a smooth spline. It runs entirely on GPU via JAX + MuJoCo MJX and plugs into hydrax as a drop-in controller.
Paper: Model Tensor Planning - An T. Le, Khai Nguyen, Minh Nhat Vu, João Carvalho, Jan Peters · TMLR 2025
MTP has been upstreamed into hydrax as hydrax.algs.MTP (PRs #74, #75, #76). If you only need MTP as one of many controllers, you can use it directly from hydrax:

```python
from hydrax.algs import MTP  # upstream copy
```

This repository (anindex/mtp) remains the canonical reference implementation with:
- Extensive per-task tuning examples across 9 environments
- Self-contained Akima and B-spline spline code for research purposes
- The paper's official benchmark scripts and plotting utilities
The two codebases are kept API-compatible. Use whichever fits your workflow.
Local samplers (PS, MPPI, CEM) work well around an existing trajectory but get trapped in local minima - e.g. when the straight line from start to goal goes through a wall. MTP injects structured global exploration on top of a CEM-style local update:
- A tensor graph of M layers x N candidates encodes every combination of waypoints; a path is one index per layer.
- β · num_samples paths are sampled from the graph and interpolated (Akima / B-spline / linear) into smooth control trajectories.
- The remaining (1 − β) · num_samples are local CEM perturbations around the current best plan.
- All trajectories are rolled out in parallel through MJX (with optional domain randomization), elites are picked with jax.lax.top_k, and the CEM mean / variance are updated with a softmax-weighted, baseline-subtracted, Bessel-corrected estimator.
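The first three steps above can be sketched in a few lines. This is a minimal illustration in plain Python rather than JAX, not the code in mtp/mtp.py: the function names are hypothetical, controls are scalar, layer indices are assumed to be drawn uniformly, and only the linear interpolation option is shown. With M = 3 and N = 50 the graph encodes 50³ = 125,000 distinct paths, of which only β · num_samples are drawn per replan.

```python
import random


def split_budget(num_samples, beta):
    """Split the sample budget into tensor paths and local CEM perturbations."""
    num_tensor = int(round(beta * num_samples))
    return num_tensor, num_samples - num_tensor


def sample_tensor_paths(m_pts, n_per_layer, num_paths, rng):
    """Draw paths through the M-partite graph: one candidate index per layer.

    Uniform index sampling is an assumption made for this sketch.
    """
    return [
        tuple(rng.randrange(n_per_layer) for _ in range(m_pts))
        for _ in range(num_paths)
    ]


def interpolate_linear(waypoints, num_knots):
    """Piecewise-linear interpolation of scalar waypoints onto num_knots points.

    Requires at least two waypoints and two knots.
    """
    last = len(waypoints) - 1
    out = []
    for k in range(num_knots):
        t = k * last / (num_knots - 1)  # position in waypoint-index space
        i = min(int(t), last - 1)       # left waypoint of the active segment
        frac = t - i
        out.append((1 - frac) * waypoints[i] + frac * waypoints[i + 1])
    return out


rng = random.Random(0)
m_pts, n_per_layer, num_samples, beta = 3, 50, 128, 0.5

# Randomized graph: each layer holds n_per_layer scalar candidates in [-1, 1].
layers = [
    [rng.uniform(-1.0, 1.0) for _ in range(n_per_layer)] for _ in range(m_pts)
]

num_tensor, num_local = split_budget(num_samples, beta)
paths = sample_tensor_paths(m_pts, n_per_layer, num_tensor, rng)
trajectories = [
    interpolate_linear([layers[m][idx] for m, idx in enumerate(path)], num_knots=10)
    for path in paths
]
```

Each trajectory starts at the waypoint chosen in the first layer and ends at the one chosen in the last layer; the remaining num_local samples would come from perturbing the current best plan.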
See examples/navigation.py: PS / MPPI / CEM
get stuck behind the U-shaped wall; MTP routes around it.
Requires Python >= 3.12 and CUDA 13 for GPU rollouts.
```bash
git clone https://github.com/anindex/mtp.git
cd mtp
uv venv --python 3.12 .venv && source .venv/bin/activate
uv pip install -e .
```

pip install -e . works equivalently. Hydrax is pinned to commit 33ec819, which includes the merged MTP PRs, spline bug fixes, MPPI-CMA, and MjWarp backend support. Bump the pin explicitly in pyproject.toml.
| Package | Minimum | Notes |
|---|---|---|
| jax | >= 0.8.0 | CUDA 13 required for GPU |
| mujoco / mujoco-mjx | >= 3.8.0 | MJX physics backend |
| flax | >= 0.12.0 | Immutable dataclasses |
| interpax | >= 0.3.12 | Akima spline utilities (hydrax dep) |
```python
import mujoco
from hydrax.tasks.pendulum import Pendulum
from hydrax.simulation.deterministic import run_interactive
from mtp import MTP

task = Pendulum()
ctrl = MTP(
    task,
    num_samples=128,
    m_pts=3, n_per_layer=50,    # 3-layer graph, 50 candidates per layer
    beta=0.5,                   # 50 % tensor paths, 50 % local CEM
    mtp_interpolation="akima",  # "akima" | "bspline" | "linear"
    plan_horizon=1.0,
    num_knots=10,
    spline_type="zero",         # hydrax low-level control spline
)

mj_model = task.mj_model
mj_data = mujoco.MjData(mj_model)
run_interactive(ctrl, mj_model, mj_data, frequency=25)
```

Each example accepts mtp / ps / mppi / cem as a positional argument. All examples also accept --warp for the experimental MjWarp backend (note: --warp must come before the algorithm subcommand):
| Example | Highlights |
|---|---|
| navigation.py | U-maze (BugTrap) with a local minimum; MTP escapes, others don't |
| pendulum.py · double_cart_pole.py · walker.py | Classic underactuated benchmarks |
| pusht.py · cube.py · crane.py | Contact-rich manipulation |
| g1_standup.py · g1_mocap.py | Unitree G1 humanoid |
```bash
python examples/navigation.py mtp
python examples/pusht.py mppi          # baseline comparison
python examples/walker.py --warp mtp   # experimental MjWarp backend
```

Visualize the spline tensor structures with scripts/plot_splines.py.
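The requirement that --warp precede the algorithm name is what you would get from a parser that defines --warp on the top level and the algorithms as subcommands. The sketch below shows that behaviour with the standard-library argparse module. It is an assumption about how such a CLI could be structured, not a copy of the example scripts, and build_parser is a hypothetical name.

```python
import argparse


def build_parser():
    """Hypothetical CLI: a global --warp flag plus one subcommand per algorithm."""
    parser = argparse.ArgumentParser(description="illustrative example runner")
    parser.add_argument(
        "--warp", action="store_true", help="use the experimental MjWarp backend"
    )
    sub = parser.add_subparsers(dest="algorithm", required=True)
    for name in ("mtp", "ps", "mppi", "cem"):
        sub.add_parser(name)
    return parser


parser = build_parser()

# Global flag first, then the subcommand: parses cleanly.
args = parser.parse_args(["--warp", "mtp"])
```

With this layout, everything after the subcommand name is handed to the subcommand's own parser, which does not know --warp, so `parser.parse_args(["mtp", "--warp"])` exits with an "unrecognized arguments" error.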
MTP-specific (see mtp/mtp.py)
| Paper symbol | Argument | Description | Typical |
|---|---|---|---|
| M | m_pts | Graph depth (waypoint layers) | 2-5 |
| N | n_per_layer | Graph width (candidates / layer) | 20-100 |
| β | beta | Tensor / CEM mix (1.0 = all tensor) | 0.1-1.0 |
| K | num_elites | Elite count for the CEM update | 5-50 |
| σ_min, σ_max | sigma_min, sigma_max | Variance clamp | 0.05-1.0 |
| α | alpha | Variance smoothing (0 = full update) | 0.0-0.5 |
| λ | temperature | Softmax temperature for elites | 0.01-1.0 |
| - | mtp_interpolation | "akima" (local, no overshoot), "bspline" (globally smooth, requires m_pts >= degree + 1), "linear" | - |
| - | degree | B-spline degree (>= 2) | 2-4 |
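To make the roles of num_elites, temperature, alpha, and the sigma clamp concrete, here is one plausible reading of a softmax-weighted, baseline-subtracted, Bessel-corrected CEM update. It is a sketch for a single scalar control in plain Python, not the implementation in mtp/mtp.py: the function name cem_update is hypothetical, and the exact formulas (best elite cost as baseline, 1 − Σw² as the weighted Bessel correction, smoothing applied to the standard deviation) are assumptions.

```python
import math


def cem_update(samples, costs, sigma, num_elites, temperature, alpha,
               sigma_min, sigma_max):
    """Return (new_mean, new_sigma) for one scalar control dimension."""
    # Elite selection: indices of the num_elites lowest costs
    # (plays the role that jax.lax.top_k has on the GPU).
    elite = sorted(range(len(costs)), key=costs.__getitem__)[:num_elites]

    # Baseline subtraction: shift costs by the best elite cost so the
    # exponentials stay well-scaled, then apply a softmax over elites.
    baseline = costs[elite[0]]
    unnorm = [math.exp(-(costs[i] - baseline) / temperature) for i in elite]
    total = sum(unnorm)
    weights = [u / total for u in unnorm]

    # Weighted mean of the elites.
    new_mean = sum(w * samples[i] for w, i in zip(weights, elite))

    # Weighted variance with a Bessel-style correction, 1 - sum(w^2).
    correction = max(1.0 - sum(w * w for w in weights), 1e-8)
    var = sum(
        w * (samples[i] - new_mean) ** 2 for w, i in zip(weights, elite)
    ) / correction

    # Smoothing: alpha = 0 takes the new estimate in full.
    new_sigma = alpha * sigma + (1.0 - alpha) * math.sqrt(var)

    # Clamp to [sigma_min, sigma_max].
    return new_mean, min(max(new_sigma, sigma_min), sigma_max)


# Toy problem: cost is (x - 2)^2, so the update should centre on x = 2.
samples = [0.0, 1.0, 2.0, 3.0]
costs = [(x - 2.0) ** 2 for x in samples]
mean, sigma = cem_update(samples, costs, sigma=0.3, num_elites=3,
                         temperature=1.0, alpha=0.0,
                         sigma_min=0.05, sigma_max=1.0)
```

A lower temperature concentrates the weights on the best elite, and a higher alpha keeps sigma closer to its previous value.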
| Argument | Description | Typical |
|---|---|---|
| plan_horizon | Planning horizon, seconds | 0.1-2.0 |
| num_knots | Hydrax spline knots | 4-20 |
| spline_type | "zero", "linear", "cubic" | "zero" |
| num_randomizations | Domain-randomized rollouts | 1-8 |
hydrax.simulation.deterministic.run_interactive runs the controller and
viewer in the same thread, so realtime rate ≈ min(frequency, 1 / plan_time).
If the viewer feels choppy:
- Lower frequency (e.g. 50 -> 25 Hz) to widen the per-replan budget.
- Lower num_samples and/or num_randomizations - total work scales as num_samples · num_randomizations · ctrl_steps.
- Lower max_traces (and trace_width) on run_interactive - each trace is a Python mjv_connector loop redrawn every replan.
- Press Tab inside the viewer to hide the side panels.
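The rate bound and the work scaling above are easy to check with a couple of helper functions. The sketch below is illustrative only: the function names are hypothetical, and the 30 ms plan time and 25 control steps are made-up numbers, not measurements.

```python
def realtime_rate_hz(frequency, plan_time):
    """Bound from the text: min(frequency, 1 / plan_time), in Hz."""
    return min(frequency, 1.0 / plan_time)


def rollout_work(num_samples, num_randomizations, ctrl_steps):
    """Relative rollout cost per replan: the product of the three factors."""
    return num_samples * num_randomizations * ctrl_steps


plan_time = 0.03  # hypothetical: 30 ms per replan

# Requesting 50 Hz is limited by planning time; 25 Hz fits within the budget.
rate_at_50 = realtime_rate_hz(50, plan_time)
rate_at_25 = realtime_rate_hz(25, plan_time)

# Halving num_samples halves the rollout work per replan.
work_full = rollout_work(128, 4, 25)
work_half = rollout_work(64, 4, 25)
```

In this example the controller cannot keep up at 50 Hz because one replan takes longer than the 20 ms period, while at 25 Hz the 40 ms period leaves headroom.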
```
mtp/
├── mtp.py              # MTP controller (tensor sampling + CEM update)
├── splines/
│   ├── akima.py        # Modified-Akima cubic, vectorized for JAX
│   └── bsplines.py     # Cox-de Boor B-spline basis matrix
└── tasks/              # (empty - navigation moved upstream to hydrax)
examples/               # one runnable script per task, all four algorithms
scripts/                # plotting helpers
demos/                  # GIFs used in this README
```
Note: The U-maze navigation task (NavigationParticle) has been upstreamed as hydrax.tasks.bugtrap.BugTrap. The example script examples/navigation.py now imports directly from hydrax.
If you find this repository useful, please consider citing:
```bibtex
@article{le2025model,
  title={Model Tensor Planning},
  author={Le, An Thai and Nguyen, Khai and Vu, Minh Nhat and Carvalho, Joao and Peters, Jan},
  journal={Transactions on Machine Learning Research},
  issn={2835-8856},
  year={2025},
  url={https://openreview.net/forum?id=fk1ZZdXCE3}
}

@misc{kurtz2024hydrax,
  title={Hydrax: Sampling-based model predictive control on GPU with JAX and MuJoCo MJX},
  author={Kurtz, Vince},
  year={2024},
  note={https://github.com/vincekurtz/hydrax}
}
```

Built on Hydrax and MuJoCo MJX. The author thanks Vince Kurtz for the upstream framework!


