Skip to content

Commit 946a545

Browse files
committed
update random tests for new behavior and keywords in multivariate_normal
now ignore irrelevant RuntimeWarnings and align with the test in NumPy's test suite
1 parent f76067e commit 946a545

1 file changed

Lines changed: 42 additions & 9 deletions

File tree

mkl_random/tests/test_random.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
import mkl_random as rnd
3131
from numpy.testing import (
3232
assert_, assert_raises, assert_equal,
33-
assert_warns, suppress_warnings)
33+
suppress_warnings, assert_no_warnings)
3434
import sys
3535
import warnings
3636

@@ -720,24 +720,57 @@ def test_randomdist_multivariate_normal(randomdist):
720720
# Hmm... not even symmetric.
721721
cov = [[1, 0], [1, 0]]
722722
size = (3, 2)
723-
actual = rnd.multivariate_normal(mean, cov, size)
723+
# ignore RuntimeWarning from non-positive-semidefinite covariance
724+
with warnings.catch_warnings():
725+
warnings.simplefilter("ignore", RuntimeWarning)
726+
actual = rnd.multivariate_normal(mean, cov, size)
724727
desired = np.array([[[-2.42282709811266, 10.0],
725-
[1.2267795840027274, 10.0]],
728+
[1.2267795840027274, 10.0]],
726729
[[0.06813924868067336, 10.0],
727-
[1.001190462507746, 10.0]],
730+
[1.001190462507746, 10.0]],
728731
[[-1.74157261455869, 10.0],
729-
[1.0400952859037553, 10.0]]])
732+
[1.0400952859037553, 10.0]]])
730733
np.testing.assert_allclose(actual, desired, atol=1e-10, rtol=1e-10)
731734

732735
# Check for default size, was raising deprecation warning
733-
actual = rnd.multivariate_normal(mean, cov)
736+
# ignore RuntimeWarning from non-positive-semidefinite covariance
737+
with warnings.catch_warnings():
738+
warnings.simplefilter("ignore", RuntimeWarning)
739+
actual = rnd.multivariate_normal(mean, cov)
734740
desired = np.array([1.0579899448949994, 10.0])
735741
np.testing.assert_allclose(actual, desired, atol=1e-10, rtol=1e-10)
736742

737-
# Check that non positive-semidefinite covariance raises warning
743+
# Check that non positive-semidefinite covariance warns with
744+
# RuntimeWarning
738745
mean = [0, 0]
739-
cov = [[1, 1 + 1e-10], [1 + 1e-10, 1]]
740-
assert_warns(RuntimeWarning, rnd.multivariate_normal, mean, cov)
746+
cov = [[1, 2], [2, 1]]
747+
pytest.warns(RuntimeWarning, rnd.multivariate_normal, mean, cov)
748+
749+
# and that it doesn't warn with RuntimeWarning check_valid='ignore'
750+
assert_no_warnings(
751+
rnd.multivariate_normal, mean, cov, check_valid="ignore"
752+
)
753+
754+
# and that it raises with RuntimeWarning check_valid='raises'
755+
assert_raises(
756+
ValueError, rnd.multivariate_normal, mean, cov, check_valid="raise"
757+
)
758+
759+
cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32)
760+
with warnings.catch_warnings():
761+
warnings.simplefilter("error", RuntimeWarning)
762+
rnd.multivariate_normal(mean, cov)
763+
764+
mu = np.zeros(2)
765+
cov = np.eye(2)
766+
assert_raises(
767+
ValueError, rnd.multivariate_normal, mean, cov, check_valid="other"
768+
)
769+
assert_raises(
770+
ValueError, rnd.multivariate_normal, np.zeros((2, 1, 1)), cov
771+
)
772+
assert_raises(ValueError, rnd.multivariate_normal, mu, np.empty((3, 2)))
773+
assert_raises(ValueError, rnd.multivariate_normal, mu, np.eye(3))
741774

742775

743776
def test_randomdist_multinormal_cholesky(randomdist):

0 commit comments

Comments
 (0)