diff --git a/src/hex_cli_auth.erl b/src/hex_cli_auth.erl index 7c2c1016..11a7fdd1 100644 --- a/src/hex_cli_auth.erl +++ b/src/hex_cli_auth.erl @@ -8,7 +8,7 @@ %% == Callbacks == %% %% Callbacks are provided via the `cli_auth_callbacks' key in the config map. -%% All callbacks are required: +%% All callbacks below are required unless marked optional: %% %% ``` %% #{ @@ -28,6 +28,13 @@ %% RefreshToken :: binary() | undefined, %% ExpiresAt :: integer()) -> ok), %% +%% %% Invalidate the stored global OAuth token after it expired and could +%% %% not be refreshed (optional). Lets the build tool drop the unusable +%% %% token so concurrent and subsequent callers stop retrying the doomed +%% %% refresh, and warn the user. Invoked at most once per resolution, while +%% %% holding the token-refresh lock. +%% clear_oauth_tokens => fun(() -> ok), +%% %% %% User interaction %% prompt_otp => fun((Message :: binary()) -> {ok, OtpCode :: binary()} | cancelled), %% should_authenticate => fun((Reason :: no_credentials | token_refresh_failed) -> boolean()), @@ -112,6 +119,7 @@ ExpiresAt :: integer() ) -> ok ), + clear_oauth_tokens => fun(() -> ok), prompt_otp := fun((Message :: binary()) -> {ok, OtpCode :: binary()} | cancelled), should_authenticate := fun((Reason :: auth_prompt_reason()) -> boolean()), get_client_id := fun(() -> binary()) @@ -600,7 +608,7 @@ do_resolve_oauth_token_with_context(Config) -> is_binary(maps:get(refresh_token, Tokens)), case is_token_expired(ExpiresAt) of true -> - maybe_refresh_token_with_context(Config, Tokens); + refresh_or_clear(Config, Tokens); false -> BearerToken = <<"Bearer ", AccessToken/binary>>, {ok, BearerToken, #{source => oauth, has_refresh_token => HasRefreshToken}} @@ -609,6 +617,21 @@ do_resolve_oauth_token_with_context(Config) -> {error, no_auth} end. +%% @private +%% Refresh an expired global token; if the refresh fails, invalidate the stored +%% token via the optional clear_oauth_tokens callback. This runs inside the +%% token_refresh lock, so the unusable token is dropped exactly once and the +%% callers serialized behind the lock re-read it as absent instead of each +%% retrying the doomed refresh against the server. +refresh_or_clear(Config, Tokens) -> + case maybe_refresh_token_with_context(Config, Tokens) of + {ok, _Bearer, _Ctx} = Ok -> + Ok; + {error, _} = Error -> + maybe_call_callback(Config, clear_oauth_tokens, []), + Error + end. + %% @private maybe_refresh_token_with_context(Config, #{refresh_token := RefreshToken}) when is_binary(RefreshToken) @@ -756,3 +779,13 @@ call_callback(Config, Name, Args) -> #{cli_auth_callbacks := Callbacks} = Config, Fun = maps:get(Name, Callbacks), erlang:apply(Fun, Args). + +%% @private +%% Like call_callback/3 but for optional callbacks: returns ok without doing +%% anything when the callback is not provided. +maybe_call_callback(Config, Name, Args) -> + #{cli_auth_callbacks := Callbacks} = Config, + case maps:find(Name, Callbacks) of + {ok, Fun} -> erlang:apply(Fun, Args); + error -> ok + end. diff --git a/test/hex_cli_auth_SUITE.erl b/test/hex_cli_auth_SUITE.erl index dc2d61d9..732da0a7 100644 --- a/test/hex_cli_auth_SUITE.erl +++ b/test/hex_cli_auth_SUITE.erl @@ -71,6 +71,7 @@ all() -> %% concurrency tests resolve_oauth_token_concurrent_refresh_serialized_test, + resolve_oauth_token_refresh_failure_clears_once_test, device_auth_concurrent_serialized_reuses_login_test ]. @@ -932,6 +933,93 @@ resolve_oauth_token_concurrent_refresh_serialized_test(_Config) -> end, ok. +resolve_oauth_token_refresh_failure_clears_once_test(_Config) -> + %% Several concurrent callers share one expired global token whose refresh + %% the server rejects (400). The first failure must invalidate the token via + %% clear_oauth_tokens while holding the token-refresh lock; every caller + %% serialized behind the lock then re-reads it as absent instead of each + %% re-POSTing to /oauth/token. So the refresh is attempted exactly once. + Now = erlang:system_time(second), + NumCallers = 5, + Self = self(), + ClearCount = counters:new(1, [atomics]), + + %% Shared token store: starts with the expired token, emptied by the first + %% (and only) clear so subsequent callers see no credentials. + TokenStore = ets:new(token_store, [public, set]), + true = ets:insert( + TokenStore, + {oauth_tokens, + {ok, #{ + access_token => <<"expired_token">>, + refresh_token => <<"refresh_token">>, + expires_at => Now - 100 + }}} + ), + + Config = config_with_callbacks(#{ + get_oauth_tokens => fun() -> + [{oauth_tokens, Tokens}] = ets:lookup(TokenStore, oauth_tokens), + Tokens + end, + clear_oauth_tokens => fun() -> + %% Slow clear: a missing lock or missing re-read would let other + %% callers race in and refresh again, tripping the count assertion. + timer:sleep(50), + ets:insert(TokenStore, {oauth_tokens, error}), + counters:add(ClearCount, 1, 1), + ok + end + }), + + %% Each caller plants a 400 response for its own refresh request. Only the + %% caller that wins the lock actually performs the refresh and consumes it. + FailResponse = + {ok, + {400, #{<<"content-type">> => <<"application/vnd.hex+erlang; charset=utf-8">>}, + term_to_binary(#{<<"error">> => <<"invalid_grant">>})}}, + + [ + spawn(fun() -> + self() ! {hex_http_test, oauth_refresh_response, FailResponse}, + Self ! {ready, self()}, + receive + go -> ok + end, + Result = hex_cli_auth:resolve_api_auth(read, Config), + Self ! {result, Result} + end) + || _ <- lists:seq(1, NumCallers) + ], + + %% Barrier: release all callers together to maximize the race. + Pids = [ + receive + {ready, Pid} -> Pid + after 1000 -> + error(caller_not_ready) + end + || _ <- lists:seq(1, NumCallers) + ], + [Pid ! go || Pid <- Pids], + + Results = [ + receive + {result, R} -> R + after 5000 -> + error(caller_timed_out) + end + || _ <- lists:seq(1, NumCallers) + ], + + %% The token was cleared exactly once => exactly one refresh POST happened. + ?assertEqual(1, counters:get(ClearCount, 1)), + %% No caller obtained a usable token. + [?assertMatch({error, _}, R) || R <- Results], + + ets:delete(TokenStore), + ok. + device_auth_concurrent_serialized_reuses_login_test(_Config) -> %% Multiple concurrent callers that all need to authenticate via device auth %% must serialize: the FIRST caller runs the device auth flow, persists the @@ -1091,6 +1179,7 @@ make_callbacks(Opts) -> PromptOtp = maps:get(prompt_otp, Opts, fun(_) -> cancelled end), ShouldAuthenticate = maps:get(should_authenticate, Opts, fun(_) -> false end), PersistFn = maps:get(persist_oauth_tokens, Opts, fun(_, _, _, _) -> ok end), + ClearFn = maps:get(clear_oauth_tokens, Opts, fun() -> ok end), DefaultGetOAuthTokens = fun() -> maps:get(oauth_tokens, Opts, error) end, GetOAuthTokensFn = maps:get(get_oauth_tokens, Opts, DefaultGetOAuthTokens), @@ -1098,6 +1187,7 @@ make_callbacks(Opts) -> get_auth_config => fun(RepoName) -> maps:get(RepoName, AuthConfig, undefined) end, get_oauth_tokens => GetOAuthTokensFn, persist_oauth_tokens => PersistFn, + clear_oauth_tokens => ClearFn, prompt_otp => PromptOtp, should_authenticate => ShouldAuthenticate, get_client_id => fun() -> <<"test_client">> end diff --git a/test/support/hex_http_test.erl b/test/support/hex_http_test.erl index 6de9d2e9..01a88e76 100644 --- a/test/support/hex_http_test.erl +++ b/test/support/hex_http_test.erl @@ -371,16 +371,21 @@ fixture(post, <>, _, {_, Body}) -> }, {ok, {200, api_headers(), term_to_binary(Payload)}}; <<"refresh_token">> -> - % Simulate successful token refresh - NewAccessToken = base64:encode(crypto:strong_rand_bytes(32)), - NewRefreshToken = base64:encode(crypto:strong_rand_bytes(32)), - Payload = #{ - <<"access_token">> => NewAccessToken, - <<"refresh_token">> => NewRefreshToken, - <<"token_type">> => <<"Bearer">>, - <<"expires_in">> => 3600 - }, - {ok, {200, api_headers(), term_to_binary(Payload)}}; + receive + {hex_http_test, oauth_refresh_response, Response} -> + Response + after 0 -> + % Default: simulate successful token refresh + NewAccessToken = base64:encode(crypto:strong_rand_bytes(32)), + NewRefreshToken = base64:encode(crypto:strong_rand_bytes(32)), + Payload = #{ + <<"access_token">> => NewAccessToken, + <<"refresh_token">> => NewRefreshToken, + <<"token_type">> => <<"Bearer">>, + <<"expires_in">> => 3600 + }, + {ok, {200, api_headers(), term_to_binary(Payload)}} + end; <<"client_credentials">> -> % Simulate successful client credentials token exchange #{<<"scope">> := Scope} = DecodedBody,