Skip to content

Commit 3ae1ee1

Browse files
committed
O(1) bit-level nextafter with steps parameter
Replace the O(n) loop in nextafter(x, y, steps) with O(1) IEEE 754 bit manipulation matching math_nextafter_impl in mathmodule.c. Handles sign-crossing, saturation at y when steps overshoot, and the NaN, zero-step, and equal-operand edge cases. Add pyo3 proptest and edge/extreme-step tests for nextafter.
1 parent bcf75ff commit 3ae1ee1

1 file changed

Lines changed: 196 additions & 13 deletions

File tree

src/math/misc.rs

Lines changed: 196 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,63 @@ super::libm_simple!(@1 ceil, floor, trunc);
66

77
/// Return the next floating-point value after x towards y.
88
///
9-
/// If steps is provided, move that many steps towards y.
10-
/// Steps must be non-negative.
9+
/// If steps is provided, move that many steps towards y using O(1) bit
10+
/// manipulation on the IEEE 754 representation. Steps that overshoot y
11+
/// are clamped so the result never passes y.
12+
///
13+
/// See math_nextafter_impl in mathmodule.c.
1114
#[inline]
1215
pub fn nextafter(x: f64, y: f64, steps: Option<u64>) -> f64 {
13-
match steps {
14-
Some(n) => {
15-
let mut result = x;
16-
for _ in 0..n {
17-
result = crate::m::nextafter(result, y);
18-
if result == y {
19-
break;
20-
}
21-
}
22-
result
16+
let usteps = match steps {
17+
None => return crate::m::nextafter(x, y),
18+
Some(n) => n,
19+
};
20+
21+
if usteps == 0 || x.is_nan() {
22+
return x;
23+
}
24+
if y.is_nan() {
25+
return y;
26+
}
27+
28+
let mut ux = x.to_bits();
29+
let uy = y.to_bits();
30+
if ux == uy {
31+
return x;
32+
}
33+
34+
const SIGN_BIT: u64 = 1u64 << 63;
35+
let ax = ux & !SIGN_BIT;
36+
let ay = uy & !SIGN_BIT;
37+
38+
if (ux ^ uy) & SIGN_BIT != 0 {
39+
// opposite signs — may need to cross zero
40+
// ax + ay can never overflow because bit 63 is cleared in both
41+
if ax + ay <= usteps {
42+
return y;
43+
} else if ax < usteps {
44+
// cross zero: remaining steps land on y's side
45+
return f64::from_bits((uy & SIGN_BIT) | (usteps - ax));
46+
} else {
47+
ux -= usteps;
48+
return f64::from_bits(ux);
49+
}
50+
} else if ax > ay {
51+
// same sign, moving toward zero
52+
if ax - ay >= usteps {
53+
ux -= usteps;
54+
f64::from_bits(ux)
55+
} else {
56+
y
57+
}
58+
} else {
59+
// same sign, moving away from zero
60+
if ay - ax >= usteps {
61+
ux += usteps;
62+
f64::from_bits(ux)
63+
} else {
64+
y
2365
}
24-
None => crate::m::nextafter(x, y),
2566
}
2667
}
2768

@@ -587,4 +628,146 @@ mod tests {
587628
test_fma_impl(x, y, z);
588629
}
589630
}
631+
632+
fn test_nextafter(x: f64, y: f64) {
633+
use pyo3::prelude::*;
634+
635+
let rs = nextafter(x, y, None);
636+
pyo3::Python::attach(|py| {
637+
let math = pyo3::types::PyModule::import(py, "math").unwrap();
638+
let py_f: f64 = math
639+
.getattr("nextafter")
640+
.unwrap()
641+
.call1((x, y))
642+
.unwrap()
643+
.extract()
644+
.unwrap();
645+
if py_f.is_nan() && rs.is_nan() {
646+
return;
647+
}
648+
assert_eq!(
649+
py_f.to_bits(),
650+
rs.to_bits(),
651+
"nextafter({x}, {y}): py={py_f} vs rs={rs}"
652+
);
653+
});
654+
}
655+
656+
fn test_nextafter_steps(x: f64, y: f64, steps: u64) {
657+
use pyo3::prelude::*;
658+
659+
let rs = nextafter(x, y, Some(steps));
660+
pyo3::Python::attach(|py| {
661+
let math = pyo3::types::PyModule::import(py, "math").unwrap();
662+
let kwargs = pyo3::types::PyDict::new(py);
663+
kwargs.set_item("steps", steps).unwrap();
664+
let py_f: f64 = math
665+
.getattr("nextafter")
666+
.unwrap()
667+
.call((x, y), Some(&kwargs))
668+
.unwrap()
669+
.extract()
670+
.unwrap();
671+
if py_f.is_nan() && rs.is_nan() {
672+
return;
673+
}
674+
assert_eq!(
675+
py_f.to_bits(),
676+
rs.to_bits(),
677+
"nextafter({x}, {y}, steps={steps}): py={py_f} vs rs={rs}"
678+
);
679+
});
680+
}
681+
682+
#[test]
683+
fn edgetest_nextafter() {
684+
for &x in crate::test::EDGE_VALUES {
685+
for &y in crate::test::EDGE_VALUES {
686+
test_nextafter(x, y);
687+
}
688+
}
689+
}
690+
691+
#[test]
692+
fn edgetest_nextafter_steps() {
693+
let x_vals = [
694+
0.0,
695+
-0.0,
696+
1.0,
697+
-1.0,
698+
f64::INFINITY,
699+
f64::NEG_INFINITY,
700+
f64::NAN,
701+
];
702+
let y_vals = [
703+
0.0,
704+
-0.0,
705+
1.0,
706+
-1.0,
707+
f64::INFINITY,
708+
f64::NEG_INFINITY,
709+
f64::NAN,
710+
];
711+
let steps = [0, 1, 2, 10, 100, 1000, u64::MAX];
712+
713+
for &x in &x_vals {
714+
for &y in &y_vals {
715+
for &s in &steps {
716+
test_nextafter_steps(x, y, s);
717+
}
718+
}
719+
}
720+
}
721+
722+
#[test]
723+
fn test_nextafter_steps_large() {
724+
// Large steps should saturate to target
725+
test_nextafter_steps(0.0, 1.0, u64::MAX);
726+
test_nextafter_steps(0.0, f64::INFINITY, u64::MAX);
727+
test_nextafter_steps(1.0, -1.0, u64::MAX);
728+
test_nextafter_steps(-1.0, 1.0, u64::MAX);
729+
730+
// Steps exactly reaching a value
731+
// From 0.0 toward inf, 10 steps = 10 * 5e-324
732+
test_nextafter_steps(0.0, f64::INFINITY, 10);
733+
test_nextafter_steps(0.0, f64::NEG_INFINITY, 10);
734+
735+
// Crossing zero
736+
test_nextafter_steps(5e-324, -5e-324, 1);
737+
test_nextafter_steps(5e-324, -5e-324, 2);
738+
test_nextafter_steps(5e-324, -5e-324, 3);
739+
test_nextafter_steps(-5e-324, 5e-324, 1);
740+
test_nextafter_steps(-5e-324, 5e-324, 2);
741+
test_nextafter_steps(-5e-324, 5e-324, 3);
742+
743+
// Extreme steps that would hang with O(n) loop
744+
let extreme_steps: &[u64] = &[
745+
10u64.pow(9),
746+
10u64.pow(15),
747+
10u64.pow(18),
748+
u64::MAX / 2,
749+
u64::MAX - 1,
750+
u64::MAX,
751+
];
752+
for &s in extreme_steps {
753+
test_nextafter_steps(0.0, 1.0, s);
754+
test_nextafter_steps(0.0, f64::INFINITY, s);
755+
test_nextafter_steps(1.0, 0.0, s);
756+
test_nextafter_steps(-1.0, 1.0, s);
757+
test_nextafter_steps(f64::MIN_POSITIVE, f64::MAX, s);
758+
test_nextafter_steps(f64::MAX, f64::MIN_POSITIVE, s);
759+
}
760+
}
761+
762+
proptest::proptest! {
763+
#[test]
764+
fn proptest_nextafter(x: f64, y: f64) {
765+
test_nextafter(x, y);
766+
}
767+
768+
#[test]
769+
fn proptest_nextafter_steps(x: f64, y: f64, steps: u64) {
770+
test_nextafter_steps(x, y, steps);
771+
}
772+
}
590773
}

0 commit comments

Comments
 (0)