Skip to content

Commit 060c7cd

Browse files
authored
Add files via upload
1 parent 62a7a59 commit 060c7cd

3 files changed

Lines changed: 20 additions & 22 deletions

File tree

multioptpy/Potential/keep_dihedral_angle_potential.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -114,15 +114,16 @@ def calc_energy(self, geom_num_list, bias_pot_params=[]):
114114
# Normalize vectors safely
115115
# We add a small epsilon to denominator to avoid NaN in the "linear" branch,
116116
# even though we will overwrite the result with 0 later.
117-
n1_norm = torch.sqrt(n1_sq_norm)
118-
n2_norm = torch.sqrt(n2_sq_norm)
117+
n1_norm = torch.clamp(torch.sqrt(n1_sq_norm), min=1e-12)
118+
n2_norm = torch.clamp(torch.sqrt(n2_sq_norm), min=1e-12)
119119

120-
n1_hat = n1 / (n1_norm.unsqueeze(-1) + 1e-12)
121-
n2_hat = n2 / (n2_norm.unsqueeze(-1) + 1e-12)
120+
121+
n1_hat = n1 / (n1_norm.unsqueeze(-1))
122+
n2_hat = n2 / (n2_norm.unsqueeze(-1))
122123

123124
# Normalize b2 to define the reference frame for sign
124-
b2_norm = torch.linalg.norm(b2)
125-
b2_hat = b2 / (b2_norm.unsqueeze(-1) + 1e-12)
125+
b2_norm = torch.clamp(torch.linalg.norm(b2), min=1e-12)
126+
b2_hat = b2 / (b2_norm.unsqueeze(-1))
126127

127128
# ========================================
128129
# 4. Angle Calculation (atan2)
@@ -259,14 +260,14 @@ def get_indices(key):
259260
is_linear = (n1_sq_norm < self.COLLINEAR_CUT_SQ) | (n2_sq_norm < self.COLLINEAR_CUT_SQ)
260261

261262
# Safe normalization
262-
n1_norm = torch.sqrt(n1_sq_norm)
263-
n2_norm = torch.sqrt(n2_sq_norm)
263+
n1_norm = torch.clamp(torch.sqrt(n1_sq_norm), min=1e-12)
264+
n2_norm = torch.clamp(torch.sqrt(n2_sq_norm), min=1e-12)
264265

265-
n1_hat = n1 / (n1_norm.unsqueeze(-1) + 1e-12)
266-
n2_hat = n2 / (n2_norm.unsqueeze(-1) + 1e-12)
266+
n1_hat = n1 / (n1_norm.unsqueeze(-1))
267+
n2_hat = n2 / (n2_norm.unsqueeze(-1))
267268

268-
b2_norm = torch.linalg.norm(b2)
269-
b2_hat = b2 / (b2_norm.unsqueeze(-1) + 1e-12)
269+
b2_norm = torch.clamp(torch.linalg.norm(b2), min=1e-12)
270+
b2_hat = b2 / (b2_norm.unsqueeze(-1))
270271

271272
# ========================================
272273
# 4. Angle Calculation (atan2)

multioptpy/Potential/keep_outofplain_angle_potential.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def calc_energy(self, geom_num_list, bias_pot_params=[]):
100100

101101
# Safe normalization
102102
n_norm = torch.sqrt(n_sq_norm)
103-
n_hat = n / (n_norm.unsqueeze(-1) + 1e-12)
103+
n_hat_demon = torch.clamp(n_norm.unsqueeze(-1), min=1e-12)
104+
n_hat = n / n_hat_demon
104105

105106
# ========================================
106107
# 4. Angle Calculation (Robust atan2)
@@ -237,7 +238,8 @@ def get_indices(key):
237238

238239
# Safe normalization
239240
n_norm = torch.sqrt(n_sq_norm)
240-
n_hat = n / (n_norm.unsqueeze(-1) + 1e-12)
241+
n_hat_demon = torch.clamp(n_norm.unsqueeze(-1), min=1e-12)
242+
n_hat = n / n_hat_demon
241243

242244
# ========================================
243245
# 4. Angle Calculation (Robust atan2)

multioptpy/Potential/keep_potential.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,8 @@ def calc_energy(self, geom_num_list, bias_pot_params=[]):
5050
vec = geom_num_list[idx1] - geom_num_list[idx2]
5151

5252
# 3. Robust Distance Calculation
53-
# Standard torch.linalg.norm gradient is undefined at 0.
54-
# Use sqrt(sum(x^2) + epsilon) for safety.
55-
dist = torch.sqrt(torch.sum(vec**2))
56-
if dist < 1e-12:
57-
dist = dist + 1e-12
53+
54+
dist = torch.clamp(torch.sqrt(torch.sum(vec**2)), min=1e-12)
5855

5956
# 4. Energy Calculation
6057
# E = 0.5 * k * (r - r0)^2
@@ -111,9 +108,7 @@ def calc_energy(self, geom_num_list, bias_pot_params=[]):
111108

112109
# 3. Robust Distance Calculation
113110
vec = fragm_1_center - fragm_2_center
114-
distance = torch.sqrt(torch.sum(vec**2))
115-
if distance < 1e-12:
116-
distance = distance + 1e-12
111+
distance = torch.clamp(torch.sqrt(torch.sum(vec**2)), min=1e-12)
117112

118113
# 4. Energy
119114
energy = 0.5 * k * (distance - r0) ** 2

0 commit comments

Comments
 (0)