|
21 | 21 | import java.util.List; |
22 | 22 |
|
23 | 23 | import org.junit.jupiter.api.Test; |
| 24 | +import org.junit.jupiter.api.condition.EnabledIfSystemProperty; |
24 | 25 |
|
25 | 26 | import org.springframework.cache.Cache; |
26 | 27 | import org.springframework.dao.DataRetrievalFailureException; |
|
42 | 43 | import org.springframework.security.core.userdetails.PasswordEncodedUser; |
43 | 44 | import org.springframework.security.core.userdetails.User; |
44 | 45 | import org.springframework.security.core.userdetails.UserDetails; |
| 46 | +import org.springframework.security.core.userdetails.UserDetailsChecker; |
45 | 47 | import org.springframework.security.core.userdetails.UserDetailsPasswordService; |
46 | 48 | import org.springframework.security.core.userdetails.UserDetailsService; |
47 | 49 | import org.springframework.security.core.userdetails.UsernameNotFoundException; |
|
62 | 64 | import static org.mockito.ArgumentMatchers.eq; |
63 | 65 | import static org.mockito.ArgumentMatchers.isA; |
64 | 66 | import static org.mockito.BDDMockito.given; |
| 67 | +import static org.mockito.BDDMockito.willThrow; |
65 | 68 | import static org.mockito.Mockito.mock; |
66 | 69 | import static org.mockito.Mockito.times; |
67 | 70 | import static org.mockito.Mockito.verify; |
@@ -452,12 +455,10 @@ public void constructWhenPasswordEncoderProvidedThenSets() { |
452 | 455 | assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance()); |
453 | 456 | } |
454 | 457 |
|
455 | | - /** |
456 | | - * This is an explicit test for SEC-2056. It is intentionally ignored since this test |
457 | | - * is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that |
458 | | - * SEC-2056 is fixed. |
459 | | - */ |
460 | | - public void IGNOREtestSec2056() { |
| 458 | + // SEC-2056 |
| 459 | + @Test |
| 460 | + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") |
| 461 | + public void testSec2056() { |
461 | 462 | UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod", |
462 | 463 | "koala"); |
463 | 464 | UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken |
@@ -491,6 +492,41 @@ public void IGNOREtestSec2056() { |
491 | 492 | .isTrue(); |
492 | 493 | } |
493 | 494 |
|
| 495 | + // related to SEC-2056 |
| 496 | + @Test |
| 497 | + @EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true") |
| 498 | + public void testDisabledUserTiming() { |
| 499 | + UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala"); |
| 500 | + PasswordEncoder encoder = new BCryptPasswordEncoder(); |
| 501 | + DaoAuthenticationProvider provider = new DaoAuthenticationProvider(); |
| 502 | + provider.setPasswordEncoder(encoder); |
| 503 | + MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod(); |
| 504 | + users.password = encoder.encode((CharSequence) user.getCredentials()); |
| 505 | + provider.setUserDetailsService(users); |
| 506 | + int sampleSize = 100; |
| 507 | + List<Long> enabledTimes = new ArrayList<>(sampleSize); |
| 508 | + for (int i = 0; i < sampleSize; i++) { |
| 509 | + long start = System.currentTimeMillis(); |
| 510 | + provider.authenticate(user); |
| 511 | + enabledTimes.add(System.currentTimeMillis() - start); |
| 512 | + } |
| 513 | + UserDetailsChecker preChecks = mock(UserDetailsChecker.class); |
| 514 | + willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class)); |
| 515 | + provider.setPreAuthenticationChecks(preChecks); |
| 516 | + List<Long> disabledTimes = new ArrayList<>(sampleSize); |
| 517 | + for (int i = 0; i < sampleSize; i++) { |
| 518 | + long start = System.currentTimeMillis(); |
| 519 | + assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user)); |
| 520 | + disabledTimes.add(System.currentTimeMillis() - start); |
| 521 | + } |
| 522 | + double enabledAvg = avg(enabledTimes); |
| 523 | + double disabledAvg = avg(disabledTimes); |
| 524 | + assertThat(Math.abs(disabledAvg - enabledAvg) <= 3) |
| 525 | + .withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average " |
| 526 | + + enabledAvg) |
| 527 | + .isTrue(); |
| 528 | + } |
| 529 | + |
494 | 530 | private double avg(List<Long> counts) { |
495 | 531 | return counts.stream().mapToLong(Long::longValue).average().orElse(0); |
496 | 532 | } |
|
0 commit comments