@@ -15,18 +15,17 @@ class StructKeepAnglePotential:
1515
1616 Singularity Handling Strategies:
1717 1. **Taylor Expansion (Physical Accuracy)**:
18- Applied when the equilibrium angle :math:`\\ theta_0` is EXACTLY 0 or :math:`\\ pi` (within `EPSILON_PARAM`) .
19- This uses a high-order Taylor expansion of :math:` \\ arccos` to maintain physical accuracy for linear
18+ Applied when the equilibrium angle :math:`\\ theta_0` is EXACTLY 0 or :math:`\\ pi`.
19+ This uses a high-order Taylor expansion to maintain physical accuracy for linear
2020 molecules or planar transition states without gradient explosion.
21-
22- *Note*: To avoid approximation errors at large angles, the method switches back to the exact
23- analytical solution when the current angle is far from the singularity.
2421
2522 2. **Quadratic Extrapolation (Numerical Stability)**:
26- Applied when a normally bent molecule (:math:`\\ theta_0 \\ neq 0, \\ pi`) is forced into linearity.
27- The potential is replaced by a quadratic polynomial near the singularity (`THETA_CUT`).
28- A **Gauss-Newton Approximation** is used for the Hessian (:math:`d^2E/du^2 \\ approx k (d\\ theta/du)^2`),
29- ensuring positive curvature and preventing optimizer instability.
23+ Applied in two cases:
24+ - When a normally bent molecule (:math:`\\ theta_0 \\ neq 0, \\ pi`) is forced into linearity (0 or 180 deg).
25+ - When a linear molecule (:math:`\\ theta_0 = 0`) bends all the way to the antipodal singularity (:math:`180` deg), or vice versa.
26+
27+ This uses a Gauss-Newton Approximation for the Hessian (:math:`d^2E/du^2 \\ approx k (d\\ theta/du)^2`),
28+ ensuring positive curvature and preventing optimizer instability at the poles.
3029
3130 Attributes:
3231 config (dict): Configuration dictionary containing potential parameters.
@@ -40,21 +39,19 @@ def __init__(self, **kwarg):
4039
4140 Args:
4241 **kwarg: Arbitrary keyword arguments containing configuration parameters.
43- Expected keys include 'keep_angle_spring_const', 'keep_angle_angle', etc.
42+ Expected keys include 'keep_angle_spring_const', 'keep_angle_angle', 'keep_angle_atom_pairs', etc.
4443 """
4544 self .config = kwarg
4645 UVL = UnitValueLib ()
4746 self .hartree2kcalmol = UVL .hartree2kcalmol
4847 self .bohr2angstroms = UVL .bohr2angstroms
4948 self .hartree2kjmol = UVL .hartree2kjmol
5049
51- # Threshold: Below this angle (rad), switch to extrapolation.
50+ # Threshold: Below this angle (rad), switch to extrapolation/Taylor .
5251 self .THETA_CUT = 1e-3
5352
5453 # Threshold for Equilibrium check.
55- # Set to 1e-9 to strictly distinguish between theta_0=0 (Taylor) and theta_0=0.001 (General).
5654 self .EPSILON_PARAM = 1e-9
57-
5855 return
5956
6057 def calc_energy (self , geom_num_list , bias_pot_params = []):
@@ -118,86 +115,109 @@ def calc_energy(self, geom_num_list, bias_pot_params=[]):
118115 # 3. Singularity Handling Logic
119116 # ========================================
120117
121- # Pre-calculate thresholds for switching
122- u_cut_pos = torch .cos (theta_cut_val )
123- u_cut_neg = torch .cos (PI - theta_cut_val )
118+ # Thresholds
119+ u_cut_pos = torch .cos (theta_cut_val ) # Corresponds to theta ~ 0
120+ u_cut_neg = torch .cos (PI - theta_cut_val ) # Corresponds to theta ~ pi
121+
122+ # --- Helper for Quadratic Extrapolation ---
123+ # Used for any region where the angle approaches a singularity that is NOT the equilibrium.
124+ def get_quad_params (th_cut ):
125+ # Calculate properties at the cutoff boundary
126+ sin_cut = torch .sin (th_cut )
127+ dth_du = - 1.0 / sin_cut # Chain rule: d(acos(u))/du = -1/sqrt(1-u^2) = -1/sin(theta)
128+
129+ # Energy value at cutoff
130+ val = 0.5 * k * (th_cut - theta_0 )** 2
131+
132+ # First derivative w.r.t u at cutoff: dE/du = dE/dth * dth/du
133+ dE_dth = k * (th_cut - theta_0 )
134+ d1 = dE_dth * dth_du
135+
136+ # Second derivative w.r.t u (Gauss-Newton Approx): d2E/du2 approx k * (dth/du)^2
137+ # We ignore the d2th/du2 term to ensure positive curvature (convexity)
138+ d2 = k * (dth_du ** 2 )
139+
140+ return val .detach (), d1 .detach (), d2 .detach ()
141+
124142
125143 # --- BRANCH A: EXACTLY Linear Equilibrium (theta_0 ~ 0) ---
126144 if torch .abs (theta_0 ) < epsilon_param :
127- # Region 1: Singularity (theta ~ 0, u ~ 1) -> Use Taylor Expansion
128- # Formula: theta^2 approx 2(1-u) + (1-u)^2/3 + 8/45*(1-u)^3
145+ # Region 1: Singularity at Equilibrium (theta ~ 0) -> Taylor Expansion
129146 delta = 1.0 - u
130- # Corrected coeff: 4.0/45.0 -> 8.0/45.0 for (acos(1-x)) ^2 expansion
147+ # Corrected Taylor expansion for theta ^2 around u=1
131148 theta_sq_taylor = delta * (2.0 + delta * (1.0 / 3.0 + delta * 8.0 / 45.0 ))
132149 E_taylor = 0.5 * k * theta_sq_taylor
133150
134- # Region 2: Normal (theta > cut) -> Use Exact Analytical Solution
135- # Note: Force u to be safe for acos to prevent NaN gradients in masked region
136- u_safe = torch .clamp (u , - 1.0 , u_cut_pos )
151+ # Region 2: Singularity at Opposite Pole (theta ~ pi) -> Quadratic Extrapolation
152+ # Even if equilibrium is 0, we must protect against 180 degrees.
153+ theta_cut_pi = PI - theta_cut_val
154+ val_pi , d1_pi , d2_pi = get_quad_params (theta_cut_pi )
155+ diff_pi = u - u_cut_neg
156+ E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi ** 2 )
157+
158+ # Region 3: Normal -> Exact Analytical
159+ u_safe = torch .clamp (u , - 1.0 , u_cut_pos )
137160 theta_exact = torch .acos (u_safe )
138161 E_exact = 0.5 * k * (theta_exact ** 2 )
139162
140- # Switch based on current angle to avoid Taylor error at large angles
141- return torch .where (u > u_cut_pos , E_taylor , E_exact )
163+ # Combine
164+ return torch .where (
165+ u > u_cut_pos ,
166+ E_taylor ,
167+ torch .where (u < u_cut_neg , E_quad_pi , E_exact )
168+ )
142169
143170 # --- BRANCH B: EXACTLY Planar Equilibrium (theta_0 ~ pi) ---
144171 elif torch .abs (theta_0 - PI ) < epsilon_param :
145- # Region 1: Singularity (theta ~ pi, u ~ -1 ) -> Use Taylor Expansion
172+ # Region 1: Singularity at Equilibrium (theta ~ pi) -> Taylor Expansion
146173 delta = 1.0 + u
147- # Corrected coeff: 4.0/45.0 -> 8.0/45.0
174+ # Corrected Taylor expansion for (theta-pi)^2 around u=-1
148175 diff_sq_taylor = delta * (2.0 + delta * (1.0 / 3.0 + delta * 8.0 / 45.0 ))
149176 E_taylor = 0.5 * k * diff_sq_taylor
150177
151- # Region 2: Normal -> Use Exact Analytical Solution
178+ # Region 2: Singularity at Opposite Pole (theta ~ 0) -> Quadratic Extrapolation
179+ val_0 , d1_0 , d2_0 = get_quad_params (theta_cut_val )
180+ diff_0 = u - u_cut_pos
181+ E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0 ** 2 )
182+
183+ # Region 3: Normal -> Exact Analytical
152184 u_safe = torch .clamp (u , u_cut_neg , 1.0 )
153185 theta_exact = torch .acos (u_safe )
154186 E_exact = 0.5 * k * (theta_exact - theta_0 ) ** 2
155187
156- return torch .where (u < u_cut_neg , E_taylor , E_exact )
188+ # Combine
189+ return torch .where (
190+ u < u_cut_neg ,
191+ E_taylor ,
192+ torch .where (u > u_cut_pos , E_quad_0 , E_exact )
193+ )
157194
158195 # --- BRANCH C: General Angle ---
159196 else :
160197 is_singular_0 = (u > u_cut_pos ) # theta -> 0
161198 is_singular_pi = (u < u_cut_neg ) # theta -> pi
162- is_safe = ~ (is_singular_0 | is_singular_pi )
163199
164200 # 1. Normal Region (Safe acos)
165201 theta_safe = torch .acos (u )
166202 E_safe = 0.5 * k * (theta_safe - theta_0 ) ** 2
167203
168- # Helper for Quadratic Extrapolation with Gauss-Newton Hessian
169- def get_quad_params (th_cut , u_bnd ):
170- sin_cut = torch .sin (th_cut )
171- dth_du = - 1.0 / sin_cut
172-
173- # --- Gauss-Newton Approximation ---
174- # Drop d2th/du2 term to ensure positive curvature (convexity)
175-
176- val = 0.5 * k * (th_cut - theta_0 )** 2
177- d1 = k * (th_cut - theta_0 ) * dth_du
178- d2 = k * (dth_du ** 2 ) # Gauss-Newton Approx: strictly positive
179-
180- return val .detach (), d1 .detach (), d2 .detach ()
181-
182204 # 2. Extrapolation: theta -> 0
183- val_0 , d1_0 , d2_0 = get_quad_params (theta_cut_val , u_cut_pos )
205+ val_0 , d1_0 , d2_0 = get_quad_params (theta_cut_val )
184206 diff_0 = u - u_cut_pos
185207 E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0 ** 2 )
186208
187209 # 3. Extrapolation: theta -> pi
188210 theta_cut_pi = PI - theta_cut_val
189- val_pi , d1_pi , d2_pi = get_quad_params (theta_cut_pi , u_cut_neg )
211+ val_pi , d1_pi , d2_pi = get_quad_params (theta_cut_pi )
190212 diff_pi = u - u_cut_neg
191213 E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi ** 2 )
192214
193215 # Integration using masks
194- energy = torch .where (
216+ return torch .where (
195217 is_singular_0 ,
196218 E_quad_0 ,
197219 torch .where (is_singular_pi , E_quad_pi , E_safe )
198220 )
199-
200- return energy
201221
202222
203223class StructKeepAnglePotentialv2 :
@@ -209,12 +229,19 @@ class StructKeepAnglePotentialv2:
209229 but operates on the geometric centers of three specified atom fragments.
210230
211231 Singularity Handling Strategies:
212- 1. **Taylor Expansion**: For exactly linear/planar equilibrium geometries (F1-F2-F3 aligned).
213- 2. **Quadratic Extrapolation**: For general bent geometries forced into linearity,
214- using Gauss-Newton approximation for Hessian stability.
232+ 1. **Taylor Expansion (Physical Accuracy)**:
233+ Applied when the equilibrium angle :math:`\\ theta_0` is EXACTLY 0 or :math:`\\ pi`.
234+ Maintains accuracy for linear/planar equilibrium states.
235+
236+ 2. **Quadratic Extrapolation (Numerical Stability)**:
237+ Applied to singularities (0 or 180 deg) that are NOT the equilibrium angle.
238+ This includes the antipodal pole for linear/planar molecules (e.g., theta ~ 180 when theta_0 ~ 0),
239+ ensuring the gradient never explodes even in extreme bending configurations.
215240
216241 Attributes:
217242 config (dict): Configuration dictionary containing potential parameters.
243+ THETA_CUT (float): Angle threshold (radians) to switch to extrapolation. Default is 1e-3.
244+ EPSILON_PARAM (float): Threshold (radians) to distinguish exactly linear/planar equilibrium. Default is 1e-9.
218245 """
219246
220247 def __init__ (self , ** kwarg ):
@@ -291,70 +318,98 @@ def get_centroid(key):
291318 u = torch .dot (vec1 , vec2 ) / (norm1_2 )
292319 u = torch .clamp (u , - 1.0 , 1.0 )
293320
321+ # ========================================
294322 # 3. Singularity Handling Logic
323+ # ========================================
324+
325+ # Thresholds
295326 u_cut_pos = torch .cos (theta_cut_val )
296327 u_cut_neg = torch .cos (PI - theta_cut_val )
297328
298- # --- BRANCH A: EXACTLY Linear Equilibrium ---
329+ # --- Helper for Quadratic Extrapolation ---
330+ def get_quad_params (th_cut ):
331+ sin_cut = torch .sin (th_cut )
332+ dth_du = - 1.0 / sin_cut
333+
334+ val = 0.5 * k * (th_cut - theta_0 )** 2
335+
336+ dE_dth = k * (th_cut - theta_0 )
337+ d1 = dE_dth * dth_du
338+
339+ # Gauss-Newton Approximation
340+ d2 = k * (dth_du ** 2 )
341+
342+ return val .detach (), d1 .detach (), d2 .detach ()
343+
344+ # --- BRANCH A: EXACTLY Linear Equilibrium (theta_0 ~ 0) ---
299345 if torch .abs (theta_0 ) < epsilon_param :
346+ # Region 1: Singularity at Equilibrium (theta ~ 0) -> Taylor Expansion
300347 delta = 1.0 - u
301- # Corrected Taylor expansion for theta^2
302348 theta_sq_taylor = delta * (2.0 + delta * (1.0 / 3.0 + delta * 8.0 / 45.0 ))
303349 E_taylor = 0.5 * k * theta_sq_taylor
304350
351+ # Region 2: Singularity at Opposite Pole (theta ~ pi) -> Quadratic Extrapolation
352+ theta_cut_pi = PI - theta_cut_val
353+ val_pi , d1_pi , d2_pi = get_quad_params (theta_cut_pi )
354+ diff_pi = u - u_cut_neg
355+ E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi ** 2 )
356+
357+ # Region 3: Normal -> Exact Analytical
305358 u_safe = torch .clamp (u , - 1.0 , u_cut_pos )
306359 theta_exact = torch .acos (u_safe )
307360 E_exact = 0.5 * k * (theta_exact ** 2 )
308361
309- return torch .where (u > u_cut_pos , E_taylor , E_exact )
362+ return torch .where (
363+ u > u_cut_pos ,
364+ E_taylor ,
365+ torch .where (u < u_cut_neg , E_quad_pi , E_exact )
366+ )
310367
311- # --- BRANCH B: EXACTLY Planar Equilibrium ---
368+ # --- BRANCH B: EXACTLY Planar Equilibrium (theta_0 ~ pi) ---
312369 elif torch .abs (theta_0 - PI ) < epsilon_param :
370+ # Region 1: Singularity at Equilibrium (theta ~ pi) -> Taylor Expansion
313371 delta = 1.0 + u
314372 diff_sq_taylor = delta * (2.0 + delta * (1.0 / 3.0 + delta * 8.0 / 45.0 ))
315373 E_taylor = 0.5 * k * diff_sq_taylor
316374
375+ # Region 2: Singularity at Opposite Pole (theta ~ 0) -> Quadratic Extrapolation
376+ val_0 , d1_0 , d2_0 = get_quad_params (theta_cut_val )
377+ diff_0 = u - u_cut_pos
378+ E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0 ** 2 )
379+
380+ # Region 3: Normal -> Exact Analytical
317381 u_safe = torch .clamp (u , u_cut_neg , 1.0 )
318382 theta_exact = torch .acos (u_safe )
319383 E_exact = 0.5 * k * (theta_exact - theta_0 ) ** 2
320384
321- return torch .where (u < u_cut_neg , E_taylor , E_exact )
385+ return torch .where (
386+ u < u_cut_neg ,
387+ E_taylor ,
388+ torch .where (u > u_cut_pos , E_quad_0 , E_exact )
389+ )
322390
323391 # --- BRANCH C: General Angle ---
324392 else :
325393 is_singular_0 = (u > u_cut_pos )
326394 is_singular_pi = (u < u_cut_neg )
327- is_safe = ~ (is_singular_0 | is_singular_pi )
328395
329396 theta_safe = torch .acos (u )
330397 E_safe = 0.5 * k * (theta_safe - theta_0 ) ** 2
331398
332- def get_quad_params (th_cut , u_bnd ):
333- sin_cut = torch .sin (th_cut )
334- dth_du = - 1.0 / sin_cut
335-
336- val = 0.5 * k * (th_cut - theta_0 )** 2
337- d1 = k * (th_cut - theta_0 ) * dth_du
338- d2 = k * (dth_du ** 2 ) # Gauss-Newton Approximation
339- return val .detach (), d1 .detach (), d2 .detach ()
340-
341- val_0 , d1_0 , d2_0 = get_quad_params (theta_cut_val , u_cut_pos )
399+ val_0 , d1_0 , d2_0 = get_quad_params (theta_cut_val )
342400 diff_0 = u - u_cut_pos
343401 E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0 ** 2 )
344402
345403 theta_cut_pi = PI - theta_cut_val
346- val_pi , d1_pi , d2_pi = get_quad_params (theta_cut_pi , u_cut_neg )
404+ val_pi , d1_pi , d2_pi = get_quad_params (theta_cut_pi )
347405 diff_pi = u - u_cut_neg
348406 E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi ** 2 )
349407
350- energy = torch .where (
408+ return torch .where (
351409 is_singular_0 ,
352410 E_quad_0 ,
353411 torch .where (is_singular_pi , E_quad_pi , E_safe )
354- )
355- return energy
356-
357-
412+ )
358413class StructKeepAnglePotentialAtomDistDependent :
359414 def __init__ (self , ** kwarg ):
360415 self .config = kwarg
0 commit comments