Skip to content

Commit 6e5f8f2

Browse files
committed
Merge remote-tracking branch 'origin/6.5.x' into 6.5.x
2 parents 5b638a5 + 4187af3 commit 6e5f8f2

8 files changed

Lines changed: 174 additions & 19 deletions

File tree

core/src/main/java/org/springframework/security/authentication/dao/AbstractUserDetailsAuthenticationProvider.java

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ public abstract class AbstractUserDetailsAuthenticationProvider
9292

9393
private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks();
9494

95+
private boolean alwaysPerformAdditionalChecksOnUser = true;
96+
9597
private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper();
9698

9799
/**
@@ -146,8 +148,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
146148
Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract");
147149
}
148150
try {
149-
this.preAuthenticationChecks.check(user);
150-
additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
151+
performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
151152
}
152153
catch (AuthenticationException ex) {
153154
if (!cacheWasUsed) {
@@ -157,8 +158,7 @@ public Authentication authenticate(Authentication authentication) throws Authent
157158
// we're using latest data (i.e. not from the cache)
158159
cacheWasUsed = false;
159160
user = retrieveUser(username, (UsernamePasswordAuthenticationToken) authentication);
160-
this.preAuthenticationChecks.check(user);
161-
additionalAuthenticationChecks(user, (UsernamePasswordAuthenticationToken) authentication);
161+
performPreCheck(user, (UsernamePasswordAuthenticationToken) authentication);
162162
}
163163
this.postAuthenticationChecks.check(user);
164164
if (!cacheWasUsed) {
@@ -171,6 +171,25 @@ public Authentication authenticate(Authentication authentication) throws Authent
171171
return createSuccessAuthentication(principalToReturn, authentication, user);
172172
}
173173

174+
private void performPreCheck(UserDetails user, UsernamePasswordAuthenticationToken authentication) {
175+
try {
176+
this.preAuthenticationChecks.check(user);
177+
}
178+
catch (AuthenticationException ex) {
179+
if (!this.alwaysPerformAdditionalChecksOnUser) {
180+
throw ex;
181+
}
182+
try {
183+
additionalAuthenticationChecks(user, authentication);
184+
}
185+
catch (AuthenticationException ignored) {
186+
// preserve the original failed check
187+
}
188+
throw ex;
189+
}
190+
additionalAuthenticationChecks(user, authentication);
191+
}
192+
174193
private String determineUsername(Authentication authentication) {
175194
return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName();
176195
}
@@ -313,6 +332,22 @@ public void setPostAuthenticationChecks(UserDetailsChecker postAuthenticationChe
313332
this.postAuthenticationChecks = postAuthenticationChecks;
314333
}
315334

335+
/**
336+
* Set whether to always perform the additional checks on the user, even if the
337+
* pre-authentication checks fail. This is useful to ensure that regardless of the
338+
* state of the user account, authentication takes the same amount of time to
339+
* complete.
340+
*
341+
* <p>
342+
* For applications that rely on the additional checks running only once should set
343+
* this value to {@code false}
344+
* @param alwaysPerformAdditionalChecksOnUser
345+
* @since 5.7.23
346+
*/
347+
public void setAlwaysPerformAdditionalChecksOnUser(boolean alwaysPerformAdditionalChecksOnUser) {
348+
this.alwaysPerformAdditionalChecksOnUser = alwaysPerformAdditionalChecksOnUser;
349+
}
350+
316351
public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
317352
this.authoritiesMapper = authoritiesMapper;
318353
}

core/src/main/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenService.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ public OneTimeToken consume(OneTimeTokenAuthenticationToken authenticationToken)
152152
return null;
153153
}
154154
OneTimeToken token = tokens.get(0);
155-
deleteOneTimeToken(token);
155+
if (deleteOneTimeToken(token) == 0) {
156+
return null;
157+
}
156158
if (isExpired(token)) {
157159
return null;
158160
}
@@ -170,11 +172,11 @@ private List<OneTimeToken> selectOneTimeToken(OneTimeTokenAuthenticationToken au
170172
return this.jdbcOperations.query(SELECT_ONE_TIME_TOKEN_SQL, pss, this.oneTimeTokenRowMapper);
171173
}
172174

173-
private void deleteOneTimeToken(OneTimeToken oneTimeToken) {
175+
private int deleteOneTimeToken(OneTimeToken oneTimeToken) {
174176
List<SqlParameterValue> parameters = List
175177
.of(new SqlParameterValue(Types.VARCHAR, oneTimeToken.getTokenValue()));
176178
PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
177-
this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss);
179+
return this.jdbcOperations.update(DELETE_ONE_TIME_TOKEN_SQL, pss);
178180
}
179181

180182
private ThreadPoolTaskScheduler createTaskScheduler(String cleanupCron) {

core/src/test/java/org/springframework/security/authentication/dao/DaoAuthenticationProviderTests.java

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.List;
2222

2323
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
2425

2526
import org.springframework.cache.Cache;
2627
import org.springframework.dao.DataRetrievalFailureException;
@@ -42,6 +43,7 @@
4243
import org.springframework.security.core.userdetails.PasswordEncodedUser;
4344
import org.springframework.security.core.userdetails.User;
4445
import org.springframework.security.core.userdetails.UserDetails;
46+
import org.springframework.security.core.userdetails.UserDetailsChecker;
4547
import org.springframework.security.core.userdetails.UserDetailsPasswordService;
4648
import org.springframework.security.core.userdetails.UserDetailsService;
4749
import org.springframework.security.core.userdetails.UsernameNotFoundException;
@@ -62,6 +64,7 @@
6264
import static org.mockito.ArgumentMatchers.eq;
6365
import static org.mockito.ArgumentMatchers.isA;
6466
import static org.mockito.BDDMockito.given;
67+
import static org.mockito.BDDMockito.willThrow;
6568
import static org.mockito.Mockito.mock;
6669
import static org.mockito.Mockito.times;
6770
import static org.mockito.Mockito.verify;
@@ -452,12 +455,10 @@ public void constructWhenPasswordEncoderProvidedThenSets() {
452455
assertThat(daoAuthenticationProvider.getPasswordEncoder()).isSameAs(NoOpPasswordEncoder.getInstance());
453456
}
454457

455-
/**
456-
* This is an explicit test for SEC-2056. It is intentionally ignored since this test
457-
* is not deterministic and {@link #testUserNotFoundEncodesPassword()} ensures that
458-
* SEC-2056 is fixed.
459-
*/
460-
public void IGNOREtestSec2056() {
458+
// SEC-2056
459+
@Test
460+
@EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
461+
public void testSec2056() {
461462
UsernamePasswordAuthenticationToken foundUser = UsernamePasswordAuthenticationToken.unauthenticated("rod",
462463
"koala");
463464
UsernamePasswordAuthenticationToken notFoundUser = UsernamePasswordAuthenticationToken
@@ -491,6 +492,41 @@ public void IGNOREtestSec2056() {
491492
.isTrue();
492493
}
493494

495+
// related to SEC-2056
496+
@Test
497+
@EnabledIfSystemProperty(named = "spring.security.timing-tests", matches = "true")
498+
public void testDisabledUserTiming() {
499+
UsernamePasswordAuthenticationToken user = UsernamePasswordAuthenticationToken.unauthenticated("rod", "koala");
500+
PasswordEncoder encoder = new BCryptPasswordEncoder();
501+
DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
502+
provider.setPasswordEncoder(encoder);
503+
MockUserDetailsServiceUserRod users = new MockUserDetailsServiceUserRod();
504+
users.password = encoder.encode((CharSequence) user.getCredentials());
505+
provider.setUserDetailsService(users);
506+
int sampleSize = 100;
507+
List<Long> enabledTimes = new ArrayList<>(sampleSize);
508+
for (int i = 0; i < sampleSize; i++) {
509+
long start = System.currentTimeMillis();
510+
provider.authenticate(user);
511+
enabledTimes.add(System.currentTimeMillis() - start);
512+
}
513+
UserDetailsChecker preChecks = mock(UserDetailsChecker.class);
514+
willThrow(new DisabledException("User is disabled")).given(preChecks).check(any(UserDetails.class));
515+
provider.setPreAuthenticationChecks(preChecks);
516+
List<Long> disabledTimes = new ArrayList<>(sampleSize);
517+
for (int i = 0; i < sampleSize; i++) {
518+
long start = System.currentTimeMillis();
519+
assertThatExceptionOfType(DisabledException.class).isThrownBy(() -> provider.authenticate(user));
520+
disabledTimes.add(System.currentTimeMillis() - start);
521+
}
522+
double enabledAvg = avg(enabledTimes);
523+
double disabledAvg = avg(disabledTimes);
524+
assertThat(Math.abs(disabledAvg - enabledAvg) <= 3)
525+
.withFailMessage("Disabled user average " + disabledAvg + " should be within 3ms of enabled user average "
526+
+ enabledAvg)
527+
.isTrue();
528+
}
529+
494530
private double avg(List<Long> counts) {
495531
return counts.stream().mapToLong(Long::longValue).average().orElse(0);
496532
}

core/src/test/java/org/springframework/security/authentication/ott/JdbcOneTimeTokenServiceTests.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,24 @@
2121
import java.time.Instant;
2222
import java.time.ZoneOffset;
2323
import java.time.temporal.ChronoUnit;
24+
import java.util.List;
2425

2526
import org.junit.jupiter.api.AfterEach;
2627
import org.junit.jupiter.api.BeforeEach;
2728
import org.junit.jupiter.api.Test;
29+
import org.mockito.ArgumentMatchers;
2830

2931
import org.springframework.jdbc.core.JdbcOperations;
3032
import org.springframework.jdbc.core.JdbcTemplate;
33+
import org.springframework.jdbc.core.PreparedStatementSetter;
34+
import org.springframework.jdbc.core.RowMapper;
3135
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
3236
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
3337
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
3438

3539
import static org.assertj.core.api.Assertions.assertThat;
3640
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
41+
import static org.mockito.ArgumentMatchers.any;
3742
import static org.mockito.BDDMockito.given;
3843
import static org.mockito.Mockito.mock;
3944

@@ -145,6 +150,27 @@ void consumeWhenTokenDoesNotExistsThenReturnNull() {
145150
assertThat(consumedOneTimeToken).isNull();
146151
}
147152

153+
@Test
154+
void consumeWhenTokenIsDeletedConcurrentlyThenReturnNull() throws Exception {
155+
// Simulates a concurrent consume: SELECT finds the token but DELETE affects
156+
// 0 rows because another caller already consumed it.
157+
JdbcOperations jdbcOperations = mock(JdbcOperations.class);
158+
Instant notExpired = Instant.now().plus(5, ChronoUnit.MINUTES);
159+
OneTimeToken token = new DefaultOneTimeToken(TOKEN_VALUE, USERNAME, notExpired);
160+
given(jdbcOperations.query(any(String.class), any(PreparedStatementSetter.class),
161+
ArgumentMatchers.<RowMapper<OneTimeToken>>any()))
162+
.willReturn(List.of(token));
163+
given(jdbcOperations.update(any(String.class), any(PreparedStatementSetter.class))).willReturn(0);
164+
JdbcOneTimeTokenService service = new JdbcOneTimeTokenService(jdbcOperations);
165+
try {
166+
OneTimeToken consumed = service.consume(new OneTimeTokenAuthenticationToken(TOKEN_VALUE));
167+
assertThat(consumed).isNull();
168+
}
169+
finally {
170+
service.destroy();
171+
}
172+
}
173+
148174
@Test
149175
void consumeWhenTokenIsExpiredThenReturnNull() {
150176
GenerateOneTimeTokenRequest request = new GenerateOneTimeTokenRequest(USERNAME);

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoder.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ public static JwkSetUriJwtDecoderBuilder withIssuerLocation(String issuer) {
230230
.getConfigurationForIssuerLocation(issuer, rest);
231231
JwtDecoderProviderConfigurationUtils.validateIssuer(configuration, issuer);
232232
return configuration.get("jwks_uri").toString();
233-
}, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms);
233+
}, JwtDecoderProviderConfigurationUtils::getJWSAlgorithms)
234+
.validator(JwtValidators.createDefaultWithIssuer(issuer));
234235
}
235236

236237
/**
@@ -289,6 +290,8 @@ public static final class JwkSetUriJwtDecoderBuilder {
289290

290291
private Consumer<ConfigurableJWTProcessor<SecurityContext>> jwtProcessorCustomizer;
291292

293+
private OAuth2TokenValidator<Jwt> validator = JwtValidators.createDefault();
294+
292295
private JwkSetUriJwtDecoderBuilder(String jwkSetUri) {
293296
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
294297
this.jwkSetUri = (rest) -> jwkSetUri;
@@ -423,6 +426,12 @@ public JwkSetUriJwtDecoderBuilder jwtProcessorCustomizer(
423426
return this;
424427
}
425428

429+
JwkSetUriJwtDecoderBuilder validator(OAuth2TokenValidator<Jwt> validator) {
430+
Assert.notNull(validator, "validator cannot be null");
431+
this.validator = validator;
432+
return this;
433+
}
434+
426435
JWSKeySelector<SecurityContext> jwsKeySelector(JWKSource<SecurityContext> jwkSource) {
427436
if (this.signatureAlgorithms.isEmpty()) {
428437
return new JWSVerificationKeySelector<>(this.defaultAlgorithms.apply(jwkSource), jwkSource);
@@ -461,7 +470,9 @@ JWTProcessor<SecurityContext> processor() {
461470
* @return the configured {@link NimbusJwtDecoder}
462471
*/
463472
public NimbusJwtDecoder build() {
464-
return new NimbusJwtDecoder(processor());
473+
NimbusJwtDecoder decoder = new NimbusJwtDecoder(processor());
474+
decoder.setJwtValidator(this.validator);
475+
return decoder;
465476
}
466477

467478
private static final class SpringJWKSource<C extends SecurityContext> implements JWKSetSource<C> {

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoder.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ public static JwkSetUriReactiveJwtDecoderBuilder withIssuerLocation(String issue
241241
}
242242
return Mono.just(configuration.get("jwks_uri").toString());
243243
}),
244-
ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms);
244+
ReactiveJwtDecoderProviderConfigurationUtils::getJWSAlgorithms)
245+
.validator(JwtValidators.createDefaultWithIssuer(issuer));
245246
}
246247

247248
/**
@@ -332,6 +333,8 @@ public static final class JwkSetUriReactiveJwtDecoderBuilder {
332333

333334
private BiFunction<ReactiveRemoteJWKSource, ConfigurableJWTProcessor<JWKSecurityContext>, Mono<ConfigurableJWTProcessor<JWKSecurityContext>>> jwtProcessorCustomizer;
334335

336+
private OAuth2TokenValidator<Jwt> validator = JwtValidators.createDefault();
337+
335338
private JwkSetUriReactiveJwtDecoderBuilder(String jwkSetUri) {
336339
Assert.hasText(jwkSetUri, "jwkSetUri cannot be empty");
337340
this.jwkSetUri = (web) -> Mono.just(jwkSetUri);
@@ -456,6 +459,11 @@ public JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
456459
return this;
457460
}
458461

462+
JwkSetUriReactiveJwtDecoderBuilder validator(OAuth2TokenValidator<Jwt> validator) {
463+
this.validator = validator;
464+
return this;
465+
}
466+
459467
JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
460468
BiFunction<ReactiveRemoteJWKSource, ConfigurableJWTProcessor<JWKSecurityContext>, Mono<ConfigurableJWTProcessor<JWKSecurityContext>>> jwtProcessorCustomizer) {
461469
Assert.notNull(jwtProcessorCustomizer, "jwtProcessorCustomizer cannot be null");
@@ -468,7 +476,9 @@ JwkSetUriReactiveJwtDecoderBuilder jwtProcessorCustomizer(
468476
* @return the configured {@link NimbusReactiveJwtDecoder}
469477
*/
470478
public NimbusReactiveJwtDecoder build() {
471-
return new NimbusReactiveJwtDecoder(processor());
479+
NimbusReactiveJwtDecoder decoder = new NimbusReactiveJwtDecoder(processor());
480+
decoder.setJwtValidator(this.validator);
481+
return decoder;
472482
}
473483

474484
Mono<JWSKeySelector<JWKSecurityContext>> jwsKeySelector(ReactiveRemoteJWKSource source) {

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwtDecoderTests.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,11 +328,26 @@ public void decodeWhenIssuerLocationThenOk() {
328328
.willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK));
329329
given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
330330
.willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
331-
JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build();
331+
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer)
332+
.restOperations(restOperations)
333+
.build();
334+
jwtDecoder.setJwtValidator(JwtValidators.createDefault());
332335
Jwt jwt = jwtDecoder.decode(SIGNED_JWT);
333336
assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull();
334337
}
335338

339+
@Test
340+
public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() {
341+
String issuer = "https://example.org/wrong-issuer";
342+
RestOperations restOperations = mock(RestOperations.class);
343+
given(restOperations.exchange(any(RequestEntity.class), any(ParameterizedTypeReference.class)))
344+
.willReturn(new ResponseEntity<>(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks"), HttpStatus.OK));
345+
given(restOperations.exchange(any(RequestEntity.class), eq(String.class)))
346+
.willReturn(new ResponseEntity<>(JWK_SET, HttpStatus.OK));
347+
JwtDecoder jwtDecoder = NimbusJwtDecoder.withIssuerLocation(issuer).restOperations(restOperations).build();
348+
assertThatExceptionOfType(JwtValidationException.class).isThrownBy(() -> jwtDecoder.decode(SIGNED_JWT));
349+
}
350+
336351
@Test
337352
public void withJwkSetUriWhenNullOrEmptyThenThrowsException() {
338353
// @formatter:off

oauth2/oauth2-jose/src/test/java/org/springframework/security/oauth2/jwt/NimbusReactiveJwtDecoderTests.java

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,13 +617,33 @@ public void decodeWhenIssuerLocationThenOk() {
617617
given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class)))
618618
.willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks")));
619619
given(spec.retrieve()).willReturn(responseSpec);
620-
ReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer)
620+
NimbusReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer)
621621
.webClient(webClient)
622622
.build();
623+
jwtDecoder.setJwtValidator(JwtValidators.createDefault());
623624
Jwt jwt = jwtDecoder.decode(this.messageReadToken).block();
624625
assertThat(jwt.hasClaim(JwtClaimNames.EXP)).isNotNull();
625626
}
626627

628+
@Test
629+
public void decodeWhenIssuerLocationThenRejectsMismatchingIssuers() {
630+
String issuer = "https://example.org/wrong-issuer";
631+
WebClient real = WebClient.builder().build();
632+
WebClient.RequestHeadersUriSpec spec = spy(real.get());
633+
WebClient webClient = spy(WebClient.class);
634+
given(webClient.get()).willReturn(spec);
635+
WebClient.ResponseSpec responseSpec = mock(WebClient.ResponseSpec.class);
636+
given(responseSpec.bodyToMono(String.class)).willReturn(Mono.just(this.jwkSet));
637+
given(responseSpec.bodyToMono(any(ParameterizedTypeReference.class)))
638+
.willReturn(Mono.just(Map.of("issuer", issuer, "jwks_uri", issuer + "/jwks")));
639+
given(spec.retrieve()).willReturn(responseSpec);
640+
ReactiveJwtDecoder jwtDecoder = NimbusReactiveJwtDecoder.withIssuerLocation(issuer)
641+
.webClient(webClient)
642+
.build();
643+
assertThatExceptionOfType(JwtValidationException.class)
644+
.isThrownBy(() -> jwtDecoder.decode(this.messageReadToken).block());
645+
}
646+
627647
@Test
628648
public void jwsKeySelectorWhenNoAlgorithmThenReturnsRS256Selector() {
629649
ReactiveRemoteJWKSource jwkSource = mock(ReactiveRemoteJWKSource.class);

0 commit comments

Comments
 (0)