Skip to content

Commit 044b666

Browse files
committed
clean up test_random
cleans up visual indentation and various linter/style mistakes
1 parent 946a545 commit 044b666

1 file changed

Lines changed: 90 additions & 89 deletions

File tree

mkl_random/tests/test_random.py

Lines changed: 90 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -39,23 +39,24 @@
3939

4040
def test_zero_scalar_seed():
4141
evs_zero_seed = {
42-
'MT19937' : 844, 'SFMT19937' : 857,
43-
'WH' : 0, 'MT2203' : 890,
44-
'MCG31' : 0, 'R250' : 229,
45-
'MRG32K3A' : 0, 'MCG59' : 0 }
42+
'MT19937': 844, 'SFMT19937': 857,
43+
'WH': 0, 'MT2203': 890,
44+
'MCG31': 0, 'R250': 229,
45+
'MRG32K3A': 0, 'MCG59': 0}
4646
for brng_algo in evs_zero_seed:
47-
s = rnd.MKLRandomState(0, brng = brng_algo)
47+
s = rnd.MKLRandomState(0, brng=brng_algo)
4848
assert_equal(s.get_state()[0], brng_algo)
4949
assert_equal(s.randint(1000), evs_zero_seed[brng_algo])
5050

51+
5152
def test_max_scalar_seed():
5253
evs_max_seed = {
53-
'MT19937' : 635, 'SFMT19937' : 25,
54-
'WH' : 100, 'MT2203' : 527,
55-
'MCG31' : 0, 'R250' : 229,
56-
'MRG32K3A' : 961, 'MCG59' : 0 }
54+
'MT19937': 635, 'SFMT19937': 25,
55+
'WH': 100, 'MT2203': 527,
56+
'MCG31': 0, 'R250': 229,
57+
'MRG32K3A': 961, 'MCG59': 0}
5758
for brng_algo in evs_max_seed:
58-
s = rnd.MKLRandomState(4294967295, brng = brng_algo)
59+
s = rnd.MKLRandomState(4294967295, brng=brng_algo)
5960
assert_equal(s.get_state()[0], brng_algo)
6061
assert_equal(s.randint(1000), evs_max_seed[brng_algo])
6162

@@ -130,11 +131,10 @@ def test_size():
130131
assert_equal(rnd.multinomial(1, p, np.uint32(1)).shape, (1, 2))
131132
assert_equal(rnd.multinomial(1, p, [2, 2]).shape, (2, 2, 2))
132133
assert_equal(rnd.multinomial(1, p, (2, 2)).shape, (2, 2, 2))
133-
assert_equal(rnd.multinomial(1, p, np.array((2, 2))).shape,
134-
(2, 2, 2))
134+
assert_equal(rnd.multinomial(1, p, np.array((2, 2))).shape, (2, 2, 2))
135+
136+
pytest.raises(TypeError, rnd.multinomial, 1, p, np.float64(1))
135137

136-
pytest.raises(TypeError, rnd.multinomial, 1, p,
137-
np.float64(1))
138138

139139
class RngState(NamedTuple):
140140
seed: int
@@ -203,8 +203,8 @@ def test_set_state_negative_binomial(rng_state):
203203

204204

205205
class RandIntData(NamedTuple):
206-
rfunc : object
207-
itype : list
206+
rfunc: object
207+
itype: list
208208

209209

210210
@pytest.fixture
@@ -262,14 +262,14 @@ def test_randint_repeatability(randint):
262262
# in the range [0, 6) for all but np.bool, where the range
263263
# is [0, 2). Hashes are for little endian numbers.
264264
tgt = {'bool': '4fee98a6885457da67c39331a9ec336f',
265-
'int16': '80a5ff69c315ab6f80b03da1d570b656',
266-
'int32': '15a3c379b6c7b0f296b162194eab68bc',
267-
'int64': 'ea9875f9334c2775b00d4976b85a1458',
268-
'int8': '0f56333af47de94930c799806158a274',
269-
'uint16': '80a5ff69c315ab6f80b03da1d570b656',
270-
'uint32': '15a3c379b6c7b0f296b162194eab68bc',
271-
'uint64': 'ea9875f9334c2775b00d4976b85a1458',
272-
'uint8': '0f56333af47de94930c799806158a274'}
265+
'int16': '80a5ff69c315ab6f80b03da1d570b656',
266+
'int32': '15a3c379b6c7b0f296b162194eab68bc',
267+
'int64': 'ea9875f9334c2775b00d4976b85a1458',
268+
'int8': '0f56333af47de94930c799806158a274',
269+
'uint16': '80a5ff69c315ab6f80b03da1d570b656',
270+
'uint32': '15a3c379b6c7b0f296b162194eab68bc',
271+
'uint64': 'ea9875f9334c2775b00d4976b85a1458',
272+
'uint8': '0f56333af47de94930c799806158a274'}
273273

274274
for dt in randint.itype[1:]:
275275
rnd.seed(1234, brng='MT19937')
@@ -306,12 +306,12 @@ def test_randint_respect_dtype_singleton(randint):
306306
# gh-7284: Ensure that we get Python data types
307307
sample = randint.rfunc(lbnd, ubnd, dtype=dt)
308308
assert not hasattr(sample, 'dtype')
309-
assert (type(sample) == dt)
309+
assert (type(sample) is dt)
310310

311311

312312
class RandomDistData(NamedTuple):
313-
seed : int
314-
brng : str
313+
seed: int
314+
brng: str
315315

316316

317317
@pytest.fixture
@@ -321,9 +321,10 @@ def randomdist():
321321

322322
# Make sure the random distribution returns the correct value for a
323323
# given seed. Low value of decimal argument is intended, since functional
324-
# transformations's implementation or approximations thereof used to produce non-uniform
325-
# random variates can vary across platforms, yet be statistically indistinguishable to the end user,
326-
# that is no computationally feasible statistical experiment can detect the difference.
324+
# transformations's implementation or approximations thereof used to produce
325+
# non-uniform random variates can vary across platforms, yet be statistically
326+
# indistinguishable to the end user, that is no computationally feasible
327+
# statistical experiment can detect the difference.
327328

328329
def test_randomdist_rand(randomdist):
329330
rnd.seed(randomdist.seed, brng=randomdist.brng)
@@ -369,8 +370,7 @@ def test_random_integers_max_int():
369370
# to generate this integer.
370371
with suppress_warnings() as sup:
371372
w = sup.record(DeprecationWarning)
372-
actual = rnd.random_integers(np.iinfo('l').max,
373-
np.iinfo('l').max)
373+
actual = rnd.random_integers(np.iinfo('l').max, np.iinfo('l').max)
374374
assert len(w) == 1
375375
desired = np.iinfo('l').max
376376
np.testing.assert_equal(actual, desired)
@@ -381,14 +381,17 @@ def test_random_integers_deprecated():
381381
warnings.simplefilter("error", DeprecationWarning)
382382

383383
# DeprecationWarning raised with high == None
384-
assert_raises(DeprecationWarning,
385-
rnd.random_integers,
386-
np.iinfo('l').max)
384+
assert_raises(
385+
DeprecationWarning, rnd.random_integers, np.iinfo('l').max
386+
)
387387

388388
# DeprecationWarning raised with high != None
389-
assert_raises(DeprecationWarning,
390-
rnd.random_integers,
391-
np.iinfo('l').max, np.iinfo('l').max)
389+
assert_raises(
390+
DeprecationWarning,
391+
rnd.random_integers,
392+
np.iinfo('l').max,
393+
np.iinfo('l').max
394+
)
392395

393396

394397
def test_randomdist_random_sample(randomdist):
@@ -414,13 +417,6 @@ def test_randomdist_choice_nonuniform_replace(randomdist):
414417
np.testing.assert_array_equal(actual, desired)
415418

416419

417-
def test_randomdist_choice_nonuniform_replace(randomdist):
418-
rnd.seed(randomdist.seed, brng=randomdist.brng)
419-
actual = rnd.choice(4, 4, p=[0.4, 0.4, 0.1, 0.1])
420-
desired = np.array([3, 0, 0, 1])
421-
np.testing.assert_array_equal(actual, desired)
422-
423-
424420
def test_randomdist_choice_uniform_noreplace(randomdist):
425421
rnd.seed(randomdist.seed, brng=randomdist.brng)
426422
actual = rnd.choice(4, 3, replace=False)
@@ -430,8 +426,7 @@ def test_randomdist_choice_uniform_noreplace(randomdist):
430426

431427
def test_randomdist_choice_nonuniform_noreplace(randomdist):
432428
rnd.seed(randomdist.seed, brng=randomdist.brng)
433-
actual = rnd.choice(4, 3, replace=False,
434-
p=[0.1, 0.3, 0.5, 0.1])
429+
actual = rnd.choice(4, 3, replace=False, p=[0.1, 0.3, 0.5, 0.1])
435430
desired = np.array([3, 0, 1])
436431
np.testing.assert_array_equal(actual, desired)
437432

@@ -449,14 +444,16 @@ def test_choice_exceptions():
449444
pytest.raises(ValueError, sample, 3., 3)
450445
pytest.raises(ValueError, sample, [[1, 2], [3, 4]], 3)
451446
pytest.raises(ValueError, sample, [], 3)
452-
pytest.raises(ValueError, sample, [1, 2, 3, 4], 3,
453-
p=[[0.25, 0.25], [0.25, 0.25]])
447+
pytest.raises(
448+
ValueError, sample, [1, 2, 3, 4], 3, p=[[0.25, 0.25], [0.25, 0.25]]
449+
)
454450
pytest.raises(ValueError, sample, [1, 2], 3, p=[0.4, 0.4, 0.2])
455451
pytest.raises(ValueError, sample, [1, 2], 3, p=[1.1, -0.1])
456452
pytest.raises(ValueError, sample, [1, 2], 3, p=[0.4, 0.4])
457453
pytest.raises(ValueError, sample, [1, 2, 3], 4, replace=False)
458-
pytest.raises(ValueError, sample, [1, 2, 3], 2, replace=False,
459-
p=[1, 0, 0])
454+
pytest.raises(
455+
ValueError, sample, [1, 2, 3], 2, replace=False, p=[1, 0, 0]
456+
)
460457

461458

462459
def test_choice_return_shape():
@@ -507,18 +504,19 @@ def test_randomdist_shuffle(randomdist):
507504
# Test lists, arrays (of various dtypes), and multidimensional versions
508505
# of both, c-contiguous or not:
509506
for conv in [lambda x: np.array([]),
510-
lambda x: x,
511-
lambda x: np.asarray(x).astype(np.int8),
512-
lambda x: np.asarray(x).astype(np.float32),
513-
lambda x: np.asarray(x).astype(np.complex64),
514-
lambda x: np.asarray(x).astype(object),
515-
lambda x: [(i, i) for i in x],
516-
lambda x: np.asarray([[i, i] for i in x]),
517-
lambda x: np.vstack([x, x]).T,
518-
# gh-4270
519-
lambda x: np.asarray([(i, i) for i in x],
520-
[("a", object, (1,)),
521-
("b", np.int32, (1,))])]:
507+
lambda x: x,
508+
lambda x: np.asarray(x).astype(np.int8),
509+
lambda x: np.asarray(x).astype(np.float32),
510+
lambda x: np.asarray(x).astype(np.complex64),
511+
lambda x: np.asarray(x).astype(object),
512+
lambda x: [(i, i) for i in x],
513+
lambda x: np.asarray([[i, i] for i in x]),
514+
lambda x: np.vstack([x, x]).T,
515+
# gh-4270
516+
lambda x: np.asarray(
517+
[(i, i) for i in x],
518+
[("a", object, (1,)),
519+
("b", np.int32, (1,))])]:
522520
rnd.seed(randomdist.seed, brng=randomdist.brng)
523521
alist = conv([1, 2, 3, 4, 5, 6, 7, 8, 9, 0])
524522
rnd.shuffle(alist)
@@ -529,7 +527,7 @@ def test_randomdist_shuffle(randomdist):
529527

530528
def test_shuffle_masked():
531529
# gh-3263
532-
a = np.ma.masked_values(np.reshape(range(20), (5,4)) % 3 - 1, -1)
530+
a = np.ma.masked_values(np.reshape(range(20), (5, 4)) % 3 - 1, -1)
533531
b = np.ma.masked_values(np.arange(20) % 3 - 1, -1)
534532
a_orig = a.copy()
535533
b_orig = b.copy()
@@ -563,8 +561,8 @@ def test_randomdist_chisquare(randomdist):
563561
rnd.seed(randomdist.seed, brng=randomdist.brng)
564562
actual = rnd.chisquare(50, size=(3, 2))
565563
desired = np.array([[50.955833609920589, 50.133178918244099],
566-
[61.513615847062013, 50.757127871422448],
567-
[52.79816819717081, 49.973023331993552]])
564+
[61.513615847062013, 50.757127871422448],
565+
[52.79816819717081, 49.973023331993552]])
568566
np.testing.assert_allclose(actual, desired, atol=1e-7, rtol=1e-10)
569567

570568

@@ -573,11 +571,11 @@ def test_randomdist_dirichlet(randomdist):
573571
alpha = np.array([51.72840233779265162, 39.74494232180943953])
574572
actual = rnd.dirichlet(alpha, size=(3, 2))
575573
desired = np.array([[[0.6332947001908874, 0.36670529980911254],
576-
[0.5376828907571894, 0.4623171092428107]],
574+
[0.5376828907571894, 0.4623171092428107]],
577575
[[0.6835615930093024, 0.3164384069906976],
578-
[0.5452378139016114, 0.45476218609838875]],
576+
[0.5452378139016114, 0.45476218609838875]],
579577
[[0.6498494402738553, 0.3501505597261446],
580-
[0.5622024400324822, 0.43779755996751785]]])
578+
[0.5622024400324822, 0.43779755996751785]]])
581579
np.testing.assert_allclose(actual, desired, atol=4e-10, rtol=4e-10)
582580

583581

@@ -687,8 +685,9 @@ def test_randomdist_lognormal(randomdist):
687685
[0.1769118704670423, 3.415299544410577],
688686
[1.2417099625339398, 102.0631392685238]])
689687
np.testing.assert_allclose(actual, desired, atol=1e-6, rtol=1e-10)
690-
actual = rnd.lognormal(mean=.123456789, sigma=2.0, size=(3,2),
691-
method='Box-Muller2')
688+
actual = rnd.lognormal(
689+
mean=.123456789, sigma=2.0, size=(3, 2), method='Box-Muller2'
690+
)
692691
desired = np.array([[0.2585388231094821, 0.43734953048924663],
693692
[26.050836228611697, 26.76266237820882],
694693
[0.24216420175675096, 0.2481945765083541]])
@@ -781,11 +780,11 @@ def test_randomdist_multinormal_cholesky(randomdist):
781780
size = (3, 2)
782781
actual = rnd.multinormal_cholesky(mean, chol_mat, size, method='ICDF')
783782
desired = np.array([[[2.26461778189133, 6.857632824379853],
784-
[-0.8043233941855025, 11.01629429884193]],
783+
[-0.8043233941855025, 11.01629429884193]],
785784
[[0.1699731103551746, 12.227809261928217],
786-
[-0.6146263106001378, 9.893801873973892]],
785+
[-0.6146263106001378, 9.893801873973892]],
787786
[[1.691753328795276, 10.797627196240155],
788-
[-0.647341237129921, 9.626899489691816]]])
787+
[-0.647341237129921, 9.626899489691816]]])
789788
np.testing.assert_allclose(actual, desired, atol=1e-10, rtol=1e-10)
790789

791790

@@ -813,8 +812,7 @@ def test_randomdist_noncentral_chisquare(randomdist):
813812

814813
def test_randomdist_noncentral_f(randomdist):
815814
rnd.seed(randomdist.seed, brng=randomdist.brng)
816-
actual = rnd.noncentral_f(dfnum=5, dfden=2, nonc=1,
817-
size=(3, 2))
815+
actual = rnd.noncentral_f(dfnum=5, dfden=2, nonc=1, size=(3, 2))
818816
desired = np.array([[0.2216297348371284, 0.7632696724492449],
819817
[98.67664232828238, 0.9500319825372799],
820818
[0.3489618249246971, 1.5035633972571092]])
@@ -830,14 +828,18 @@ def test_randomdist_normal(randomdist):
830828
np.testing.assert_allclose(actual, desired, atol=1e-7, rtol=1e-10)
831829

832830
rnd.seed(randomdist.seed, brng=randomdist.brng)
833-
actual = rnd.normal(loc=.123456789, scale=2.0, size=(3, 2), method="BoxMuller")
831+
actual = rnd.normal(
832+
loc=.123456789, scale=2.0, size=(3, 2), method="BoxMuller"
833+
)
834834
desired = np.array([[0.16673479781277187, -3.4809986872165952],
835835
[-0.05193761082535492, 3.249201213154922],
836836
[-0.11915582299214138, 3.555636100927892]])
837837
np.testing.assert_allclose(actual, desired, atol=1e-8, rtol=1e-8)
838838

839839
rnd.seed(randomdist.seed, brng=randomdist.brng)
840-
actual = rnd.normal(loc=.123456789, scale=2.0, size=(3, 2), method="BoxMuller2")
840+
actual = rnd.normal(
841+
loc=.123456789, scale=2.0, size=(3, 2), method="BoxMuller2"
842+
)
841843
desired = np.array([[0.16673479781277187, 0.48153966449249175],
842844
[-3.4809986872165952, -0.8101190082826486],
843845
[-0.051937610825354905, 2.4088402362484342]])
@@ -849,8 +851,8 @@ def test_randomdist_pareto(randomdist):
849851
actual = rnd.pareto(a=.123456789, size=(3, 2))
850852
desired = np.array(
851853
[[0.14079174875385214, 82372044085468.92],
852-
[1247881.6368437486, 15.086855668610944],
853-
[203.2638558933401, 0.10445383654349749]])
854+
[1247881.6368437486, 15.086855668610944],
855+
[203.2638558933401, 0.10445383654349749]])
854856
# For some reason on 32-bit x86 Ubuntu 12.10 the [1, 0] entry in this
855857
# matrix differs by 24 nulps. Discussion:
856858
# http://mail.scipy.org/pipermail/numpy-discussion/2012-September/063801.html
@@ -953,8 +955,7 @@ def test_randomdist_standard_t(randomdist):
953955

954956
def test_randomdist_triangular(randomdist):
955957
rnd.seed(randomdist.seed, brng=randomdist.brng)
956-
actual = rnd.triangular(left=5.12, mode=10.23, right=20.34,
957-
size=(3, 2))
958+
actual = rnd.triangular(left=5.12, mode=10.23, right=20.34, size=(3, 2))
958959
desired = np.array([[18.764540652669638, 6.340166306695037],
959960
[8.827752689522429, 13.65605077739865],
960961
[11.732872979633328, 18.970392754850423]])
@@ -977,8 +978,8 @@ def test_uniform_range_bounds():
977978
func = rnd.uniform
978979
np.testing.assert_raises(OverflowError, func, -np.inf, 0)
979980
np.testing.assert_raises(OverflowError, func, 0, np.inf)
980-
# this should not throw any error, since rng can be sampled as fmin*u + fmax*(1-u)
981-
# for 0<u<1 and it stays completely in range
981+
# this should not throw any error, since rng can be sampled as
982+
# fmin*u + fmax*(1-u) for 0<u<1 and it stays completely in range
982983
rnd.uniform(fmin, fmax)
983984

984985
# (fmax / 1e17) - fmin is within range, so this should not throw
@@ -1006,9 +1007,9 @@ def test_randomdist_wald(randomdist):
10061007
actual = rnd.wald(mean=1.23, scale=1.54, size=(3, 2))
10071008
desired = np.array(
10081009
[[0.22448558337033758, 0.23485255518098838],
1009-
[2.756850184899666, 2.005347850108636],
1010-
[1.179918636588408, 0.20928649815442452]
1011-
])
1010+
[2.756850184899666, 2.005347850108636],
1011+
[1.179918636588408, 0.20928649815442452]]
1012+
)
10121013
np.testing.assert_allclose(actual, desired, atol=1e-10, rtol=1e-10)
10131014

10141015

@@ -1042,7 +1043,7 @@ def _check_function(seed_list, function, sz):
10421043

10431044
# threaded generation
10441045
t = [Thread(target=function, args=(rnd.MKLRandomState(s), o))
1045-
for s, o in zip(seed_list, out1)]
1046+
for s, o in zip(seed_list, out1)]
10461047
[x.start() for x in t]
10471048
[x.join() for x in t]
10481049

@@ -1074,4 +1075,4 @@ def test_multinomial(seed_vector):
10741075
# make sure each state produces the same sequence even in threads
10751076
def gen_random(state, out):
10761077
out[...] = state.multinomial(10, [1/6.]*6, size=10000)
1077-
_check_function(seed_vector, gen_random, sz=(10000,6))
1078+
_check_function(seed_vector, gen_random, sz=(10000, 6))

0 commit comments

Comments
 (0)