Skip to content

Commit a317a3d

Browse files
committed
Add Support for Always Running Additional Authentication Checks
Signed-off-by: Josh Cummings <3627351+jzheaux@users.noreply.github.com>
1 parent 68b820e commit a317a3d

2 files changed

Lines changed: 81 additions & 10 deletions

File tree

core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ public abstract class AbstractUserDetailsAuthenticationProvider
9292

9393
private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks();
9494

95+
private boolean alwaysPerformAdditionalChecksOnUser = true;
96+
9597
private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper();
9698

9799
/**
@@ -146,8 +148,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
146148
Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract");
147149
}
148150
try {
149-
this.preAuthenticationChecks.check(user);
150-
additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
151+
performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
151152
}
152153
catch (AuthenticationException ex) {
153154
if (!cacheWasUsed) {
@@ -157,8 +158,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
157158
// we're using latest data (i.e. not from the cache)
158159
cacheWasUsed = false;
159160
user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication);
160-
this.preAuthenticationChecks.check(user);
161-
additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
161+
performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
162162
}
163163
this.postAuthenticationChecks.check(user);
164164
if (!cacheWasUsed) {
@@ -171,6 +171,25 @@ public Authentication authenticate(Authentication authentication) throws Authent
171171
return createSuccessAuthentication(principalToReturn, authentication, user);
172172
}
173173

174+
private void performPreCheck(UserDetails user, UsernamePasswordAuthenticationToken authentication) {
175+
try {
176+
this.preAuthenticationChecks.check(user);
177+
}
178+
catch (AuthenticationException ex) {
179+
if (!this.alwaysPerformAdditionalChecksOnUser) {
180+
throw ex;
181+
}
182+
try {
183+
additionalAuthenticationChecks(user, authentication);
184+
}
185+
catch (AuthenticationException ignored) {
186+
// preserve the original failed check
187+
}
188+
throw ex;
189+
}
190+
additionalAuthenticationChecks(user, authentication);
191+
}
192+
174193
private String determineUsername(Authentication authentication) {
175194
return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName();
176195
}
@@ -313,6 +332,22 @@ public void setPostAuthenticationChecks(UserDetailsChecker postAuthenticationChe
313332
this.postAuthenticationChecks = postAuthenticationChecks;
314333
}
315334

335+
/**
336+
* Set whether to always perform the additional checks on the user, even if the
337+
* pre-authentication checks fail. This is useful to ensure that regardless of the
338+
* state of the user account, authentication takes the same amount of time to
339+
* complete.
340+
*
341+
* <p>
342+
* For applications that rely on the additional checks running only once should set
343+
* this value to {@code false}
344+
* @param alwaysPerformAdditionalChecksOnUser
345+
* @since 5.7.23
346+
*/
347+
public void setAlwaysPerformAdditionalChecksOnUser(boolean alwaysPerformAdditionalChecksOnUser) {
348+
this.alwaysPerformAdditionalChecksOnUser = alwaysPerformAdditionalChecksOnUser;
349+
}
350+
316351
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
317352
this.authoritiesMapper = authoritiesMapper;
318353
}

core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.List;
2222

2323
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
2425

2526
import org.springframework.cache.Cache;
2627
import org.springframework.dao.DataRetrievalFailureException;
@@ -42,6 +43,7 @@
4243
import org.springframework.security.core.userdetails.PasswordEncodedUser;
4344
import org.springframework.security.core.userdetails.User;
4445
import org.springframework.security.core.userdetails.UserDetails;
46+
import org.springframework.security.core.userdetails.UserDetailsChecker;
4547
import org.springframework.security.core.userdetails.UserDetailsPasswordService;
4648
import org.springframework.security.core.userdetails.UserDetailsService;
4749
import org.springframework.security.core.userdetails.UsernameNotFoundException;
@@ -62,6 +64,7 @@
6264
import static org.mockito.ArgumentMatchers.eq;
6365
import static org.mockito.ArgumentMatchers.isA;
6466
import static org.mockito.BDDMockito.given;
67+
import static org.mockito.BDDMockito.willThrow;
6568
import static org.mockito.Mockito.mock;
6669
import static org.mockito.Mockito.times;
6770
import static org.mockito.Mockito.verify;
@@ -452,12 +455,10 @@ public void constructWhenPasswordEncoderProvidedThenSets() {
452455
assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance());
453456
}
454457

455-
/**
456-
* This is an explicit test for SEC-2056. It is intentionally ignored since this test
457-
* is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that
458-
* SEC-2056 is fixed.
459-
*/
460-
public void IGNOREtestSec2056() {
458+
// SEC-2056
459+
@Test
460+
@EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
461+
public void testSec2056() {
461462
UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod",
462463
"koala");
463464
UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken
@@ -491,6 +492,41 @@ public void IGNOREtestSec2056() {
491492
.isTrue();
492493
}
493494

495+
// related to SEC-2056
496+
@Test
497+
@EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
498+
public void testDisabledUserTiming() {
499+
UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala");
500+
PasswordEncoder encoder = new BCryptPasswordEncoder();
501+
DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
502+
provider.setPasswordEncoder(encoder);
503+
MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod();
504+
users.password = encoder.encode((CharSequence) user.getCredentials());
505+
provider.setUserDetailsService(users);
506+
int sampleSize = 100;
507+
List<Long> enabledTimes = new ArrayList<>(sampleSize);
508+
for (int i = 0; i < sampleSize; i++) {
509+
long start = System.currentTimeMillis();
510+
provider.authenticate(user);
511+
enabledTimes.add(System.currentTimeMillis() - start);
512+
}
513+
UserDetailsChecker preChecks = mock(UserDetailsChecker.class);
514+
willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class));
515+
provider.setPreAuthenticationChecks(preChecks);
516+
List<Long> disabledTimes = new ArrayList<>(sampleSize);
517+
for (int i = 0; i < sampleSize; i++) {
518+
long start = System.currentTimeMillis();
519+
assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user));
520+
disabledTimes.add(System.currentTimeMillis() - start);
521+
}
522+
double enabledAvg = avg(enabledTimes);
523+
double disabledAvg = avg(disabledTimes);
524+
assertThat(Math.abs(disabledAvg - enabledAvg) <= 3)
525+
.withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average "
526+
+ enabledAvg)
527+
.isTrue();
528+
}
529+
494530
private double avg(List<Long> counts) {
495531
return counts.stream().mapToLong(Long::longValue).average().orElse(0);
496532
}

0 commit comments

Comments
 (0)