diff --git a/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj b/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj index c60ca96e13c2..35c08bd531e1 100644 --- a/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj +++ b/src/System.Data.SqlClient/src/System.Data.SqlClient.csproj @@ -1,4 +1,4 @@ - + {D4550556-4745-457F-BA8F-3EBF3836D6B4} System.Data.SqlClient @@ -276,6 +276,8 @@ + + @@ -285,6 +287,8 @@ + + @@ -479,6 +483,8 @@ + + diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Unix.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Unix.cs new file mode 100644 index 000000000000..f9fd9dc7e904 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Unix.cs @@ -0,0 +1,33 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +namespace System.Data.SqlClient +{ + // this structure is used for transporting packet handle references between the TdsParserStateObject + // base class and Managed or Native implementations. + // It prevents the native IntPtr type from being boxed and prevents the need to cast from object which loses compile time type safety + // It carries type information so that assertions about the type of handle can be made in the implemented abstract methods + // it is a ref struct so that it can only be used to transport the handles and not store them + + // N.B. If you change this type you must also change the version for the other platform + + internal readonly ref struct PacketHandle + { + public const int NativePointerType = 1; + public const int NativePacketType = 2; + public const int ManagedPacketType = 3; + + public readonly SNI.SNIPacket ManagedPacket; + public readonly int Type; + + private PacketHandle(SNI.SNIPacket managedPacket, int type) + { + Type = type; + ManagedPacket = managedPacket; + } + + public static PacketHandle FromManagedPacket(SNI.SNIPacket managedPacket) => new PacketHandle(managedPacket, ManagedPacketType); + } +} diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Windows.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Windows.cs new file mode 100644 index 000000000000..f15d26fd38b0 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/PacketHandle.Windows.cs @@ -0,0 +1,44 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +namespace System.Data.SqlClient +{ + // this structure is used for transporting packet handle references between the TdsParserStateObject + // base class and Managed or Native implementations. + // It prevents the native IntPtr type from being boxed and prevents the need to cast from object which loses compile time type safety + // It carries type information so that assertions about the type of handle can be made in the implemented abstract methods + // it is a ref struct so that it can only be used to transport the handles and not store them + + // N.B. If you change this type you must also change the version for the other platform + + internal readonly ref struct PacketHandle + { + public const int NativePointerType = 1; + public const int NativePacketType = 2; + public const int ManagedPacketType = 3; + + public readonly IntPtr NativePointer; + public readonly SNIPacket NativePacket; + + public readonly SNI.SNIPacket ManagedPacket; + public readonly int Type; + + private PacketHandle(IntPtr nativePointer, SNIPacket nativePacket, SNI.SNIPacket managedPacket, int type) + { + Type = type; + ManagedPacket = managedPacket; + NativePointer = nativePointer; + NativePacket = nativePacket; + } + + public static PacketHandle FromManagedPacket(SNI.SNIPacket managedPacket) => new PacketHandle(default, default, managedPacket, ManagedPacketType); + + public static PacketHandle FromNativePointer(IntPtr nativePointer) => new PacketHandle(nativePointer, default, default, NativePointerType); + + public static PacketHandle FromNativePacket(SNIPacket nativePacket) => new PacketHandle(default, nativePacket, default, NativePacketType); + + + } +} diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs index 4506ff4c9b01..5a0e8f7bace2 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIMarsHandle.cs @@ -318,7 +318,7 @@ public void HandleReceiveError(SNIPacket packet) _packetEvent.Set(); } - ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(packet, 1); + ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 1); } /// @@ -332,7 +332,7 @@ public void HandleSendComplete(SNIPacket packet, uint sniErrorCode) { Debug.Assert(_callbackObject != null); - ((TdsParserStateObject)_callbackObject).WriteAsyncCallback(packet, sniErrorCode); + ((TdsParserStateObject)_callbackObject).WriteAsyncCallback(PacketHandle.FromManagedPacket(packet), sniErrorCode); } } @@ -378,7 +378,7 @@ public void HandleReceiveComplete(SNIPacket packet, SNISMUXHeader header) _asyncReceives--; Debug.Assert(_callbackObject != null); - ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(packet, 0); + ((TdsParserStateObject)_callbackObject).ReadAsyncCallback(PacketHandle.FromManagedPacket(packet), 0); } } diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs index f7ba249b066f..f15dca8f589b 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SNI/SNIPacket.cs @@ -20,8 +20,6 @@ internal class SNIPacket : IDisposable, IEquatable private int _offset; private string _description; private SNIAsyncCallback _completionCallback; - - private ArrayPool _arrayPool = ArrayPool.Shared; private bool _isBufferFromArrayPool = false; public SNIPacket() { } @@ -98,14 +96,14 @@ public void Allocate(int capacity) { if (_isBufferFromArrayPool) { - _arrayPool.Return(_data); + ArrayPool.Shared.Return(_data); } _data = null; } if (_data == null) { - _data = _arrayPool.Rent(capacity); + _data = ArrayPool.Shared.Rent(capacity); _isBufferFromArrayPool = true; } @@ -221,7 +219,7 @@ public void Release() { if(_isBufferFromArrayPool) { - _arrayPool.Return(_data); + ArrayPool.Shared.Return(_data); } _data = null; _capacity = 0; diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Unix.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Unix.cs new file mode 100644 index 000000000000..5bf099aba58e --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Unix.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +namespace System.Data.SqlClient +{ + // this structure is used for transporting packet handle references between the TdsParserStateObject + // base class and Managed or Native implementations. + // It carries type information so that assertions about the type of handle can be made in the + // implemented abstract methods + // it is a ref struct so that it can only be used to transport the handles and not store them + + // N.B. If you change this type you must also change the version for the other platform + + internal readonly ref struct SessionHandle + { + public const int NativeHandleType = 1; + public const int ManagedHandleType = 2; + + public readonly SNI.SNIHandle ManagedHandle; + public readonly int Type; + + public SessionHandle(SNI.SNIHandle managedHandle, int type) + { + Type = type; + ManagedHandle = managedHandle; + } + + public bool IsNull => ManagedHandle is null; + + public static SessionHandle FromManagedSession(SNI.SNIHandle managedSessionHandle) => new SessionHandle(managedSessionHandle, ManagedHandleType); + } +} diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Windows.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Windows.cs new file mode 100644 index 000000000000..a7215963c8b2 --- /dev/null +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/SessionHandle.Windows.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + + +namespace System.Data.SqlClient +{ + // this structure is used for transporting packet handle references between the TdsParserStateObject + // base class and Managed or Native implementations. + // It carries type information so that assertions about the type of handle can be made in the + // implemented abstract methods + // it is a ref struct so that it can only be used to transport the handles and not store them + + // N.B. If you change this type you must also change the version for the other platform + + internal readonly ref struct SessionHandle + { + public const int NativeHandleType = 1; + public const int ManagedHandleType = 2; + + public readonly SNI.SNIHandle ManagedHandle; + public readonly SNIHandle NativeHandle; + + public readonly int Type; + + public SessionHandle(SNI.SNIHandle managedHandle, SNIHandle nativeHandle, int type) + { + Type = type; + ManagedHandle = managedHandle; + NativeHandle = nativeHandle; + } + + public bool IsNull => (Type == NativeHandleType) ? NativeHandle is null : ManagedHandle is null; + + public static SessionHandle FromManagedSession(SNI.SNIHandle managedSessionHandle) => new SessionHandle(managedSessionHandle, default, ManagedHandleType); + + public static SessionHandle FromNativeHandle(SNIHandle nativeSessionHandle) => new SessionHandle(default, nativeSessionHandle, NativeHandleType); + } +} diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.Windows.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.Windows.cs index 12ce515bf93d..6fec58ad11dd 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.Windows.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParser.Windows.cs @@ -20,20 +20,22 @@ internal void PostReadAsyncForMars() // Have to post read to initialize MARS - will get callback on this when connection goes // down or is closed. - IntPtr temp = IntPtr.Zero; + PacketHandle temp = default; uint error = TdsEnums.SNI_SUCCESS; _pMarsPhysicalConObj.IncrementPendingCallbacks(); - object handle = _pMarsPhysicalConObj.SessionHandle; - temp = (IntPtr)_pMarsPhysicalConObj.ReadAsync(out error, ref handle); + SessionHandle handle = _pMarsPhysicalConObj.SessionHandle; + temp = _pMarsPhysicalConObj.ReadAsync(handle, out error); - if (temp != IntPtr.Zero) + Debug.Assert(temp.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); + + if (temp.NativePointer != IntPtr.Zero) { // Be sure to release packet, otherwise it will be leaked by native. _pMarsPhysicalConObj.ReleasePacket(temp); } - - Debug.Assert(IntPtr.Zero == temp, "unexpected syncReadPacket without corresponding SNIPacketRelease"); + + Debug.Assert(IntPtr.Zero == temp.NativePointer, "unexpected syncReadPacket without corresponding SNIPacketRelease"); if (TdsEnums.SNI_SUCCESS_IO_PENDING != error) { Debug.Assert(TdsEnums.SNI_SUCCESS != error, "Unexpected successful read async on physical connection before enabling MARS!"); @@ -118,4 +120,4 @@ private SNIErrorDetails GetSniErrorDetails() } } // tdsparser -}//namespace \ No newline at end of file +}//namespace diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserSafeHandles.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserSafeHandles.cs index 55cd8a1c5fb5..7dce0de70ae5 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserSafeHandles.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserSafeHandles.cs @@ -104,7 +104,7 @@ private static void ReadDispatcher(IntPtr key, IntPtr packet, uint error) if (null != stateObj) { - stateObj.ReadAsyncCallback(IntPtr.Zero, packet, error); + stateObj.ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromNativePointer(packet), error); } } } @@ -125,7 +125,7 @@ private static void WriteDispatcher(IntPtr key, IntPtr packet, uint error) if (null != stateObj) { - stateObj.WriteAsyncCallback(IntPtr.Zero, packet, error); + stateObj.WriteAsyncCallback(IntPtr.Zero, PacketHandle.FromNativePointer(packet), error); } } } @@ -296,4 +296,4 @@ public void Dispose() } } } -} \ No newline at end of file +} diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs index fea7447e4f5f..e6ba9a6674a2 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObject.cs @@ -392,7 +392,7 @@ internal abstract uint Status get; } - internal abstract object SessionHandle + internal abstract SessionHandle SessionHandle { get; } @@ -761,27 +761,27 @@ private void ResetCancelAndProcessAttention() internal abstract void DisposePacketCache(); - internal abstract bool IsPacketEmpty(object readPacket); + internal abstract bool IsPacketEmpty(PacketHandle readPacket); - internal abstract object ReadSyncOverAsync(int timeoutRemaining, out uint error); + internal abstract PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error); - internal abstract object ReadAsync(out uint error, ref object handle); + internal abstract PacketHandle ReadAsync(SessionHandle handle, out uint error); internal abstract uint CheckConnection(); internal abstract uint SetConnectionBufferSize(ref uint unsignedPacketSize); - internal abstract void ReleasePacket(object syncReadPacket); + internal abstract void ReleasePacket(PacketHandle syncReadPacket); - protected abstract uint SNIPacketGetData(object packet, byte[] _inBuff, ref uint dataSize); + protected abstract uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize); - internal abstract object GetResetWritePacket(); + internal abstract PacketHandle GetResetWritePacket(); internal abstract void ClearAllWritePackets(); - internal abstract object AddPacketToPendingList(object packet); + internal abstract PacketHandle AddPacketToPendingList(PacketHandle packet); - protected abstract void RemovePacketFromPendingList(object pointer); + protected abstract void RemovePacketFromPendingList(PacketHandle pointer); internal abstract uint GenerateSspiClientContext(byte[] receivedBuff, uint receivedLength, ref byte[] sendBuff, ref uint sendLength, byte[] _sniSpnBuffer); @@ -855,7 +855,7 @@ internal int DecrementPendingCallbacks(bool release) // NOTE: TdsParserSessionPool may call DecrementPendingCallbacks on a TdsParserStateObject which is already disposed // This is not dangerous (since the stateObj is no longer in use), but we need to add a workaround in the assert for it - Debug.Assert((remaining == -1 && SessionHandle == null) || (0 <= remaining && remaining < 3), string.Format("_pendingCallbacks values is invalid after decrementing: {0}", remaining)); + Debug.Assert((remaining == -1 && SessionHandle.IsNull) || (0 <= remaining && remaining < 3), string.Format("_pendingCallbacks values is invalid after decrementing: {0}", remaining)); return remaining; } @@ -2069,7 +2069,7 @@ internal void ReadSniSyncOverAsync() throw ADP.ClosedConnectionError(); } - object readPacket = null; + PacketHandle readPacket = default; uint error; @@ -2291,7 +2291,7 @@ internal void ReadSni(TaskCompletionSource completion) #endif - object readPacket = null; + PacketHandle readPacket = default; uint error = 0; @@ -2317,16 +2317,14 @@ internal void ReadSni(TaskCompletionSource completion) ChangeNetworkPacketTimeout(msecsRemaining, Timeout.Infinite); } - object handle = null; - Interlocked.Increment(ref _readingCount); - handle = SessionHandle; - if (handle != null) + SessionHandle handle = SessionHandle; + if (!handle.IsNull) { IncrementPendingCallbacks(); - readPacket = ReadAsync(out error, ref handle); + readPacket = ReadAsync(handle, out error); if (!(TdsEnums.SNI_SUCCESS == error || TdsEnums.SNI_SUCCESS_IO_PENDING == error)) { @@ -2335,8 +2333,8 @@ internal void ReadSni(TaskCompletionSource completion) } Interlocked.Decrement(ref _readingCount); - - if (handle == null) + + if (handle.IsNull) { throw ADP.ClosedConnectionError(); } @@ -2419,8 +2417,8 @@ internal bool IsConnectionAlive(bool throwOnException) { uint error; SniContext = SniContext.Snix_Connect; - error = CheckConnection(); + if ((error != TdsEnums.SNI_SUCCESS) && (error != TdsEnums.SNI_WAIT_TIMEOUT)) { // Connection is dead @@ -2498,7 +2496,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) { stateObj.SendAttention(mustTakeWriteLock: true); - object syncReadPacket = null; + PacketHandle syncReadPacket = default; bool shouldDecrement = false; try @@ -2570,7 +2568,7 @@ private void ReadSniError(TdsParserStateObject stateObj, uint error) AssertValidState(); } - public void ProcessSniPacket(object packet, uint error) + public void ProcessSniPacket(PacketHandle packet, uint error) { if (error != 0) { @@ -2669,13 +2667,12 @@ private void SetBufferSecureStrings() } } - public void ReadAsyncCallback(T packet, uint error) + public void ReadAsyncCallback(PacketHandle packet, uint error) { ReadAsyncCallback(IntPtr.Zero, packet, error); } - - public void ReadAsyncCallback(IntPtr key, T packet, uint error) + public void ReadAsyncCallback(IntPtr key, PacketHandle packet, uint error) { // Key never used. // Note - it's possible that when native calls managed that an asynchronous exception @@ -2755,7 +2752,7 @@ public void ReadAsyncCallback(IntPtr key, T packet, uint error) } } - protected abstract bool CheckPacket(object packet, TaskCompletionSource source); + protected abstract bool CheckPacket(PacketHandle packet, TaskCompletionSource source); private void ReadAsyncCallbackCaptureException(TaskCompletionSource source) { @@ -2801,12 +2798,12 @@ private void ReadAsyncCallbackCaptureException(TaskCompletionSource sour #pragma warning disable 0420 // a reference to a volatile field will not be treated as volatile - public void WriteAsyncCallback(T packet, uint sniError) + public void WriteAsyncCallback(PacketHandle packet, uint sniError) { WriteAsyncCallback(IntPtr.Zero, packet, sniError); } - public void WriteAsyncCallback(IntPtr key, T packet, uint sniError) + public void WriteAsyncCallback(IntPtr key, PacketHandle packet, uint sniError) { // Key never used. RemovePacketFromPendingList(packet); try @@ -3218,7 +3215,7 @@ private void CancelWritePacket() #pragma warning disable 0420 // a reference to a volatile field will not be treated as volatile - private Task SNIWritePacket(object packet, out uint sniError, bool canAccumulate, bool callerHasConnectionLock) + private Task SNIWritePacket(PacketHandle packet, out uint sniError, bool canAccumulate, bool callerHasConnectionLock) { // Check for a stored exception var delayedException = Interlocked.Exchange(ref _delayedWriteAsyncCallbackException, null); @@ -3230,7 +3227,7 @@ private Task SNIWritePacket(object packet, out uint sniError, bool canAccumulate Task task = null; _writeCompletionSource = null; - object packetPointer = EmptyReadPacket; + PacketHandle packetPointer = EmptyReadPacket; bool sync = !_parser._asyncWrite; if (sync && _asyncWriteCount > 0) @@ -3351,8 +3348,9 @@ private Task SNIWritePacket(object packet, out uint sniError, bool canAccumulate return task; } - internal abstract bool IsValidPacket(object packetPointer); - internal abstract uint WritePacket(object packet, bool sync); + internal abstract bool IsValidPacket(PacketHandle packetPointer); + + internal abstract uint WritePacket(PacketHandle packet, bool sync); #pragma warning restore 0420 @@ -3369,7 +3367,7 @@ internal void SendAttention(bool mustTakeWriteLock = false) return; } - object attnPacket = CreateAndSetAttentionPacket(); + PacketHandle attnPacket = CreateAndSetAttentionPacket(); try { @@ -3427,14 +3425,14 @@ internal void SendAttention(bool mustTakeWriteLock = false) } } - internal abstract object CreateAndSetAttentionPacket(); + internal abstract PacketHandle CreateAndSetAttentionPacket(); - internal abstract void SetPacketData(object packet, byte[] buffer, int bytesUsed); + internal abstract void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed); private Task WriteSni(bool canAccumulate) { // Prepare packet, and write to packet. - object packet = GetResetWritePacket(); + PacketHandle packet = GetResetWritePacket(); SetBufferSecureStrings(); SetPacketData(packet, _outBuff, _outBytesUsed); @@ -3646,7 +3644,7 @@ internal int WarningCount } } - protected abstract object EmptyReadPacket { get; } + protected abstract PacketHandle EmptyReadPacket { get; } /// /// Gets the full list of errors and warnings (including the pre-attention ones), then wipes all error and warning lists diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs index 96832fb8b46f..fb57c55f556e 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectFactory.Windows.cs @@ -19,7 +19,14 @@ internal sealed class TdsParserStateObjectFactory //private static bool shouldUseLegacyNetorking; //public static bool UseManagedSNI { get; } = AppContext.TryGetSwitch(UseLegacyNetworkingOnWindows, out shouldUseLegacyNetorking) ? !shouldUseLegacyNetorking : true; +#if DEBUG + private static Lazy useManagedSNIOnWindows = new Lazy( + () => bool.TrueString.Equals(Environment.GetEnvironmentVariable("System.Data.SqlClient.UseManagedSNIOnWindows"), StringComparison.InvariantCultureIgnoreCase) + ); + public static bool UseManagedSNI => useManagedSNIOnWindows.Value; +#else public static bool UseManagedSNI { get; } = false; +#endif public EncryptionOptions EncryptionOptions { diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs index 103e6fb163c5..151d4e554aa0 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -31,14 +31,12 @@ internal TdsParserStateObjectManaged(TdsParser parser, TdsParserStateObject phys internal override uint Status => _sessionHandle != null ? _sessionHandle.Status : TdsEnums.SNI_UNINITIALIZED; - internal override object SessionHandle => _sessionHandle; + internal override SessionHandle SessionHandle => SessionHandle.FromManagedSession(_sessionHandle); - protected override object EmptyReadPacket => null; - - protected override bool CheckPacket(object packet, TaskCompletionSource source) + protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource source) { - SNIPacket p = packet as SNIPacket; - return p.IsInvalid || (!p.IsInvalid && source != null); + SNIPacket p = packet.ManagedPacket; + return p.IsInvalid || source != null; } protected override void CreateSessionHandle(TdsParserStateObject physicalConnection, bool async) @@ -54,7 +52,7 @@ internal SNIMarsHandle CreateMarsSession(object callbackObject, bool async) return _marsConnection.CreateMarsSession(callbackObject, async); } - protected override uint SNIPacketGetData(object packet, byte[] _inBuff, ref uint dataSize) => SNIProxy.Singleton.PacketGetData(packet as SNIPacket, _inBuff, ref dataSize); + protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) => SNIProxy.Singleton.PacketGetData(packet.ManagedPacket, _inBuff, ref dataSize); internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSniOpenTimeout, long timerExpire, out byte[] instanceName, ref byte[] spnBuffer, bool flushCache, bool async, bool parallel, bool isIntegratedSecurity) { @@ -72,11 +70,11 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni } } - internal void ReadAsyncCallback(SNIPacket packet, uint error) => ReadAsyncCallback(IntPtr.Zero, packet, error); + internal void ReadAsyncCallback(SNIPacket packet, uint error) => ReadAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), error); - internal void WriteAsyncCallback(SNIPacket packet, uint sniError) => WriteAsyncCallback(IntPtr.Zero, packet, sniError); + internal void WriteAsyncCallback(SNIPacket packet, uint sniError) => WriteAsyncCallback(IntPtr.Zero, PacketHandle.FromManagedPacket(packet), sniError); - protected override void RemovePacketFromPendingList(object packet) + protected override void RemovePacketFromPendingList(PacketHandle packet) { // No-Op } @@ -125,7 +123,7 @@ protected override void FreeGcHandle(int remaining, bool release) internal override bool IsFailedHandle() => _sessionHandle.Status != TdsEnums.SNI_SUCCESS; - internal override object ReadSyncOverAsync(int timeoutRemaining, out uint error) + internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error) { SNIHandle handle = Handle; if (handle == null) @@ -134,17 +132,19 @@ internal override object ReadSyncOverAsync(int timeoutRemaining, out uint error) } SNIPacket packet = null; error = SNIProxy.Singleton.ReadSyncOverAsync(handle, out packet, timeoutRemaining); - return packet; + return PacketHandle.FromManagedPacket(packet); } - internal override bool IsPacketEmpty(object packet) + protected override PacketHandle EmptyReadPacket => PacketHandle.FromManagedPacket(null); + + internal override bool IsPacketEmpty(PacketHandle packet) { - return packet == null; + return packet.ManagedPacket == null; } - internal override void ReleasePacket(object syncReadPacket) + internal override void ReleasePacket(PacketHandle syncReadPacket) { - ((SNIPacket)syncReadPacket).Dispose(); + syncReadPacket.ManagedPacket?.Dispose(); } internal override uint CheckConnection() @@ -153,38 +153,46 @@ internal override uint CheckConnection() return handle == null ? TdsEnums.SNI_SUCCESS : SNIProxy.Singleton.CheckConnection(handle); } - internal override object ReadAsync(out uint error, ref object handle) + internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) { SNIPacket packet; - error = SNIProxy.Singleton.ReadAsync((SNIHandle)handle, out packet); - return packet; + error = SNIProxy.Singleton.ReadAsync(handle.ManagedHandle, out packet); + return PacketHandle.FromManagedPacket(packet); } - internal override object CreateAndSetAttentionPacket() + internal override PacketHandle CreateAndSetAttentionPacket() { if (_sniAsyncAttnPacket == null) { SNIPacket attnPacket = new SNIPacket(); - SetPacketData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN); + SetPacketData(PacketHandle.FromManagedPacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN); _sniAsyncAttnPacket = attnPacket; } - return _sniAsyncAttnPacket; + return PacketHandle.FromManagedPacket(_sniAsyncAttnPacket); } - internal override uint WritePacket(object packet, bool sync) + internal override uint WritePacket(PacketHandle packet, bool sync) { - return SNIProxy.Singleton.WritePacket((SNIHandle)Handle, (SNIPacket)packet, sync); + return SNIProxy.Singleton.WritePacket(Handle, packet.ManagedPacket, sync); } - internal override object AddPacketToPendingList(object packet) + internal override PacketHandle AddPacketToPendingList(PacketHandle packet) { // No-Op return packet; } - internal override bool IsValidPacket(object packetPointer) => (SNIPacket)packetPointer != null && !((SNIPacket)packetPointer).IsInvalid; + internal override bool IsValidPacket(PacketHandle packet) + { + Debug.Assert(packet.Type == PacketHandle.ManagedPacketType, "unexpected packet type when requiring ManagedPacket"); + return ( + packet.Type == PacketHandle.ManagedPacketType && + packet.ManagedPacket != null && + !packet.ManagedPacket.IsInvalid + ); + } - internal override object GetResetWritePacket() + internal override PacketHandle GetResetWritePacket() { if (_sniPacket != null) { @@ -197,7 +205,7 @@ internal override object GetResetWritePacket() _sniPacket = _writePacketCache.Take(Handle); } } - return _sniPacket; + return PacketHandle.FromManagedPacket(_sniPacket); } internal override void ClearAllWritePackets() @@ -214,8 +222,8 @@ internal override void ClearAllWritePackets() } } - internal override void SetPacketData(object packet, byte[] buffer, int bytesUsed) => SNIProxy.Singleton.PacketSetData((SNIPacket)packet, buffer, bytesUsed); - + internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) => SNIProxy.Singleton.PacketSetData(packet.ManagedPacket, buffer, bytesUsed); + internal override uint SniGetConnectionId(ref Guid clientConnectionId) => SNIProxy.Singleton.GetConnectionId(Handle, ref clientConnectionId); internal override uint DisabeSsl() => SNIProxy.Singleton.DisableSsl(Handle); diff --git a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs index e830fc074de6..5c43bdb07902 100644 --- a/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/System.Data.SqlClient/src/System/Data/SqlClient/TdsParserStateObjectNative.cs @@ -12,8 +12,6 @@ namespace System.Data.SqlClient { internal class TdsParserStateObjectNative : TdsParserStateObject { - private static readonly object s_cachedEmptyReadPacketObjectPointer = (object)IntPtr.Zero; - private SNIHandle _sessionHandle = null; // the SNI handle we're to work on private SNIPacket _sniPacket = null; // Will have to re-vamp this for MARS @@ -35,9 +33,7 @@ internal TdsParserStateObjectNative(TdsParser parser, TdsParserStateObject physi internal override uint Status => _sessionHandle != null ? _sessionHandle.Status : TdsEnums.SNI_UNINITIALIZED; - internal override object SessionHandle => _sessionHandle; - - protected override object EmptyReadPacket => s_cachedEmptyReadPacketObjectPointer; + internal override SessionHandle SessionHandle => SessionHandle.FromNativeHandle(_sessionHandle); protected override void CreateSessionHandle(TdsParserStateObject physicalConnection, bool async) { @@ -99,11 +95,16 @@ internal override void CreatePhysicalSNIHandle(string serverName, bool ignoreSni _sessionHandle = new SNIHandle(myInfo, serverName, spnBuffer, ignoreSniOpenTimeout, checked((int)timeout), out instanceName, flushCache, !async, fParallel); } - protected override uint SNIPacketGetData(object packet, byte[] _inBuff, ref uint dataSize) => SNINativeMethodWrapper.SNIPacketGetData((IntPtr)packet, _inBuff, ref dataSize); + protected override uint SNIPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) + { + Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); + return SNINativeMethodWrapper.SNIPacketGetData(packet.NativePointer, _inBuff, ref dataSize); + } - protected override bool CheckPacket(object packet, TaskCompletionSource source) + protected override bool CheckPacket(PacketHandle packet, TaskCompletionSource source) { - IntPtr ptr = (IntPtr)(object)packet; + Debug.Assert(packet.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); + IntPtr ptr = packet.NativePointer; return IntPtr.Zero == ptr || IntPtr.Zero != ptr && source != null; } @@ -111,9 +112,10 @@ protected override bool CheckPacket(object packet, TaskCompletionSource public void WriteAsyncCallback(IntPtr key, IntPtr packet, uint sniError) => WriteAsyncCallback(key, packet, sniError); - protected override void RemovePacketFromPendingList(object ptr) + protected override void RemovePacketFromPendingList(PacketHandle ptr) { - IntPtr pointer = (IntPtr)ptr; + Debug.Assert(ptr.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); + IntPtr pointer = ptr.NativePointer; SNIPacket recoveredPacket; @@ -171,7 +173,7 @@ protected override void FreeGcHandle(int remaining, bool release) internal override bool IsFailedHandle() => _sessionHandle.Status != TdsEnums.SNI_SUCCESS; - internal override object ReadSyncOverAsync(int timeoutRemaining, out uint error) + internal override PacketHandle ReadSyncOverAsync(int timeoutRemaining, out uint error) { SNIHandle handle = Handle; if (handle == null) @@ -180,12 +182,22 @@ internal override object ReadSyncOverAsync(int timeoutRemaining, out uint error) } IntPtr readPacketPtr = IntPtr.Zero; error = SNINativeMethodWrapper.SNIReadSyncOverAsync(handle, ref readPacketPtr, GetTimeoutRemaining()); - return readPacketPtr; + return PacketHandle.FromNativePointer(readPacketPtr); } - internal override bool IsPacketEmpty(object readPacket) => IntPtr.Zero == (IntPtr)readPacket; + protected override PacketHandle EmptyReadPacket => PacketHandle.FromNativePointer(default); - internal override void ReleasePacket(object syncReadPacket) => SNINativeMethodWrapper.SNIPacketRelease((IntPtr)syncReadPacket); + internal override bool IsPacketEmpty(PacketHandle readPacket) + { + Debug.Assert(readPacket.Type == PacketHandle.NativePointerType || readPacket.Type == 0, "unexpected packet type when requiring NativePointer"); + return IntPtr.Zero == readPacket.NativePointer; + } + + internal override void ReleasePacket(PacketHandle syncReadPacket) + { + Debug.Assert(syncReadPacket.Type == PacketHandle.NativePointerType, "unexpected packet type when requiring NativePointer"); + SNINativeMethodWrapper.SNIPacketRelease(syncReadPacket.NativePointer); + } internal override uint CheckConnection() { @@ -193,27 +205,33 @@ internal override uint CheckConnection() return handle == null ? TdsEnums.SNI_SUCCESS : SNINativeMethodWrapper.SNICheckConnection(handle); } - internal override object ReadAsync(out uint error, ref object handle) + internal override PacketHandle ReadAsync(SessionHandle handle, out uint error) { + Debug.Assert(handle.Type == SessionHandle.NativeHandleType, "unexpected handle type when requiring NativePointer"); IntPtr readPacketPtr = IntPtr.Zero; - error = SNINativeMethodWrapper.SNIReadAsync((SNIHandle)handle, ref readPacketPtr); - return readPacketPtr; + error = SNINativeMethodWrapper.SNIReadAsync(handle.NativeHandle, ref readPacketPtr); + return PacketHandle.FromNativePointer(readPacketPtr); } - internal override object CreateAndSetAttentionPacket() + internal override PacketHandle CreateAndSetAttentionPacket() { SNIHandle handle = Handle; SNIPacket attnPacket = new SNIPacket(handle); _sniAsyncAttnPacket = attnPacket; - SetPacketData(attnPacket, SQL.AttentionHeader, TdsEnums.HEADER_LEN); - return attnPacket; + SetPacketData(PacketHandle.FromNativePacket(attnPacket), SQL.AttentionHeader, TdsEnums.HEADER_LEN); + return PacketHandle.FromNativePacket(attnPacket); } - internal override uint WritePacket(object packet, bool sync) => SNINativeMethodWrapper.SNIWritePacket(Handle, (SNIPacket)packet, sync); + internal override uint WritePacket(PacketHandle packet, bool sync) + { + Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket"); + return SNINativeMethodWrapper.SNIWritePacket(Handle, packet.NativePacket, sync); + } - internal override object AddPacketToPendingList(object packetToAdd) + internal override PacketHandle AddPacketToPendingList(PacketHandle packetToAdd) { - SNIPacket packet = (SNIPacket)packetToAdd; + Debug.Assert(packetToAdd.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket"); + SNIPacket packet = packetToAdd.NativePacket; Debug.Assert(packet == _sniPacket, "Adding a packet other than the current packet to the pending list"); _sniPacket = null; IntPtr pointer = packet.DangerousGetHandle(); @@ -223,12 +241,20 @@ internal override object AddPacketToPendingList(object packetToAdd) _pendingWritePackets.Add(pointer, packet); } - return pointer; + return PacketHandle.FromNativePointer(pointer); } - internal override bool IsValidPacket(object packetPointer) => (IntPtr)packetPointer != IntPtr.Zero; + internal override bool IsValidPacket(PacketHandle packetPointer) + { + Debug.Assert(packetPointer.Type == PacketHandle.NativePointerType || packetPointer.Type==PacketHandle.NativePacketType, "unexpected packet type when requiring NativePointer"); + return ( + (packetPointer.Type == PacketHandle.NativePointerType && packetPointer.NativePointer != IntPtr.Zero) + || + (packetPointer.Type == PacketHandle.NativePacketType && packetPointer.NativePacket != null) + ); + } - internal override object GetResetWritePacket() + internal override PacketHandle GetResetWritePacket() { if (_sniPacket != null) { @@ -241,7 +267,7 @@ internal override object GetResetWritePacket() _sniPacket = _writePacketCache.Take(Handle); } } - return _sniPacket; + return PacketHandle.FromNativePacket(_sniPacket); } internal override void ClearAllWritePackets() @@ -258,8 +284,11 @@ internal override void ClearAllWritePackets() } } - internal override void SetPacketData(object packet, byte[] buffer, int bytesUsed) - => SNINativeMethodWrapper.SNIPacketSetData((SNIPacket)packet, buffer, bytesUsed); + internal override void SetPacketData(PacketHandle packet, byte[] buffer, int bytesUsed) + { + Debug.Assert(packet.Type == PacketHandle.NativePacketType, "unexpected packet type when requiring NativePacket"); + SNINativeMethodWrapper.SNIPacketSetData(packet.NativePacket, buffer, bytesUsed); + } internal override uint SniGetConnectionId(ref Guid clientConnectionId) => SNINativeMethodWrapper.SniGetConnectionId(Handle, ref clientConnectionId);