Skip to content

Commit e203a2c

Browse files
authored
Add files via upload
1 parent 0faa2c4 commit e203a2c

1 file changed

Lines changed: 128 additions & 73 deletions

File tree

multioptpy/Potential/keep_angle_potential.py

Lines changed: 128 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,17 @@ class StructKeepAnglePotential:
1515
1616
Singularity Handling Strategies:
1717
1. **Taylor Expansion (Physical Accuracy)**:
18-
Applied when the equilibrium angle :math:`\\theta_0` is EXACTLY 0 or :math:`\\pi` (within `EPSILON_PARAM`).
19-
This uses a high-order Taylor expansion of :math:`\\arccos` to maintain physical accuracy for linear
18+
Applied when the equilibrium angle :math:`\\theta_0` is EXACTLY 0 or :math:`\\pi`.
19+
This uses a high-order Taylor expansion to maintain physical accuracy for linear
2020
molecules or planar transition states without gradient explosion.
21-
22-
*Note*: To avoid approximation errors at large angles, the method switches back to the exact
23-
analytical solution when the current angle is far from the singularity.
2421
2522
2. **Quadratic Extrapolation (Numerical Stability)**:
26-
Applied when a normally bent molecule (:math:`\\theta_0 \\neq 0, \\pi`) is forced into linearity.
27-
The potential is replaced by a quadratic polynomial near the singularity (`THETA_CUT`).
28-
A **Gauss-Newton Approximation** is used for the Hessian (:math:`d^2E/du^2 \\approx k (d\\theta/du)^2`),
29-
ensuring positive curvature and preventing optimizer instability.
23+
Applied in two cases:
24+
- When a normally bent molecule (:math:`\\theta_0 \\neq 0, \\pi`) is forced into linearity (0 or 180 deg).
25+
- When a linear molecule (:math:`\\theta_0 = 0`) bends all the way to the antipodal singularity (:math:`180` deg), or vice versa.
26+
27+
This uses a Gauss-Newton Approximation for the Hessian (:math:`d^2E/du^2 \\approx k (d\\theta/du)^2`),
28+
ensuring positive curvature and preventing optimizer instability at the poles.
3029
3130
Attributes:
3231
config (dict): Configuration dictionary containing potential parameters.
@@ -40,21 +39,19 @@ def __init__(self, **kwarg):
4039
4140
Args:
4241
**kwarg: Arbitrary keyword arguments containing configuration parameters.
43-
Expected keys include 'keep_angle_spring_const', 'keep_angle_angle', etc.
42+
Expected keys include 'keep_angle_spring_const', 'keep_angle_angle', 'keep_angle_atom_pairs', etc.
4443
"""
4544
self.config = kwarg
4645
UVL = UnitValueLib()
4746
self.hartree2kcalmol = UVL.hartree2kcalmol
4847
self.bohr2angstroms = UVL.bohr2angstroms
4948
self.hartree2kjmol = UVL.hartree2kjmol
5049

51-
# Threshold: Below this angle (rad), switch to extrapolation.
50+
# Threshold: Below this angle (rad), switch to extrapolation/Taylor.
5251
self.THETA_CUT = 1e-3
5352

5453
# Threshold for Equilibrium check.
55-
# Set to 1e-9 to strictly distinguish between theta_0=0 (Taylor) and theta_0=0.001 (General).
5654
self.EPSILON_PARAM = 1e-9
57-
5855
return
5956

6057
def calc_energy(self, geom_num_list, bias_pot_params=[]):
@@ -118,86 +115,109 @@ def calc_energy(self, geom_num_list, bias_pot_params=[]):
118115
# 3. Singularity Handling Logic
119116
# ========================================
120117

121-
# Pre-calculate thresholds for switching
122-
u_cut_pos = torch.cos(theta_cut_val)
123-
u_cut_neg = torch.cos(PI - theta_cut_val)
118+
# Thresholds
119+
u_cut_pos = torch.cos(theta_cut_val) # Corresponds to theta ~ 0
120+
u_cut_neg = torch.cos(PI - theta_cut_val) # Corresponds to theta ~ pi
121+
122+
# --- Helper for Quadratic Extrapolation ---
123+
# Used for any region where the angle approaches a singularity that is NOT the equilibrium.
124+
def get_quad_params(th_cut):
125+
# Calculate properties at the cutoff boundary
126+
sin_cut = torch.sin(th_cut)
127+
dth_du = -1.0 / sin_cut # Chain rule: d(acos(u))/du = -1/sqrt(1-u^2) = -1/sin(theta)
128+
129+
# Energy value at cutoff
130+
val = 0.5 * k * (th_cut - theta_0)**2
131+
132+
# First derivative w.r.t u at cutoff: dE/du = dE/dth * dth/du
133+
dE_dth = k * (th_cut - theta_0)
134+
d1 = dE_dth * dth_du
135+
136+
# Second derivative w.r.t u (Gauss-Newton Approx): d2E/du2 approx k * (dth/du)^2
137+
# We ignore the d2th/du2 term to ensure positive curvature (convexity)
138+
d2 = k * (dth_du**2)
139+
140+
return val.detach(), d1.detach(), d2.detach()
141+
124142

125143
# --- BRANCH A: EXACTLY Linear Equilibrium (theta_0 ~ 0) ---
126144
if torch.abs(theta_0) < epsilon_param:
127-
# Region 1: Singularity (theta ~ 0, u ~ 1) -> Use Taylor Expansion
128-
# Formula: theta^2 approx 2(1-u) + (1-u)^2/3 + 8/45*(1-u)^3
145+
# Region 1: Singularity at Equilibrium (theta ~ 0) -> Taylor Expansion
129146
delta = 1.0 - u
130-
# Corrected coeff: 4.0/45.0 -> 8.0/45.0 for (acos(1-x))^2 expansion
147+
# Corrected Taylor expansion for theta^2 around u=1
131148
theta_sq_taylor = delta * (2.0 + delta * (1.0/3.0 + delta * 8.0/45.0))
132149
E_taylor = 0.5 * k * theta_sq_taylor
133150

134-
# Region 2: Normal (theta > cut) -> Use Exact Analytical Solution
135-
# Note: Force u to be safe for acos to prevent NaN gradients in masked region
136-
u_safe = torch.clamp(u, -1.0, u_cut_pos)
151+
# Region 2: Singularity at Opposite Pole (theta ~ pi) -> Quadratic Extrapolation
152+
# Even if equilibrium is 0, we must protect against 180 degrees.
153+
theta_cut_pi = PI - theta_cut_val
154+
val_pi, d1_pi, d2_pi = get_quad_params(theta_cut_pi)
155+
diff_pi = u - u_cut_neg
156+
E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi**2)
157+
158+
# Region 3: Normal -> Exact Analytical
159+
u_safe = torch.clamp(u, -1.0, u_cut_pos)
137160
theta_exact = torch.acos(u_safe)
138161
E_exact = 0.5 * k * (theta_exact ** 2)
139162

140-
# Switch based on current angle to avoid Taylor error at large angles
141-
return torch.where(u > u_cut_pos, E_taylor, E_exact)
163+
# Combine
164+
return torch.where(
165+
u > u_cut_pos,
166+
E_taylor,
167+
torch.where(u < u_cut_neg, E_quad_pi, E_exact)
168+
)
142169

143170
# --- BRANCH B: EXACTLY Planar Equilibrium (theta_0 ~ pi) ---
144171
elif torch.abs(theta_0 - PI) < epsilon_param:
145-
# Region 1: Singularity (theta ~ pi, u ~ -1) -> Use Taylor Expansion
172+
# Region 1: Singularity at Equilibrium (theta ~ pi) -> Taylor Expansion
146173
delta = 1.0 + u
147-
# Corrected coeff: 4.0/45.0 -> 8.0/45.0
174+
# Corrected Taylor expansion for (theta-pi)^2 around u=-1
148175
diff_sq_taylor = delta * (2.0 + delta * (1.0/3.0 + delta * 8.0/45.0))
149176
E_taylor = 0.5 * k * diff_sq_taylor
150177

151-
# Region 2: Normal -> Use Exact Analytical Solution
178+
# Region 2: Singularity at Opposite Pole (theta ~ 0) -> Quadratic Extrapolation
179+
val_0, d1_0, d2_0 = get_quad_params(theta_cut_val)
180+
diff_0 = u - u_cut_pos
181+
E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0**2)
182+
183+
# Region 3: Normal -> Exact Analytical
152184
u_safe = torch.clamp(u, u_cut_neg, 1.0)
153185
theta_exact = torch.acos(u_safe)
154186
E_exact = 0.5 * k * (theta_exact - theta_0) ** 2
155187

156-
return torch.where(u < u_cut_neg, E_taylor, E_exact)
188+
# Combine
189+
return torch.where(
190+
u < u_cut_neg,
191+
E_taylor,
192+
torch.where(u > u_cut_pos, E_quad_0, E_exact)
193+
)
157194

158195
# --- BRANCH C: General Angle ---
159196
else:
160197
is_singular_0 = (u > u_cut_pos) # theta -> 0
161198
is_singular_pi = (u < u_cut_neg) # theta -> pi
162-
is_safe = ~(is_singular_0 | is_singular_pi)
163199

164200
# 1. Normal Region (Safe acos)
165201
theta_safe = torch.acos(u)
166202
E_safe = 0.5 * k * (theta_safe - theta_0) ** 2
167203

168-
# Helper for Quadratic Extrapolation with Gauss-Newton Hessian
169-
def get_quad_params(th_cut, u_bnd):
170-
sin_cut = torch.sin(th_cut)
171-
dth_du = -1.0 / sin_cut
172-
173-
# --- Gauss-Newton Approximation ---
174-
# Drop d2th/du2 term to ensure positive curvature (convexity)
175-
176-
val = 0.5 * k * (th_cut - theta_0)**2
177-
d1 = k * (th_cut - theta_0) * dth_du
178-
d2 = k * (dth_du**2) # Gauss-Newton Approx: strictly positive
179-
180-
return val.detach(), d1.detach(), d2.detach()
181-
182204
# 2. Extrapolation: theta -> 0
183-
val_0, d1_0, d2_0 = get_quad_params(theta_cut_val, u_cut_pos)
205+
val_0, d1_0, d2_0 = get_quad_params(theta_cut_val)
184206
diff_0 = u - u_cut_pos
185207
E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0**2)
186208

187209
# 3. Extrapolation: theta -> pi
188210
theta_cut_pi = PI - theta_cut_val
189-
val_pi, d1_pi, d2_pi = get_quad_params(theta_cut_pi, u_cut_neg)
211+
val_pi, d1_pi, d2_pi = get_quad_params(theta_cut_pi)
190212
diff_pi = u - u_cut_neg
191213
E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi**2)
192214

193215
# Integration using masks
194-
energy = torch.where(
216+
return torch.where(
195217
is_singular_0,
196218
E_quad_0,
197219
torch.where(is_singular_pi, E_quad_pi, E_safe)
198220
)
199-
200-
return energy
201221

202222

203223
class StructKeepAnglePotentialv2:
@@ -209,12 +229,19 @@ class StructKeepAnglePotentialv2:
209229
but operates on the geometric centers of three specified atom fragments.
210230
211231
Singularity Handling Strategies:
212-
1. **Taylor Expansion**: For exactly linear/planar equilibrium geometries (F1-F2-F3 aligned).
213-
2. **Quadratic Extrapolation**: For general bent geometries forced into linearity,
214-
using Gauss-Newton approximation for Hessian stability.
232+
1. **Taylor Expansion (Physical Accuracy)**:
233+
Applied when the equilibrium angle :math:`\\theta_0` is EXACTLY 0 or :math:`\\pi`.
234+
Maintains accuracy for linear/planar equilibrium states.
235+
236+
2. **Quadratic Extrapolation (Numerical Stability)**:
237+
Applied to singularities (0 or 180 deg) that are NOT the equilibrium angle.
238+
This includes the antipodal pole for linear/planar molecules (e.g., theta ~ 180 when theta_0 ~ 0),
239+
ensuring the gradient never explodes even in extreme bending configurations.
215240
216241
Attributes:
217242
config (dict): Configuration dictionary containing potential parameters.
243+
THETA_CUT (float): Angle threshold (radians) to switch to extrapolation. Default is 1e-3.
244+
EPSILON_PARAM (float): Threshold (radians) to distinguish exactly linear/planar equilibrium. Default is 1e-9.
218245
"""
219246

220247
def __init__(self, **kwarg):
@@ -291,70 +318,98 @@ def get_centroid(key):
291318
u = torch.dot(vec1, vec2) / (norm1_2)
292319
u = torch.clamp(u, -1.0, 1.0)
293320

321+
# ========================================
294322
# 3. Singularity Handling Logic
323+
# ========================================
324+
325+
# Thresholds
295326
u_cut_pos = torch.cos(theta_cut_val)
296327
u_cut_neg = torch.cos(PI - theta_cut_val)
297328

298-
# --- BRANCH A: EXACTLY Linear Equilibrium ---
329+
# --- Helper for Quadratic Extrapolation ---
330+
def get_quad_params(th_cut):
331+
sin_cut = torch.sin(th_cut)
332+
dth_du = -1.0 / sin_cut
333+
334+
val = 0.5 * k * (th_cut - theta_0)**2
335+
336+
dE_dth = k * (th_cut - theta_0)
337+
d1 = dE_dth * dth_du
338+
339+
# Gauss-Newton Approximation
340+
d2 = k * (dth_du**2)
341+
342+
return val.detach(), d1.detach(), d2.detach()
343+
344+
# --- BRANCH A: EXACTLY Linear Equilibrium (theta_0 ~ 0) ---
299345
if torch.abs(theta_0) < epsilon_param:
346+
# Region 1: Singularity at Equilibrium (theta ~ 0) -> Taylor Expansion
300347
delta = 1.0 - u
301-
# Corrected Taylor expansion for theta^2
302348
theta_sq_taylor = delta * (2.0 + delta * (1.0/3.0 + delta * 8.0/45.0))
303349
E_taylor = 0.5 * k * theta_sq_taylor
304350

351+
# Region 2: Singularity at Opposite Pole (theta ~ pi) -> Quadratic Extrapolation
352+
theta_cut_pi = PI - theta_cut_val
353+
val_pi, d1_pi, d2_pi = get_quad_params(theta_cut_pi)
354+
diff_pi = u - u_cut_neg
355+
E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi**2)
356+
357+
# Region 3: Normal -> Exact Analytical
305358
u_safe = torch.clamp(u, -1.0, u_cut_pos)
306359
theta_exact = torch.acos(u_safe)
307360
E_exact = 0.5 * k * (theta_exact ** 2)
308361

309-
return torch.where(u > u_cut_pos, E_taylor, E_exact)
362+
return torch.where(
363+
u > u_cut_pos,
364+
E_taylor,
365+
torch.where(u < u_cut_neg, E_quad_pi, E_exact)
366+
)
310367

311-
# --- BRANCH B: EXACTLY Planar Equilibrium ---
368+
# --- BRANCH B: EXACTLY Planar Equilibrium (theta_0 ~ pi) ---
312369
elif torch.abs(theta_0 - PI) < epsilon_param:
370+
# Region 1: Singularity at Equilibrium (theta ~ pi) -> Taylor Expansion
313371
delta = 1.0 + u
314372
diff_sq_taylor = delta * (2.0 + delta * (1.0/3.0 + delta * 8.0/45.0))
315373
E_taylor = 0.5 * k * diff_sq_taylor
316374

375+
# Region 2: Singularity at Opposite Pole (theta ~ 0) -> Quadratic Extrapolation
376+
val_0, d1_0, d2_0 = get_quad_params(theta_cut_val)
377+
diff_0 = u - u_cut_pos
378+
E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0**2)
379+
380+
# Region 3: Normal -> Exact Analytical
317381
u_safe = torch.clamp(u, u_cut_neg, 1.0)
318382
theta_exact = torch.acos(u_safe)
319383
E_exact = 0.5 * k * (theta_exact - theta_0) ** 2
320384

321-
return torch.where(u < u_cut_neg, E_taylor, E_exact)
385+
return torch.where(
386+
u < u_cut_neg,
387+
E_taylor,
388+
torch.where(u > u_cut_pos, E_quad_0, E_exact)
389+
)
322390

323391
# --- BRANCH C: General Angle ---
324392
else:
325393
is_singular_0 = (u > u_cut_pos)
326394
is_singular_pi = (u < u_cut_neg)
327-
is_safe = ~(is_singular_0 | is_singular_pi)
328395

329396
theta_safe = torch.acos(u)
330397
E_safe = 0.5 * k * (theta_safe - theta_0) ** 2
331398

332-
def get_quad_params(th_cut, u_bnd):
333-
sin_cut = torch.sin(th_cut)
334-
dth_du = -1.0 / sin_cut
335-
336-
val = 0.5 * k * (th_cut - theta_0)**2
337-
d1 = k * (th_cut - theta_0) * dth_du
338-
d2 = k * (dth_du**2) # Gauss-Newton Approximation
339-
return val.detach(), d1.detach(), d2.detach()
340-
341-
val_0, d1_0, d2_0 = get_quad_params(theta_cut_val, u_cut_pos)
399+
val_0, d1_0, d2_0 = get_quad_params(theta_cut_val)
342400
diff_0 = u - u_cut_pos
343401
E_quad_0 = val_0 + d1_0 * diff_0 + 0.5 * d2_0 * (diff_0**2)
344402

345403
theta_cut_pi = PI - theta_cut_val
346-
val_pi, d1_pi, d2_pi = get_quad_params(theta_cut_pi, u_cut_neg)
404+
val_pi, d1_pi, d2_pi = get_quad_params(theta_cut_pi)
347405
diff_pi = u - u_cut_neg
348406
E_quad_pi = val_pi + d1_pi * diff_pi + 0.5 * d2_pi * (diff_pi**2)
349407

350-
energy = torch.where(
408+
return torch.where(
351409
is_singular_0,
352410
E_quad_0,
353411
torch.where(is_singular_pi, E_quad_pi, E_safe)
354-
)
355-
return energy
356-
357-
412+
)
358413
class StructKeepAnglePotentialAtomDistDependent:
359414
def __init__(self, **kwarg):
360415
self.config = kwarg

0 commit comments

Comments
 (0)