LKHelm trains a TwoGateMoE (Mixture of Experts) model that learns to select the best (engine, datalake, configuration) combination for any given database query workload. The system covers 9 execution combos — {Spark, Presto, Trino} × {Delta, Iceberg, Hudi} — across 5 benchmarks at multiple scale factors.
Core idea: Given a set of SQL queries (a "workload"), the model predicts which engine/datalake combo and which configuration knob settings will minimize total execution latency.
Requirements: Python 3.8+, PyTorch, NumPy.

    pip install torch numpy
    # CUDA 12.x example:
    pip install torch --index-url https://download.pytorch.org/whl/cu124

    python3 train_local.py --benchmark tpcds --sf 100 --epochs 30 --eval-mode per_query \
        --stage1-subepochs 5 --stage2-subepochs 10 --stage3-subepochs 10

    bash run_all.sh

| Argument | Default | Description |
|---|---|---|
| `--benchmark` | `tpcds` | Which benchmark: `tpcds`, `tpch`, `ssb`, `ssb_flat`, `job`. If unset, all benchmarks are pooled. |
| `--sf` | `None` | Scale-factor filter (e.g. 1, 10, 100). Only queries at that sf are used. |
| `--epochs` | `40` | Outer training epochs |
| `--seed` | `42` | Random seed for reproducibility |
| `--stage1-subepochs` | `1` | Stage-1 (end-to-end) sub-epochs per outer epoch |
| `--stage2-subepochs` | `2` | Stage-2 (gate-focused) sub-epochs per outer epoch |
| `--stage3-subepochs` | `2` | Stage-3 (expert-focused) sub-epochs per outer epoch |
| `--lambda-div` | `0.1` | Weight on diversity regularization (paper L_div) |
| `--lambda-diversity` | `5.0` | Weight on entropy-max anti-collapse term |
| `--lambda-emb-spread` | `2.0` | Weight on workload-embedding spread regularizer |
| `--tree-weight-decay` | `1e-3` | Weight decay specifically for tree-conv encoder |
| `--gumbel-tau` | `1.0` | Temperature for Gumbel-softmax routing |

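
The flags and defaults in the table can be mirrored with `argparse` for quick experimentation. This is a minimal sketch, not the parser used by `train_local.py`: the function name `build_parser` is hypothetical, the argument types (for example `int` for `--sf`) are assumptions, and `--eval-mode` is left out because the table does not document it.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser that mirrors the documented flags and defaults."""
    p = argparse.ArgumentParser(description="Sketch of the documented CLI flags")
    p.add_argument("--benchmark", default="tpcds",
                   choices=["tpcds", "tpch", "ssb", "ssb_flat", "job"])
    p.add_argument("--sf", type=int, default=None)  # assumed to be an integer
    p.add_argument("--epochs", type=int, default=40)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--stage1-subepochs", type=int, default=1)
    p.add_argument("--stage2-subepochs", type=int, default=2)
    p.add_argument("--stage3-subepochs", type=int, default=2)
    p.add_argument("--lambda-div", type=float, default=0.1)
    p.add_argument("--lambda-diversity", type=float, default=5.0)
    p.add_argument("--lambda-emb-spread", type=float, default=2.0)
    p.add_argument("--tree-weight-decay", type=float, default=1e-3)
    p.add_argument("--gumbel-tau", type=float, default=1.0)
    return p


# argparse converts dashes to underscores in attribute names.
args = build_parser().parse_args(["--benchmark", "tpch", "--sf", "10"])
print(args.benchmark, args.sf, args.stage2_subepochs)
```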
Step 1: Load CSV data from the chosen benchmark(s)
Step 2: Random query-level split: 70% train / 15% valid / 15% test
Step 3: Build tree-conv embeddings from SQL execution plans
Step 4: Generate workloads (random subsets of queries)
Step 5: Train TwoGateMoE for `epochs` outer epochs with three stages each
Step 6: Pick the checkpoint with the best validation ratio; report its test ratio
Random query-level split is the default — there is no leave-one-out evaluation in this codebase.
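
The split and workload-generation steps can be sketched in plain Python. This is an illustration under assumptions, not the project's code: `split_queries` and `sample_workload` are hypothetical names, and the rounding rule and the workload size are choices made for the example.

```python
import random


def split_queries(query_ids, seed=42, train_frac=0.70, valid_frac=0.15):
    """Seeded query-level split into train / valid / test lists."""
    ids = sorted(set(query_ids))      # fixed starting order for reproducibility
    rng = random.Random(seed)
    rng.shuffle(ids)
    n_train = int(round(train_frac * len(ids)))
    n_valid = int(round(valid_frac * len(ids)))
    train = ids[:n_train]
    valid = ids[n_train:n_train + n_valid]
    test = ids[n_train + n_valid:]    # whatever remains (about 15%)
    return train, valid, test


def sample_workload(pool, size, rng):
    """A workload is a random subset of the queries in one split."""
    return rng.sample(pool, min(size, len(pool)))


queries = [f"q{i}" for i in range(1, 101)]
train, valid, test = split_queries(queries)
workload = sample_workload(train, 8, random.Random(0))
print(len(train), len(valid), len(test), len(workload))
```

Because the split is made at the query level, no query appears in more than one of the three sets, so every workload sampled from `train` is disjoint from the test queries.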
Converts SQL execution plans into 288-dimensional query embeddings:
- Input: SQL execution plan tree
- Feature extraction: Node type, referenced tables, column histograms → feature vector per node
- Tree convolution: `BatchTreeConvCBAM` with 4 kernels (channel attention)
- Output: `feat_dim × num_kernels = 72 × 4 = 288` per query
- LayerNorm: per-query unit variance to prevent encoder collapse
- Fallback: queries without plan files get learnable `nn.Embedding` vectors

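
The output width and the per-query normalization can be checked with a small sketch. It assumes plain layer normalization without a learnable scale or shift; whether the encoder's LayerNorm has affine parameters is not stated here, and `layer_norm` is a hypothetical helper.

```python
import math
import random

FEAT_DIM, NUM_KERNELS = 72, 4
EMB_DIM = FEAT_DIM * NUM_KERNELS  # 72 x 4 = 288 values per query


def layer_norm(vec, eps=1e-5):
    """Normalize one query embedding to zero mean and (near) unit variance."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]


rng = random.Random(0)
raw = [rng.gauss(5.0, 3.0) for _ in range(EMB_DIM)]  # stand-in encoder output
normed = layer_norm(raw)
out_mean = sum(normed) / EMB_DIM
out_var = sum((v - out_mean) ** 2 for v in normed) / EMB_DIM
print(EMB_DIM, round(out_mean, 6), round(out_var, 4))
```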
Per-query embeddings → workload embedding via multi-head attention pool with concat(mean, max) residual:
output = concat(head_1, head_2, head_3, head_4, mean, max) # → 6 × 288 = 1728-dim
Each attention head learns a different scoring function over the per-query embeddings; this prevents the gate from receiving near-identical inputs across different workloads.
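
A dependency-free sketch of the pooling step follows. It assumes each head scores a query with a single weight vector and a `1/sqrt(dim)` scale; the real scoring function may differ (for example an MLP), and `attention_pool` is a hypothetical name. Only the output layout, 4 heads plus mean plus max, is taken from the description above.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(queries, head_weights):
    """queries: N vectors of size D. head_weights: H scoring vectors of size D.

    Returns concat(head_1..head_H, mean, max), a vector of size (H + 2) * D.
    """
    dim = len(queries[0])
    pooled = []
    for w in head_weights:
        # One scalar score per query, then a softmax over the workload.
        scores = [sum(wi * qi for wi, qi in zip(w, q)) / math.sqrt(dim)
                  for q in queries]
        attn = softmax(scores)
        pooled.extend(sum(a * q[d] for a, q in zip(attn, queries))
                      for d in range(dim))
    mean = [sum(q[d] for q in queries) / len(queries) for d in range(dim)]
    peak = [max(q[d] for q in queries) for d in range(dim)]
    return pooled + mean + peak


rng = random.Random(0)
D, H = 288, 4
workload = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(5)]
heads = [[rng.gauss(0, 1) for _ in range(D)] for _ in range(H)]
emb = attention_pool(workload, heads)
print(len(emb))  # (4 + 2) * 288 = 1728
```

With a single-query workload every head puts all of its attention on that query, so the six segments of the output are identical copies of the query embedding.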

                      ┌────────────────────┐
                      │ Workload Embedding │  (1728-dim, AttentionPool output)
                      └─────────┬──────────┘
                                │
                ┌───────────────┼───────────────┐
                ▼                               ▼
        ┌─────────────────┐             ┌─────────────────┐
        │  Engine Gate    │             │   Lake Gate     │
        │ MLP [128, 256]  │             │ MLP [128, 256]  │
        │  → 3 classes    │             │  → 3 classes    │
        └───────┬─────────┘             └───────┬─────────┘
                │ Gumbel-softmax                │
                ▼                               ▼
        ┌─────────────────┐             ┌─────────────────┐
        │ 3 Engine Experts│             │ 3 Lake Experts  │
        │ MLP each:       │             │ MLP each:       │
        │ → 128-dim       │             │ → 128-dim       │
        └───────┬─────────┘             └───────┬─────────┘
                │                               │
                └───────────────┬───────────────┘
                                ▼
                      ┌───────────────────┐
                      │ Concat            │  (256-dim = 128 + 128)
                      │ + Config Encoder  │  (64-dim ConfEncoder)
                      └─────────┬─────────┘
                                ▼
                      ┌───────────────────┐
                      │ Post-MLP          │
                      │ → 1 (predicted    │
                      │     ratio)        │
                      └───────────────────┘

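
The Gumbel-softmax routing step between each gate and its experts can be illustrated without PyTorch. This sketch shows the standard soft (non-straight-through) relaxation over 3 classes; whether the model uses the soft or the hard variant is not stated here, and `gumbel_softmax` below is a hypothetical stand-alone helper, not the library call.

```python
import math
import random


def gumbel_softmax(logits, tau=1.0, rng=random):
    """Perturb logits with Gumbel(0, 1) noise, then apply a tempered softmax."""
    noisy = []
    for logit in logits:
        # Clamp u away from 0 and 1 so both logarithms stay finite.
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        gumbel = -math.log(-math.log(u))
        noisy.append((logit + gumbel) / tau)
    m = max(noisy)
    exps = [math.exp(v - m) for v in noisy]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(42)
gate_logits = [1.2, 0.3, -0.5]          # one score per engine (3 classes)
probs = gumbel_softmax(gate_logits, tau=1.0, rng=rng)
print([round(p, 3) for p in probs])
```

Lower values of `tau` push the sampled probabilities toward one-hot routing, while higher values spread weight across all three experts.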
L_total = L_MSE + L_CE + λ_div × L_div
| Component | Definition | Purpose |
|---|---|---|
| L_MSE | `(r̂ - r)² × p_gumbel_eng[c*] × p_gumbel_lake[f*]` | Predict ratio `r = lat / optimal_lat` for the (combo, conf), weighted by Gumbel probabilities of correct gates |
| L_CE | `-log p_eng[c*] - log p_lake[f*]` | Cross-entropy on gate predictions vs ground-truth best subsystem |
| L_div | `Σ (p̄_eng - 1/3)² + Σ (p̄_lake - 1/3)²` | Diversity regularizer: keep batch-mean gate probabilities near uniform |

Additional anti-collapse regularizers (configurable):
- `lambda_diversity` × entropy-max term (push batch-mean prob away from 1-hot)
- `lambda_emb_spread` × variance-of-workload-embeddings + InfoNCE-style cosine penalty

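
The three main loss terms can be written out directly from their definitions. This is a sketch for one sample (L_MSE, L_CE) and one batch (L_div); how the code reduces over a batch is an assumption, the function names are hypothetical, and the two additional anti-collapse regularizers are omitted.

```python
import math


def sample_losses(pred_ratio, true_ratio, p_eng, p_lake,
                  g_eng, g_lake, best_eng, best_lake):
    """L_MSE and L_CE for one record.

    p_*: softmax gate probabilities; g_*: Gumbel-softmax probabilities;
    best_*: index of the ground-truth best engine / lake (c*, f*).
    """
    l_mse = (pred_ratio - true_ratio) ** 2 * g_eng[best_eng] * g_lake[best_lake]
    l_ce = -math.log(p_eng[best_eng]) - math.log(p_lake[best_lake])
    return l_mse, l_ce


def diversity_loss(batch_p_eng, batch_p_lake):
    """L_div: squared distance of batch-mean gate probabilities from uniform."""
    def term(batch):
        k = len(batch[0])
        mean = [sum(p[i] for p in batch) / len(batch) for i in range(k)]
        return sum((m - 1.0 / k) ** 2 for m in mean)
    return term(batch_p_eng) + term(batch_p_lake)


l_mse, l_ce = sample_losses(1.5, 1.2,
                            p_eng=[0.5, 0.3, 0.2], p_lake=[0.2, 0.5, 0.3],
                            g_eng=[0.8, 0.1, 0.1], g_lake=[0.3, 0.5, 0.2],
                            best_eng=0, best_lake=1)
collapsed = [[1.0, 0.0, 0.0]] * 4       # every sample routed to class 0
l_div = diversity_loss(collapsed, collapsed)
lambda_div = 0.1
l_total = l_mse + l_ce + lambda_div * l_div
print(round(l_mse, 4), round(l_ce, 4), round(l_div, 4), round(l_total, 4))
```

A fully collapsed gate gives the largest L_div contribution for that gate, while perfectly uniform batch-mean probabilities give zero.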
Each outer epoch runs three sub-stages back to back:
- End-to-end (Stage 1): All params trained with `L_MSE + L_CE + λ_div × L_div` via Gumbel-soft routing.
- Gate-focused (Stage 2): Only gates trained on `L_CE + L_div` (tree-conv frozen).
- Expert-focused (Stage 3): Each expert trained on every (config, ratio) record routed via the actual (engine, lake) ID (tree-conv frozen).

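
The order of sub-stages within each outer epoch, and which parameter groups each stage updates, can be sketched as a small schedule. The group names (`tree_conv`, `gates`, `experts`) and both helpers are hypothetical. Stage 3 is assumed to freeze the gates as well as the tree-conv encoder; the description above only states that tree-conv is frozen.

```python
def is_trainable(stage, group):
    """Which parameter groups receive gradient updates in each stage."""
    if stage == 1:
        return True                    # end-to-end: all params
    if stage == 2:
        return group == "gates"        # gate-focused: only gates
    if stage == 3:
        return group == "experts"      # expert-focused (assumption, see above)
    raise ValueError(f"unknown stage {stage}")


def schedule(epochs, s1, s2, s3):
    """Yield (epoch, stage, sub_epoch) in execution order."""
    for epoch in range(epochs):
        for stage, n_sub in ((1, s1), (2, s2), (3, s3)):
            for sub in range(n_sub):
                yield epoch, stage, sub


steps = list(schedule(epochs=40, s1=1, s2=2, s3=2))   # documented defaults
print(len(steps), steps[:5])
```

With the default settings this gives 40 × (1 + 2 + 2) = 200 sub-epoch passes in total.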
- Adam lr=3e-4
- Weight decay: 1e-3 on tree-conv params (anti-collapse), 1e-5 elsewhere
- CosineAnnealing scheduler, eta_min=1e-5
- Gradient clipping at 1.0
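
The learning-rate schedule and the clipping rule can be reproduced in closed form. The sketch assumes the schedule spans `total_steps` scheduler steps (the actual period is not stated here) and that clipping is applied to the global gradient norm rather than per value; both helper names are hypothetical.

```python
import math

BASE_LR, ETA_MIN = 3e-4, 1e-5
# Two weight-decay groups, as listed above.
PARAM_GROUPS = [
    {"name": "tree_conv", "weight_decay": 1e-3},
    {"name": "everything_else", "weight_decay": 1e-5},
]


def cosine_lr(step, total_steps, base_lr=BASE_LR, eta_min=ETA_MIN):
    """Cosine annealing from base_lr (step 0) down to eta_min (last step)."""
    cos = math.cos(math.pi * step / total_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cos)


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale all gradients together if their L2 norm exceeds max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (norm + 1e-6))
    return [g * scale for g in grads]


lrs = [cosine_lr(t, 40) for t in (0, 20, 40)]
clipped = clip_by_global_norm([3.0, 4.0])      # norm 5 -> rescaled to about 1
print(lrs, [round(g, 3) for g in clipped])
```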
For each test query:
- Compute query embedding (single query → AttentionPool)
- Run gates → pick `(eng*, lake*)` via argmax
- Score every conf in `(eng*, lake*)` for that query via `forward_for_eng_lak`
- Pick argmin pred → `chosen_actual_lat`

`ratio(q) = chosen_actual_lat / min_actual_latency(across all combos)`

Average ratio across all test queries. Lower is better (1.0 = always optimal).
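
The metric can be illustrated with a toy example. The dictionaries and the helper names `query_ratio` and `mean_ratio` are made up for illustration; only the selection rule (argmin over predicted ratios) and the ratio definition come from the description above.

```python
def query_ratio(pred_by_conf, actual_by_conf, min_actual_latency):
    """Ratio for one query: latency of the chosen conf over the global optimum.

    pred_by_conf / actual_by_conf: predicted ratio and measured latency for
    every conf of the gate-selected (engine, lake) combo.
    min_actual_latency: best measured latency across all combos and confs.
    """
    chosen_conf = min(pred_by_conf, key=pred_by_conf.get)   # argmin prediction
    return actual_by_conf[chosen_conf] / min_actual_latency


def mean_ratio(ratios):
    return sum(ratios) / len(ratios)


# Toy numbers: the model prefers conf "b", which is 20% slower than optimal.
r1 = query_ratio({"a": 1.3, "b": 1.1}, {"a": 2000.0, "b": 2400.0}, 2000.0)
# Here the model picks the truly optimal conf, so the ratio is 1.0.
r2 = query_ratio({"a": 1.0, "b": 1.4}, {"a": 500.0, "b": 900.0}, 500.0)
print(r1, r2, mean_ratio([r1, r2]))
```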
Random query-level split, default 70 / 15 / 15. Best checkpoint is selected on the validation set; the test set is evaluated only once at the end with the best-validation checkpoint.
- Tree cache: First run processes plan files into tree tensors and saves `.tree_cache.pt`. Subsequent runs load instantly.
- Latency floor repair: Exactly-1500ms records (timeout artifacts) are replaced with samples drawn from that query+combo's latency distribution.
- Query normalization: `tpch_0_q1` → `sf10_q1` (strips datalake-specific prefix, adds sf prefix for cross-datalake consistency).
- Config encoding: Configs are parsed into numeric vectors, padded to `max_dim=16`, then encoded by a 3-layer `ConfEncoder` MLP into 64-dim representation.
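
Two of these preprocessing steps are simple enough to sketch. Both helpers are hypothetical: the repair below resamples from the non-timeout records of the same query and combo, which is one plausible reading of "latency distribution", and the padding uses zeros, which is also an assumption.

```python
import random

TIMEOUT_MS = 1500.0
MAX_DIM = 16


def repair_latencies(latencies, rng):
    """Replace exactly-1500 ms records for one (query, combo) group."""
    valid = [x for x in latencies if x != TIMEOUT_MS]
    if not valid:
        return list(latencies)          # nothing to sample from; leave as-is
    return [x if x != TIMEOUT_MS else rng.choice(valid) for x in latencies]


def pad_config(values, max_dim=MAX_DIM):
    """Right-pad a numeric config vector with zeros up to max_dim."""
    if len(values) > max_dim:
        raise ValueError(f"config has {len(values)} values, max is {max_dim}")
    return list(values) + [0.0] * (max_dim - len(values))


rng = random.Random(0)
group = [820.0, 1500.0, 790.0, 1500.0, 905.0]
repaired = repair_latencies(group, rng)
padded = pad_config([4.0, 0.5, 128.0])
print(repaired, len(padded))
```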