Skip to content

Commit 30088d2

Browse files
authored
fix(auth): Address ClientSideCredentialAccessBoundary RefreshTask race condition (#12681)
This change addresses a race condition in ClientSideCredentialAccessBoundaryFactory that occurred when multiple concurrent calls were made to generateToken. The fix involves: - Waiting on the RefreshTask itself rather than its internal task. - Using a single listener in RefreshTask to ensure finishRefreshTask completes before the outer future unblocks waiting threads. - Adding a regression test generateToken_freshInstance_concurrent_noNpe.
1 parent 7871849 commit 30088d2

2 files changed

Lines changed: 62 additions & 22 deletions

File tree

google-auth-library-java/cab-token-generator/java/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactory.java

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
import com.google.common.annotations.VisibleForTesting;
5454
import com.google.common.base.Strings;
5555
import com.google.common.util.concurrent.AbstractFuture;
56-
import com.google.common.util.concurrent.FutureCallback;
5756
import com.google.common.util.concurrent.Futures;
5857
import com.google.common.util.concurrent.ListenableFuture;
5958
import com.google.common.util.concurrent.ListenableFutureTask;
@@ -79,7 +78,6 @@
7978
import java.util.Date;
8079
import java.util.List;
8180
import java.util.concurrent.ExecutionException;
82-
import javax.annotation.Nullable;
8381

8482
/**
8583
* A factory for generating downscoped access tokens using a client-side approach.
@@ -248,7 +246,7 @@ void refreshCredentialsIfRequired() throws IOException {
248246
}
249247
try {
250248
// Wait for the refresh task to complete.
251-
currentRefreshTask.task.get();
249+
currentRefreshTask.get();
252250
} catch (InterruptedException e) {
253251
// Restore the interrupted status and throw an exception.
254252
Thread.currentThread().interrupt();
@@ -495,31 +493,18 @@ class RefreshTask extends AbstractFuture<IntermediateCredentials> implements Run
495493
this.task = task;
496494
this.isNew = isNew;
497495

498-
// Add listener to update factory's credentials when the task completes.
496+
// Single listener to guarantee that finishRefreshTask updates the internal state BEFORE
497+
// the outer future completes and unblocks waiters.
499498
task.addListener(
500499
() -> {
501500
try {
502501
finishRefreshTask(task);
502+
RefreshTask.this.set(Futures.getDone(task));
503503
} catch (ExecutionException e) {
504504
Throwable cause = e.getCause();
505-
RefreshTask.this.setException(cause);
506-
}
507-
},
508-
MoreExecutors.directExecutor());
509-
510-
// Add callback to set the result or exception based on the outcome.
511-
Futures.addCallback(
512-
task,
513-
new FutureCallback<IntermediateCredentials>() {
514-
@Override
515-
public void onSuccess(IntermediateCredentials result) {
516-
RefreshTask.this.set(result);
517-
}
518-
519-
@Override
520-
public void onFailure(@Nullable Throwable t) {
521-
RefreshTask.this.setException(
522-
t != null ? t : new IOException("Refresh failed with null Throwable."));
505+
RefreshTask.this.setException(cause != null ? cause : e);
506+
} catch (Throwable t) {
507+
RefreshTask.this.setException(t);
523508
}
524509
},
525510
MoreExecutors.directExecutor());

google-auth-library-java/cab-token-generator/javatests/com/google/auth/credentialaccessboundary/ClientSideCredentialAccessBoundaryFactoryTest.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,4 +988,59 @@ void generateToken_withMalformSessionKey_failure() throws Exception {
988988

989989
assertThrows(GeneralSecurityException.class, () -> factory.generateToken(accessBoundary));
990990
}
991+
992+
@Test
993+
void generateToken_freshInstance_concurrent_noNpe() throws Exception {
994+
for (int run = 0; run < 10; run++) { // Run 10 times in a single test instance to save time
995+
GoogleCredentials sourceCredentials =
996+
getServiceAccountSourceCredentials(mockTokenServerTransportFactory);
997+
ClientSideCredentialAccessBoundaryFactory factory =
998+
ClientSideCredentialAccessBoundaryFactory.newBuilder()
999+
.setSourceCredential(sourceCredentials)
1000+
.setHttpTransportFactory(mockStsTransportFactory)
1001+
.build();
1002+
1003+
CredentialAccessBoundary.Builder cabBuilder = CredentialAccessBoundary.newBuilder();
1004+
CredentialAccessBoundary accessBoundary =
1005+
cabBuilder
1006+
.addRule(
1007+
CredentialAccessBoundary.AccessBoundaryRule.newBuilder()
1008+
.setAvailableResource("resource")
1009+
.setAvailablePermissions(ImmutableList.of("role"))
1010+
.build())
1011+
.build();
1012+
1013+
int numThreads = 5;
1014+
CountDownLatch latch = new CountDownLatch(numThreads);
1015+
java.util.concurrent.atomic.AtomicInteger npeCount =
1016+
new java.util.concurrent.atomic.AtomicInteger();
1017+
java.util.concurrent.ExecutorService executor =
1018+
java.util.concurrent.Executors.newFixedThreadPool(numThreads);
1019+
1020+
try {
1021+
for (int i = 0; i < numThreads; i++) {
1022+
executor.submit(
1023+
() -> {
1024+
try {
1025+
latch.countDown();
1026+
latch.await();
1027+
factory.generateToken(accessBoundary);
1028+
} catch (NullPointerException e) {
1029+
npeCount.incrementAndGet();
1030+
} catch (Exception e) {
1031+
// Ignore other exceptions for the sake of the race reproduction
1032+
}
1033+
});
1034+
}
1035+
} finally {
1036+
executor.shutdown();
1037+
executor.awaitTermination(5, java.util.concurrent.TimeUnit.SECONDS);
1038+
}
1039+
1040+
org.junit.jupiter.api.Assertions.assertEquals(
1041+
0,
1042+
npeCount.get(),
1043+
"Expected zero NullPointerExceptions due to the race condition, but some were thrown.");
1044+
}
1045+
}
9911046
}

0 commit comments

Comments
 (0)