Skip to content

Commit 70c9b4b

Browse files
1. Fix banded matric derivative checking
2. Standardize the output generation.
1 parent 5bbe7ba commit 70c9b4b

405 files changed

Lines changed: 7728 additions & 3118 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

BLAS/test/test_caxpy.f90

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ subroutine run_test_for_size(n, passed)
4747
integer :: incy
4848

4949
! Derivative variables
50+
complex(4), dimension(n) :: cx_d
5051
complex(4), dimension(n) :: cy_d
5152
complex(4) :: ca_d
52-
complex(4), dimension(n) :: cx_d
5353

5454
! Array restoration and derivative storage
55+
complex(4), dimension(n) :: cx_orig, cx_d_orig
5556
complex(4), dimension(n) :: cy_orig, cy_d_orig
5657
complex(4) :: ca_orig, ca_d_orig
57-
complex(4), dimension(n) :: cx_orig, cx_d_orig
5858
real(4) :: temp_re, temp_im ! For complex random init
5959
integer :: i, j
6060

@@ -80,24 +80,24 @@ subroutine run_test_for_size(n, passed)
8080
do i = 1, n
8181
call random_number(temp_re)
8282
call random_number(temp_im)
83-
cy_d(i) = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
83+
cx_d(i) = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
8484
end do
85-
call random_number(temp_re)
86-
call random_number(temp_im)
87-
ca_d = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
8885
do i = 1, n
8986
call random_number(temp_re)
9087
call random_number(temp_im)
91-
cx_d(i) = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
88+
cy_d(i) = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
9289
end do
90+
call random_number(temp_re)
91+
call random_number(temp_im)
92+
ca_d = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
9393

9494
! Store _orig and _d_orig
95+
cx_d_orig = cx_d
9596
cy_d_orig = cy_d
9697
ca_d_orig = ca_d
97-
cx_d_orig = cx_d
98+
cx_orig = cx
9899
cy_orig = cy
99100
ca_orig = ca
100-
cx_orig = cx
101101

102102
write(*,*) 'Testing CAXPY (n =', n, ')'
103103
cy_orig = cy
@@ -108,17 +108,17 @@ subroutine run_test_for_size(n, passed)
108108
write(*,*) 'Function calls completed successfully'
109109

110110
! Numerical differentiation check
111-
call check_derivatives_numerically(n, nsize, cy_orig, ca_orig, cx_orig, cy_d_orig, ca_d_orig, cx_d_orig, cy_d, passed)
111+
call check_derivatives_numerically(n, nsize, cx_orig, cy_orig, ca_orig, cx_d_orig, cy_d_orig, ca_d_orig, cy_d, passed)
112112

113113
end subroutine run_test_for_size
114114

115-
subroutine check_derivatives_numerically(n, nsize, cy_orig, ca_orig, cx_orig, cy_d_orig, ca_d_orig, cx_d_orig, cy_d, passed)
115+
subroutine check_derivatives_numerically(n, nsize, cx_orig, cy_orig, ca_orig, cx_d_orig, cy_d_orig, ca_d_orig, cy_d, passed)
116116
implicit none
117117
integer, intent(in) :: n
118118
integer, intent(in) :: nsize
119+
complex(4), intent(in) :: cx_orig(n), cx_d_orig(n)
119120
complex(4), intent(in) :: cy_orig(n), cy_d_orig(n)
120121
complex(4), intent(in) :: ca_orig, ca_d_orig
121-
complex(4), intent(in) :: cx_orig(n), cx_d_orig(n)
122122
complex(4), intent(in) :: cy_d(n)
123123
logical, intent(out) :: passed
124124

@@ -129,9 +129,9 @@ subroutine check_derivatives_numerically(n, nsize, cy_orig, ca_orig, cx_orig, cy
129129
logical :: has_large_errors
130130
complex(4), dimension(n) :: cy_forward, cy_backward
131131
integer :: i, j
132+
complex(4), dimension(n) :: cx
132133
complex(4), dimension(n) :: cy
133134
complex(4) :: ca
134-
complex(4), dimension(n) :: cx
135135

136136
max_error = 0.0e0
137137
has_large_errors = .false.
@@ -140,16 +140,16 @@ subroutine check_derivatives_numerically(n, nsize, cy_orig, ca_orig, cx_orig, cy
140140
write(*,*) 'Step size h =', h
141141

142142
! Forward perturbation: f(x + h)
143+
cx = cx_orig + h * cx_d_orig
143144
cy = cy_orig + h * cy_d_orig
144145
ca = ca_orig + h * ca_d_orig
145-
cx = cx_orig + h * cx_d_orig
146146
call caxpy(nsize, ca, cx, 1, cy, 1)
147147
cy_forward = cy
148148

149149
! Backward perturbation: f(x - h)
150+
cx = cx_orig - h * cx_d_orig
150151
cy = cy_orig - h * cy_d_orig
151152
ca = ca_orig - h * ca_d_orig
152-
cx = cx_orig - h * cx_d_orig
153153
call caxpy(nsize, ca, cx, 1, cy, 1)
154154
cy_backward = cy
155155

@@ -178,7 +178,7 @@ subroutine check_derivatives_numerically(n, nsize, cy_orig, ca_orig, cx_orig, cy
178178
write(*,*) 'Tolerance thresholds: rtol=1.0e-3, atol=1.0e-3'
179179
passed = .not. has_large_errors
180180
if (has_large_errors) then
181-
write(*,*) 'FAIL: Large errors detected in derivatives (outside tolerance)'
181+
write(*,*) 'FAIL: Derivatives are outside tolerance'
182182
else
183183
write(*,*) 'PASS: Derivatives are within tolerance (rtol + atol)'
184184
end if

BLAS/test/test_caxpy_reverse.f90

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,11 @@ subroutine check_vjp_numerically(n, nsize, incx_val, incy_val, ca_orig, cx_orig,
205205
relative_error = abs_error
206206
end if
207207
max_error = relative_error
208-
209-
write(*,*) ''
210208
write(*,*) 'Maximum relative error:', max_error
211209
write(*,*) 'Tolerance thresholds: rtol=1.0e-3, atol=1.0e-3'
212210
passed = .not. has_large_errors
213211
if (has_large_errors) then
214-
write(*,*) 'FAIL: Large errors detected in derivatives (outside tolerance)'
212+
write(*,*) 'FAIL: Derivatives are outside tolerance'
215213
else
216214
write(*,*) 'PASS: Derivatives are within tolerance (rtol + atol)'
217215
end if

BLAS/test/test_caxpy_vector_forward.f90

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ program test_caxpy_vector_forward
2929
all_passed = all_passed .and. passed
3030
end do
3131
if (all_passed) then
32-
write(*,*) 'PASS: Vector forward mode - all sizes completed successfully'
32+
write(*,*) 'PASS: All sizes completed successfully'
3333
else
34-
write(*,*) 'FAIL: Vector forward mode - one or more sizes had derivative errors'
34+
write(*,*) 'FAIL: One or more sizes had derivative errors'
3535
end if
3636

3737
contains
@@ -126,7 +126,7 @@ subroutine check_derivatives_numerically(n, nbdirs, nsize, incx_val, incy_val, a
126126
max_error = 0.0e0
127127
has_large_errors = .false.
128128

129-
write(*,*) 'Checking vector derivatives against numerical differentiation:'
129+
write(*,*) 'Checking derivatives against numerical differentiation:'
130130
write(*,*) 'Step size h =', h
131131

132132
do idir = 1, nbdirs
@@ -152,13 +152,13 @@ subroutine check_derivatives_numerically(n, nbdirs, nsize, incx_val, incy_val, a
152152
end do
153153
end do
154154

155-
write(*,*) 'Maximum relative error across all directions:', max_error
155+
write(*,*) 'Maximum relative error:', max_error
156156
write(*,*) 'Tolerance thresholds: rtol=1.0e-3, atol=1.0e-3'
157157
passed = .not. has_large_errors
158158
if (has_large_errors) then
159-
write(*,*) 'FAIL: Large errors detected in vector derivatives (outside tolerance)'
159+
write(*,*) 'FAIL: Derivatives are outside tolerance'
160160
else
161-
write(*,*) 'PASS: Vector derivatives are within tolerance (rtol + atol)'
161+
write(*,*) 'PASS: Derivatives are within tolerance (rtol + atol)'
162162
end if
163163

164164
end subroutine check_derivatives_numerically

BLAS/test/test_caxpy_vector_reverse.f90

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ program test_caxpy_vector_reverse
2929
all_passed = all_passed .and. passed
3030
end do
3131
if (all_passed) then
32-
write(*,*) 'PASS: Vector reverse mode - all sizes completed successfully'
32+
write(*,*) 'PASS: All sizes completed successfully'
3333
else
34-
write(*,*) 'FAIL: Vector reverse mode - one or more sizes had derivative errors'
34+
write(*,*) 'FAIL: One or more sizes had derivative errors'
3535
end if
3636

3737
contains
@@ -173,13 +173,11 @@ subroutine check_vjp_numerically(n, nbdirs, nsize, incx_val, incy_val, alpha_ori
173173
end if
174174
if (relative_error > max_error) max_error = relative_error
175175
end do
176-
177-
write(*,*) ''
178176
write(*,*) 'Maximum relative error:', max_error
179177
write(*,*) 'Tolerance thresholds: rtol=1.0e-3, atol=1.0e-3'
180178
passed = .not. has_large_errors
181179
if (has_large_errors) then
182-
write(*,*) 'FAIL: Large errors detected in derivatives (outside tolerance)'
180+
write(*,*) 'FAIL: Derivatives are outside tolerance'
183181
else
184182
write(*,*) 'PASS: Derivatives are within tolerance (rtol + atol)'
185183
end if

BLAS/test/test_ccopy.f90

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ subroutine run_test_for_size(n, passed)
4646
integer :: incy
4747

4848
! Derivative variables
49-
complex(4), dimension(n) :: cy_d
5049
complex(4), dimension(n) :: cx_d
50+
complex(4), dimension(n) :: cy_d
5151

5252
! Array restoration and derivative storage
53-
complex(4), dimension(n) :: cy_orig, cy_d_orig
5453
complex(4), dimension(n) :: cx_orig, cx_d_orig
54+
complex(4), dimension(n) :: cy_orig, cy_d_orig
5555
real(4) :: temp_re, temp_im ! For complex random init
5656
integer :: i, j
5757

@@ -74,19 +74,19 @@ subroutine run_test_for_size(n, passed)
7474
do i = 1, n
7575
call random_number(temp_re)
7676
call random_number(temp_im)
77-
cy_d(i) = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
77+
cx_d(i) = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
7878
end do
7979
do i = 1, n
8080
call random_number(temp_re)
8181
call random_number(temp_im)
82-
cx_d(i) = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
82+
cy_d(i) = cmplx(temp_re * 2.0 - 1.0, temp_im * 2.0 - 1.0, kind=4)
8383
end do
8484

8585
! Store _orig and _d_orig
86-
cy_d_orig = cy_d
8786
cx_d_orig = cx_d
88-
cy_orig = cy
87+
cy_d_orig = cy_d
8988
cx_orig = cx
89+
cy_orig = cy
9090

9191
write(*,*) 'Testing CCOPY (n =', n, ')'
9292

@@ -169,7 +169,7 @@ subroutine check_derivatives_numerically(n, nsize, cx_orig, cy_orig, cx_d_orig,
169169
write(*,*) 'Tolerance thresholds: rtol=1.0e-3, atol=1.0e-3'
170170
passed = .not. has_large_errors
171171
if (has_large_errors) then
172-
write(*,*) 'FAIL: Large errors detected in derivatives (outside tolerance)'
172+
write(*,*) 'FAIL: Derivatives are outside tolerance'
173173
else
174174
write(*,*) 'PASS: Derivatives are within tolerance (rtol + atol)'
175175
end if

BLAS/test/test_ccopy_reverse.f90

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,11 @@ subroutine check_vjp_numerically(n, nsize, incx_val, incy_val, cx_orig, cy_orig,
187187
relative_error = abs_error
188188
end if
189189
max_error = relative_error
190-
191-
write(*,*) ''
192190
write(*,*) 'Maximum relative error:', max_error
193191
write(*,*) 'Tolerance thresholds: rtol=1.0e-3, atol=1.0e-3'
194192
passed = .not. has_large_errors
195193
if (has_large_errors) then
196-
write(*,*) 'FAIL: Large errors detected in derivatives (outside tolerance)'
194+
write(*,*) 'FAIL: Derivatives are outside tolerance'
197195
else
198196
write(*,*) 'PASS: Derivatives are within tolerance (rtol + atol)'
199197
end if

BLAS/test/test_ccopy_vector_forward.f90

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ program test_ccopy_vector_forward
2929
all_passed = all_passed .and. passed
3030
end do
3131
if (all_passed) then
32-
write(*,*) 'PASS: Vector forward mode - all sizes completed successfully'
32+
write(*,*) 'PASS: All sizes completed successfully'
3333
else
34-
write(*,*) 'FAIL: Vector forward mode - one or more sizes had derivative errors'
34+
write(*,*) 'FAIL: One or more sizes had derivative errors'
3535
end if
3636

3737
contains
@@ -113,7 +113,7 @@ subroutine check_derivatives_numerically(n, nbdirs, nsize, incx_val, incy_val, x
113113
max_error = 0.0e0
114114
has_large_errors = .false.
115115

116-
write(*,*) 'Checking vector derivatives against numerical differentiation:'
116+
write(*,*) 'Checking derivatives against numerical differentiation:'
117117
write(*,*) 'Step size h =', h
118118

119119
do idir = 1, nbdirs
@@ -137,13 +137,13 @@ subroutine check_derivatives_numerically(n, nbdirs, nsize, incx_val, incy_val, x
137137
end do
138138
end do
139139

140-
write(*,*) 'Maximum relative error across all directions:', max_error
140+
write(*,*) 'Maximum relative error:', max_error
141141
write(*,*) 'Tolerance thresholds: rtol=1.0e-3, atol=1.0e-3'
142142
passed = .not. has_large_errors
143143
if (has_large_errors) then
144-
write(*,*) 'FAIL: Large errors detected in vector derivatives (outside tolerance)'
144+
write(*,*) 'FAIL: Derivatives are outside tolerance'
145145
else
146-
write(*,*) 'PASS: Vector derivatives are within tolerance (rtol + atol)'
146+
write(*,*) 'PASS: Derivatives are within tolerance (rtol + atol)'
147147
end if
148148

149149
end subroutine check_derivatives_numerically

BLAS/test/test_ccopy_vector_reverse.f90

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ program test_ccopy_vector_reverse
2929
all_passed = all_passed .and. passed
3030
end do
3131
if (all_passed) then
32-
write(*,*) 'PASS: Vector reverse mode - all sizes completed successfully'
32+
write(*,*) 'PASS: All sizes completed successfully'
3333
else
34-
write(*,*) 'FAIL: Vector reverse mode - one or more sizes had derivative errors'
34+
write(*,*) 'FAIL: One or more sizes had derivative errors'
3535
end if
3636

3737
contains
@@ -154,13 +154,11 @@ subroutine check_vjp_numerically(n, nbdirs, nsize, incx_val, incy_val, x_orig, y
154154
end if
155155
if (relative_error > max_error) max_error = relative_error
156156
end do
157-
158-
write(*,*) ''
159157
write(*,*) 'Maximum relative error:', max_error
160158
write(*,*) 'Tolerance thresholds: rtol=1.0e-3, atol=1.0e-3'
161159
passed = .not. has_large_errors
162160
if (has_large_errors) then
163-
write(*,*) 'FAIL: Large errors detected in derivatives (outside tolerance)'
161+
write(*,*) 'FAIL: Derivatives are outside tolerance'
164162
else
165163
write(*,*) 'PASS: Derivatives are within tolerance (rtol + atol)'
166164
end if

0 commit comments

Comments
 (0)