diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs index d15e241f75..a8fdf219d3 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs @@ -118,28 +118,46 @@ public override async Task AcquireTokenAsync(SqlAuthenti string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix; string[] scopes = new string[] { scope }; + TokenRequestContext tokenRequestContext = new(scopes); + + /* We split audience from Authority URL here. Audience can be one of the following: + * The Azure AD authority audience enumeration + * The tenant ID, which can be: + * - A GUID (the ID of your Azure AD instance), for single-tenant applications + * - A domain name associated with your Azure AD instance (also for single-tenant applications) + * One of these placeholders as a tenant ID in place of the Azure AD authority audience enumeration: + * - `organizations` for a multitenant application + * - `consumers` to sign in users only with their personal accounts + * - `common` to sign in users with their work and school accounts or their personal Microsoft accounts + * + * MSAL will throw a meaningful exception if you specify both the Azure AD authority audience and the tenant ID. + * If you don't specify an audience, your app will target Azure AD and personal Microsoft accounts as an audience. (That is, it will behave as though `common` were specified.) + * More information: https://docs.microsoft.com/azure/active-directory/develop/msal-client-application-configuration + **/ int seperatorIndex = parameters.Authority.LastIndexOf('/'); - string tenantId = parameters.Authority.Substring(seperatorIndex + 1); string authority = parameters.Authority.Remove(seperatorIndex + 1); - - TokenRequestContext tokenRequestContext = new TokenRequestContext(scopes); + string audience = parameters.Authority.Substring(seperatorIndex + 1); string clientId = string.IsNullOrWhiteSpace(parameters.UserId) ? null : parameters.UserId; if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDefault) { - DefaultAzureCredentialOptions defaultAzureCredentialOptions = new DefaultAzureCredentialOptions() + DefaultAzureCredentialOptions defaultAzureCredentialOptions = new() { AuthorityHost = new Uri(authority), - ManagedIdentityClientId = clientId, - InteractiveBrowserTenantId = tenantId, - SharedTokenCacheTenantId = tenantId, - SharedTokenCacheUsername = clientId, - VisualStudioCodeTenantId = tenantId, - VisualStudioTenantId = tenantId, + SharedTokenCacheTenantId = audience, + VisualStudioCodeTenantId = audience, + VisualStudioTenantId = audience, ExcludeInteractiveBrowserCredential = true // Force disabled, even though it's disabled by default to respect driver specifications. }; - AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token); + + // Optionally set clientId when available + if (clientId is not null) + { + defaultAzureCredentialOptions.ManagedIdentityClientId = clientId; + defaultAzureCredentialOptions.SharedTokenCacheUsername = clientId; + } + AccessToken accessToken = await new DefaultAzureCredential(defaultAzureCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Default auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } @@ -148,7 +166,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryManagedIdentity || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryMSI) { - AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token); + AccessToken accessToken = await new ManagedIdentityCredential(clientId, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Managed Identity auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } @@ -156,7 +174,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti AuthenticationResult result; if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal) { - AccessToken accessToken = await new ClientSecretCredential(tenantId, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token); + AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Service Principal auth mode. Expiry Time: {0}", accessToken.ExpiresOn); return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn); } @@ -194,13 +212,15 @@ public override async Task AcquireTokenAsync(SqlAuthenti result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) .WithCorrelationId(parameters.ConnectionId) .WithUsername(parameters.UserId) - .ExecuteAsync(cancellationToken: cts.Token); + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); } else { result = await app.AcquireTokenByIntegratedWindowsAuth(scopes) .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token); + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); } SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn); } @@ -213,14 +233,15 @@ public override async Task AcquireTokenAsync(SqlAuthenti result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password) .WithCorrelationId(parameters.ConnectionId) - .ExecuteAsync(cancellationToken: cts.Token); + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn); } else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive || parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow) { // Fetch available accounts from 'app' instance - System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync()).GetEnumerator(); + System.Collections.Generic.IEnumerator accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator(); IAccount account = default; if (accounts.MoveNext()) @@ -250,7 +271,7 @@ public override async Task AcquireTokenAsync(SqlAuthenti { // If 'account' is available in 'app', we use the same to acquire token silently. // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent - result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token); + result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } catch (MsalUiRequiredException) @@ -258,14 +279,14 @@ public override async Task AcquireTokenAsync(SqlAuthenti // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application, // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired), // or the user needs to perform two factor authentication. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts); + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } } else { // If no existing 'account' is found, we request user to sign in interactively. - result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts); + result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false); SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn); } } @@ -304,7 +325,8 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( .WithCorrelationId(connectionId) .WithCustomWebUi(_customWebUI) .WithLoginHint(userId) - .ExecuteAsync(ctsInteractive.Token); + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); } else { @@ -328,7 +350,8 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( return await app.AcquireTokenInteractive(scopes) .WithCorrelationId(connectionId) .WithLoginHint(userId) - .ExecuteAsync(ctsInteractive.Token); + .ExecuteAsync(ctsInteractive.Token) + .ConfigureAwait(false); } } else @@ -336,7 +359,8 @@ private async Task AcquireTokenInteractiveDeviceFlowAsync( AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes, deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)) .WithCorrelationId(connectionId) - .ExecuteAsync(cancellationToken: cts.Token); + .ExecuteAsync(cancellationToken: cts.Token) + .ConfigureAwait(false); return result; } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs index 90405689ab..a05d485dff 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs @@ -76,6 +76,16 @@ private static void ConnectAndDisconnect(string connectionString, SqlCredential private static bool IsAADConnStringsSetup() => DataTestUtility.IsAADPasswordConnStrSetup(); private static bool IsManagedIdentitySetup() => DataTestUtility.ManagedIdentitySupported; + [PlatformSpecific(TestPlatforms.Windows)] + [ConditionalFact(nameof(IsAccessTokenSetup), nameof(IsAADConnStringsSetup))] + public static void KustoDatabaseTest() + { + // This is a sample Kusto database that can be connected by any AD account. + using SqlConnection connection = new SqlConnection("Data Source=help.kusto.windows.net; Authentication=Active Directory Default;Trust Server Certificate=True;"); + connection.Open(); + Assert.True(connection.State == System.Data.ConnectionState.Open); + } + [ConditionalFact(nameof(IsAccessTokenSetup), nameof(IsAADConnStringsSetup))] public static void AccessTokenTest() {