Skip to content

Commit c5bb605

Browse files
committed
Add pytest fixture for cleaning work directory
1 parent 2776b94 commit c5bb605

9 files changed

Lines changed: 56 additions & 65 deletions

tests/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import shutil
2+
from pathlib import Path
3+
import pytest
4+
5+
@pytest.fixture
6+
def clean_tmp_dir(tmp_path):
7+
path = Path(tmp_path)
8+
9+
if path.exists():
10+
shutil.rmtree(path)
11+
12+
path.mkdir(parents=True, exist_ok=True)
13+
yield path
14+
15+
if path.exists():
16+
shutil.rmtree(path)

tests/test_solver/test_competitive_pinn.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,8 @@ def test_solver_test(problem, batch_size, compile):
113113

114114

115115
@pytest.mark.parametrize("problem", [problem, inverse_problem])
116-
def test_train_load_restore(problem):
117-
dir = "tests/test_solver/tmp"
118-
problem = problem
116+
def test_train_load_restore(clean_tmp_dir, problem):
117+
dir = clean_tmp_dir
119118
solver = CompPINN(problem=problem, model=model)
120119
trainer = Trainer(
121120
solver=solver,
@@ -150,9 +149,4 @@ def test_train_load_restore(problem):
150149
)
151150
torch.testing.assert_close(
152151
new_solver.forward(test_pts), solver.forward(test_pts)
153-
)
154-
155-
# rm directories
156-
import shutil
157-
158-
shutil.rmtree("tests/test_solver/tmp")
152+
)

tests/test_solver/test_ensemble_pinn.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def test_solver_test(batch_size, compile):
107107
)
108108

109109

110-
def test_train_load_restore():
111-
dir = "tests/test_solver/tmp"
110+
def test_train_load_restore(clean_tmp_dir):
111+
dir = clean_tmp_dir
112112
solver = DeepEnsemblePINN(models=models, problem=problem)
113113
trainer = Trainer(
114114
solver=solver,
@@ -141,8 +141,3 @@ def test_train_load_restore():
141141
torch.testing.assert_close(
142142
new_solver.forward(test_pts), solver.forward(test_pts)
143143
)
144-
145-
# rm directories
146-
import shutil
147-
148-
shutil.rmtree("tests/test_solver/tmp")

tests/test_solver/test_ensemble_supervised_solver.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def test_solver_test_graph(batch_size, use_lt):
235235
trainer.test()
236236

237237

238-
def test_train_load_restore():
239-
dir = "tests/test_solver/tmp/"
238+
def test_train_load_restore(clean_tmp_dir):
239+
dir = clean_tmp_dir
240240
problem = LabelTensorProblem()
241241
solver = DeepEnsembleSupervisedSolver(problem=problem, models=models)
242242
trainer = Trainer(
@@ -270,8 +270,3 @@ def test_train_load_restore():
270270
torch.testing.assert_close(
271271
new_solver.forward(test_pts), solver.forward(test_pts)
272272
)
273-
274-
# rm directories
275-
import shutil
276-
277-
shutil.rmtree("tests/test_solver/tmp")

tests/test_solver/test_garom.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,8 @@ def test_solver_test(batch_size, compile):
163163
)
164164

165165

166-
def test_train_load_restore():
167-
dir = "tests/test_solver/tmp/"
168-
problem = TensorProblem()
166+
def test_train_load_restore(clean_tmp_dir):
167+
dir = clean_tmp_dir
169168
solver = GAROM(
170169
problem=TensorProblem(),
171170
generator=Generator(),
@@ -201,8 +200,3 @@ def test_train_load_restore():
201200
test_pts = torch.rand(20, 1)
202201
assert new_solver.forward(test_pts).shape == (20, 2)
203202
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
204-
205-
# rm directories
206-
import shutil
207-
208-
shutil.rmtree("tests/test_solver/tmp")

tests/test_solver/test_gradient_pinn.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,8 @@ def test_solver_test(problem, batch_size, compile):
119119

120120

121121
@pytest.mark.parametrize("problem", [problem, inverse_problem])
122-
def test_train_load_restore(problem):
123-
dir = "tests/test_solver/tmp"
124-
problem = problem
122+
def test_train_load_restore(clean_tmp_dir, problem):
123+
dir = clean_tmp_dir
125124
solver = GradientPINN(model=model, problem=problem)
126125
trainer = Trainer(
127126
solver=solver,
@@ -157,8 +156,3 @@ def test_train_load_restore(problem):
157156
torch.testing.assert_close(
158157
new_solver.forward(test_pts), solver.forward(test_pts)
159158
)
160-
161-
# rm directories
162-
import shutil
163-
164-
shutil.rmtree("tests/test_solver/tmp")

tests/test_solver/test_pinn.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,8 @@ def test_solver_test(problem, batch_size, compile):
101101

102102

103103
@pytest.mark.parametrize("problem", [problem, inverse_problem])
104-
def test_train_load_restore(problem):
105-
dir = "tests/test_solver/tmp"
106-
problem = problem
104+
def test_train_load_restore(clean_tmp_dir, problem):
105+
dir = clean_tmp_dir
107106
solver = PINN(model=model, problem=problem)
108107
trainer = Trainer(
109108
solver=solver,
@@ -116,6 +115,8 @@ def test_train_load_restore(problem):
116115
default_root_dir=dir,
117116
)
118117
trainer.train()
118+
import os
119+
print(os.listdir(f"{dir}/lightning_logs/version_0/checkpoints/"))
119120

120121
# restore
121122
new_trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
@@ -124,6 +125,7 @@ def test_train_load_restore(problem):
124125
+ "epoch=4-step=5.ckpt"
125126
)
126127

128+
127129
# loading
128130
new_solver = PINN.load_from_checkpoint(
129131
f"{dir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt",
@@ -136,9 +138,4 @@ def test_train_load_restore(problem):
136138
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
137139
torch.testing.assert_close(
138140
new_solver.forward(test_pts), solver.forward(test_pts)
139-
)
140-
141-
# rm directories
142-
import shutil
143-
144-
shutil.rmtree("tests/test_solver/tmp")
141+
)

tests/test_solver/test_rba_pinn.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,8 @@ def test_solver_test(problem, batch_size, loss, compile):
122122

123123

124124
@pytest.mark.parametrize("problem", [problem, inverse_problem])
125-
def test_train_load_restore(problem):
126-
dir = "tests/test_solver/tmp"
127-
problem = problem
125+
def test_train_load_restore(clean_tmp_dir, problem):
126+
dir = clean_tmp_dir
128127
solver = RBAPINN(model=model, problem=problem)
129128
trainer = Trainer(
130129
solver=solver,
@@ -159,9 +158,4 @@ def test_train_load_restore(problem):
159158
)
160159
torch.testing.assert_close(
161160
new_solver.forward(test_pts), solver.forward(test_pts)
162-
)
163-
164-
# rm directories
165-
import shutil
166-
167-
shutil.rmtree("tests/test_solver/tmp")
161+
)

tests/test_solver/test_supervised_solver.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,25 @@ def test_solver_test_graph(batch_size, use_lt):
212212
trainer.test()
213213

214214

215-
def test_train_load_restore():
216-
dir = "tests/test_solver/tmp/"
215+
import shutil
216+
from pathlib import Path
217+
import pytest
218+
219+
@pytest.fixture
220+
def clean_tmp_dir():
221+
path = Path("tests/test_solver/tmp/")
222+
223+
if path.exists():
224+
shutil.rmtree(path)
225+
226+
path.mkdir(parents=True, exist_ok=True)
227+
yield path
228+
229+
if path.exists():
230+
shutil.rmtree(path)
231+
232+
def test_train_load_restore(clean_tmp_dir):
233+
dir = clean_tmp_dir
217234
problem = LabelTensorProblem()
218235
solver = SupervisedSolver(problem=problem, model=model)
219236
trainer = Trainer(
@@ -247,9 +264,4 @@ def test_train_load_restore():
247264
assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
248265
torch.testing.assert_close(
249266
new_solver.forward(test_pts), solver.forward(test_pts)
250-
)
251-
252-
# rm directories
253-
import shutil
254-
255-
shutil.rmtree("tests/test_solver/tmp")
267+
)

0 commit comments

Comments
 (0)