Skip to content

Commit 29aca16

Browse files
AnHeuermannclaude
andauthored
Harden simulation and reference comparison validation (#29)
- Fail run_simulate when sol.t is empty or the system has no states or observed variables, preventing empty CSVs and downstream panics - Include observed variables in the simulation CSV output so models without state variables (e.g. BusUsage) still produce usable results - Fail compare_with_reference (instead of skip) when reference signals are absent from the simulation; write NaN sim/relerr columns for signals with no simulation counterpart - Validate upfront that all comparisonSignals.txt entries exist in the reference CSV, and error if the simulation time interval does not cover the reference interval - Extend the per-model status line to show a CMP OK/FAIL phase and signal counts including skipped signals Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 92cacb5 commit 29aca16

3 files changed

Lines changed: 76 additions & 33 deletions

File tree

src/compare.jl

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -370,12 +370,17 @@ function compare_with_reference(
370370
isempty(times) && return 0, 0, 0, ""
371371

372372
# Determine which signals to compare: prefer comparisonSignals.txt
373-
sig_file = joinpath(dirname(ref_csv_path), "comparisonSignals.txt")
374-
signals = if isfile(sig_file)
375-
filter(s -> lowercase(s) != "time" && !isempty(s), strip.(readlines(sig_file)))
373+
sig_file = joinpath(dirname(ref_csv_path), "comparisonSignals.txt")
374+
using_sig_file = isfile(sig_file)
375+
signals = if using_sig_file
376+
sigs = filter(s -> lowercase(s) != "time" && !isempty(s), strip.(readlines(sig_file)))
377+
sigs_missing = filter(s -> !haskey(ref_data, s), sigs)
378+
isempty(sigs_missing) || error("Signal(s) listed in comparisonSignals.txt not present in reference CSV: $(join(sigs_missing, ", "))")
379+
sigs
376380
else
377381
filter(k -> lowercase(k) != "time", collect(keys(ref_data)))
378382
end
383+
n_total = length(signals)
379384

380385
# ── Build variable accessor map ──────────────────────────────────────────────
381386
# var_access: normalized name → Int (state index) or MTK symbolic (observed).
@@ -397,32 +402,40 @@ function compare_with_reference(
397402
@warn "Could not enumerate observed variables: $(sprint(showerror, e))"
398403
end
399404

400-
# Clip reference time to the simulation interval
405+
# Verify the simulation covers the expected reference time interval.
406+
# A large gap means the solver stopped early or started late.
407+
isempty(sol.t) && return n_total, 0, 0, ""
401408
t_start = sol.t[1]
402409
t_end = sol.t[end]
410+
ref_t_start = times[1]
411+
ref_t_end = times[end]
412+
if t_start > ref_t_start || t_end < ref_t_end
413+
@error "Simulation interval [$(t_start), $(t_end)] does not cover " *
414+
"reference interval [$(ref_t_start), $(ref_t_end)]"
415+
return n_total, 0, 0, ""
416+
end
417+
418+
# Clip reference time to the simulation interval
403419
valid_mask = (times .>= t_start) .& (times .<= t_end)
404420
t_ref = times[valid_mask]
405-
isempty(t_ref) && return 0, 0, 0, ""
421+
isempty(t_ref) && return n_total, 0, 0, ""
406422

407-
n_total = 0
408423
n_pass = 0
409424
pass_sigs = String[]
410425
fail_sigs = String[]
411-
skip_sigs = String[]
412426
fail_scales = Dict{String,Float64}()
413427

414428
for sig in signals
415-
haskey(ref_data, sig) || continue # signal absent from ref CSV entirely
429+
signal_name = _normalize_var(sig)
430+
ref_vals = ref_data[sig][valid_mask]
416431

417-
norm = _normalize_var(sig)
418-
if !haskey(var_access, norm)
419-
push!(skip_sigs, sig)
432+
if !haskey(var_access, signal_name)
433+
push!(fail_sigs, sig)
434+
fail_scales[sig] = isempty(ref_vals) ? 0.0 : maximum(abs, ref_vals)
420435
continue
421436
end
422437

423-
accessor = var_access[norm]
424-
ref_vals = ref_data[sig][valid_mask]
425-
n_total += 1
438+
accessor = var_access[signal_name]
426439

427440
# Peak |ref| — used as amplitude floor so relative error stays finite
428441
# near zero crossings.
@@ -431,10 +444,10 @@ function compare_with_reference(
431444
# Interpolate simulation at reference time points.
432445
sim_vals = [_eval_sim(sol, accessor, t) for t in t_ref]
433446

434-
# If evaluation returned NaN (observed-var access failed), treat as skip.
447+
# If evaluation returned NaN (observed-var access failed), treat as fail.
435448
if any(isnan, sim_vals)
436-
n_total -= 1
437-
push!(skip_sigs, sig)
449+
push!(fail_sigs, sig)
450+
fail_scales[sig] = ref_scale
438451
continue
439452
end
440453

@@ -468,9 +481,10 @@ function compare_with_reference(
468481
for sig in fail_sigs
469482
ref_vals = ref_data[sig][valid_mask]
470483
r = ref_vals[ti]
471-
s = _eval_sim(sol, var_access[_normalize_var(sig)], t)
484+
acc = get(var_access, _normalize_var(sig), nothing)
485+
s = acc === nothing ? NaN : _eval_sim(sol, acc, t)
472486
ref_scale = get(fail_scales, sig, 0.0)
473-
relerr = abs(s - r) / max(abs(r), ref_scale, settings.abs_tol)
487+
relerr = isnan(s) ? NaN : abs(s - r) / max(abs(r), ref_scale, settings.abs_tol)
474488
push!(row, @sprintf("%.10g", r),
475489
@sprintf("%.10g", s),
476490
@sprintf("%.6g", relerr))
@@ -481,12 +495,11 @@ function compare_with_reference(
481495
end
482496

483497
# ── Write detail HTML whenever there is anything worth showing ───────────────
484-
if !isempty(fail_sigs) || !isempty(skip_sigs)
498+
if !isempty(fail_sigs)
485499
write_diff_html(model_dir, model;
486500
diff_csv_path = diff_csv,
487-
pass_sigs = pass_sigs,
488-
skip_sigs = skip_sigs)
501+
pass_sigs = pass_sigs)
489502
end
490503

491-
return n_total, n_pass, length(skip_sigs), diff_csv
504+
return n_total, n_pass, 0, diff_csv
492505
end

src/pipeline.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,11 +189,23 @@ function main(;
189189
result = test_model(omc, model, results_root, ref_root; csv_max_size_mb)
190190
push!(results, result)
191191

192-
phase = result.sim_success ? "SIM OK" :
193-
result.parse_success ? "SIM FAIL" :
194-
result.export_success ? "PARSE FAIL" : "EXPORT FAIL"
195-
cmp_info = result.cmp_total > 0 ?
196-
" cmp=$(result.cmp_pass)/$(result.cmp_total)" : ""
192+
phase = if result.sim_success && result.cmp_total > 0
193+
result.cmp_pass == result.cmp_total ? "CMP OK" : "CMP FAIL"
194+
elseif result.sim_success
195+
"SIM OK"
196+
elseif result.parse_success
197+
"SIM FAIL"
198+
elseif result.export_success
199+
"PARSE FAIL"
200+
else
201+
"EXPORT FAIL"
202+
end
203+
cmp_info = if result.cmp_total > 0
204+
skip_note = result.cmp_skip > 0 ? " skip=$(result.cmp_skip)" : ""
205+
" cmp=$(result.cmp_pass)/$(result.cmp_total)$skip_note"
206+
else
207+
""
208+
end
197209
@info "$phase export=$(round(result.export_time;digits=2))s" *
198210
" parse=$(round(result.parse_time;digits=2))s" *
199211
" sim=$(round(result.sim_time;digits=2))s$cmp_info"

src/simulate.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,16 @@ function run_simulate(ode_prob, model_dir::String, model::String;
3737
end
3838
sim_time = time() - t0
3939
if sol.retcode == ReturnCode.Success
40-
sim_success = true
40+
sys = sol.prob.f.sys
41+
n_vars = length(ModelingToolkit.unknowns(sys))
42+
n_obs = length(ModelingToolkit.observed(sys))
43+
if isempty(sol.t)
44+
sim_error = "Simulation produced no time points"
45+
elseif n_vars == 0 && n_obs == 0
46+
sim_error = "Simulation produced no output variables (no states or observed)"
47+
else
48+
sim_success = true
49+
end
4150
else
4251
sim_error = "Solver returned: $(sol.retcode)"
4352
end
@@ -50,21 +59,30 @@ function run_simulate(ode_prob, model_dir::String, model::String;
5059
isempty(sim_error) || println(log_file, "\n--- Error ---\n$sim_error")
5160
close(log_file)
5261

53-
# Write simulation results CSV (time + all state variables)
62+
# Write simulation results CSV (time + state variables + observed variables)
5463
if sim_success && sol !== nothing
5564
short_name = split(model, ".")[end]
5665
sim_csv = joinpath(model_dir, "$(short_name)_sim.csv")
5766
try
58-
sys = sol.prob.f.sys
59-
vars = ModelingToolkit.unknowns(sys)
60-
col_names = [_clean_var_name(string(v)) for v in vars]
67+
sys = sol.prob.f.sys
68+
vars = ModelingToolkit.unknowns(sys)
69+
obs_eqs = ModelingToolkit.observed(sys)
70+
obs_syms = [eq.lhs for eq in obs_eqs]
71+
col_names = vcat(
72+
[_clean_var_name(string(v)) for v in vars],
73+
[_clean_var_name(string(s)) for s in obs_syms],
74+
)
6175
open(sim_csv, "w") do f
6276
println(f, join(["time"; col_names], ","))
6377
for (ti, t) in enumerate(sol.t)
6478
row = [@sprintf("%.10g", t)]
6579
for vi in eachindex(vars)
6680
push!(row, @sprintf("%.10g", sol[vi, ti]))
6781
end
82+
for sym in obs_syms
83+
val = try Float64(sol(t; idxs = sym)) catch; NaN end
84+
push!(row, @sprintf("%.10g", val))
85+
end
6886
println(f, join(row, ","))
6987
end
7088
end

0 commit comments

Comments
 (0)