@@ -214,6 +214,94 @@ def _save_population(
214214 except OSError as exc :
215215 logger .warning ("RCMC population CSV could not be saved: %s" , exc )
216216
217+ def _solve_ode_rk4 (
218+ self ,
219+ D : np .ndarray ,
220+ p0 : np .ndarray ,
221+ t_end : float ,
222+ ) -> "tuple[np.ndarray, bool]" :
223+ """Integrate the contracted master equation dp/dt = D p via classical RK4.
224+
225+ RCMC contraction guarantees that D is non-stiff: Schur-complement
226+ elimination has already removed the fast (S-state) modes from K, so
227+ explicit RK4 is appropriate here and avoids the overhead of implicit
228+ stiff solvers.
229+
230+ The step size is chosen so that dt * rho(D) stays inside the RK4
231+ stability region on the real axis (|z| <= 2.78). A safety factor of
232+ ~0.9 gives dt_max = 2.5 / rho(D).
233+
234+ Parameters
235+ ----------
236+ D : np.ndarray
237+ Contracted effective rate matrix (|T| x |T|) [s^-1]. Non-stiff
238+ by RCMC construction.
239+ p0 : np.ndarray
240+ Initial population vector over T super-states, already normalised.
241+ t_end : float
242+ Target integration time [s].
243+
244+ Returns
245+ -------
246+ y : np.ndarray
247+ Population distribution at t_end, clipped to [0, 1] and normalised.
248+ reliable : bool
249+ True — integration completed within _MAX_STEPS; result is trustworthy.
250+ False — required step count exceeded _MAX_STEPS, meaning D remained
251+ effectively stiff after RCMC contraction (e.g. a strongly
252+ irreversible edge or an isolated node still in T).
253+ The caller should fall back to the QSS algebraic approximation.
254+ """
255+ if t_end < 1e-16 :
256+ return p0 .copy (), True
257+
258+ abs_D = np .abs (D )
259+ # Gershgorin-based spectral radius estimate: min of max row and column
260+ # absolute sums (tighter than either bound alone).
261+ rho_D = min (
262+ float (np .max (abs_D .sum (axis = 1 ))),
263+ float (np .max (abs_D .sum (axis = 0 ))),
264+ )
265+
266+ if rho_D < 1e-16 :
267+ # D is effectively zero — population does not evolve.
268+ return p0 .copy (), True
269+
270+ # RK4 stability region on the negative real axis ends at z ≈ -2.78.
271+ # Safety factor 0.9 keeps dt * rho_D <= 2.5.
272+ dt_max = 2.5 / rho_D
273+ n_steps = int (np .ceil (t_end / dt_max ))
274+
275+ _MAX_STEPS = 10_000_000
276+ if n_steps > _MAX_STEPS :
277+ # D is still stiff despite RCMC contraction — signal caller to use
278+ # the QSS algebraic fallback rather than capping steps silently.
279+ logger .warning (
280+ "RK4: %d steps required for t_end=%.3e s (rho_D=%.3e s^-1) "
281+ "exceeds _MAX_STEPS=%d — D is still stiff after RCMC contraction "
282+ "(likely a strongly irreversible edge or isolated node in T). "
283+ "Falling back to quasi-steady-state (QSS) algebraic approximation." ,
284+ n_steps , t_end , rho_D , _MAX_STEPS ,
285+ )
286+ return p0 .copy (), False
287+
288+ dt = t_end / n_steps
289+ y = p0 .copy ()
290+
291+ for _ in range (n_steps ):
292+ k1 = D @ y
293+ k2 = D @ (y + 0.5 * dt * k1 )
294+ k3 = D @ (y + 0.5 * dt * k2 )
295+ k4 = D @ (y + dt * k3 )
296+ y = y + (dt / 6.0 ) * (k1 + 2.0 * k2 + 2.0 * k3 + k4 )
297+ # Clip and renormalise every step to prevent slow numerical drift.
298+ y = np .maximum (y , 0.0 )
299+ total = y .sum ()
300+ if total > 0.0 :
301+ y /= total
302+
303+ return y , True
304+
217305 def pop (self ) -> ExplorationTask | None :
218306 if not self ._tasks :
219307 return None
@@ -421,18 +509,67 @@ def _edge_ts_e(edge):
421509 p_T = p [T ]
422510
423511 try :
424- # Factorise K_SS once; reuse for the three back-solves.
512+ # ── Build initial population for T super-states ───────────
513+ # Each T super-state collects the probability of every node
514+ # absorbed into it during RCMC contraction (the T representative
515+ # itself plus any S nodes whose fast dynamics were projected onto
516+ # it via Schur-complement elimination).
517+ # Summing p[members] is the standard probability-conserving
518+ # lumping step and is self-consistent with the quasi-steady-state
519+ # assumption used to derive D: within each super-state the fast
520+ # internal modes are already eliminated, so aggregating the raw
521+ # probability gives the correct effective initial condition for
522+ # the contracted dynamics.
523+ p_T_init = np .zeros (len (T ), dtype = np .float64 )
524+ for i , t_global in enumerate (T ):
525+ members = superstate_members .get (t_global , [t_global ])
526+ p_T_init [i ] = float (np .sum (p [members ]))
527+
528+ # Normalise: guards against floating-point drift and the rare
529+ # edge case where the start node ended up in S rather than T.
530+ total_p_T_init = p_T_init .sum ()
531+ if total_p_T_init > 0.0 :
532+ p_T_init /= total_p_T_init
533+
534+ # ── Factorise K_SS once; reused by both RK4 and QSS paths ─
425535 lu , piv = lu_factor (K_SS_buf )
426- X_ST = lu_solve ((lu , piv ), K_ST )
427- X_pS = lu_solve ((lu , piv ), p_S )
428- X_ST_2 = lu_solve ((lu , piv ), X_ST )
429-
430- M = np .eye (len (T )) + K_TS @ X_ST_2
431- m_vec = np .sum (M , axis = 0 )
432- V_TT_diag = 1.0 / np .where (np .abs (m_vec ) > 1e-16 , m_vec , 1e-16 )
536+ X_ST = lu_solve ((lu , piv ), K_ST )
537+
538+ # ── Primary path: integrate contracted dynamics via RK4 ───
539+ # D is non-stiff by RCMC construction (all fast modes removed),
540+ # so classical explicit RK4 is appropriate and efficient here.
541+ # _solve_ode_rk4 returns a reliability flag: False means the
542+ # required step count exceeded _MAX_STEPS, i.e. D was still
543+ # effectively stiff (isolated node, strongly irreversible edge).
544+ q_T , rk4_reliable = self ._solve_ode_rk4 (
545+ D , p_T_init , self .reaction_time_s
546+ )
433547
434- q_T = V_TT_diag * (p_T - K_TS @ X_pS )
435- q_S = - X_ST @ q_T
548+ if not rk4_reliable :
549+ # ── Fallback: quasi-steady-state algebraic approximation ─
550+ # Used only when RK4 would need more than _MAX_STEPS steps.
551+ # QSS solves the adiabatic-elimination equations directly
552+ # and is step-count-independent, making it robust for stiff
553+ # residual dynamics that RCMC failed to fully remove.
554+ logger .warning (
555+ "RK4 fallback to QSS for pop_step=%d "
556+ "(|T|=%d, |S|=%d, t=%.3e s)." ,
557+ self ._pop_count , len (T ), len (S ), self .reaction_time_s ,
558+ )
559+ X_pS = lu_solve ((lu , piv ), p_S )
560+ X_ST_2 = lu_solve ((lu , piv ), X_ST )
561+ M = np .eye (len (T )) + K_TS @ X_ST_2
562+ m_vec = np .sum (M , axis = 0 )
563+ V_TT_diag = 1.0 / np .where (
564+ np .abs (m_vec ) > 1e-16 , m_vec , 1e-16
565+ )
566+ q_T = V_TT_diag * (p_T - K_TS @ X_pS )
567+
568+ # ── Back-project S populations via quasi-steady-state ─────
569+ # Under the adiabatic-elimination assumption that underpins D:
570+ # K_SS * q_S + K_ST * q_T = 0 => q_S = -K_SS^{-1} K_ST q_T
571+ # X_ST = K_SS^{-1} K_ST was computed above and is reused here.
572+ q_S = - (X_ST @ q_T )
436573
437574 # Guarantee non-negativity and normalize
438575 q_T = np .maximum (q_T , 0.0 )
@@ -745,7 +882,7 @@ def all_edges(self) -> list:
745882 print (f"\n { 'Node' :<{col_w }} { 'Population' :>18} { 'Rank' :>6} " )
746883 print (f"{ '─' * col_w } { '─' * 18 } { '─' * 6 } " )
747884 for rank , task in enumerate (all_tasks , start = 1 ):
748- marker = " ◀ highest priority" if rank == 1 else ""
885+ marker = " <- highest priority" if rank == 1 else ""
749886 print (
750887 f"EQ{ task .node_id :<{col_w - 2 }} { task .priority :18.8e} { rank :6d} { marker } "
751888 )
0 commit comments