Skip to content

Commit ab4bd8c

Browse files
committed
feat: enable refresh token rotation for salesforce
previously, we were not using the new refresh token if one was provided
1 parent feb3755 commit ab4bd8c

4 files changed

Lines changed: 106 additions & 17 deletions

File tree

apps/api/src/app/routes/route.middleware.ts

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,14 +357,14 @@ export async function getOrgForRequest(
357357
}
358358

359359
// Handle org refresh - then remove event listener if refreshed
360-
const handleRefresh = async (accessToken: string, refreshToken: string) => {
360+
const handleRefresh = async (newAccessToken: string, newRefreshToken: string) => {
361361
// Refresh event will be fired when renewed access token
362362
// to store it in your storage for next request
363363
try {
364-
if (!refreshToken) {
364+
if (!newRefreshToken) {
365365
return;
366366
}
367-
await salesforceOrgsDb.updateAccessToken_UNSAFE({ accessToken, refreshToken, org, userId: user.id });
367+
await salesforceOrgsDb.updateAccessToken_UNSAFE({ accessToken: newAccessToken, refreshToken: newRefreshToken, org, userId: user.id });
368368
} catch (ex) {
369369
logger.error({ requestId, ...getExceptionLog(ex) }, '[ORG][REFRESH] Error saving refresh token');
370370
}
@@ -378,6 +378,25 @@ export async function getOrgForRequest(
378378
}
379379
};
380380

381+
// Re-reads current tokens from DB so concurrent workers that lose the refresh token rotation race
382+
// can retry with the tokens written by the worker that won.
383+
const getFreshTokens = async () => {
384+
try {
385+
const freshOrg = await salesforceOrgsDb.findByUniqueId_UNSAFE(user.id, uniqueId);
386+
if (!freshOrg) {
387+
return null;
388+
}
389+
const [freshAccessToken, freshRefreshToken] = await sfdcEncService.decryptAccessToken({
390+
encryptedAccessToken: freshOrg.accessToken,
391+
userId: user.id,
392+
});
393+
return { accessToken: freshAccessToken, refreshToken: freshRefreshToken };
394+
} catch (ex) {
395+
logger.error({ requestId, ...getExceptionLog(ex) }, '[ORG][REFRESH] Error fetching fresh tokens for race condition check');
396+
return null;
397+
}
398+
};
399+
381400
const jetstreamConn = new ApiConnection(
382401
{
383402
apiRequestAdapter: getApiRequestFactoryFn(fetch),
@@ -392,6 +411,7 @@ export async function getOrgForRequest(
392411
logger,
393412
sfdcClientId: ENV.SFDC_CONSUMER_KEY,
394413
sfdcClientSecret: ENV.SFDC_CONSUMER_SECRET,
414+
getFreshTokens,
395415
},
396416
handleRefresh,
397417
handleConnectionError,
@@ -574,7 +594,6 @@ export function setPermissionPolicy(_req: express.Request, res: express.Response
574594
next();
575595
}
576596

577-
578597
export function setCacheControlForApiRoutes(_req: express.Request, res: express.Response, next: express.NextFunction) {
579598
res.setHeader('Cache-Control', 'no-store, no-cache, must-revalidate, proxy-revalidate, max-age=0');
580599
next();

apps/jetstream-desktop/src/utils/route.utils.ts

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,14 +200,14 @@ export function initApiConnection(
200200
const accessToken = plaintext.slice(0, spaceIndex);
201201
const refreshToken = plaintext.slice(spaceIndex + 1);
202202

203-
const handleRefresh = async (accessToken: string, refreshToken: string) => {
203+
const handleRefresh = async (newAccessToken: string, newRefreshToken: string) => {
204204
// Refresh event will be fired when renewed access token
205205
// to store it in your storage for next request
206206
try {
207-
if (!refreshToken) {
207+
if (!newRefreshToken) {
208208
return;
209209
}
210-
await updateAccessTokens(org.uniqueId, { accessToken, refreshToken });
210+
await updateAccessTokens(org.uniqueId, { accessToken: newAccessToken, refreshToken: newRefreshToken });
211211
} catch (ex) {
212212
logger.error('[ORG][REFRESH] Error saving refresh token', getErrorMessage(ex));
213213
}
@@ -221,6 +221,26 @@ export function initApiConnection(
221221
}
222222
};
223223

224+
// Re-reads current tokens from the in-memory store so concurrent requests that lose the
225+
// refresh token rotation race can retry with the tokens written by the request that won.
226+
const getFreshTokens = async () => {
227+
try {
228+
const freshOrg = getSalesforceOrgById(org.uniqueId);
229+
if (!freshOrg) {
230+
return null;
231+
}
232+
const plaintext = decryptTokenPortable(freshOrg.accessToken);
233+
const spaceIndex = plaintext.indexOf(' ');
234+
if (spaceIndex === -1) {
235+
return null;
236+
}
237+
return { accessToken: plaintext.slice(0, spaceIndex), refreshToken: plaintext.slice(spaceIndex + 1) };
238+
} catch (ex) {
239+
logger.error('[ORG][REFRESH] Error fetching fresh tokens for race condition check', getErrorMessage(ex));
240+
return null;
241+
}
242+
};
243+
224244
const jetstreamConn = new ApiConnection(
225245
{
226246
apiRequestAdapter: getApiRequestFactoryFn(fetch),
@@ -234,6 +254,7 @@ export function initApiConnection(
234254
logger: logger as any,
235255
enableLogging: false,
236256
sfdcClientId: ENV.DESKTOP_SFDC_CLIENT_ID,
257+
getFreshTokens,
237258
},
238259
handleRefresh,
239260
handleConnectionError,

libs/salesforce-api/src/lib/callout-adapter.ts

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
import { ERROR_MESSAGES, HTTP } from '@jetstream/shared/constants';
2+
import { getErrorMessageAndStackObj } from '@jetstream/shared/utils';
23
import { parse } from '@jetstreamapp/simple-xml';
34
import isObject from 'lodash/isObject';
4-
import { ApiRequestOptions, ApiRequestOutputType, BulkXmlErrorResponse, FetchFn, FetchResponse, Logger, SoapErrorResponse } from './types';
5+
import {
6+
ApiRequestOptions,
7+
ApiRequestOutputType,
8+
BulkXmlErrorResponse,
9+
FetchFn,
10+
FetchResponse,
11+
Logger,
12+
SessionInfo,
13+
SoapErrorResponse,
14+
} from './types';
515

616
const SOAP_API_AUTH_ERROR_REGEX = /<faultcode>[a-zA-Z]+:INVALID_SESSION_ID<\/faultcode>/;
717
// Shows up for certain API requests, such as Identity
@@ -46,9 +56,14 @@ function parseXml(value: string) {
4656
export function getApiRequestFactoryFn(fetch: FetchFn) {
4757
return (
4858
logger: Logger,
49-
onRefresh?: (accessToken: string) => void,
59+
onRefresh?: (accessToken: string, refreshToken?: string) => void,
5060
onConnectionError?: (accessToken: string) => void,
61+
/**
62+
* Enable logging only applies to request/response data
63+
* other logging for refresh flow and logic errors will still be logged
64+
*/
5165
enableLogging?: boolean,
66+
getFreshTokens?: () => Promise<Pick<SessionInfo, 'accessToken' | 'refreshToken'> | null>,
5267
) => {
5368
const apiRequest = async <Response = unknown>(options: ApiRequestOptions, attemptRefresh = true): Promise<Response> => {
5469
// eslint-disable-next-line prefer-const
@@ -125,18 +140,37 @@ export function getApiRequestFactoryFn(fetch: FetchFn) {
125140
sessionInfo.refreshToken
126141
) {
127142
try {
128-
// if 401 and we have a refresh token, then attempt to refresh the token
129-
const { access_token: newAccessToken } = await exchangeRefreshToken(fetch, sessionInfo);
130-
onRefresh?.(newAccessToken);
143+
logger.debug({ url, method, status: response.status }, '[TOKEN REFRESH] Attempting token refresh');
144+
const { access_token: newAccessToken, refresh_token: newRefreshToken } = await exchangeRefreshToken(fetch, sessionInfo);
145+
logger.debug({ url, method, tokenRotated: !!newRefreshToken }, '[TOKEN REFRESH] Token refresh successful');
146+
onRefresh?.(newAccessToken, newRefreshToken);
131147
// replace token in body
132148
if (typeof options.body === 'string' && options.body.includes(accessToken)) {
133149
// if the response is soap, we need to return the response as is
134150
options.body = options.body.replace(accessToken, newAccessToken);
135151
}
136152

137153
return apiRequest({ ...options, sessionInfo: { ...sessionInfo, accessToken: newAccessToken } }, false);
138-
} catch {
139-
logger.warn('Unable to refresh accessToken');
154+
} catch (ex) {
155+
logger.warn({ url, method, ...getErrorMessageAndStackObj(ex) }, '[TOKEN REFRESH] Unable to refresh accessToken');
156+
157+
// Check if another worker already refreshed (race condition on token rotation).
158+
// If the DB has a different access token, a concurrent request won the race — retry with fresh tokens.
159+
if (getFreshTokens) {
160+
try {
161+
const freshTokens = await getFreshTokens();
162+
if (freshTokens && freshTokens.accessToken !== accessToken) {
163+
logger.info({ url, method }, '[TOKEN REFRESH] Concurrent refresh detected — retrying with tokens from another worker');
164+
return apiRequest({ ...options, sessionInfo: { ...sessionInfo, ...freshTokens } }, false);
165+
}
166+
} catch (freshEx) {
167+
logger.warn(
168+
{ url, method, ...getErrorMessageAndStackObj(freshEx) },
169+
'[TOKEN REFRESH] Failed to retrieve fresh tokens for race condition check',
170+
);
171+
}
172+
}
173+
140174
responseText = ERROR_MESSAGES.SFDC_EXPIRED_TOKEN;
141175
onConnectionError?.(ERROR_MESSAGES.SFDC_EXPIRED_TOKEN);
142176
}
@@ -192,7 +226,10 @@ function handleSalesforceApiError(outputType: ApiRequestOutputType, responseText
192226
return output;
193227
}
194228

195-
function exchangeRefreshToken(fetch: FetchFn, sessionInfo: ApiRequestOptions['sessionInfo']): Promise<{ access_token: string }> {
229+
function exchangeRefreshToken(
230+
fetch: FetchFn,
231+
sessionInfo: ApiRequestOptions['sessionInfo'],
232+
): Promise<{ access_token: string; refresh_token?: string }> {
196233
const searchParams = new URLSearchParams({
197234
grant_type: 'refresh_token',
198235
});

libs/salesforce-api/src/lib/connection.ts

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ export interface ApiConnectionOptions {
2323
sfdcClientId?: string;
2424
sfdcClientSecret?: string;
2525
logger: Logger;
26+
/** Re-reads current tokens from the source of truth (e.g. DB) to handle concurrent refresh token rotation across workers */
27+
getFreshTokens?: () => Promise<Pick<SessionInfo, 'accessToken' | 'refreshToken'> | null>;
2628
}
2729

2830
export class ApiConnection {
@@ -56,12 +58,19 @@ export class ApiConnection {
5658
sfdcClientId,
5759
sfdcClientSecret,
5860
logger,
61+
getFreshTokens,
5962
}: ApiConnectionOptions,
6063
refreshCallback?: (accessToken: string, refreshToken: string) => void,
6164
onConnectionError?: (error: string) => void,
6265
) {
6366
this.logger = logger;
64-
this.apiRequest = apiRequestAdapter(logger, this.handleRefresh.bind(this), this.handleConnectionError.bind(this), enableLogging);
67+
this.apiRequest = apiRequestAdapter(
68+
logger,
69+
this.handleRefresh.bind(this),
70+
this.handleConnectionError.bind(this),
71+
enableLogging,
72+
getFreshTokens,
73+
);
6574
this.refreshCallback = refreshCallback;
6675
this.onConnectionError = onConnectionError;
6776
this.sessionInfo = {
@@ -122,8 +131,11 @@ export class ApiConnection {
122131
this.sessionInfo.userId = userId ?? this.sessionInfo.userId;
123132
}
124133

125-
public handleRefresh(accessToken: string) {
134+
public handleRefresh(accessToken: string, newRefreshToken?: string) {
126135
this.sessionInfo.accessToken = accessToken;
136+
if (newRefreshToken) {
137+
this.sessionInfo.refreshToken = newRefreshToken;
138+
}
127139
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
128140
this.refreshCallback?.(accessToken, this.sessionInfo.refreshToken!);
129141
}

0 commit comments

Comments
 (0)