Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
using System.Threading.Tasks;
using Microsoft.Data.Common;
using Microsoft.Data.ProviderBase;
using Microsoft.Data.SqlClient.ManagedSni;

namespace Microsoft.Data.SqlClient
{
Expand Down Expand Up @@ -57,10 +56,6 @@ protected TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCo
// General methods //
/////////////////////

internal abstract uint EnableSsl(ref uint info, bool tlsFirst, string serverCertificateFilename);

internal abstract uint CheckConnection();

internal int DecrementPendingCallbacks(bool release)
{
int remaining = Interlocked.Decrement(ref _pendingCallbacks);
Expand Down Expand Up @@ -215,8 +210,10 @@ private uint GetSniPacket(PacketHandle packet, ref uint dataSize)
return SniPacketGetData(packet, _inBuff, ref dataSize);
}

private void SetBufferSecureStrings()
private bool TrySetBufferSecureStrings()
{
bool mustClearBuffer = false;

if (_securePasswords != null)
{
for (int i = 0; i < _securePasswords.Length; i++)
Expand All @@ -240,6 +237,8 @@ private void SetBufferSecureStrings()
}
TdsParserStaticMethods.ObfuscatePassword(data);
data.CopyTo(_outBuff, _securePasswordOffsetsInBuffer[i]);

mustClearBuffer = true;
}
finally
{
Expand All @@ -248,6 +247,8 @@ private void SetBufferSecureStrings()
}
}
}

return mustClearBuffer;
}

public void ReadAsyncCallback(PacketHandle packet, uint error) =>
Expand Down Expand Up @@ -561,13 +562,7 @@ private Task SNIWritePacket(PacketHandle packet, out uint sniError, bool canAccu
}

// Async operation completion may be delayed (success pending).
try
{
}
finally
{
sniError = WritePacket(packet, sync);
}
sniError = WritePacket(packet, sync);

if (sniError == TdsEnums.SNI_SUCCESS_IO_PENDING)
{
Expand Down Expand Up @@ -730,17 +725,17 @@ internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = fa
}
}

internal abstract PacketHandle CreateAndSetAttentionPacket();

internal abstract void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed);

private Task WriteSni(bool canAccumulate)
{
// Prepare packet, and write to packet.
PacketHandle packet = GetResetWritePacket(_outBytesUsed);
bool mustClearBuffer = TrySetBufferSecureStrings();

SetBufferSecureStrings();
SetPacketData(packet, _outBuff, _outBytesUsed);
if (mustClearBuffer)
{
_outBuff.AsSpan(0, _outBytesUsed).Clear();
}

Debug.Assert(Parser.Connection._parserLock.ThreadMayHaveLock(), "Thread is writing without taking the connection lock");
Task task = SNIWritePacket(packet, out _, canAccumulate, callerHasConnectionLock: true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
using System.Net;
using System.Runtime.InteropServices;
using System.Security.Authentication;
using System.Text;
using System.Threading.Tasks;
using Interop.Windows.Sni;
using Microsoft.Data.Common;
Expand Down Expand Up @@ -308,10 +307,9 @@ internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint

internal override PacketHandle CreateAndSetAttentionPacket()
{
SNIHandle handle = Handle;
SNIPacket attnPacket = new SNIPacket(handle);
SNIPacket attnPacket = new SNIPacket(Handle);
_sniAsyncAttnPacket = attnPacket;
SetPacketData(PacketHandle.FromNativePacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN);
SniNativeWrapper.SniPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN);
return PacketHandle.FromNativePacket(attnPacket);
}

Expand Down Expand Up @@ -399,28 +397,20 @@ internal override uint PostReadAsyncForMars(TdsParserStateObject physicalStateOb
PacketHandle temp = default;
uint error = TdsEnums.SNI_SUCCESS;

#if NETFRAMEWORK
RuntimeHelpers.PrepareConstrainedRegions();
#endif
try
{ }
finally
{
IncrementPendingCallbacks();
SessionHandle handle = SessionHandle;
// we do not need to consider partial packets when making this read because we
// expect this read to pend. a partial packet should not exist at setup of the
// parser
Debug.Assert(physicalStateObject.PartialPacket == null);
temp = ReadAsync(handle, out error);
IncrementPendingCallbacks();
SessionHandle handle = SessionHandle;
// we do not need to consider partial packets when making this read because we
// expect this read to pend. a partial packet should not exist at setup of the
// parser
Debug.Assert(physicalStateObject.PartialPacket == null);
temp = ReadAsync(handle, out error);

Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");
Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer");

if (temp.NativePointer != IntPtr.Zero)
{
// Be sure to release packet, otherwise it will be leaked by native.
ReleasePacket(temp);
}
if (temp.NativePointer != IntPtr.Zero)
{
// Be sure to release packet, otherwise it will be leaked by native.
ReleasePacket(temp);
}

Debug.Assert(IntPtr.Zero == temp.NativePointer, "unexpected syncReadPacket without corresponding SNIPacketRelease");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,19 +942,8 @@ private void EnableSsl(uint info, SqlConnectionEncryptOption encrypt, bool integ
info |= TdsEnums.SNI_SSL_IGNORE_CHANNEL_BINDINGS;
}

// Add SSL (Encryption) SNI provider.
AuthProviderInfo authInfo = new AuthProviderInfo();
authInfo.flags = info;
authInfo.tlsFirst = encrypt == SqlConnectionEncryptOption.Strict;
authInfo.certId = null;
authInfo.certHash = false;
authInfo.clientCertificateCallbackContext = IntPtr.Zero;
authInfo.clientCertificateCallback = null;
authInfo.serverCertFileName = string.IsNullOrEmpty(serverCertificateFilename) ? null : serverCertificateFilename;

Debug.Assert((_encryptionOption & EncryptionOptions.CLIENT_CERT) == 0, "Client certificate authentication support has been removed");

error = SniNativeWrapper.SniAddProvider(_physicalStateObj.Handle, Provider.SSL_PROV, authInfo);
error = _physicalStateObj.EnableSsl(ref info, encrypt == SqlConnectionEncryptOption.Strict, serverCertificateFilename);

if (error != TdsEnums.SNI_SUCCESS)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,18 @@

using System;
using System.Buffers.Binary;
using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.ConstrainedExecution;
using System.Runtime.InteropServices;
using System.Security;
using System.Threading;
using System.Threading.Tasks;
using Interop.Windows.Sni;
using Microsoft.Data.Common;
using Microsoft.Data.ProviderBase;

namespace Microsoft.Data.SqlClient
{
internal partial class TdsParserStateObject
{
protected SNIHandle _sessionHandle = null; // the SNI handle we're to work on

// SNI variables // multiple resultsets in one batch.
protected SNIPacket _sniPacket = null; // Will have to re-vamp this for MARS
internal SNIPacket _sniAsyncAttnPacket = null; // Packet to use to send Attn

// Used for blanking out password in trace.
internal int _tracePasswordOffset = 0;
internal int _tracePasswordLength = 0;
Expand Down Expand Up @@ -68,23 +58,10 @@ protected TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCo
_lastSuccessfulIOTimer = parser._physicalStateObj._lastSuccessfulIOTimer;
}

////////////////
// Properties //
////////////////
internal SNIHandle Handle
{
get
{
return _sessionHandle;
}
}

/////////////////////
// General methods //
/////////////////////

internal uint CheckConnection() => SniNativeWrapper.SniCheckConnection(Handle);

internal int DecrementPendingCallbacks(bool release)
{
int remaining = Interlocked.Decrement(ref _pendingCallbacks);
Expand All @@ -94,7 +71,7 @@ internal int DecrementPendingCallbacks(bool release)

// NOTE: TdsParserSessionPool may call DecrementPendingCallbacks on a TdsParserStateObject which is already disposed
// This is not dangerous (since the stateObj is no longer in use), but we need to add a workaround in the assert for it
Debug.Assert((remaining == -1 && _sessionHandle == null) || (0 <= remaining && remaining < 3), $"_pendingCallbacks values is invalid after decrementing: {remaining}");
Debug.Assert((remaining == -1 && SessionHandle.IsNull) || (0 <= remaining && remaining < 3), $"_pendingCallbacks values is invalid after decrementing: {remaining}");
return remaining;
}

Expand All @@ -121,11 +98,7 @@ internal bool ValidateSNIConnection()
try
{
Interlocked.Increment(ref _readingCount);
SNIHandle handle = Handle;
if (handle != null)
{
error = SniNativeWrapper.SniCheckConnection(handle);
}
error = CheckConnection();
}
finally
{
Expand Down Expand Up @@ -243,6 +216,47 @@ private uint GetSniPacket(PacketHandle packet, ref uint dataSize)
return SniPacketGetData(packet, _inBuff, ref dataSize);
}

private bool TrySetBufferSecureStrings()
{
bool mustClearBuffer = false;

if (_securePasswords != null)
{
for (int i = 0; i < _securePasswords.Length; i++)
{
if (_securePasswords[i] != null)
{
IntPtr str = IntPtr.Zero;
try
{
str = Marshal.SecureStringToBSTR(_securePasswords[i]);
byte[] data = new byte[_securePasswords[i].Length * 2];
Marshal.Copy(str, data, 0, _securePasswords[i].Length * 2);
if (!BitConverter.IsLittleEndian)
{
Span<byte> span = data.AsSpan();
for (int ii = 0; ii < _securePasswords[i].Length * 2; ii += 2)
{
short value = BinaryPrimitives.ReadInt16LittleEndian(span.Slice(ii));
BinaryPrimitives.WriteInt16BigEndian(span.Slice(ii), value);
}
}
TdsParserStaticMethods.ObfuscatePassword(data);
data.CopyTo(_outBuff, _securePasswordOffsetsInBuffer[i]);

mustClearBuffer = true;
}
finally
{
Marshal.ZeroFreeBSTR(str);
}
}
}
}

return mustClearBuffer;
}

public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error)
{
// Key never used.
Expand Down Expand Up @@ -717,20 +731,17 @@ internal void SendAttention(bool mustTakeWriteLock = false, bool asyncClose = fa
}
}

internal PacketHandle CreateAndSetAttentionPacket()
{
SNIPacket attnPacket = new SNIPacket(Handle);
_sniAsyncAttnPacket = attnPacket;
SniNativeWrapper.SniPacketSetData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN, null, null);
return PacketHandle.FromNativePacket(attnPacket);
}

private Task WriteSni(bool canAccumulate)
{
// Prepare packet, and write to packet.
PacketHandle packet = GetResetWritePacket(_outBytesUsed);
SNIPacket nativePacket = packet.NativePacket;
SniNativeWrapper.SniPacketSetData(nativePacket, _outBuff, _outBytesUsed, _securePasswords, _securePasswordOffsetsInBuffer);
bool mustClearBuffer = TrySetBufferSecureStrings();

SetPacketData(packet, _outBuff, _outBytesUsed);
if (mustClearBuffer)
{
_outBuff.AsSpan(0, _outBytesUsed).Clear();
}

Debug.Assert(Parser.Connection._parserLock.ThreadMayHaveLock(), "Thread is writing without taking the connection lock");
Task task = SNIWritePacket(packet, out _, canAccumulate, callerHasConnectionLock: true);
Expand Down
Loading
Loading