2727import pytest
2828
2929import mkl_random
30+ import mkl_random .interfaces .numpy_random as _nrand
3031
3132
3233def test_is_patched ():
@@ -53,9 +54,9 @@ def test_patch_and_restore():
5354 assert np .random .randint is not orig_randint
5455 assert np .random .RandomState is not orig_RandomState
5556
56- # Check that they are from mkl_random
57- assert np .random .normal is mkl_random .normal
58- assert np .random .RandomState is mkl_random .RandomState
57+ # Check that they are from mkl_random interface module
58+ assert np .random .normal is _nrand .normal
59+ assert np .random .RandomState is _nrand .RandomState
5960
6061 finally :
6162 mkl_random .restore_numpy_random ()
@@ -135,10 +136,10 @@ def test_patch_redundant_patching():
135136 mkl_random .patch_numpy_random (np )
136137 mkl_random .patch_numpy_random (np )
137138 assert mkl_random .is_patched ()
138- assert np .random .normal is mkl_random .normal
139+ assert np .random .normal is _nrand .normal
139140 mkl_random .restore_numpy_random ()
140141 assert mkl_random .is_patched ()
141- assert np .random .normal is mkl_random .normal
142+ assert np .random .normal is _nrand .normal
142143 mkl_random .restore_numpy_random ()
143144 assert not mkl_random .is_patched ()
144145 assert np .random .normal is orig_normal
0 commit comments