@@ -103,10 +103,7 @@ def calc_energy(self, geom_num_list, bias_pot_params=[]):
103103 norm2 = torch .linalg .norm (vec2 )
104104
105105 # u = cos(theta)
106- # Add epsilon to denominator to prevent NaN if atoms overlap exactly
107- norm1_2 = norm1 * norm2
108- if norm1_2 < 1e-12 :
109- norm1_2 = norm1_2 + 1e-12
106+ norm1_2 = torch .clamp (norm1 * norm2 , min = 1e-12 )
110107
111108 u = torch .dot (vec1 , vec2 ) / (norm1_2 )
112109 u = torch .clamp (u , - 1.0 , 1.0 )
@@ -137,7 +134,7 @@ def get_quad_params(th_cut):
137134 # We ignore the d2th/du2 term to ensure positive curvature (convexity)
138135 d2 = k * (dth_du ** 2 )
139136
140- return val . detach () , d1 . detach () , d2 . detach ()
137+ return val , d1 , d2
141138
142139
143140 # --- BRANCH A: EXACTLY Linear Equilibrium (theta_0 ~ 0) ---
@@ -311,9 +308,7 @@ def get_centroid(key):
311308 norm1 = torch .linalg .norm (vec1 )
312309 norm2 = torch .linalg .norm (vec2 )
313310
314- norm1_2 = norm1 * norm2
315- if norm1_2 < 1e-12 :
316- norm1_2 = norm1_2 + 1e-12
311+ norm1_2 = torch .clamp (norm1 * norm2 , min = 1e-12 )
317312
318313 u = torch .dot (vec1 , vec2 ) / (norm1_2 )
319314 u = torch .clamp (u , - 1.0 , 1.0 )
@@ -339,7 +334,7 @@ def get_quad_params(th_cut):
339334 # Gauss-Newton Approximation
340335 d2 = k * (dth_du ** 2 )
341336
342- return val . detach () , d1 . detach () , d2 . detach ()
337+ return val , d1 , d2
343338
344339 # --- BRANCH A: EXACTLY Linear Equilibrium (theta_0 ~ 0) ---
345340 if torch .abs (theta_0 ) < epsilon_param :
0 commit comments