Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions src/hex_cli_auth.erl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
%% == Callbacks ==
%%
%% Callbacks are provided via the `cli_auth_callbacks' key in the config map.
%% All callbacks are required:
%% All callbacks below are required unless marked optional:
%%
%% ```
%% #{
Expand All @@ -28,6 +28,13 @@
%% RefreshToken :: binary() | undefined,
%% ExpiresAt :: integer()) -> ok),
%%
%% %% Invalidate the stored global OAuth token after it expired and could
%% %% not be refreshed (optional). Lets the build tool drop the unusable
%% %% token so concurrent and subsequent callers stop retrying the doomed
%% %% refresh, and warn the user. Invoked at most once per resolution, while
%% %% holding the token-refresh lock.
%% clear_oauth_tokens => fun(() -> ok),
%%
%% %% User interaction
%% prompt_otp => fun((Message :: binary()) -> {ok, OtpCode :: binary()} | cancelled),
%% should_authenticate => fun((Reason :: no_credentials | token_refresh_failed) -> boolean()),
Expand Down Expand Up @@ -112,6 +119,7 @@
ExpiresAt :: integer()
) -> ok
),
clear_oauth_tokens => fun(() -> ok),
prompt_otp := fun((Message :: binary()) -> {ok, OtpCode :: binary()} | cancelled),
should_authenticate := fun((Reason :: auth_prompt_reason()) -> boolean()),
get_client_id := fun(() -> binary())
Expand Down Expand Up @@ -600,7 +608,7 @@ do_resolve_oauth_token_with_context(Config) ->
is_binary(maps:get(refresh_token, Tokens)),
case is_token_expired(ExpiresAt) of
true ->
maybe_refresh_token_with_context(Config, Tokens);
refresh_or_clear(Config, Tokens);
false ->
BearerToken = <<"Bearer ", AccessToken/binary>>,
{ok, BearerToken, #{source => oauth, has_refresh_token => HasRefreshToken}}
Expand All @@ -609,6 +617,21 @@ do_resolve_oauth_token_with_context(Config) ->
{error, no_auth}
end.

%% @private
%% Refresh an expired global token; if the refresh fails, invalidate the stored
%% token via the optional clear_oauth_tokens callback. This runs inside the
%% token_refresh lock, so the unusable token is dropped exactly once and the
%% callers serialized behind the lock re-read it as absent instead of each
%% retrying the doomed refresh against the server.
refresh_or_clear(Config, Tokens) ->
case maybe_refresh_token_with_context(Config, Tokens) of
{ok, _Bearer, _Ctx} = Ok ->
Ok;
{error, _} = Error ->
maybe_call_callback(Config, clear_oauth_tokens, []),
Error
end.

%% @private
maybe_refresh_token_with_context(Config, #{refresh_token := RefreshToken}) when
is_binary(RefreshToken)
Expand Down Expand Up @@ -756,3 +779,13 @@ call_callback(Config, Name, Args) ->
#{cli_auth_callbacks := Callbacks} = Config,
Fun = maps:get(Name, Callbacks),
erlang:apply(Fun, Args).

%% @private
%% Like call_callback/3 but for optional callbacks: returns ok without doing
%% anything when the callback is not provided.
maybe_call_callback(Config, Name, Args) ->
#{cli_auth_callbacks := Callbacks} = Config,
case maps:find(Name, Callbacks) of
{ok, Fun} -> erlang:apply(Fun, Args);
error -> ok
end.
90 changes: 90 additions & 0 deletions test/hex_cli_auth_SUITE.erl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ all() ->

%% concurrency tests
resolve_oauth_token_concurrent_refresh_serialized_test,
resolve_oauth_token_refresh_failure_clears_once_test,
device_auth_concurrent_serialized_reuses_login_test
].

Expand Down Expand Up @@ -932,6 +933,93 @@ resolve_oauth_token_concurrent_refresh_serialized_test(_Config) ->
end,
ok.

resolve_oauth_token_refresh_failure_clears_once_test(_Config) ->
%% Several concurrent callers share one expired global token whose refresh
%% the server rejects (400). The first failure must invalidate the token via
%% clear_oauth_tokens while holding the token-refresh lock; every caller
%% serialized behind the lock then re-reads it as absent instead of each
%% re-POSTing to /oauth/token. So the refresh is attempted exactly once.
Now = erlang:system_time(second),
NumCallers = 5,
Self = self(),
ClearCount = counters:new(1, [atomics]),

%% Shared token store: starts with the expired token, emptied by the first
%% (and only) clear so subsequent callers see no credentials.
TokenStore = ets:new(token_store, [public, set]),
true = ets:insert(
TokenStore,
{oauth_tokens,
{ok, #{
access_token => <<"expired_token">>,
refresh_token => <<"refresh_token">>,
expires_at => Now - 100
}}}
),

Config = config_with_callbacks(#{
get_oauth_tokens => fun() ->
[{oauth_tokens, Tokens}] = ets:lookup(TokenStore, oauth_tokens),
Tokens
end,
clear_oauth_tokens => fun() ->
%% Slow clear: a missing lock or missing re-read would let other
%% callers race in and refresh again, tripping the count assertion.
timer:sleep(50),
ets:insert(TokenStore, {oauth_tokens, error}),
counters:add(ClearCount, 1, 1),
ok
end
}),

%% Each caller plants a 400 response for its own refresh request. Only the
%% caller that wins the lock actually performs the refresh and consumes it.
FailResponse =
{ok,
{400, #{<<"content-type">> => <<"application/vnd.hex+erlang; charset=utf-8">>},
term_to_binary(#{<<"error">> => <<"invalid_grant">>})}},

[
spawn(fun() ->
self() ! {hex_http_test, oauth_refresh_response, FailResponse},
Self ! {ready, self()},
receive
go -> ok
end,
Result = hex_cli_auth:resolve_api_auth(read, Config),
Self ! {result, Result}
end)
|| _ <- lists:seq(1, NumCallers)
],

%% Barrier: release all callers together to maximize the race.
Pids = [
receive
{ready, Pid} -> Pid
after 1000 ->
error(caller_not_ready)
end
|| _ <- lists:seq(1, NumCallers)
],
[Pid ! go || Pid <- Pids],

Results = [
receive
{result, R} -> R
after 5000 ->
error(caller_timed_out)
end
|| _ <- lists:seq(1, NumCallers)
],

%% The token was cleared exactly once => exactly one refresh POST happened.
?assertEqual(1, counters:get(ClearCount, 1)),
%% No caller obtained a usable token.
[?assertMatch({error, _}, R) || R <- Results],

ets:delete(TokenStore),
ok.

device_auth_concurrent_serialized_reuses_login_test(_Config) ->
%% Multiple concurrent callers that all need to authenticate via device auth
%% must serialize: the FIRST caller runs the device auth flow, persists the
Expand Down Expand Up @@ -1091,13 +1179,15 @@ make_callbacks(Opts) ->
PromptOtp = maps:get(prompt_otp, Opts, fun(_) -> cancelled end),
ShouldAuthenticate = maps:get(should_authenticate, Opts, fun(_) -> false end),
PersistFn = maps:get(persist_oauth_tokens, Opts, fun(_, _, _, _) -> ok end),
ClearFn = maps:get(clear_oauth_tokens, Opts, fun() -> ok end),
DefaultGetOAuthTokens = fun() -> maps:get(oauth_tokens, Opts, error) end,
GetOAuthTokensFn = maps:get(get_oauth_tokens, Opts, DefaultGetOAuthTokens),

#{
get_auth_config => fun(RepoName) -> maps:get(RepoName, AuthConfig, undefined) end,
get_oauth_tokens => GetOAuthTokensFn,
persist_oauth_tokens => PersistFn,
clear_oauth_tokens => ClearFn,
prompt_otp => PromptOtp,
should_authenticate => ShouldAuthenticate,
get_client_id => fun() -> <<"test_client">> end
Expand Down
25 changes: 15 additions & 10 deletions test/support/hex_http_test.erl
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,21 @@ fixture(post, <<?TEST_API_URL, "/oauth/token">>, _, {_, Body}) ->
},
{ok, {200, api_headers(), term_to_binary(Payload)}};
<<"refresh_token">> ->
% Simulate successful token refresh
NewAccessToken = base64:encode(crypto:strong_rand_bytes(32)),
NewRefreshToken = base64:encode(crypto:strong_rand_bytes(32)),
Payload = #{
<<"access_token">> => NewAccessToken,
<<"refresh_token">> => NewRefreshToken,
<<"token_type">> => <<"Bearer">>,
<<"expires_in">> => 3600
},
{ok, {200, api_headers(), term_to_binary(Payload)}};
receive
{hex_http_test, oauth_refresh_response, Response} ->
Response
after 0 ->
% Default: simulate successful token refresh
NewAccessToken = base64:encode(crypto:strong_rand_bytes(32)),
NewRefreshToken = base64:encode(crypto:strong_rand_bytes(32)),
Payload = #{
<<"access_token">> => NewAccessToken,
<<"refresh_token">> => NewRefreshToken,
<<"token_type">> => <<"Bearer">>,
<<"expires_in">> => 3600
},
{ok, {200, api_headers(), term_to_binary(Payload)}}
end;
<<"client_credentials">> ->
% Simulate successful client credentials token exchange
#{<<"scope">> := Scope} = DecodedBody,
Expand Down
Loading