@@ -9,45 +9,40 @@ use ndarray::prelude::*;
99use ndarray:: Zip ;
1010
1111fn main ( ) {
12- let n = 16 ;
12+ let n = 6 ;
13+
1314 let mut a = Array :: < f32 , _ > :: zeros ( ( n, n) ) ;
14- let mut b = Array :: < f32 , _ > :: from_elem ( ( n, n) , 1. ) ;
15+ let mut b = Array :: < f32 , _ > :: zeros ( ( n, n) ) ;
1516 for ( ( i, j) , elt) in b. indexed_iter_mut ( ) {
16- * elt / = 1. + ( i + 2 * j) as f32 ;
17+ * elt = 1. / ( 1. + ( i + 2 * j) as f32 ) ;
1718 }
1819 let c = Array :: < f32 , _ > :: from_elem ( ( n, n + 1 ) , 1.7 ) ;
1920 let c = c. slice ( s ! [ .., ..-1 ] ) ;
2021
21- {
22- let a = a. view_mut ( ) . reversed_axes ( ) ;
23- azip ! ( ( a in a, & b in b. t( ) ) * a = b) ;
24- }
25- assert_eq ! ( a, b) ;
22+ // Using Zip for arithmetic ops across a, b, c
23+ Zip :: from ( & mut a) . and ( & b) . and ( & c)
24+ . for_each ( |a, & b, & c| * a = b + c) ;
25+ assert_eq ! ( a, & b + & c) ;
2626
27+ // and this is how to do the *same thing* with azip!()
2728 azip ! ( ( a in & mut a, & b in & b, & c in c) * a = b + c) ;
28- assert_eq ! ( a, & b + & c) ;
29+
30+ println ! ( "{:8.4}" , a) ;
2931
3032 // sum of each row
31- let ax = Axis ( 0 ) ;
32- let mut sums = Array :: zeros ( a. len_of ( ax) ) ;
33- azip ! ( ( s in & mut sums, a in a. axis_iter( ax) ) * s = a. sum( ) ) ;
33+ let mut sums = Array :: zeros ( a. nrows ( ) ) ;
34+ Zip :: from ( a. rows ( ) ) . and ( & mut sums)
35+ . for_each ( |row, sum| * sum = row. sum ( ) ) ;
36+ // show sums as a column matrix
37+ println ! ( "{:8.4}" , sums. insert_axis( Axis ( 1 ) ) ) ;
3438
35- // sum of each chunk
39+ // sum of each 2x2 chunk
3640 let chunk_sz = ( 2 , 2 ) ;
3741 let nchunks = ( n / chunk_sz. 0 , n / chunk_sz. 1 ) ;
3842 let mut sums = Array :: zeros ( nchunks) ;
39- azip ! ( ( s in & mut sums, a in a. exact_chunks( chunk_sz) ) * s = a. sum( ) ) ;
40-
41- // Let's imagine we split to parallelize
42- {
43- let ( x, y) = Zip :: indexed ( & mut a) . split ( ) ;
44- x. for_each ( |( _, j) , elt| {
45- * elt = elt. powi ( j as i32 ) ;
46- } ) ;
47-
48- y. for_each ( |( _, j) , elt| {
49- * elt = elt. powi ( j as i32 ) ;
50- } ) ;
51- }
52- println ! ( "{:8.3?}" , a) ;
43+
44+ Zip :: from ( a. exact_chunks ( chunk_sz) )
45+ . and ( & mut sums)
46+ . for_each ( |chunk, sum| * sum = chunk. sum ( ) ) ;
47+ println ! ( "{:8.4}" , sums) ;
5348}
0 commit comments