Skip to content

Commit ae41144

Browse files
Add tests + sample = modified implementation
1 parent 7687984 commit ae41144

File tree

5 files changed

+106
-22
lines changed

5 files changed

+106
-22
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//<Snippet1>
2+
using System;
3+
using System.Threading.Tasks;
4+
using Microsoft.Identity.Client;
5+
using Microsoft.Data.SqlClient;
6+
7+
namespace CustomAuthenticationProviderExamples
8+
{
9+
/// <summary>
10+
/// Example demonstrating creating custom device code flow authentication provider and attaching it with the driver.
11+
/// This is helpful for applications that wish to override Callback of Device Code Result as implemented by SqlClient driver.
12+
/// </summary>
13+
public class CustomDeviceCodeFlowAzureAuthenticationProvider : SqlAuthenticationProvider
14+
{
15+
public async override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters)
16+
{
17+
string clientId = "my-client-id";
18+
string clientName = "My Application Name";
19+
string s_defaultScopeSuffix = "/.default";
20+
21+
string[] scopes = new string[] { parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix };
22+
23+
IPublicClientApplication app = PublicClientApplicationBuilder.Create(clientId)
24+
.WithAuthority(parameters.Authority)
25+
.WithClientName(clientName)
26+
.WithRedirectUri("https://login.microsoftonline.com/common/oauth2/nativeclient")
27+
.Build();
28+
29+
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
30+
deviceCodeResult => CustomDeviceFlowCallback(deviceCodeResult)).ExecuteAsync();
31+
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
32+
}
33+
34+
public override bool IsSupported(SqlAuthenticationMethod authenticationMethod)
35+
{
36+
if (authenticationMethod.Equals(SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow))
37+
{
38+
return true;
39+
}
40+
return false;
41+
}
42+
43+
private Task CustomDeviceFlowCallback(DeviceCodeResult result)
44+
{
45+
Console.WriteLine(result.Message);
46+
return Task.FromResult(0);
47+
}
48+
}
49+
50+
public class Program
51+
{
52+
public static void Main()
53+
{
54+
SqlAuthenticationProvider.SetProvider(SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow, new CustomDeviceCodeFlowAzureAuthenticationProvider());
55+
using (SqlConnection sqlConnection = new SqlConnection("Server=<myserver>.database.windows.net;Authentication=Active Directory Device Code Flow;Database=<db>;"))
56+
{
57+
sqlConnection.Open();
58+
Console.WriteLine("Connected successfully!");
59+
}
60+
}
61+
}
62+
}
63+
//</Snippet1>

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
namespace Microsoft.Data.SqlClient
1212
{
13-
14-
1513
/// <summary>
1614
/// Authentication provider manager.
1715
/// </summary>

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,12 @@ internal class ActiveDirectoryAuthenticationProvider : SqlAuthenticationProvider
2222
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
2323
private readonly SqlClientLogger _logger = new SqlClientLogger();
2424

25-
/// <summary>
26-
/// Get Token
27-
/// </summary>
28-
/// <param name="parameters"></param>
29-
/// <returns>Authentication token</returns>
30-
public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters) =>
31-
AcquireTokenAsync(parameters, DefaultDeviceFlowCallback);
32-
3325
/// <summary>
3426
/// Get token.
3527
/// </summary>
3628
/// <param name="parameters"></param>
37-
/// <param name="deviceCodeResultCallback"></param>
3829
/// <returns>Authentication token</returns>
39-
public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters, Func<DeviceCodeResult, Task> deviceCodeResultCallback) => Task.Run(async () =>
30+
public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters) => Task.Run(async () =>
4031
{
4132
AuthenticationResult result;
4233
string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
@@ -125,12 +116,12 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
125116
}
126117
catch (MsalUiRequiredException)
127118
{
128-
result = await AcquireTokenInteractiveDeviceFlow(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, deviceCodeResultCallback);
119+
result = await AcquireTokenInteractiveDeviceFlow(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
129120
}
130121
}
131122
else
132123
{
133-
result = await AcquireTokenInteractiveDeviceFlow(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, deviceCodeResultCallback);
124+
result = await AcquireTokenInteractiveDeviceFlow(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
134125
}
135126
}
136127
else
@@ -142,7 +133,7 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
142133
});
143134

144135
private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlow(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
145-
SqlAuthenticationMethod authenticationMethod, Func<DeviceCodeResult, Task> deviceCodeResultCallback)
136+
SqlAuthenticationMethod authenticationMethod)
146137
{
147138
CancellationTokenSource cts = new CancellationTokenSource();
148139
#if netcoreapp
@@ -187,7 +178,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlow(IPubl
187178
else
188179
{
189180
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
190-
deviceCodeResult => deviceCodeResultCallback(deviceCodeResult)).ExecuteAsync();
181+
deviceCodeResult => DeviceFlowCallback(deviceCodeResult)).ExecuteAsync();
191182
return result;
192183
}
193184
}
@@ -199,7 +190,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlow(IPubl
199190
}
200191
}
201192

202-
private Task DefaultDeviceFlowCallback(DeviceCodeResult result)
193+
private Task DeviceFlowCallback(DeviceCodeResult result)
203194
{
204195
// This will print the message on the console which tells the user where to go sign-in using
205196
// a separate browser and the code to enter once they sign in.

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlAuthenticationProvider.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,5 @@ public virtual void BeforeUnload(SqlAuthenticationMethod authenticationMethod) {
3535

3636
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlAuthenticationProvider.xml' path='docs/members[@name="SqlAuthenticationProvider"]/AcquireTokenAsync/*'/>
3737
public abstract Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters);
38-
39-
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlAuthenticationProvider.xml' path='docs/members[@name="SqlAuthenticationProvider"]/AcquireTokenAsync2/*'/>
40-
public abstract Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters, Func<DeviceCodeResult, Task> deviceCodeResultCallback);
4138
}
4239
}

src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/ConnectivityTests/AADConnectionTest.cs

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Diagnostics;
7+
using System.Security;
78
using Xunit;
89

910
namespace Microsoft.Data.SqlClient.ManualTesting.Tests
@@ -307,7 +308,7 @@ public static void NoCredentialsActiveDirectoryServicePrincipal()
307308
// test Passes with correct connection string.
308309
string[] removeKeys = { "Authentication", "User ID", "Password", "UID", "PWD" };
309310
string connStr = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, removeKeys) +
310-
$"Authentication=Active Directory Service Principal; User ID={DataTestUtility.AADServicePrincipalId}; Password={DataTestUtility.AADServicePrincipalSecret};";
311+
$"Authentication=Active Directory Service Principal; User ID={DataTestUtility.AADServicePrincipalId}; PWD={DataTestUtility.AADServicePrincipalSecret};";
311312
ConnectAndDisconnect(connStr);
312313

313314
// connection fails with expected error message.
@@ -320,7 +321,41 @@ public static void NoCredentialsActiveDirectoryServicePrincipal()
320321
Assert.Contains(expectedMessage, e.Message);
321322
}
322323

323-
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsIntegratedSecuritySetup), nameof(DataTestUtility.AreConnStringsSetup)) ]
324+
[ConditionalFact(nameof(IsAADConnStringsSetup))]
325+
public static void ActiveDirectoryDeviceCodeFlowWithUserIdMustFail()
326+
{
327+
// connection fails with expected error message.
328+
string[] credKeys = { "Authentication", "User ID", "Password", "UID", "PWD" };
329+
string connStrWithUID = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys) +
330+
"Authentication=Active Directory Device Code Flow; UID=someuser;";
331+
ArgumentException e = Assert.Throws<ArgumentException>(() => ConnectAndDisconnect(connStrWithUID));
332+
333+
string expectedMessage = "Cannot use 'Authentication=Active Directory Device Code Flow' with 'User ID', 'UID', 'Password' or 'PWD' connection string keywords.";
334+
Assert.Contains(expectedMessage, e.Message);
335+
}
336+
337+
[ConditionalFact(nameof(IsAADConnStringsSetup))]
338+
public static void ActiveDirectoryDeviceCodeFlowWithCredentialsMustFail()
339+
{
340+
// connection fails with expected error message.
341+
string[] credKeys = { "Authentication", "User ID", "Password", "UID", "PWD" };
342+
string connStrWithNoCred = DataTestUtility.RemoveKeysInConnStr(DataTestUtility.AADPasswordConnectionString, credKeys) +
343+
"Authentication=Active Directory Device Code Flow;";
344+
345+
SecureString str = new SecureString();
346+
foreach (char c in "hello")
347+
{
348+
str.AppendChar(c);
349+
}
350+
str.MakeReadOnly();
351+
SqlCredential credential = new SqlCredential("someuser", str);
352+
InvalidOperationException e = Assert.Throws<InvalidOperationException>(() => ConnectAndDisconnect(connStrWithNoCred, credential));
353+
354+
string expectedMessage = "Cannot set the Credential property if 'Authentication=Active Directory Device Code Flow' has been specified in the connection string.";
355+
Assert.Contains(expectedMessage, e.Message);
356+
}
357+
358+
[ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.IsIntegratedSecuritySetup), nameof(DataTestUtility.AreConnStringsSetup))]
324359
public static void ADInteractiveUsingSSPI()
325360
{
326361
// test Passes with correct connection string.

0 commit comments

Comments
 (0)