|
30 | 30 | import mkl_random as rnd |
31 | 31 | from numpy.testing import ( |
32 | 32 | assert_, assert_raises, assert_equal, |
33 | | - assert_warns, suppress_warnings) |
| 33 | + suppress_warnings, assert_no_warnings) |
34 | 34 | import sys |
35 | 35 | import warnings |
36 | 36 |
|
@@ -720,24 +720,57 @@ def test_randomdist_multivariate_normal(randomdist): |
720 | 720 | # Hmm... not even symmetric. |
721 | 721 | cov = [[1, 0], [1, 0]] |
722 | 722 | size = (3, 2) |
723 | | - actual = rnd.multivariate_normal(mean, cov, size) |
| 723 | + # ignore RuntimeWarning from non-positive-semidefinite covariance |
| 724 | + with warnings.catch_warnings(): |
| 725 | + warnings.simplefilter("ignore", RuntimeWarning) |
| 726 | + actual = rnd.multivariate_normal(mean, cov, size) |
724 | 727 | desired = np.array([[[-2.42282709811266, 10.0], |
725 | | - [1.2267795840027274, 10.0]], |
| 728 | + [1.2267795840027274, 10.0]], |
726 | 729 | [[0.06813924868067336, 10.0], |
727 | | - [1.001190462507746, 10.0]], |
| 730 | + [1.001190462507746, 10.0]], |
728 | 731 | [[-1.74157261455869, 10.0], |
729 | | - [1.0400952859037553, 10.0]]]) |
| 732 | + [1.0400952859037553, 10.0]]]) |
730 | 733 | np.testing.assert_allclose(actual, desired, atol=1e-10, rtol=1e-10) |
731 | 734 |
|
732 | 735 | # Check for default size, was raising deprecation warning |
733 | | - actual = rnd.multivariate_normal(mean, cov) |
| 736 | + # ignore RuntimeWarning from non-positive-semidefinite covariance |
| 737 | + with warnings.catch_warnings(): |
| 738 | + warnings.simplefilter("ignore", RuntimeWarning) |
| 739 | + actual = rnd.multivariate_normal(mean, cov) |
734 | 740 | desired = np.array([1.0579899448949994, 10.0]) |
735 | 741 | np.testing.assert_allclose(actual, desired, atol=1e-10, rtol=1e-10) |
736 | 742 |
|
737 | | - # Check that non positive-semidefinite covariance raises warning |
| 743 | + # Check that non positive-semidefinite covariance warns with |
| 744 | + # RuntimeWarning |
738 | 745 | mean = [0, 0] |
739 | | - cov = [[1, 1 + 1e-10], [1 + 1e-10, 1]] |
740 | | - assert_warns(RuntimeWarning, rnd.multivariate_normal, mean, cov) |
| 746 | + cov = [[1, 2], [2, 1]] |
| 747 | + pytest.warns(RuntimeWarning, rnd.multivariate_normal, mean, cov) |
| 748 | + |
| 749 | + # and that it doesn't warn with RuntimeWarning check_valid='ignore' |
| 750 | + assert_no_warnings( |
| 751 | + rnd.multivariate_normal, mean, cov, check_valid="ignore" |
| 752 | + ) |
| 753 | + |
| 754 | + # and that it raises with RuntimeWarning check_valid='raises' |
| 755 | + assert_raises( |
| 756 | + ValueError, rnd.multivariate_normal, mean, cov, check_valid="raise" |
| 757 | + ) |
| 758 | + |
| 759 | + cov = np.array([[1, 0.1], [0.1, 1]], dtype=np.float32) |
| 760 | + with warnings.catch_warnings(): |
| 761 | + warnings.simplefilter("error", RuntimeWarning) |
| 762 | + rnd.multivariate_normal(mean, cov) |
| 763 | + |
| 764 | + mu = np.zeros(2) |
| 765 | + cov = np.eye(2) |
| 766 | + assert_raises( |
| 767 | + ValueError, rnd.multivariate_normal, mean, cov, check_valid="other" |
| 768 | + ) |
| 769 | + assert_raises( |
| 770 | + ValueError, rnd.multivariate_normal, np.zeros((2, 1, 1)), cov |
| 771 | + ) |
| 772 | + assert_raises(ValueError, rnd.multivariate_normal, mu, np.empty((3, 2))) |
| 773 | + assert_raises(ValueError, rnd.multivariate_normal, mu, np.eye(3)) |
741 | 774 |
|
742 | 775 |
|
743 | 776 | def test_randomdist_multinormal_cholesky(randomdist): |
|
0 commit comments