Skip to content

Commit d0f648c

Browse files
authored
Merge pull request #526 from numpy/better-cov-corrcoef
2 parents 68846ac + 8e33f50 commit d0f648c

2 files changed

Lines changed: 122 additions & 34 deletions

File tree

src/numpy-stubs/@test/static/accept/lib_function_base.pyi

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,15 @@ assert_type(np.extract(AR_i8, AR_LIKE_f8), npt.NDArray[Any])
126126

127127
assert_type(np.place(AR_f8, mask=AR_i8, vals=5.0), None)
128128

129-
assert_type(np.cov(AR_f8, bias=True), npt.NDArray[np.floating])
130-
assert_type(np.cov(AR_f8, AR_c16, ddof=1), npt.NDArray[np.complexfloating])
129+
assert_type(np.cov(AR_f8, bias=True), npt.NDArray[np.float64])
130+
assert_type(np.cov(AR_f8, AR_c16, ddof=1), npt.NDArray[np.complex128])
131131
assert_type(np.cov(AR_f8, aweights=AR_f8, dtype=np.float32), npt.NDArray[np.float32])
132-
assert_type(np.cov(AR_f8, fweights=AR_f8, dtype=float), npt.NDArray[Any])
132+
assert_type(np.cov(AR_f8, fweights=AR_i8, dtype=float), npt.NDArray[np.float64])
133133

134-
assert_type(np.corrcoef(AR_f8, rowvar=True), npt.NDArray[np.floating])
135-
assert_type(np.corrcoef(AR_f8, AR_c16), npt.NDArray[np.complexfloating])
134+
assert_type(np.corrcoef(AR_f8, rowvar=True), npt.NDArray[np.float64])
135+
assert_type(np.corrcoef(AR_f8, AR_c16), npt.NDArray[np.complex128])
136136
assert_type(np.corrcoef(AR_f8, dtype=np.float32), npt.NDArray[np.float32])
137-
assert_type(np.corrcoef(AR_f8, dtype=float), npt.NDArray[Any])
137+
assert_type(np.corrcoef(AR_f8, dtype=float), npt.NDArray[np.float64])
138138

139139
assert_type(np.blackman(5), _Array1D[np.floating])
140140
assert_type(np.bartlett(6), _Array1D[np.floating])

src/numpy-stubs/lib/_function_base_impl.pyi

Lines changed: 116 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -545,49 +545,97 @@ def place(arr: _nt.Array[Incomplete], mask: ArrayLike, vals: Incomplete) -> None
545545
#
546546
@overload
547547
def cov(
548-
m: _nt.CoFloating_1nd,
548+
m: _nt.CoFloat64_1nd,
549+
y: _nt.CoFloat64_1nd | None = None,
550+
rowvar: bool = True,
551+
bias: bool = False,
552+
ddof: _ToInt | None = None,
553+
fweights: _nt.ToInteger_1d | None = None,
554+
aweights: _nt.CoFloating_1d | None = None,
555+
*,
556+
dtype: _nt.ToDTypeFloat64 | None = None,
557+
) -> _nt.Array[np.float64]: ...
558+
@overload
559+
def cov(
560+
m: _nt.ToLongDouble_1nd,
549561
y: _nt.CoFloating_1nd | None = None,
550562
rowvar: bool = True,
551563
bias: bool = False,
552564
ddof: _ToInt | None = None,
553-
fweights: ArrayLike | None = None,
554-
aweights: ArrayLike | None = None,
565+
fweights: _nt.ToInteger_1d | None = None,
566+
aweights: _nt.CoFloating_1d | None = None,
555567
*,
556-
dtype: None = None,
557-
) -> _nt.Array[np.floating]: ...
568+
dtype: _nt.ToDTypeLongDouble | None = None,
569+
) -> _nt.Array[np.longdouble]: ...
570+
@overload
571+
def cov(
572+
m: _nt.CoFloating_1nd,
573+
y: _nt.ToLongDouble_1nd,
574+
rowvar: bool = True,
575+
bias: bool = False,
576+
ddof: _ToInt | None = None,
577+
fweights: _nt.ToInteger_1d | None = None,
578+
aweights: _nt.CoFloating_1d | None = None,
579+
*,
580+
dtype: _nt.ToDTypeLongDouble | None = None,
581+
) -> _nt.Array[np.longdouble]: ...
582+
@overload
583+
def cov(
584+
m: _nt.ToComplex128_1nd | _nt.ToComplex64_1nd,
585+
y: _nt.CoComplex128_1nd | None = None,
586+
rowvar: bool = True,
587+
bias: bool = False,
588+
ddof: _ToInt | None = None,
589+
fweights: _nt.ToInteger_1d | None = None,
590+
aweights: _nt.CoFloating_1d | None = None,
591+
*,
592+
dtype: _nt.ToDTypeComplex128 | None = None,
593+
) -> _nt.Array[np.complex128]: ...
558594
@overload
559595
def cov(
560-
m: _nt.ToComplex_1nd,
596+
m: _nt.CoComplex128_1nd,
597+
y: _nt.ToComplex128_1nd | _nt.ToComplex64_1nd,
598+
rowvar: bool = True,
599+
bias: bool = False,
600+
ddof: _ToInt | None = None,
601+
fweights: _nt.ToInteger_1d | None = None,
602+
aweights: _nt.CoFloating_1d | None = None,
603+
*,
604+
dtype: _nt.ToDTypeComplex128 | None = None,
605+
) -> _nt.Array[np.complex128]: ...
606+
@overload
607+
def cov(
608+
m: _nt.ToCLongDouble_1nd,
561609
y: _nt.CoComplex_1nd | None = None,
562610
rowvar: bool = True,
563611
bias: bool = False,
564612
ddof: _ToInt | None = None,
565-
fweights: ArrayLike | None = None,
566-
aweights: ArrayLike | None = None,
613+
fweights: _nt.ToInteger_1d | None = None,
614+
aweights: _nt.CoFloating_1d | None = None,
567615
*,
568-
dtype: None = None,
569-
) -> _nt.Array[np.complexfloating]: ...
616+
dtype: _nt.ToDTypeCLongDouble | None = None,
617+
) -> _nt.Array[np.clongdouble]: ...
570618
@overload
571619
def cov(
572620
m: _nt.CoComplex_1nd,
573-
y: _nt.ToComplex_1nd,
621+
y: _nt.ToCLongDouble_1nd,
574622
rowvar: bool = True,
575623
bias: bool = False,
576624
ddof: _ToInt | None = None,
577-
fweights: ArrayLike | None = None,
578-
aweights: ArrayLike | None = None,
625+
fweights: _nt.ToInteger_1d | None = None,
626+
aweights: _nt.CoFloating_1d | None = None,
579627
*,
580-
dtype: None = None,
581-
) -> _nt.Array[np.complexfloating]: ...
628+
dtype: _nt.ToDTypeCLongDouble | None = None,
629+
) -> _nt.Array[np.clongdouble]: ...
582630
@overload
583631
def cov(
584632
m: _nt.CoComplex_1nd,
585633
y: _nt.CoComplex_1nd | None = None,
586634
rowvar: bool = True,
587635
bias: bool = False,
588636
ddof: _ToInt | None = None,
589-
fweights: ArrayLike | None = None,
590-
aweights: ArrayLike | None = None,
637+
fweights: _nt.ToInteger_1d | None = None,
638+
aweights: _nt.CoFloating_1d | None = None,
591639
*,
592640
dtype: _DTypeLike[_ScalarT],
593641
) -> _nt.Array[_ScalarT]: ...
@@ -598,43 +646,83 @@ def cov(
598646
rowvar: bool = True,
599647
bias: bool = False,
600648
ddof: _ToInt | None = None,
601-
fweights: ArrayLike | None = None,
602-
aweights: ArrayLike | None = None,
649+
fweights: _nt.ToInteger_1d | None = None,
650+
aweights: _nt.CoFloating_1d | None = None,
603651
*,
604652
dtype: DTypeLike,
605653
) -> _nt.Array[Incomplete]: ...
606654

607655
# NOTE `bias` and `ddof` are deprecated and ignored
608656
@overload
609657
def corrcoef(
610-
x: _nt.CoFloating_1nd,
658+
x: _nt.CoFloat64_1nd,
659+
y: _nt.CoFloat64_1nd | None = None,
660+
rowvar: bool = True,
661+
bias: _NoValueType = ...,
662+
ddof: _NoValueType = ...,
663+
*,
664+
dtype: _nt.ToDTypeFloat64 | None = None,
665+
) -> _nt.Array[np.float64]: ...
666+
@overload
667+
def corrcoef(
668+
x: _nt.ToLongDouble_1nd,
611669
y: _nt.CoFloating_1nd | None = None,
612670
rowvar: bool = True,
613671
bias: _NoValueType = ...,
614672
ddof: _NoValueType = ...,
615673
*,
616-
dtype: None = None,
617-
) -> _nt.Array[np.floating]: ...
674+
dtype: _nt.ToDTypeLongDouble | None = None,
675+
) -> _nt.Array[np.longdouble]: ...
618676
@overload
619677
def corrcoef(
620-
x: _nt.ToComplex_1nd,
678+
x: _nt.CoFloating_1nd,
679+
y: _nt.ToLongDouble_1nd,
680+
rowvar: bool = True,
681+
bias: _NoValueType = ...,
682+
ddof: _NoValueType = ...,
683+
*,
684+
dtype: _nt.ToDTypeLongDouble | None = None,
685+
) -> _nt.Array[np.longdouble]: ...
686+
@overload
687+
def corrcoef(
688+
x: _nt.ToComplex128_1nd | _nt.ToComplex64_1nd,
689+
y: _nt.CoComplex128_1nd | None = None,
690+
rowvar: bool = True,
691+
bias: _NoValueType = ...,
692+
ddof: _NoValueType = ...,
693+
*,
694+
dtype: _nt.ToDTypeComplex128 | None = None,
695+
) -> _nt.Array[np.complex128]: ...
696+
@overload
697+
def corrcoef(
698+
x: _nt.CoComplex128_1nd,
699+
y: _nt.ToComplex128_1nd | _nt.ToComplex64_1nd,
700+
rowvar: bool = True,
701+
bias: _NoValueType = ...,
702+
ddof: _NoValueType = ...,
703+
*,
704+
dtype: _nt.ToDTypeComplex128 | None = None,
705+
) -> _nt.Array[np.complex128]: ...
706+
@overload
707+
def corrcoef(
708+
x: _nt.ToCLongDouble_1nd,
621709
y: _nt.CoComplex_1nd | None = None,
622710
rowvar: bool = True,
623711
bias: _NoValueType = ...,
624712
ddof: _NoValueType = ...,
625713
*,
626-
dtype: None = None,
627-
) -> _nt.Array[np.complexfloating]: ...
714+
dtype: _nt.ToDTypeCLongDouble | None = None,
715+
) -> _nt.Array[np.clongdouble]: ...
628716
@overload
629717
def corrcoef(
630718
x: _nt.CoComplex_1nd,
631-
y: _nt.ToComplex_1nd,
719+
y: _nt.ToCLongDouble_1nd,
632720
rowvar: bool = True,
633721
bias: _NoValueType = ...,
634722
ddof: _NoValueType = ...,
635723
*,
636-
dtype: None = None,
637-
) -> _nt.Array[np.complexfloating]: ...
724+
dtype: _nt.ToDTypeCLongDouble | None = None,
725+
) -> _nt.Array[np.clongdouble]: ...
638726
@overload
639727
def corrcoef(
640728
x: _nt.CoComplex_1nd,

0 commit comments

Comments
 (0)