|
21 | 21 | import java.util.List; |
22 | 22 |
|
23 | 23 | import org.junit.jupiter.api.Test; |
| 24 | +import org.junit.jupiter.api.condition.EnabledIfSystemProperty; |
24 | 25 |
|
25 | 26 | import org.springframework.cache.Cache; |
26 | 27 | import org.springframework.dao.DataRetrievalFailureException; |
|
44 | 45 | import org.springframework.security.core.userdetails.PasswordEncodedUser; |
45 | 46 | import org.springframework.security.core.userdetails.User; |
46 | 47 | import org.springframework.security.core.userdetails.UserDetails; |
| 48 | +import org.springframework.security.core.userdetails.UserDetailsChecker; |
47 | 49 | import org.springframework.security.core.userdetails.UserDetailsPasswordService; |
48 | 50 | import org.springframework.security.core.userdetails.UserDetailsService; |
49 | 51 | import org.springframework.security.core.userdetails.UsernameNotFoundException; |
|
64 | 66 | import static org.mockito.ArgumentMatchers.eq; |
65 | 67 | import static org.mockito.ArgumentMatchers.isA; |
66 | 68 | import static org.mockito.BDDMockito.given; |
| 69 | +import static org.mockito.BDDMockito.willThrow; |
67 | 70 | import static org.mockito.Mockito.mock; |
68 | 71 | import static org.mockito.Mockito.times; |
69 | 72 | import static org.mockito.Mockito.verify; |
@@ -422,12 +425,10 @@ public void constructWhenPasswordEncoderProvidedThenSets() { |
422 | 425 | assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance()); |
423 | 426 | } |
424 | 427 |
|
425 | | - /** |
426 | | - * This is an explicit test for SEC-2056. It is intentionally ignored since this test |
427 | | - * is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that |
428 | | - * SEC-2056 is fixed. |
429 | | - */ |
430 | | - public void IGNOREtestSec2056() { |
| 428 | + // SEC-2056 |
| 429 | + @Test |
| 430 | + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") |
| 431 | + public void testSec2056() { |
431 | 432 | UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod", |
432 | 433 | "koala"); |
433 | 434 | UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken |
@@ -460,6 +461,41 @@ public void IGNOREtestSec2056() { |
460 | 461 | .isTrue(); |
461 | 462 | } |
462 | 463 |
|
| 464 | + // related to SEC-2056 |
| 465 | + @Test |
| 466 | + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") |
| 467 | + public void testDisabledUserTiming() { |
| 468 | + UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala"); |
| 469 | + PasswordEncoder encoder = new BCryptPasswordEncoder(); |
| 470 | + DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); |
| 471 | + provider.setPasswordEncoder(encoder); |
| 472 | + MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod(); |
| 473 | + users.password = encoder.encode((CharSequence) user.getCredentials()); |
| 474 | + provider.setUserDetailsService(users); |
| 475 | + int sampleSize = 100; |
| 476 | + List<Long> enabledTimes = new ArrayList<>(sampleSize); |
| 477 | + for (int i = 0; i < sampleSize; i++) { |
| 478 | + long start = System.currentTimeMillis(); |
| 479 | + provider.authenticate(user); |
| 480 | + enabledTimes.add(System.currentTimeMillis() - start); |
| 481 | + } |
| 482 | + UserDetailsChecker preChecks = mock(UserDetailsChecker.class); |
| 483 | + willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class)); |
| 484 | + provider.setPreAuthenticationChecks(preChecks); |
| 485 | + List<Long> disabledTimes = new ArrayList<>(sampleSize); |
| 486 | + for (int i = 0; i < sampleSize; i++) { |
| 487 | + long start = System.currentTimeMillis(); |
| 488 | + assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user)); |
| 489 | + disabledTimes.add(System.currentTimeMillis() - start); |
| 490 | + } |
| 491 | + double enabledAvg = avg(enabledTimes); |
| 492 | + double disabledAvg = avg(disabledTimes); |
| 493 | + assertThat(Math.abs(disabledAvg - enabledAvg) <= 3) |
| 494 | + .withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average " |
| 495 | + + enabledAvg) |
| 496 | + .isTrue(); |
| 497 | + } |
| 498 | + |
463 | 499 | private double avg(List<Long> counts) { |
464 | 500 | return counts.stream().mapToLong(Long::longValue).average().orElse(0); |
465 | 501 | } |
|
0 commit comments