Skip to content

Commit e83ba83

Browse files
committed
generic: directly inherit from IHostProvider
Since we want to change how we're handling the GenerateCredentialAsync call (to return more information) we need to 'promote' the generic host provider to inheriting directly from the IHostProvider interface, and not the abstract HostProvider class. Signed-off-by: Matthew John Cheetham <mjcheetham@outlook.com>
1 parent 8b66a29 commit e83ba83

2 files changed

Lines changed: 139 additions & 49 deletions

File tree

src/shared/Core.Tests/GenericHostProviderTests.cs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ public async Task GenericHostProvider_CreateCredentialAsync_WiaNotAllowed_Return
9393

9494
var provider = new GenericHostProvider(context, basicAuthMock.Object, wiaAuthMock.Object, oauthMock.Object);
9595

96-
ICredential credential = await provider.GenerateCredentialAsync(input);
96+
var result = await provider.GenerateCredentialAsync(input);
97+
ICredential credential = result.Credential;
9798

9899
Assert.NotNull(credential);
99100
Assert.Equal(testUserName, credential.Account);
@@ -128,7 +129,8 @@ public async Task GenericHostProvider_CreateCredentialAsync_LegacyAuthorityBasic
128129

129130
var provider = new GenericHostProvider(context, basicAuthMock.Object, wiaAuthMock.Object, oauthMock.Object);
130131

131-
ICredential credential = await provider.GenerateCredentialAsync(input);
132+
var result = await provider.GenerateCredentialAsync(input);
133+
ICredential credential = result.Credential;
132134

133135
Assert.NotNull(credential);
134136
Assert.Equal(testUserName, credential.Account);
@@ -160,7 +162,8 @@ public async Task GenericHostProvider_CreateCredentialAsync_NonHttpProtocol_Retu
160162

161163
var provider = new GenericHostProvider(context, basicAuthMock.Object, wiaAuthMock.Object, oauthMock.Object);
162164

163-
ICredential credential = await provider.GenerateCredentialAsync(input);
165+
var result = await provider.GenerateCredentialAsync(input);
166+
ICredential credential = result.Credential;
164167

165168
Assert.NotNull(credential);
166169
Assert.Equal(testUserName, credential.Account);
@@ -254,7 +257,8 @@ public async Task GenericHostProvider_GenerateCredentialAsync_OAuth_CompleteOAut
254257

255258
var provider = new GenericHostProvider(context, basicAuthMock.Object, wiaAuthMock.Object, oauthMock.Object);
256259

257-
ICredential credential = await provider.GenerateCredentialAsync(input);
260+
var result = await provider.GenerateCredentialAsync(input);
261+
ICredential credential = result.Credential;
258262

259263
Assert.NotNull(credential);
260264
Assert.Equal(testUserName, credential.Account);
@@ -292,7 +296,8 @@ private static async Task TestCreateCredentialAsync_ReturnsEmptyCredential(Windo
292296

293297
var provider = new GenericHostProvider(context, basicAuthMock.Object, wiaAuthMock.Object, oauthMock.Object);
294298

295-
ICredential credential = await provider.GenerateCredentialAsync(input);
299+
var result = await provider.GenerateCredentialAsync(input);
300+
ICredential credential = result.Credential;
296301

297302
Assert.NotNull(credential);
298303
Assert.Equal(string.Empty, credential.Account);
@@ -324,7 +329,8 @@ private static async Task TestCreateCredentialAsync_ReturnsBasicCredential(Windo
324329

325330
var provider = new GenericHostProvider(context, basicAuthMock.Object, wiaAuthMock.Object, oauthMock.Object);
326331

327-
ICredential credential = await provider.GenerateCredentialAsync(input);
332+
var result = await provider.GenerateCredentialAsync(input);
333+
ICredential credential = result.Credential;
328334

329335
Assert.NotNull(credential);
330336
Assert.Equal(testUserName, credential.Account);

src/shared/Core/GenericHostProvider.cs

Lines changed: 127 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
namespace GitCredentialManager
1111
{
12-
public class GenericHostProvider : HostProvider
12+
public class GenericHostProvider : DisposableObject, IHostProvider
1313
{
14+
private readonly ICommandContext _context;
1415
private readonly IBasicAuthentication _basicAuth;
1516
private readonly IWindowsIntegratedAuthentication _winAuth;
1617
private readonly IOAuthAuthentication _oauth;
@@ -23,44 +24,120 @@ public GenericHostProvider(ICommandContext context,
2324
IBasicAuthentication basicAuth,
2425
IWindowsIntegratedAuthentication winAuth,
2526
IOAuthAuthentication oauth)
26-
: base(context)
2727
{
28+
EnsureArgument.NotNull(context, nameof(context));
2829
EnsureArgument.NotNull(basicAuth, nameof(basicAuth));
2930
EnsureArgument.NotNull(winAuth, nameof(winAuth));
3031
EnsureArgument.NotNull(oauth, nameof(oauth));
3132

33+
_context = context;
3234
_basicAuth = basicAuth;
3335
_winAuth = winAuth;
3436
_oauth = oauth;
3537
}
3638

37-
public override string Id => "generic";
39+
public string Id => "generic";
3840

39-
public override string Name => "Generic";
41+
public string Name => "Generic";
4042

41-
public override IEnumerable<string> SupportedAuthorityIds =>
43+
public IEnumerable<string> SupportedAuthorityIds =>
4244
EnumerableExtensions.ConcatMany(
4345
BasicAuthentication.AuthorityIds,
4446
WindowsIntegratedAuthentication.AuthorityIds
4547
);
4648

47-
public override bool IsSupported(InputArguments input)
49+
public bool IsSupported(InputArguments input)
4850
{
4951
// The generic provider should support all possible protocols (HTTP, HTTPS, SMTP, IMAP, etc)
5052
return input != null && !string.IsNullOrWhiteSpace(input.Protocol);
5153
}
5254

53-
public override async Task<ICredential> GenerateCredentialAsync(InputArguments input)
55+
public bool IsSupported(HttpResponseMessage response)
56+
{
57+
return false;
58+
}
59+
60+
public string GetServiceName(InputArguments input)
61+
{
62+
// By default we assume the service name will be the absolute URI based on the
63+
// input arguments from Git, without any userinfo part.
64+
return input.GetRemoteUri(includeUser: false).AbsoluteUri.TrimEnd('/');
65+
}
66+
67+
public async Task<GetCredentialResult> GetCredentialAsync(InputArguments input)
68+
{
69+
// Try and locate an existing credential in the OS credential store
70+
string service = GetServiceName(input);
71+
_context.Trace.WriteLine($"Looking for existing credential in store with service={service} account={input.UserName}...");
72+
73+
ICredential credential = _context.CredentialStore.Get(service, input.UserName);
74+
if (credential == null)
75+
{
76+
_context.Trace.WriteLine("No existing credentials found.");
77+
78+
// No existing credential was found, create a new one
79+
_context.Trace.WriteLine("Creating new credential...");
80+
return await GenerateCredentialAsync(input);
81+
}
82+
else
83+
{
84+
_context.Trace.WriteLine("Existing credential found.");
85+
}
86+
87+
return new GetCredentialResult(credential);
88+
}
89+
90+
public Task StoreCredentialAsync(InputArguments input)
91+
{
92+
string service = GetServiceName(input);
93+
94+
// WIA-authentication is signaled to Git as an empty username/password pair
95+
// and we will get called to 'store' these WIA credentials.
96+
// We avoid storing empty credentials.
97+
if (string.IsNullOrWhiteSpace(input.UserName) && string.IsNullOrWhiteSpace(input.Password))
98+
{
99+
_context.Trace.WriteLine("Not storing empty credential.");
100+
}
101+
else
102+
{
103+
// Add or update the credential in the store.
104+
_context.Trace.WriteLine($"Storing credential with service={service} account={input.UserName}...");
105+
_context.CredentialStore.AddOrUpdate(service, input.UserName, input.Password);
106+
_context.Trace.WriteLine("Credential was successfully stored.");
107+
}
108+
109+
return Task.CompletedTask;
110+
}
111+
112+
public Task EraseCredentialAsync(InputArguments input)
113+
{
114+
string service = GetServiceName(input);
115+
116+
// Try to locate an existing credential
117+
_context.Trace.WriteLine($"Erasing stored credential in store with service={service} account={input.UserName}...");
118+
if (_context.CredentialStore.Remove(service, input.UserName))
119+
{
120+
_context.Trace.WriteLine("Credential was successfully erased.");
121+
}
122+
else
123+
{
124+
_context.Trace.WriteLine("No credential was erased.");
125+
}
126+
127+
return Task.CompletedTask;
128+
}
129+
130+
public async Task<GetCredentialResult> GenerateCredentialAsync(InputArguments input)
54131
{
55132
ThrowIfDisposed();
56133

57134
// We only want to *warn* about HTTP remotes for the generic provider because it supports all protocols
58135
// and, historically, we never blocked HTTP remotes in this provider.
59136
// The user can always set the 'GCM_ALLOW_UNSAFE' setting to silence the warning.
60-
if (!Context.Settings.AllowUnsafeRemotes &&
137+
if (!_context.Settings.AllowUnsafeRemotes &&
61138
StringComparer.OrdinalIgnoreCase.Equals(input.Protocol, "http"))
62139
{
63-
Context.Streams.Error.WriteLine(
140+
_context.Streams.Error.WriteLine(
64141
"warning: use of unencrypted HTTP remote URLs is not recommended; " +
65142
$"see {Constants.HelpUrls.GcmUnsafeRemotes} for more information.");
66143
}
@@ -74,56 +151,63 @@ public override async Task<ICredential> GenerateCredentialAsync(InputArguments i
74151
// Cannot check WIA or OAuth support for non-HTTP based protocols
75152
}
76153
// Check for an OAuth configuration for this remote
77-
else if (GenericOAuthConfig.TryGet(Context.Trace, Context.Settings, input, out GenericOAuthConfig oauthConfig))
154+
else if (GenericOAuthConfig.TryGet(_context.Trace, _context.Settings, input, out GenericOAuthConfig oauthConfig))
78155
{
79-
Context.Trace.WriteLine($"Found generic OAuth configuration for '{uri}':");
80-
Context.Trace.WriteLine($"\tAuthzEndpoint = {oauthConfig.Endpoints.AuthorizationEndpoint}");
81-
Context.Trace.WriteLine($"\tTokenEndpoint = {oauthConfig.Endpoints.TokenEndpoint}");
82-
Context.Trace.WriteLine($"\tDeviceEndpoint = {oauthConfig.Endpoints.DeviceAuthorizationEndpoint}");
83-
Context.Trace.WriteLine($"\tClientId = {oauthConfig.ClientId}");
84-
Context.Trace.WriteLine($"\tClientSecret = {oauthConfig.ClientSecret}");
85-
Context.Trace.WriteLine($"\tRedirectUri = {oauthConfig.RedirectUri}");
86-
Context.Trace.WriteLine($"\tScopes = [{string.Join(", ", oauthConfig.Scopes)}]");
87-
Context.Trace.WriteLine($"\tUseAuthHeader = {oauthConfig.UseAuthHeader}");
88-
Context.Trace.WriteLine($"\tDefaultUserName = {oauthConfig.DefaultUserName}");
89-
90-
return await GetOAuthAccessToken(uri, input.UserName, oauthConfig, Context.Trace2);
156+
_context.Trace.WriteLine($"Found generic OAuth configuration for '{uri}':");
157+
_context.Trace.WriteLine($"\tAuthzEndpoint = {oauthConfig.Endpoints.AuthorizationEndpoint}");
158+
_context.Trace.WriteLine($"\tTokenEndpoint = {oauthConfig.Endpoints.TokenEndpoint}");
159+
_context.Trace.WriteLine($"\tDeviceEndpoint = {oauthConfig.Endpoints.DeviceAuthorizationEndpoint}");
160+
_context.Trace.WriteLine($"\tClientId = {oauthConfig.ClientId}");
161+
_context.Trace.WriteLine($"\tClientSecret = {oauthConfig.ClientSecret}");
162+
_context.Trace.WriteLine($"\tRedirectUri = {oauthConfig.RedirectUri}");
163+
_context.Trace.WriteLine($"\tScopes = [{string.Join(", ", oauthConfig.Scopes)}]");
164+
_context.Trace.WriteLine($"\tUseAuthHeader = {oauthConfig.UseAuthHeader}");
165+
_context.Trace.WriteLine($"\tDefaultUserName = {oauthConfig.DefaultUserName}");
166+
167+
return new GetCredentialResult(
168+
await GetOAuthAccessToken(uri, input.UserName, oauthConfig, _context.Trace2)
169+
);
91170
}
92171
// Try detecting WIA for this remote, if permitted
93172
else if (IsWindowsAuthAllowed)
94173
{
95174
if (PlatformUtils.IsWindows())
96175
{
97-
Context.Trace.WriteLine($"Checking host '{uri.AbsoluteUri}' for Windows Integrated Authentication...");
176+
_context.Trace.WriteLine($"Checking host '{uri.AbsoluteUri}' for Windows Integrated Authentication...");
98177
var supportedWiaTypes = await _winAuth.GetAuthenticationTypesAsync(uri);
99178
bool isWiaSupported = supportedWiaTypes != WindowsAuthenticationTypes.None;
100179

101180
if (!isWiaSupported)
102181
{
103-
Context.Trace.WriteLine("Host does not support WIA.");
182+
_context.Trace.WriteLine("Host does not support WIA.");
104183
}
105184
else
106185
{
107-
Context.Trace.WriteLine("Host supports WIA - generating empty credential...");
186+
_context.Trace.WriteLine("Host supports WIA - generating empty credential...");
108187

109188
// WIA is signaled to Git using an empty username/password
110-
return new GitCredential(string.Empty, string.Empty);
189+
ICredential creds = new GitCredential(string.Empty, string.Empty);
190+
return new GetCredentialResult(creds);
111191
}
112192
}
113193
else
114194
{
115-
string osType = PlatformUtils.GetPlatformInformation(Context.Trace2).OperatingSystemType;
116-
Context.Trace.WriteLine($"Skipping check for Windows Integrated Authentication on {osType}.");
195+
string osType = PlatformUtils.GetPlatformInformation(_context.Trace2).OperatingSystemType;
196+
_context.Trace.WriteLine($"Skipping check for Windows Integrated Authentication on {osType}.");
117197
}
118198
}
119199
else
120200
{
121-
Context.Trace.WriteLine("Windows Integrated Authentication detection has been disabled.");
201+
_context.Trace.WriteLine("Windows Integrated Authentication detection has been disabled.");
122202
}
123203

124204
// Use basic authentication
125-
Context.Trace.WriteLine("Prompting for basic credentials...");
126-
return await _basicAuth.GetCredentialsAsync(uri.AbsoluteUri, input.UserName);
205+
_context.Trace.WriteLine("Prompting for basic credentials...");
206+
return new GetCredentialResult(
207+
await _basicAuth.GetCredentialsAsync(uri.AbsoluteUri, input.UserName)
208+
);
209+
}
210+
127211
}
128212

129213
private async Task<ICredential> GetOAuthAccessToken(Uri remoteUri, string userName, GenericOAuthConfig config, ITrace2 trace2)
@@ -152,7 +236,7 @@ private async Task<ICredential> GetOAuthAccessToken(Uri remoteUri, string userNa
152236
.Uri.AbsoluteUri.TrimEnd('/');
153237

154238
// Try to use a refresh token if we have one
155-
ICredential refreshToken = Context.CredentialStore.Get(refreshService, userName);
239+
ICredential refreshToken = _context.CredentialStore.Get(refreshService, userName);
156240
if (refreshToken != null)
157241
{
158242
try
@@ -162,7 +246,7 @@ private async Task<ICredential> GetOAuthAccessToken(Uri remoteUri, string userNa
162246
// Store new refresh token if we have been given one
163247
if (!string.IsNullOrWhiteSpace(refreshResult.RefreshToken))
164248
{
165-
Context.CredentialStore.AddOrUpdate(refreshService, refreshToken.Account, refreshResult.RefreshToken);
249+
_context.CredentialStore.AddOrUpdate(refreshService, refreshToken.Account, refreshResult.RefreshToken);
166250
}
167251

168252
// Return the new access token
@@ -172,26 +256,26 @@ private async Task<ICredential> GetOAuthAccessToken(Uri remoteUri, string userNa
172256
{
173257
// Failed to use refresh token. It may have expired or been revoked.
174258
// Fall through to an interactive OAuth flow.
175-
Context.Trace.WriteLine("Failed to use refresh token.");
176-
Context.Trace.WriteException(ex);
259+
_context.Trace.WriteLine("Failed to use refresh token.");
260+
_context.Trace.WriteException(ex);
177261
}
178262
}
179263

180264
// Determine which interactive OAuth mode to use. Start by checking for mode preference in config
181265
var supportedModes = OAuthAuthenticationModes.All;
182-
if (Context.Settings.TryGetSetting(
266+
if (_context.Settings.TryGetSetting(
183267
Constants.EnvironmentVariables.OAuthAuthenticationModes,
184268
Constants.GitConfiguration.Credential.SectionName,
185269
Constants.GitConfiguration.Credential.OAuthAuthenticationModes,
186270
out string authModesStr))
187271
{
188272
if (Enum.TryParse(authModesStr, true, out supportedModes) && supportedModes != OAuthAuthenticationModes.None)
189273
{
190-
Context.Trace.WriteLine($"Supported authentication modes override present: {supportedModes}");
274+
_context.Trace.WriteLine($"Supported authentication modes override present: {supportedModes}");
191275
}
192276
else
193277
{
194-
Context.Trace.WriteLine($"Invalid value for supported authentication modes override setting: '{authModesStr}'");
278+
_context.Trace.WriteLine($"Invalid value for supported authentication modes override setting: '{authModesStr}'");
195279
}
196280
}
197281

@@ -216,13 +300,13 @@ private async Task<ICredential> GetOAuthAccessToken(Uri remoteUri, string userNa
216300
break;
217301

218302
default:
219-
throw new Trace2Exception(Context.Trace2, "No authentication mode selected!");
303+
throw new Trace2Exception(_context.Trace2, "No authentication mode selected!");
220304
}
221305

222306
// Store the refresh token if we have one
223307
if (!string.IsNullOrWhiteSpace(tokenResult.RefreshToken))
224308
{
225-
Context.CredentialStore.AddOrUpdate(refreshService, oauthUser, tokenResult.RefreshToken);
309+
_context.CredentialStore.AddOrUpdate(refreshService, oauthUser, tokenResult.RefreshToken);
226310
}
227311

228312
return new GitCredential(oauthUser, tokenResult.AccessToken);
@@ -238,23 +322,23 @@ private bool IsWindowsAuthAllowed
238322
{
239323
get
240324
{
241-
if (Context.Settings.IsWindowsIntegratedAuthenticationEnabled)
325+
if (_context.Settings.IsWindowsIntegratedAuthenticationEnabled)
242326
{
243327
/* COMPAT: In the old GCM one workaround for common authentication problems was to specify "basic" as the authority
244328
* which prevents any smart detection of provider or NTLM etc, allowing the user a chance to manually enter
245329
* a username/password or PAT.
246330
*
247331
* We take this old setting into account to ensure a good migration experience.
248332
*/
249-
return !BasicAuthentication.AuthorityIds.Contains(Context.Settings.LegacyAuthorityOverride, StringComparer.OrdinalIgnoreCase);
333+
return !BasicAuthentication.AuthorityIds.Contains(_context.Settings.LegacyAuthorityOverride, StringComparer.OrdinalIgnoreCase);
250334
}
251335

252336
return false;
253337
}
254338
}
255339

256340
private HttpClient _httpClient;
257-
private HttpClient HttpClient => _httpClient ?? (_httpClient = Context.HttpClientFactory.CreateClient());
341+
private HttpClient HttpClient => _httpClient ?? (_httpClient = _context.HttpClientFactory.CreateClient());
258342

259343
protected override void ReleaseManagedResources()
260344
{

0 commit comments

Comments
 (0)