@@ -114,15 +114,16 @@ def calc_energy(self, geom_num_list, bias_pot_params=[]):
114114 # Normalize vectors safely
115115 # We add a small epsilon to denominator to avoid NaN in the "linear" branch,
116116 # even though we will overwrite the result with 0 later.
117- n1_norm = torch .sqrt (n1_sq_norm )
118- n2_norm = torch .sqrt (n2_sq_norm )
117+ n1_norm = torch .clamp ( torch . sqrt (n1_sq_norm ), min = 1e-12 )
118+ n2_norm = torch .clamp ( torch . sqrt (n2_sq_norm ), min = 1e-12 )
119119
120- n1_hat = n1 / (n1_norm .unsqueeze (- 1 ) + 1e-12 )
121- n2_hat = n2 / (n2_norm .unsqueeze (- 1 ) + 1e-12 )
120+
121+ n1_hat = n1 / (n1_norm .unsqueeze (- 1 ))
122+ n2_hat = n2 / (n2_norm .unsqueeze (- 1 ))
122123
123124 # Normalize b2 to define the reference frame for sign
124- b2_norm = torch .linalg .norm (b2 )
125- b2_hat = b2 / (b2_norm .unsqueeze (- 1 ) + 1e-12 )
125+ b2_norm = torch .clamp ( torch . linalg .norm (b2 ), min = 1e-12 )
126+ b2_hat = b2 / (b2_norm .unsqueeze (- 1 ))
126127
127128 # ========================================
128129 # 4. Angle Calculation (atan2)
@@ -259,14 +260,14 @@ def get_indices(key):
259260 is_linear = (n1_sq_norm < self .COLLINEAR_CUT_SQ ) | (n2_sq_norm < self .COLLINEAR_CUT_SQ )
260261
261262 # Safe normalization
262- n1_norm = torch .sqrt (n1_sq_norm )
263- n2_norm = torch .sqrt (n2_sq_norm )
263+ n1_norm = torch .clamp ( torch . sqrt (n1_sq_norm ), min = 1e-12 )
264+ n2_norm = torch .clamp ( torch . sqrt (n2_sq_norm ), min = 1e-12 )
264265
265- n1_hat = n1 / (n1_norm .unsqueeze (- 1 ) + 1e-12 )
266- n2_hat = n2 / (n2_norm .unsqueeze (- 1 ) + 1e-12 )
266+ n1_hat = n1 / (n1_norm .unsqueeze (- 1 ))
267+ n2_hat = n2 / (n2_norm .unsqueeze (- 1 ))
267268
268- b2_norm = torch .linalg .norm (b2 )
269- b2_hat = b2 / (b2_norm .unsqueeze (- 1 ) + 1e-12 )
269+ b2_norm = torch .clamp ( torch . linalg .norm (b2 ), min = 1e-12 )
270+ b2_hat = b2 / (b2_norm .unsqueeze (- 1 ))
270271
271272 # ========================================
272273 # 4. Angle Calculation (atan2)
0 commit comments