diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 8499f7801d..9dc48f24cd 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -156,6 +156,9 @@ Microsoft\Data\Sql\SqlNotificationRequest.cs + + Microsoft\Data\SqlClient\AAsyncCallContext.cs + Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationProvider.cs @@ -647,7 +650,6 @@ - diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs index fd53e3fe50..cc07cd03c3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -323,9 +323,9 @@ internal virtual SmiExtendedMetaData[] GetInternalSmiMetaData() collation != null ? collation.LCID : _defaultLCID, collation != null ? collation.SqlCompareOptions : SqlCompareOptions.None, colMetaData.udt?.Type, - false, // isMultiValued - null, // fieldmetadata - null, // extended properties + isMultiValued: false, + fieldMetaData: null, + extendedProperties: null, colMetaData.column, typeSpecificNamePart1, typeSpecificNamePart2, @@ -443,7 +443,7 @@ internal void Bind(TdsParserStateObject stateObj) _defaultLCID = _parser.DefaultLCID; } -#if NET6_0_OR_GREATER +#if !NETFRAMEWORK [SuppressMessage("ReflectionAnalysis", "IL2111", Justification = "System.Type.TypeInitializer would not be used in dataType and providerSpecificDataType columns.")] #endif @@ -763,11 +763,10 @@ private TdsOperationStatus TryCleanPartialRead() { AssertReaderState(requireData: true, permitAsync: true); - TdsOperationStatus result; - // VSTS DEVDIV2 380446: It is possible that read attempt we are cleaning after ended with partially // processed header (if it falls between network packets). In this case the first thing to do is to // finish reading the header, otherwise code will start treating unread header as TDS payload. + TdsOperationStatus result; if (_stateObj._partialHeaderBytesRead > 0) { result = _stateObj.TryProcessHeader(); @@ -1154,7 +1153,9 @@ private TdsOperationStatus TryConsumeMetaData() // NOTE: We doom connection for TdsParserState.Closed since it indicates that it is in some abnormal and unstable state, probably as a result of // closing from another thread. In general, TdsParserState.Closed does not necessitate dooming the connection. if (_parser.Connection != null) + { _parser.Connection.DoomThisConnection(); + } throw SQL.ConnectionDoomed(); } bool ignored; @@ -1252,7 +1253,7 @@ override public IEnumerator GetEnumerator() } /// -#if NET6_0_OR_GREATER +#if !NETFRAMEWORK [SuppressMessage("ReflectionAnalysis", "IL2093:MismatchOnMethodReturnValueBetweenOverrides", Justification = "Annotations for DbDataReader was not shipped in net6.0")] [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] @@ -1273,7 +1274,7 @@ override public Type GetFieldType(int i) } } -#if NET6_0_OR_GREATER +#if !NETFRAMEWORK [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] #endif private Type GetFieldTypeInternal(_SqlMetaData metaData) @@ -1368,7 +1369,7 @@ override public string GetName(int i) } /// -#if NET8_0_OR_GREATER +#if !NETFRAMEWORK && NET8_0_OR_GREATER [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] #endif override public Type GetProviderSpecificFieldType(int i) @@ -1387,7 +1388,7 @@ override public Type GetProviderSpecificFieldType(int i) } } -#if NET6_0_OR_GREATER +#if !NETFRAMEWORK [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] #endif private Type GetProviderSpecificFieldTypeInternal(_SqlMetaData metaData) @@ -1712,7 +1713,9 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf } if (dataIndex < 0) + { throw ADP.NegativeParameter(nameof(dataIndex)); + } if (dataIndex < _columnDataBytesRead) { @@ -1730,14 +1733,20 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf // if bad buffer index, throw if (bufferIndex < 0 || bufferIndex >= buffer.Length) + { throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, nameof(bufferIndex)); + } // if there is not enough room in the buffer for data if (length + bufferIndex > buffer.Length) + { throw ADP.InvalidBufferSizeOrIndex(length, bufferIndex); + } if (length < 0) + { throw ADP.InvalidDataLength(length); + } // Skip if needed if (cb > 0) @@ -1774,7 +1783,9 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf // note that since we are caching in an array, and arrays aren't 64 bit ready yet, // we need can cast to int if the dataIndex is in range if (dataIndex < 0) + { throw ADP.NegativeParameter(nameof(dataIndex)); + } if (dataIndex > int.MaxValue) { @@ -1828,9 +1839,13 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf { // help the user out in the case where there's less data than requested if ((ndataIndex + length) > cbytes) + { cbytes = cbytes - ndataIndex; + } else + { cbytes = length; + } } Buffer.BlockCopy(data, ndataIndex, buffer, bufferIndex, cbytes); @@ -1844,15 +1859,21 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf cbytes = data.Length; if (length < 0) + { throw ADP.InvalidDataLength(length); + } // if bad buffer index, throw if (bufferIndex < 0 || bufferIndex >= buffer.Length) + { throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, nameof(bufferIndex)); + } // if there is not enough room in the buffer for data if (cbytes + bufferIndex > buffer.Length) + { throw ADP.InvalidBufferSizeOrIndex(cbytes, bufferIndex); + } throw; } @@ -1868,7 +1889,6 @@ internal int GetBytesInternalSequential(int i, byte[] buffer, int index, int len throw ADP.AsyncOperationPending(); } - TdsOperationStatus result; int value; SqlStatistics statistics = null; Debug.Assert(_stateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); @@ -1877,7 +1897,7 @@ internal int GetBytesInternalSequential(int i, byte[] buffer, int index, int len statistics = SqlStatistics.StartTimer(Statistics); SetTimeout(timeoutMilliseconds ?? _defaultTimeoutMilliseconds); - result = TryReadColumnHeader(i); + TdsOperationStatus result = TryReadColumnHeader(i); if (result != TdsOperationStatus.Done) { throw SQL.SynchronousCallMayNotPend(); @@ -2161,7 +2181,9 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn // if dataIndex outside of data range, return 0 if (ndataIndex < 0 || ndataIndex >= cchars) + { return 0; + } try { @@ -2169,9 +2191,13 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn { // help the user out in the case where there's less data than requested if ((ndataIndex + length) > cchars) + { cchars = cchars - ndataIndex; + } else + { cchars = length; + } } Array.Copy(_columnDataChars, ndataIndex, buffer, bufferIndex, cchars); @@ -2186,15 +2212,21 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn cchars = _columnDataChars.Length; if (length < 0) + { throw ADP.InvalidDataLength(length); + } // if bad buffer index, throw if (bufferIndex < 0 || bufferIndex >= buffer.Length) + { throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, nameof(bufferIndex)); + } // if there is not enough room in the buffer for data if (cchars + bufferIndex > buffer.Length) + { throw ADP.InvalidBufferSizeOrIndex(cchars, bufferIndex); + } throw; } @@ -2244,7 +2276,9 @@ private long GetCharsFromPlpData(int i, long dataIndex, char[] buffer, int buffe // _columnDataCharsRead is 0 and dataIndex > _columnDataCharsRead is true below. // In both cases we will clean decoder if (dataIndex == 0) + { _stateObj._plpdecoder = null; + } bool isUnicode = _metaData[i].metaType.IsNCharType; @@ -2617,7 +2651,7 @@ private object GetSqlValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met } else { - throw ADP.DataReaderClosed(nameof(GetSqlValueFromSqlBufferInternal)); + throw ADP.DataReaderClosed(); } } else @@ -2817,7 +2851,7 @@ private object GetValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaDa } else { - throw ADP.DataReaderClosed(nameof(GetValueFromSqlBufferInternal)); + throw ADP.DataReaderClosed(); } } } @@ -2896,7 +2930,7 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met { return (T)(object)data.DateTime; } -#if NET6_0_OR_GREATER +#if !NETFRAMEWORK else if (typeof(T) == typeof(DateOnly) && dataType == typeof(DateTime) && _typeSystem > SqlConnectionString.TypeSystem.SQLServer2005) { return (T)(object)data.DateOnly; @@ -3557,7 +3591,7 @@ private TdsOperationStatus TryReadInternal(bool setTimeout, out bool more) SqlStatistics statistics = null; using (TryEventScope.Create("SqlDataReader.TryReadInternal | API | Object Id {0}", ObjectID)) { -#if !NET6_0_OR_GREATER +#if NETFRAMEWORK RuntimeHelpers.PrepareConstrainedRegions(); #endif @@ -3708,10 +3742,10 @@ private TdsOperationStatus TryReadInternal(bool setTimeout, out bool more) if ((!_sharedState._dataReady) && (_stateObj.HasPendingData)) { byte token; - TdsOperationStatus debugResult = _stateObj.TryPeekByte(out token); - if (debugResult != TdsOperationStatus.Done) + result = _stateObj.TryPeekByte(out token); + if (result != TdsOperationStatus.Done) { - return debugResult; + return result; } Debug.Assert(TdsParser.IsValidTdsToken(token), $"DataReady is false, but next token is invalid: {token,-2:X2}"); @@ -3808,7 +3842,7 @@ private TdsOperationStatus TryReadColumnData() TdsOperationStatus result = _parser.TryReadSqlValue(_data[_sharedState._nextColumnDataToRead], columnMetaData, (int)_sharedState._columnDataBytesRemaining, _stateObj, _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, - columnMetaData.column); + columnMetaData.column, _command); if (result != TdsOperationStatus.Done) { // will read UDTs as VARBINARY. @@ -3970,7 +4004,6 @@ internal TdsOperationStatus TryReadColumnInternal(int i, bool readHeaderOnly = f } else { - // we have read past the column somehow, this is an error Debug.Assert(false, "We have read past the column somehow, this is an error"); } } @@ -4310,7 +4343,6 @@ internal TdsOperationStatus TrySetMetaData(_SqlMetaDataSet metaData, bool moreIn if (metaData != null) { - TdsOperationStatus result; // we are done consuming metadata only if there is no moreInfo if (!moreInfo) { @@ -4321,7 +4353,7 @@ internal TdsOperationStatus TrySetMetaData(_SqlMetaDataSet metaData, bool moreIn // Peek, and if row token present, set _hasRows true since there is a // row in the result byte b; - result = _stateObj.TryPeekByte(out b); + TdsOperationStatus result = _stateObj.TryPeekByte(out b); if (result != TdsOperationStatus.Done) { return result; @@ -5201,7 +5233,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat var metaData = _metaData; if ((data != null) && (metaData != null)) { - return Task.FromResult(GetFieldValueFromSqlBufferInternal(data[i], metaData[i], isAsync:false)); + return Task.FromResult(GetFieldValueFromSqlBufferInternal(data[i], metaData[i], isAsync: false)); } else { @@ -5241,7 +5273,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat { _stateObj._shouldHaveEnoughData = true; #endif - return Task.FromResult(GetFieldValueInternal(i, isAsync:true)); + return Task.FromResult(GetFieldValueInternal(i, isAsync: true)); #if DEBUG } finally @@ -5305,19 +5337,24 @@ private static Task GetFieldValueAsyncExecute(Task task, object state) reader.PrepareForAsyncContinuation(); } - TdsOperationStatus result; if (typeof(T) == typeof(Stream) || typeof(T) == typeof(TextReader) || typeof(T) == typeof(XmlReader)) { - if (reader.IsCommandBehavior(CommandBehavior.SequentialAccess) && reader._sharedState._dataReady && reader.TryReadColumnInternal(context._columnIndex, readHeaderOnly: true) == TdsOperationStatus.Done) + if (reader.IsCommandBehavior(CommandBehavior.SequentialAccess) && reader._sharedState._dataReady) { - return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex], isAsync: true)); + bool internalReadSuccess = false; + internalReadSuccess = reader.TryReadColumnInternal(context._columnIndex, readHeaderOnly: true) == TdsOperationStatus.Done; + + if (internalReadSuccess) + { + return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex], isAsync: true)); + } } } - result = reader.TryReadColumn(columnIndex, setTimeout: false); + TdsOperationStatus result = reader.TryReadColumn(columnIndex, setTimeout: false); if (result == TdsOperationStatus.Done) { - return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex], isAsync:false)); + return Task.FromResult(reader.GetFieldValueFromSqlBufferInternal(reader._data[columnIndex], reader._metaData[columnIndex], isAsync: false)); } else { @@ -5683,7 +5720,7 @@ private void CompleteAsyncCall(Task task, SqlDataReaderBaseAsyncCallContex } - internal class Snapshot + internal sealed class Snapshot { public bool _dataReady; public bool _haltRead; diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj index 9b161fa048..eae42b9e60 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj @@ -202,6 +202,9 @@ Microsoft\Data\Sql\SqlNotificationRequest.cs + + Microsoft\Data\SqlClient\AAsyncCallContext.cs + Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs index 7f680ebd39..e8f4938964 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlDataReader.cs @@ -73,8 +73,8 @@ internal class SharedState private FieldNameLookup _fieldNameLookup; private CommandBehavior _commandBehavior; - private static int _objectTypeCount; // EventSource Counter - internal readonly int ObjectID = System.Threading.Interlocked.Increment(ref _objectTypeCount); + private static int s_objectTypeCount; // EventSource Counter + internal readonly int ObjectID = Interlocked.Increment(ref s_objectTypeCount); // context // undone: we may still want to do this...it's nice to pass in an lpvoid (essentially) and just have the reader keep the state @@ -164,7 +164,7 @@ override public int Depth { if (this.IsClosed) { - throw ADP.DataReaderClosed("Depth"); + throw ADP.DataReaderClosed(); } return 0; @@ -179,7 +179,7 @@ override public int FieldCount { if (this.IsClosed) { - throw ADP.DataReaderClosed("FieldCount"); + throw ADP.DataReaderClosed(); } if (_currentTask != null) { @@ -202,7 +202,7 @@ override public bool HasRows { if (this.IsClosed) { - throw ADP.DataReaderClosed("HasRows"); + throw ADP.DataReaderClosed(); } if (_currentTask != null) { @@ -253,7 +253,7 @@ internal _SqlMetaDataSet MetaData { if (IsClosed) { - throw ADP.DataReaderClosed("MetaData"); + throw ADP.DataReaderClosed(); } // metaData comes in pieces: colmetadata, tabname, colinfo, etc // if we have any metaData, return it. If we have none, @@ -371,34 +371,35 @@ internal virtual SmiExtendedMetaData[] GetInternalSmiMetaData() length /= ADP.CharSize; } - metaDataReturn[returnIndex] = new SmiQueryMetaData( - colMetaData.type, - length, - colMetaData.precision, - colMetaData.scale, - collation != null ? collation.LCID : _defaultLCID, - collation != null ? collation.SqlCompareOptions : SqlCompareOptions.None, - colMetaData.udt?.Type, - false, // isMultiValued - null, // fieldmetadata - null, // extended properties - colMetaData.column, - typeSpecificNamePart1, - typeSpecificNamePart2, - typeSpecificNamePart3, - colMetaData.IsNullable, - colMetaData.serverName, - colMetaData.catalogName, - colMetaData.schemaName, - colMetaData.tableName, - colMetaData.baseColumn, - colMetaData.IsKey, - colMetaData.IsIdentity, - colMetaData.IsReadOnly, - colMetaData.IsExpression, - colMetaData.IsDifferentName, - colMetaData.IsHidden - ); + metaDataReturn[returnIndex] = + new SmiQueryMetaData( + colMetaData.type, + length, + colMetaData.precision, + colMetaData.scale, + collation != null ? collation.LCID : _defaultLCID, + collation != null ? collation.SqlCompareOptions : SqlCompareOptions.None, + colMetaData.udt?.Type, + isMultiValued: false, + fieldMetaData: null, + extendedProperties: null, + colMetaData.column, + typeSpecificNamePart1, + typeSpecificNamePart2, + typeSpecificNamePart3, + colMetaData.IsNullable, + colMetaData.serverName, + colMetaData.catalogName, + colMetaData.schemaName, + colMetaData.tableName, + colMetaData.baseColumn, + colMetaData.IsKey, + colMetaData.IsIdentity, + colMetaData.IsReadOnly, + colMetaData.IsExpression, + colMetaData.IsDifferentName, + colMetaData.IsHidden); + returnIndex += 1; } } @@ -449,6 +450,7 @@ internal MultiPartTableName[] TableNames _tableNames = value; } } + /// override public int VisibleFieldCount { @@ -456,7 +458,7 @@ override public int VisibleFieldCount { if (this.IsClosed) { - throw ADP.DataReaderClosed("VisibleFieldCount"); + throw ADP.DataReaderClosed(); } _SqlMetaDataSet md = this.MetaData; if (md == null) @@ -489,6 +491,7 @@ override public object this[string name] internal void Bind(TdsParserStateObject stateObj) { Debug.Assert(stateObj != null, "null stateobject"); + Debug.Assert(_snapshot == null, "Should not change during execution of asynchronous command"); stateObj.Owner = this; @@ -500,6 +503,10 @@ internal void Bind(TdsParserStateObject stateObj) // Fills in a schema table with meta data information. This function should only really be called by // UNDONE: need a way to refresh the table with more information as more data comes online for browse info like // table names and key information +#if !NETFRAMEWORK + [SuppressMessage("ReflectionAnalysis", "IL2111", + Justification = "System.Type.TypeInitializer would not be used in dataType and providerSpecificDataType columns.")] +#endif internal DataTable BuildSchemaTable() { _SqlMetaDataSet md = this.MetaData; @@ -509,91 +516,91 @@ internal DataTable BuildSchemaTable() schemaTable.Locale = CultureInfo.InvariantCulture; schemaTable.MinimumCapacity = md.Length; - DataColumn ColumnName = new DataColumn(SchemaTableColumn.ColumnName, typeof(System.String)); - DataColumn Ordinal = new DataColumn(SchemaTableColumn.ColumnOrdinal, typeof(System.Int32)); - DataColumn Size = new DataColumn(SchemaTableColumn.ColumnSize, typeof(System.Int32)); - DataColumn Precision = new DataColumn(SchemaTableColumn.NumericPrecision, typeof(System.Int16)); - DataColumn Scale = new DataColumn(SchemaTableColumn.NumericScale, typeof(System.Int16)); + DataColumn columnName = new DataColumn(SchemaTableColumn.ColumnName, typeof(string)); + DataColumn ordinal = new DataColumn(SchemaTableColumn.ColumnOrdinal, typeof(int)); + DataColumn size = new DataColumn(SchemaTableColumn.ColumnSize, typeof(int)); + DataColumn precision = new DataColumn(SchemaTableColumn.NumericPrecision, typeof(short)); + DataColumn scale = new DataColumn(SchemaTableColumn.NumericScale, typeof(short)); - DataColumn DataType = new DataColumn(SchemaTableColumn.DataType, typeof(System.Type)); - DataColumn ProviderSpecificDataType = new DataColumn(SchemaTableOptionalColumn.ProviderSpecificDataType, typeof(System.Type)); - DataColumn NonVersionedProviderType = new DataColumn(SchemaTableColumn.NonVersionedProviderType, typeof(System.Int32)); - DataColumn ProviderType = new DataColumn(SchemaTableColumn.ProviderType, typeof(System.Int32)); + DataColumn dataType = new DataColumn(SchemaTableColumn.DataType, typeof(System.Type)); + DataColumn providerSpecificDataType = new DataColumn(SchemaTableOptionalColumn.ProviderSpecificDataType, typeof(System.Type)); + DataColumn nonVersionedProviderType = new DataColumn(SchemaTableColumn.NonVersionedProviderType, typeof(int)); + DataColumn providerType = new DataColumn(SchemaTableColumn.ProviderType, typeof(int)); - DataColumn IsLong = new DataColumn(SchemaTableColumn.IsLong, typeof(System.Boolean)); - DataColumn AllowDBNull = new DataColumn(SchemaTableColumn.AllowDBNull, typeof(System.Boolean)); - DataColumn IsReadOnly = new DataColumn(SchemaTableOptionalColumn.IsReadOnly, typeof(System.Boolean)); - DataColumn IsRowVersion = new DataColumn(SchemaTableOptionalColumn.IsRowVersion, typeof(System.Boolean)); + DataColumn isLong = new DataColumn(SchemaTableColumn.IsLong, typeof(bool)); + DataColumn allowDBNull = new DataColumn(SchemaTableColumn.AllowDBNull, typeof(bool)); + DataColumn isReadOnly = new DataColumn(SchemaTableOptionalColumn.IsReadOnly, typeof(bool)); + DataColumn isRowVersion = new DataColumn(SchemaTableOptionalColumn.IsRowVersion, typeof(bool)); - DataColumn IsUnique = new DataColumn(SchemaTableColumn.IsUnique, typeof(System.Boolean)); - DataColumn IsKey = new DataColumn(SchemaTableColumn.IsKey, typeof(System.Boolean)); - DataColumn IsAutoIncrement = new DataColumn(SchemaTableOptionalColumn.IsAutoIncrement, typeof(System.Boolean)); - DataColumn IsHidden = new DataColumn(SchemaTableOptionalColumn.IsHidden, typeof(System.Boolean)); + DataColumn isUnique = new DataColumn(SchemaTableColumn.IsUnique, typeof(bool)); + DataColumn isKey = new DataColumn(SchemaTableColumn.IsKey, typeof(bool)); + DataColumn isAutoIncrement = new DataColumn(SchemaTableOptionalColumn.IsAutoIncrement, typeof(bool)); + DataColumn isHidden = new DataColumn(SchemaTableOptionalColumn.IsHidden, typeof(bool)); - DataColumn BaseCatalogName = new DataColumn(SchemaTableOptionalColumn.BaseCatalogName, typeof(System.String)); - DataColumn BaseSchemaName = new DataColumn(SchemaTableColumn.BaseSchemaName, typeof(System.String)); - DataColumn BaseTableName = new DataColumn(SchemaTableColumn.BaseTableName, typeof(System.String)); - DataColumn BaseColumnName = new DataColumn(SchemaTableColumn.BaseColumnName, typeof(System.String)); + DataColumn baseCatalogName = new DataColumn(SchemaTableOptionalColumn.BaseCatalogName, typeof(string)); + DataColumn baseSchemaName = new DataColumn(SchemaTableColumn.BaseSchemaName, typeof(string)); + DataColumn baseTableName = new DataColumn(SchemaTableColumn.BaseTableName, typeof(string)); + DataColumn baseColumnName = new DataColumn(SchemaTableColumn.BaseColumnName, typeof(string)); // unique to SqlClient - DataColumn BaseServerName = new DataColumn(SchemaTableOptionalColumn.BaseServerName, typeof(System.String)); - DataColumn IsAliased = new DataColumn(SchemaTableColumn.IsAliased, typeof(System.Boolean)); - DataColumn IsExpression = new DataColumn(SchemaTableColumn.IsExpression, typeof(System.Boolean)); - DataColumn IsIdentity = new DataColumn("IsIdentity", typeof(System.Boolean)); - DataColumn DataTypeName = new DataColumn("DataTypeName", typeof(System.String)); - DataColumn UdtAssemblyQualifiedName = new DataColumn("UdtAssemblyQualifiedName", typeof(System.String)); + DataColumn baseServerName = new DataColumn(SchemaTableOptionalColumn.BaseServerName, typeof(string)); + DataColumn isAliased = new DataColumn(SchemaTableColumn.IsAliased, typeof(bool)); + DataColumn isExpression = new DataColumn(SchemaTableColumn.IsExpression, typeof(bool)); + DataColumn isIdentity = new DataColumn("IsIdentity", typeof(bool)); + DataColumn dataTypeName = new DataColumn("DataTypeName", typeof(string)); + DataColumn udtAssemblyQualifiedName = new DataColumn("UdtAssemblyQualifiedName", typeof(string)); // Xml metadata specific - DataColumn XmlSchemaCollectionDatabase = new DataColumn("XmlSchemaCollectionDatabase", typeof(System.String)); - DataColumn XmlSchemaCollectionOwningSchema = new DataColumn("XmlSchemaCollectionOwningSchema", typeof(System.String)); - DataColumn XmlSchemaCollectionName = new DataColumn("XmlSchemaCollectionName", typeof(System.String)); + DataColumn xmlSchemaCollectionDatabase = new DataColumn("XmlSchemaCollectionDatabase", typeof(string)); + DataColumn xmlSchemaCollectionOwningSchema = new DataColumn("XmlSchemaCollectionOwningSchema", typeof(string)); + DataColumn xmlSchemaCollectionName = new DataColumn("XmlSchemaCollectionName", typeof(string)); // SparseColumnSet - DataColumn IsColumnSet = new DataColumn("IsColumnSet", typeof(System.Boolean)); + DataColumn isColumnSet = new DataColumn("IsColumnSet", typeof(bool)); - Ordinal.DefaultValue = 0; - IsLong.DefaultValue = false; + ordinal.DefaultValue = 0; + isLong.DefaultValue = false; DataColumnCollection columns = schemaTable.Columns; // must maintain order for backward compatibility - columns.Add(ColumnName); - columns.Add(Ordinal); - columns.Add(Size); - columns.Add(Precision); - columns.Add(Scale); - columns.Add(IsUnique); - columns.Add(IsKey); - columns.Add(BaseServerName); - columns.Add(BaseCatalogName); - columns.Add(BaseColumnName); - columns.Add(BaseSchemaName); - columns.Add(BaseTableName); - columns.Add(DataType); - columns.Add(AllowDBNull); - columns.Add(ProviderType); - columns.Add(IsAliased); - columns.Add(IsExpression); - columns.Add(IsIdentity); - columns.Add(IsAutoIncrement); - columns.Add(IsRowVersion); - columns.Add(IsHidden); - columns.Add(IsLong); - columns.Add(IsReadOnly); - columns.Add(ProviderSpecificDataType); - columns.Add(DataTypeName); - columns.Add(XmlSchemaCollectionDatabase); - columns.Add(XmlSchemaCollectionOwningSchema); - columns.Add(XmlSchemaCollectionName); - columns.Add(UdtAssemblyQualifiedName); - columns.Add(NonVersionedProviderType); - columns.Add(IsColumnSet); + columns.Add(columnName); + columns.Add(ordinal); + columns.Add(size); + columns.Add(precision); + columns.Add(scale); + columns.Add(isUnique); + columns.Add(isKey); + columns.Add(baseServerName); + columns.Add(baseCatalogName); + columns.Add(baseColumnName); + columns.Add(baseSchemaName); + columns.Add(baseTableName); + columns.Add(dataType); + columns.Add(allowDBNull); + columns.Add(providerType); + columns.Add(isAliased); + columns.Add(isExpression); + columns.Add(isIdentity); + columns.Add(isAutoIncrement); + columns.Add(isRowVersion); + columns.Add(isHidden); + columns.Add(isLong); + columns.Add(isReadOnly); + columns.Add(providerSpecificDataType); + columns.Add(dataTypeName); + columns.Add(xmlSchemaCollectionDatabase); + columns.Add(xmlSchemaCollectionOwningSchema); + columns.Add(xmlSchemaCollectionName); + columns.Add(udtAssemblyQualifiedName); + columns.Add(nonVersionedProviderType); + columns.Add(isColumnSet); for (int i = 0; i < md.Length; i++) { _SqlMetaData col = md[i]; DataRow schemaRow = schemaTable.NewRow(); - schemaRow[ColumnName] = col.column; - schemaRow[Ordinal] = col.ordinal; + schemaRow[columnName] = col.column; + schemaRow[ordinal] = col.ordinal; // // be sure to return character count for string types, byte count otherwise // col.length is always byte count so for unicode types, half the length @@ -602,37 +609,37 @@ internal DataTable BuildSchemaTable() if (col.cipherMD != null) { Debug.Assert(col.baseTI != null && col.baseTI.metaType != null, "col.baseTI and col.baseTI.metaType should not be null."); - schemaRow[Size] = (col.baseTI.metaType.IsSizeInCharacters && (col.baseTI.length != 0x7fffffff)) ? (col.baseTI.length / 2) : col.baseTI.length; + schemaRow[size] = (col.baseTI.metaType.IsSizeInCharacters && (col.baseTI.length != 0x7fffffff)) ? (col.baseTI.length / 2) : col.baseTI.length; } else { - schemaRow[Size] = (col.metaType.IsSizeInCharacters && (col.length != 0x7fffffff)) ? (col.length / 2) : col.length; + schemaRow[size] = (col.metaType.IsSizeInCharacters && (col.length != 0x7fffffff)) ? (col.length / 2) : col.length; } - schemaRow[DataType] = GetFieldTypeInternal(col); - schemaRow[ProviderSpecificDataType] = GetProviderSpecificFieldTypeInternal(col); - schemaRow[NonVersionedProviderType] = (int)(col.cipherMD != null ? col.baseTI.type : col.type); // SqlDbType enum value - does not change with TypeSystem. - schemaRow[DataTypeName] = GetDataTypeNameInternal(col); + schemaRow[dataType] = GetFieldTypeInternal(col); + schemaRow[providerSpecificDataType] = GetProviderSpecificFieldTypeInternal(col); + schemaRow[nonVersionedProviderType] = (int)(col.cipherMD != null ? col.baseTI.type : col.type); // SqlDbType enum value - does not change with TypeSystem. + schemaRow[dataTypeName] = GetDataTypeNameInternal(col); if (_typeSystem <= SqlConnectionString.TypeSystem.SQLServer2005 && col.Is2008DateTimeType) { - schemaRow[ProviderType] = SqlDbType.NVarChar; + schemaRow[providerType] = SqlDbType.NVarChar; switch (col.type) { case SqlDbType.Date: - schemaRow[Size] = TdsEnums.WHIDBEY_DATE_LENGTH; + schemaRow[size] = TdsEnums.WHIDBEY_DATE_LENGTH; break; case SqlDbType.Time: Debug.Assert(TdsEnums.UNKNOWN_PRECISION_SCALE == col.scale || (0 <= col.scale && col.scale <= 7), "Invalid scale for Time column: " + col.scale); - schemaRow[Size] = TdsEnums.WHIDBEY_TIME_LENGTH[TdsEnums.UNKNOWN_PRECISION_SCALE != col.scale ? col.scale : col.metaType.Scale]; + schemaRow[size] = TdsEnums.WHIDBEY_TIME_LENGTH[TdsEnums.UNKNOWN_PRECISION_SCALE != col.scale ? col.scale : col.metaType.Scale]; break; case SqlDbType.DateTime2: Debug.Assert(TdsEnums.UNKNOWN_PRECISION_SCALE == col.scale || (0 <= col.scale && col.scale <= 7), "Invalid scale for DateTime2 column: " + col.scale); - schemaRow[Size] = TdsEnums.WHIDBEY_DATETIME2_LENGTH[TdsEnums.UNKNOWN_PRECISION_SCALE != col.scale ? col.scale : col.metaType.Scale]; + schemaRow[size] = TdsEnums.WHIDBEY_DATETIME2_LENGTH[TdsEnums.UNKNOWN_PRECISION_SCALE != col.scale ? col.scale : col.metaType.Scale]; break; case SqlDbType.DateTimeOffset: Debug.Assert(TdsEnums.UNKNOWN_PRECISION_SCALE == col.scale || (0 <= col.scale && col.scale <= 7), "Invalid scale for DateTimeOffset column: " + col.scale); - schemaRow[Size] = TdsEnums.WHIDBEY_DATETIMEOFFSET_LENGTH[TdsEnums.UNKNOWN_PRECISION_SCALE != col.scale ? col.scale : col.metaType.Scale]; + schemaRow[size] = TdsEnums.WHIDBEY_DATETIMEOFFSET_LENGTH[TdsEnums.UNKNOWN_PRECISION_SCALE != col.scale ? col.scale : col.metaType.Scale]; break; } } @@ -640,12 +647,12 @@ internal DataTable BuildSchemaTable() { if (_typeSystem == SqlConnectionString.TypeSystem.SQLServer2005) { - schemaRow[ProviderType] = SqlDbType.VarBinary; + schemaRow[providerType] = SqlDbType.VarBinary; } else { // TypeSystem.SQLServer2000 - schemaRow[ProviderType] = SqlDbType.Image; + schemaRow[providerType] = SqlDbType.Image; } } else if (_typeSystem != SqlConnectionString.TypeSystem.SQLServer2000) @@ -653,19 +660,19 @@ internal DataTable BuildSchemaTable() // TypeSystem.SQLServer2005 and above // SqlDbType enum value - always the actual type for SQLServer2005. - schemaRow[ProviderType] = (int)(col.cipherMD != null ? col.baseTI.type : col.type); + schemaRow[providerType] = (int)(col.cipherMD != null ? col.baseTI.type : col.type); if (col.type == SqlDbType.Udt) { // Additional metadata for UDTs. Debug.Assert(Connection.Is2005OrNewer, "Invalid Column type received from the server"); - schemaRow[UdtAssemblyQualifiedName] = col.udt?.AssemblyQualifiedName; + schemaRow[udtAssemblyQualifiedName] = col.udt?.AssemblyQualifiedName; } else if (col.type == SqlDbType.Xml) { // Additional metadata for Xml. Debug.Assert(Connection.Is2005OrNewer, "Invalid DataType (Xml) for the column"); - schemaRow[XmlSchemaCollectionDatabase] = col.xmlSchemaCollection?.Database; - schemaRow[XmlSchemaCollectionOwningSchema] = col.xmlSchemaCollection?.OwningSchema; - schemaRow[XmlSchemaCollectionName] = col.xmlSchemaCollection?.Name; + schemaRow[xmlSchemaCollectionDatabase] = col.xmlSchemaCollection?.Database; + schemaRow[xmlSchemaCollectionOwningSchema] = col.xmlSchemaCollection?.OwningSchema; + schemaRow[xmlSchemaCollectionName] = col.xmlSchemaCollection?.Name; } } else @@ -673,7 +680,7 @@ internal DataTable BuildSchemaTable() // TypeSystem.SQLServer2000 // SqlDbType enum value - variable for certain types when SQLServer2000. - schemaRow[ProviderType] = GetVersionedMetaType(col.metaType).SqlDbType; + schemaRow[providerType] = GetVersionedMetaType(col.metaType).SqlDbType; } if (col.cipherMD != null) @@ -681,110 +688,110 @@ internal DataTable BuildSchemaTable() Debug.Assert(col.baseTI != null, @"col.baseTI should not be null."); if (TdsEnums.UNKNOWN_PRECISION_SCALE != col.baseTI.precision) { - schemaRow[Precision] = col.baseTI.precision; + schemaRow[precision] = col.baseTI.precision; } else { - schemaRow[Precision] = col.baseTI.metaType.Precision; + schemaRow[precision] = col.baseTI.metaType.Precision; } } else if (TdsEnums.UNKNOWN_PRECISION_SCALE != col.precision) { - schemaRow[Precision] = col.precision; + schemaRow[precision] = col.precision; } else { - schemaRow[Precision] = col.metaType.Precision; + schemaRow[precision] = col.metaType.Precision; } if (_typeSystem <= SqlConnectionString.TypeSystem.SQLServer2005 && col.Is2008DateTimeType) { - schemaRow[Scale] = MetaType.MetaNVarChar.Scale; + schemaRow[scale] = MetaType.MetaNVarChar.Scale; } else if (col.cipherMD != null) { Debug.Assert(col.baseTI != null, @"col.baseTI should not be null."); if (TdsEnums.UNKNOWN_PRECISION_SCALE != col.baseTI.scale) { - schemaRow[Scale] = col.baseTI.scale; + schemaRow[scale] = col.baseTI.scale; } else { - schemaRow[Scale] = col.baseTI.metaType.Scale; + schemaRow[scale] = col.baseTI.metaType.Scale; } } else if (TdsEnums.UNKNOWN_PRECISION_SCALE != col.scale) { - schemaRow[Scale] = col.scale; + schemaRow[scale] = col.scale; } else { - schemaRow[Scale] = col.metaType.Scale; + schemaRow[scale] = col.metaType.Scale; } - schemaRow[AllowDBNull] = col.IsNullable; + schemaRow[allowDBNull] = col.IsNullable; // If no ColInfo token received, do not set value, leave as null. if (_browseModeInfoConsumed) { - schemaRow[IsAliased] = col.IsDifferentName; - schemaRow[IsKey] = col.IsKey; - schemaRow[IsHidden] = col.IsHidden; - schemaRow[IsExpression] = col.IsExpression; + schemaRow[isAliased] = col.IsDifferentName; + schemaRow[isKey] = col.IsKey; + schemaRow[isHidden] = col.IsHidden; + schemaRow[isExpression] = col.IsExpression; } - schemaRow[IsIdentity] = col.IsIdentity; - schemaRow[IsAutoIncrement] = col.IsIdentity; + schemaRow[isIdentity] = col.IsIdentity; + schemaRow[isAutoIncrement] = col.IsIdentity; if (col.cipherMD != null) { Debug.Assert(col.baseTI != null, @"col.baseTI should not be null."); Debug.Assert(col.baseTI.metaType != null, @"col.baseTI.metaType should not be null."); - schemaRow[IsLong] = col.baseTI.metaType.IsLong; + schemaRow[isLong] = col.baseTI.metaType.IsLong; } else { - schemaRow[IsLong] = col.metaType.IsLong; + schemaRow[isLong] = col.metaType.IsLong; } // mark unique for timestamp columns if (SqlDbType.Timestamp == col.type) { - schemaRow[IsUnique] = true; - schemaRow[IsRowVersion] = true; + schemaRow[isUnique] = true; + schemaRow[isRowVersion] = true; } else { - schemaRow[IsUnique] = false; - schemaRow[IsRowVersion] = false; + schemaRow[isUnique] = false; + schemaRow[isRowVersion] = false; } - schemaRow[IsReadOnly] = col.IsReadOnly; - schemaRow[IsColumnSet] = col.IsColumnSet; + schemaRow[isReadOnly] = col.IsReadOnly; + schemaRow[isColumnSet] = col.IsColumnSet; - if (!ADP.IsEmpty(col.serverName)) + if (!string.IsNullOrEmpty(col.serverName)) { - schemaRow[BaseServerName] = col.serverName; + schemaRow[baseServerName] = col.serverName; } - if (!ADP.IsEmpty(col.catalogName)) + if (!string.IsNullOrEmpty(col.catalogName)) { - schemaRow[BaseCatalogName] = col.catalogName; + schemaRow[baseCatalogName] = col.catalogName; } - if (!ADP.IsEmpty(col.schemaName)) + if (!string.IsNullOrEmpty(col.schemaName)) { - schemaRow[BaseSchemaName] = col.schemaName; + schemaRow[baseSchemaName] = col.schemaName; } - if (!ADP.IsEmpty(col.tableName)) + if (!string.IsNullOrEmpty(col.tableName)) { - schemaRow[BaseTableName] = col.tableName; + schemaRow[baseTableName] = col.tableName; } - if (!ADP.IsEmpty(col.baseColumn)) + if (!string.IsNullOrEmpty(col.baseColumn)) { - schemaRow[BaseColumnName] = col.baseColumn; + schemaRow[baseColumnName] = col.baseColumn; } - else if (!ADP.IsEmpty(col.column)) + else if (!string.IsNullOrEmpty(col.column)) { - schemaRow[BaseColumnName] = col.column; + schemaRow[baseColumnName] = col.column; } schemaTable.Rows.Add(schemaRow); @@ -850,7 +857,6 @@ private TdsOperationStatus TryCleanPartialRead() } else { - // iia. if we still have bytes left from a partially read column, skip result = TryResetBlobState(); if (result != TdsOperationStatus.Done) @@ -871,14 +877,13 @@ private TdsOperationStatus TryCleanPartialRead() if (_stateObj.HasPendingData) { byte token; - TdsOperationStatus debugResult = _stateObj.TryPeekByte(out token); - if (debugResult != TdsOperationStatus.Done) + result = _stateObj.TryPeekByte(out token); + if (result != TdsOperationStatus.Done) { - return debugResult; + return result; } Debug.Assert(TdsParser.IsValidTdsToken(token), string.Format("Invalid token after performing CleanPartialRead: {0,-2:X2}", token)); - } #endif _sharedState._dataReady = false; @@ -961,11 +966,11 @@ protected override void Dispose(bool disposing) } /// - override public void Close() + public override void Close() { - SqlStatistics statistics = null; using (TryEventScope.Create(" {0}", ObjectID)) { + SqlStatistics statistics = null; try { statistics = SqlStatistics.StartTimer(Statistics); @@ -1014,14 +1019,11 @@ override public void Close() if (stateObj != null) { - // protect against concurrent close and cancel lock (stateObj) { - if (_stateObj != null) { // reader not closed while we waited for the lock - // TryCloseInternal will clear out the snapshot when it is done if (_snapshot != null) { @@ -1042,7 +1044,6 @@ override public void Close() { throw SQL.SynchronousCallMayNotPend(); } - // DO NOT USE stateObj after this point - it has been returned to the TdsParser's session pool and potentially handed out to another thread } } @@ -1078,7 +1079,6 @@ private TdsOperationStatus TryCloseInternal(bool closeReader) #endif //DEBUG if ((!_isClosed) && (parser != null) && (stateObj != null) && (stateObj.HasPendingData)) { - // It is possible for this to be called during connection close on a // broken connection, so check state first. if (parser.State == TdsParserState.OpenLoggedIn) @@ -1238,7 +1238,6 @@ private TdsOperationStatus TryCloseInternal(bool closeReader) } } } - // DO NOT USE stateObj after this point - it has been returned to the TdsParser's session pool and potentially handed out to another thread } #if DEBUG @@ -1348,12 +1347,14 @@ private TdsOperationStatus TryConsumeMetaData() { if (_parser.State == TdsParserState.Broken || _parser.State == TdsParserState.Closed) { - // Happened for DEVDIV2:180509 (SqlDataReader.ConsumeMetaData Hangs In 100% CPU Loop Forever When TdsParser._state == TdsParserState.Broken) + // Happened for DEVDIV2:180509 (SqlDataReader.ConsumeMetaData Hangs In 100% CPU Loop Forever When TdsParser._state == TdsParserState.Broken) // during request for DTC address. // NOTE: We doom connection for TdsParserState.Closed since it indicates that it is in some abnormal and unstable state, probably as a result of // closing from another thread. In general, TdsParserState.Closed does not necessitate dooming the connection. if (_parser.Connection != null) + { _parser.Connection.DoomThisConnection(); + } throw SQL.ConnectionDoomed(); } bool ignored; @@ -1365,7 +1366,6 @@ private TdsOperationStatus TryConsumeMetaData() Debug.Assert(!ignored, "Parser read a row token while trying to read metadata"); } - return TdsOperationStatus.Done; } @@ -1453,6 +1453,11 @@ override public IEnumerator GetEnumerator() } /// +#if !NETFRAMEWORK + [SuppressMessage("ReflectionAnalysis", "IL2093:MismatchOnMethodReturnValueBetweenOverrides", + Justification = "Annotations for DbDataReader was not shipped in net6.0")] + [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] +#endif override public Type GetFieldType(int i) { SqlStatistics statistics = null; @@ -1469,6 +1474,9 @@ override public Type GetFieldType(int i) } } +#if !NETFRAMEWORK + [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] +#endif private Type GetFieldTypeInternal(_SqlMetaData metaData) { Type fieldType = null; @@ -1493,7 +1501,6 @@ private Type GetFieldTypeInternal(_SqlMetaData metaData) else if (_typeSystem != SqlConnectionString.TypeSystem.SQLServer2000) { // TypeSystem.SQLServer2005 and above - if (metaData.type == SqlDbType.Udt) { Debug.Assert(Connection.Is2005OrNewer, "Invalid Column type received from the server"); @@ -1516,7 +1523,6 @@ private Type GetFieldTypeInternal(_SqlMetaData metaData) else { // TypeSystem.SQLServer2000 - fieldType = GetVersionedMetaType(metaData.metaType).ClassType; // Com+ type. } @@ -1552,6 +1558,7 @@ virtual internal int GetLocaleId(int i) lcid = 0; } } + return lcid; } @@ -1565,6 +1572,9 @@ override public string GetName(int i) } /// +#if !NETFRAMEWORK && NET8_0_OR_GREATER + [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] +#endif override public Type GetProviderSpecificFieldType(int i) { SqlStatistics statistics = null; @@ -1581,6 +1591,9 @@ override public Type GetProviderSpecificFieldType(int i) } } +#if !NETFRAMEWORK + [return: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] +#endif private Type GetProviderSpecificFieldTypeInternal(_SqlMetaData metaData) { Type providerSpecificFieldType = null; @@ -1604,7 +1617,6 @@ private Type GetProviderSpecificFieldTypeInternal(_SqlMetaData metaData) else if (_typeSystem != SqlConnectionString.TypeSystem.SQLServer2000) { // TypeSystem.SQLServer2005 and above - if (metaData.type == SqlDbType.Udt) { Debug.Assert(Connection.Is2005OrNewer, "Invalid Column type received from the server"); @@ -1629,7 +1641,6 @@ private Type GetProviderSpecificFieldTypeInternal(_SqlMetaData metaData) else { // TypeSystem.SQLServer2000 - providerSpecificFieldType = GetVersionedMetaType(metaData.metaType).SqlType; // SqlType type. } @@ -1670,7 +1681,7 @@ override public int GetProviderSpecificValues(object[] values) } /// - override public DataTable GetSchemaTable() + public override DataTable GetSchemaTable() { SqlStatistics statistics = null; using (TryEventScope.Create(" {0}", ObjectID)) @@ -1706,8 +1717,8 @@ override public bool GetBoolean(int i) virtual public XmlReader GetXmlReader(int i) { // NOTE: sql_variant can not contain a XML data type: http://msdn.microsoft.com/en-us/library/ms173829.aspx - // If this ever changes, the following code should be changed to be like GetStream\GetTextReader - CheckDataIsReady(columnIndex: i, methodName: "GetXmlReader"); + // If this ever changes, the following code should be changed to be like GetStream/GetTextReader + CheckDataIsReady(columnIndex: i); MetaType mt = _metaData[i].metaType; @@ -1732,7 +1743,7 @@ virtual public XmlReader GetXmlReader(int i) if (_data[i].IsNull) { // A 'null' stream - return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(new MemoryStream(new byte[0], writable: false), closeInput: true, async: false); + return SqlTypeWorkarounds.SqlXmlCreateSqlXmlReader(new MemoryStream(Array.Empty(), writable: false), closeInput: true, async: false); } else { @@ -1745,7 +1756,7 @@ virtual public XmlReader GetXmlReader(int i) /// override public Stream GetStream(int i) { - CheckDataIsReady(columnIndex: i, methodName: "GetStream"); + CheckDataIsReady(columnIndex: i); // Streaming is not supported on encrypted columns. if (_metaData[i] != null && _metaData[i].cipherMD != null) @@ -1777,7 +1788,7 @@ override public Stream GetStream(int i) if (_data[i].IsNull) { // A 'null' stream - data = new byte[0]; + data = Array.Empty(); } else { @@ -1803,7 +1814,7 @@ override public long GetBytes(int i, long dataIndex, byte[] buffer, int bufferIn SqlStatistics statistics = null; long cbBytes = 0; - CheckDataIsReady(columnIndex: i, allowPartiallyReadColumn: true, methodName: "GetBytes"); + CheckDataIsReady(columnIndex: i, allowPartiallyReadColumn: true); // don't allow get bytes on non-long or non-binary columns MetaType mt = _metaData[i].metaType; @@ -1919,7 +1930,9 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf } if (dataIndex < 0) - throw ADP.NegativeParameter("dataIndex"); + { + throw ADP.NegativeParameter(nameof(dataIndex)); + } if (dataIndex < _columnDataBytesRead) { @@ -1937,14 +1950,20 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf // if bad buffer index, throw if (bufferIndex < 0 || bufferIndex >= buffer.Length) - throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, "bufferIndex"); + { + throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, nameof(bufferIndex)); + } // if there is not enough room in the buffer for data if (length + bufferIndex > buffer.Length) + { throw ADP.InvalidBufferSizeOrIndex(length, bufferIndex); + } if (length < 0) + { throw ADP.InvalidDataLength(length); + } // Skip if needed if (cb > 0) @@ -1981,11 +2000,13 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf // note that since we are caching in an array, and arrays aren't 64 bit ready yet, // we need can cast to int if the dataIndex is in range if (dataIndex < 0) - throw ADP.NegativeParameter("dataIndex"); + { + throw ADP.NegativeParameter(nameof(dataIndex)); + } - if (dataIndex > Int32.MaxValue) + if (dataIndex > int.MaxValue) { - throw ADP.InvalidSourceBufferIndex(cbytes, dataIndex, "dataIndex"); + throw ADP.InvalidSourceBufferIndex(cbytes, dataIndex, nameof(dataIndex)); } int ndataIndex = (int)dataIndex; @@ -2035,9 +2056,13 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf { // help the user out in the case where there's less data than requested if ((ndataIndex + length) > cbytes) + { cbytes = cbytes - ndataIndex; + } else + { cbytes = length; + } } Buffer.BlockCopy(data, ndataIndex, buffer, bufferIndex, cbytes); @@ -2052,15 +2077,21 @@ private TdsOperationStatus TryGetBytesInternal(int i, long dataIndex, byte[] buf cbytes = data.Length; if (length < 0) + { throw ADP.InvalidDataLength(length); + } // if bad buffer index, throw if (bufferIndex < 0 || bufferIndex >= buffer.Length) - throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, "bufferIndex"); + { + throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, nameof(bufferIndex)); + } // if there is not enough room in the buffer for data if (cbytes + bufferIndex > buffer.Length) + { throw ADP.InvalidBufferSizeOrIndex(cbytes, bufferIndex); + } throw; } @@ -2152,6 +2183,7 @@ internal TdsOperationStatus TryGetBytesInternalSequential(int i, byte[] buffer, Debug.Assert(index + length <= buffer.Length, "Buffer too small"); bytesRead = 0; + TdsOperationStatus result; RuntimeHelpers.PrepareConstrainedRegions(); try @@ -2176,7 +2208,7 @@ internal TdsOperationStatus TryGetBytesInternalSequential(int i, byte[] buffer, if (_metaData[i].metaType.IsPlp) { // Read in data - TdsOperationStatus result = _stateObj.TryReadPlpBytes(ref buffer, index, length, out bytesRead); + result = _stateObj.TryReadPlpBytes(ref buffer, index, length, out bytesRead); _columnDataBytesRead += bytesRead; if (result != TdsOperationStatus.Done) { @@ -2198,7 +2230,7 @@ internal TdsOperationStatus TryGetBytesInternalSequential(int i, byte[] buffer, { // Read data (not exceeding the total amount of data available) int bytesToRead = (int)Math.Min((long)length, _sharedState._columnDataBytesRemaining); - TdsOperationStatus result = _stateObj.TryReadByteArray(buffer.AsSpan(start: index), bytesToRead, out bytesRead); + result = _stateObj.TryReadByteArray(buffer.AsSpan(index), bytesToRead, out bytesRead); _columnDataBytesRead += bytesRead; _sharedState._columnDataBytesRemaining -= bytesRead; return result; @@ -2244,7 +2276,7 @@ internal TdsOperationStatus TryGetBytesInternalSequential(int i, byte[] buffer, /// override public TextReader GetTextReader(int i) { - CheckDataIsReady(columnIndex: i, methodName: "GetTextReader"); + CheckDataIsReady(columnIndex: i); // Xml type is not supported MetaType mt = null; @@ -2377,7 +2409,7 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn // if bad buffer index, throw if ((bufferIndex < 0) || (buffer != null && bufferIndex >= buffer.Length)) { - throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, "bufferIndex"); + throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, nameof(bufferIndex)); } // if there is not enough room in the buffer for data @@ -2390,7 +2422,7 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn { try { - CheckDataIsReady(columnIndex: i, allowPartiallyReadColumn: true, methodName: "GetChars"); + CheckDataIsReady(columnIndex: i, allowPartiallyReadColumn: true); } catch (Exception ex) { @@ -2409,7 +2441,7 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn } else { - CheckDataIsReady(columnIndex: i, allowPartiallyReadColumn: true, methodName: "GetChars"); + CheckDataIsReady(columnIndex: i, allowPartiallyReadColumn: true); charsRead = GetCharsFromPlpData(i, dataIndex, buffer, bufferIndex, length); } _lastColumnWithDataChunkRead = i; @@ -2437,9 +2469,9 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn // note that since we are caching in an array, and arrays aren't 64 bit ready yet, // we need can cast to int if the dataIndex is in range - if (dataIndex > Int32.MaxValue) + if (dataIndex > int.MaxValue) { - throw ADP.InvalidSourceBufferIndex(cchars, dataIndex, "dataIndex"); + throw ADP.InvalidSourceBufferIndex(cchars, dataIndex, nameof(dataIndex)); } int ndataIndex = (int)dataIndex; @@ -2461,9 +2493,13 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn { // help the user out in the case where there's less data than requested if ((ndataIndex + length) > cchars) + { cchars = cchars - ndataIndex; + } else + { cchars = length; + } } Array.Copy(_columnDataChars, ndataIndex, buffer, bufferIndex, cchars); @@ -2479,15 +2515,21 @@ override public long GetChars(int i, long dataIndex, char[] buffer, int bufferIn cchars = _columnDataChars.Length; if (length < 0) + { throw ADP.InvalidDataLength(length); + } // if bad buffer index, throw if (bufferIndex < 0 || bufferIndex >= buffer.Length) - throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, "bufferIndex"); + { + throw ADP.InvalidDestinationBufferIndex(buffer.Length, bufferIndex, nameof(bufferIndex)); + } // if there is not enough room in the buffer for data if (cchars + bufferIndex > buffer.Length) + { throw ADP.InvalidBufferSizeOrIndex(cchars, bufferIndex); + } throw; } @@ -2550,7 +2592,9 @@ private long GetCharsFromPlpData(int i, long dataIndex, char[] buffer, int buffe // _columnDataCharsRead is 0 and dataIndex > _columnDataCharsRead is true below. // In both cases we will clean decoder if (dataIndex == 0) + { _stateObj._plpdecoder = null; + } bool isUnicode = _metaData[i].metaType.IsNCharType; @@ -2580,7 +2624,6 @@ private long GetCharsFromPlpData(int i, long dataIndex, char[] buffer, int buffe // Clean decoder state: we do not reset it, but destroy to ensure // that we do not start decoding the column with decoder from the old one _stateObj._plpdecoder = null; - // TODO: for DBCS encoding skip positioning dataIndex is not in characters but is interpreted as // number of chars already read + number of bytes to skip cch = dataIndex - _columnDataCharsRead; @@ -2697,7 +2740,7 @@ override public DateTime GetDateTime(int i) } /// - override public Decimal GetDecimal(int i) + override public decimal GetDecimal(int i) { ReadColumn(i); return _data[i].Decimal; @@ -2725,21 +2768,21 @@ override public Guid GetGuid(int i) } /// - override public Int16 GetInt16(int i) + override public short GetInt16(int i) { ReadColumn(i); return _data[i].Int16; } /// - override public Int32 GetInt32(int i) + override public int GetInt32(int i) { ReadColumn(i); return _data[i].Int32; } /// - override public Int64 GetInt64(int i) + override public long GetInt64(int i) { ReadColumn(i); return _data[i].Int64; @@ -2904,7 +2947,6 @@ virtual public SqlJson GetSqlJson(int i) { ReadColumn(i); SqlJson json = _data[i].IsNull ? SqlJson.Null : _data[i].SqlJson; - return json; } @@ -2961,11 +3003,11 @@ private object GetSqlValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met } else if (_typeSystem != SqlConnectionString.TypeSystem.SQLServer2000) { - // TypeSystem.SQLServer2005 + // TypeSystem.SQLServer2005 and above if (metaData.type == SqlDbType.Udt) { - var connection = _connection; + SqlConnection connection = _connection; if (connection != null) { connection.CheckGetExtendedUDTInfo(metaData, true); @@ -2973,7 +3015,7 @@ private object GetSqlValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met } else { - throw ADP.DataReaderClosed("GetSqlValueFromSqlBufferInternal"); + throw ADP.DataReaderClosed(); } } else @@ -3006,7 +3048,7 @@ virtual public int GetSqlValues(object[] values) CheckDataIsReady(); if (values == null) { - throw ADP.ArgumentNull("values"); + throw ADP.ArgumentNull(nameof(values)); } SetTimeout(_defaultTimeoutMilliseconds); @@ -3159,7 +3201,7 @@ private object GetValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaDa } else if (_typeSystem != SqlConnectionString.TypeSystem.SQLServer2000) { - // TypeSystem.SQLServer2005 + // TypeSystem.SQLServer2005 and above if (metaData.type != SqlDbType.Udt) { @@ -3167,7 +3209,7 @@ private object GetValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaDa } else { - var connection = _connection; + SqlConnection connection = _connection; if (connection != null) { connection.CheckGetExtendedUDTInfo(metaData, true); @@ -3175,7 +3217,7 @@ private object GetValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData metaDa } else { - throw ADP.DataReaderClosed("GetValueFromSqlBufferInternal"); + throw ADP.DataReaderClosed(); } } } @@ -3254,6 +3296,16 @@ private T GetFieldValueFromSqlBufferInternal(SqlBuffer data, _SqlMetaData met { return (T)(object)data.DateTime; } +#if !NETFRAMEWORK + else if (typeof(T) == typeof(DateOnly) && dataType == typeof(DateTime) && _typeSystem > SqlConnectionString.TypeSystem.SQLServer2005) + { + return (T)(object)data.DateOnly; + } + else if (typeof(T) == typeof(TimeOnly) && dataType == typeof(TimeOnly) && _typeSystem > SqlConnectionString.TypeSystem.SQLServer2005) + { + return (T)(object)data.TimeOnly; + } +#endif else if (typeof(T) == typeof(XmlReader)) { // XmlReader only allowed on XML types @@ -3422,7 +3474,7 @@ override public int GetValues(object[] values) if (values == null) { - throw ADP.ArgumentNull("values"); + throw ADP.ArgumentNull(nameof(values)); } CheckMetaDataIsReady(); @@ -3444,7 +3496,7 @@ override public int GetValues(object[] values) for (int i = 0; i < copyLen; i++) { - // Get the usable, TypeSystem-compatible value from the iternal buffer + // Get the usable, TypeSystem-compatible value from the internal buffer int fieldIndex = _metaData.GetVisibleColumnIndex(i); values[i] = GetValueFromSqlBufferInternal(_data[fieldIndex], _metaData[fieldIndex]); @@ -3581,7 +3633,7 @@ private TdsOperationStatus TryHasMoreResults(out bool moreResults) // Dev11 Bug 316483: Stuck at SqlDataReader::TryHasMoreResults using MARS // http://vstfdevdiv:8080/web/wi.aspx?pcguid=22f9acc9-569a-41ff-b6ac-fac1b6370209&id=316483 - // TryRun() will immediately return if the TdsParser is closed\broken, causing us to enter an infinite loop + // TryRun() will immediately return if the TdsParser is closed/broken, causing us to enter an infinite loop // Instead, we will throw a closed connection exception if (_parser.State == TdsParserState.Broken || _parser.State == TdsParserState.Closed) { @@ -3657,7 +3709,6 @@ private TdsOperationStatus TryHasMoreRows(out bool moreRows) b == TdsEnums.SQLERROR || b == TdsEnums.SQLINFO)) { - if (b == TdsEnums.SQLDONE || b == TdsEnums.SQLDONEPROC || b == TdsEnums.SQLDONEINPROC) @@ -3667,7 +3718,7 @@ private TdsOperationStatus TryHasMoreRows(out bool moreRows) // Dev11 Bug 316483: Stuck at SqlDataReader::TryHasMoreResults when using MARS // http://vstfdevdiv:8080/web/wi.aspx?pcguid=22f9acc9-569a-41ff-b6ac-fac1b6370209&id=316483 - // TryRun() will immediately return if the TdsParser is closed\broken, causing us to enter an infinite loop + // TryRun() will immediately return if the TdsParser is closed/broken, causing us to enter an infinite loop // Instead, we will throw a closed connection exception if (_parser.State == TdsParserState.Broken || _parser.State == TdsParserState.Closed) { @@ -3727,7 +3778,7 @@ override public bool IsDBNull(int i) } else { - CheckHeaderIsReady(columnIndex: i, methodName: "IsDBNull"); + CheckHeaderIsReady(columnIndex: i); SetTimeout(_defaultTimeoutMilliseconds); @@ -3755,7 +3806,6 @@ override public bool NextResult() Debug.Assert(_stateObj == null || _stateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); TdsOperationStatus result = TryNextResult(out more); - if (result != TdsOperationStatus.Done) { throw SQL.SynchronousCallMayNotPend(); @@ -3790,7 +3840,7 @@ private TdsOperationStatus TryNextResult(out bool more) if (IsClosed) { - throw ADP.DataReaderClosed("NextResult"); + throw ADP.DataReaderClosed(nameof(NextResult)); } _fieldNameLookup = null; @@ -3983,7 +4033,6 @@ override public bool Read() // user must call Read() to position on the first row private TdsOperationStatus TryReadInternal(bool setTimeout, out bool more) { - TdsOperationStatus result; SqlStatistics statistics = null; using (TryEventScope.Create(" {0}", ObjectID)) { @@ -4001,6 +4050,7 @@ private TdsOperationStatus TryReadInternal(bool setTimeout, out bool more) #else { #endif //DEBUG + TdsOperationStatus result; statistics = SqlStatistics.StartTimer(Statistics); if (_parser != null) @@ -4137,7 +4187,7 @@ private TdsOperationStatus TryReadInternal(bool setTimeout, out bool more) } else if (IsClosed) { - throw ADP.DataReaderClosed("Read"); + throw ADP.DataReaderClosed(nameof(Read)); } more = false; @@ -4164,7 +4214,7 @@ private TdsOperationStatus TryReadInternal(bool setTimeout, out bool more) } #endif //DEBUG } - catch (System.OutOfMemoryException e) + catch (OutOfMemoryException e) { _isClosed = true; SqlConnection con = _connection; @@ -4174,7 +4224,7 @@ private TdsOperationStatus TryReadInternal(bool setTimeout, out bool more) } throw; } - catch (System.StackOverflowException e) + catch (StackOverflowException e) { _isClosed = true; SqlConnection con = _connection; @@ -4218,7 +4268,7 @@ private void ReadColumn(int i, bool setTimeout = true, bool allowPartiallyReadCo private TdsOperationStatus TryReadColumn(int i, bool setTimeout, bool allowPartiallyReadColumn = false, bool forStreaming = false) { - CheckDataIsReady(columnIndex: i, permitAsync: true, allowPartiallyReadColumn: allowPartiallyReadColumn); + CheckDataIsReady(columnIndex: i, permitAsync: true, allowPartiallyReadColumn: allowPartiallyReadColumn, methodName: nameof(CheckDataIsReady)); RuntimeHelpers.PrepareConstrainedRegions(); try @@ -4460,13 +4510,7 @@ internal TdsOperationStatus TryReadColumnInternal(int i, bool readHeaderOnly/* = { bool isNull; ulong dataLength; - result = _parser.TryProcessColumnHeader( - columnMetaData, - _stateObj, - _sharedState._nextColumnHeaderToRead, - out isNull, - out dataLength - ); + result = _parser.TryProcessColumnHeader(columnMetaData, _stateObj, _sharedState._nextColumnHeaderToRead, out isNull, out dataLength); if (result != TdsOperationStatus.Done) { return result; @@ -4480,12 +4524,10 @@ out dataLength { if (columnMetaData.type != SqlDbType.Timestamp) { - TdsParser.GetNullSqlValue( - _data[_sharedState._nextColumnDataToRead], + TdsParser.GetNullSqlValue(_data[_sharedState._nextColumnDataToRead], columnMetaData, _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, - _parser.Connection - ); + _parser.Connection); } } else @@ -4494,13 +4536,9 @@ out dataLength { // If we're in sequential mode try to read the data and then if it succeeds update shared // state so there are no remaining bytes and advance the next column to read - result = _parser.TryReadSqlValue( - _data[_sharedState._nextColumnDataToRead], - columnMetaData, - (int)dataLength, _stateObj, + result = _parser.TryReadSqlValue(_data[_sharedState._nextColumnDataToRead], columnMetaData, (int)dataLength, _stateObj, _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, - columnMetaData.column - ); + columnMetaData.column); if (result != TdsOperationStatus.Done) { // will read UDTs as VARBINARY. @@ -4517,7 +4555,7 @@ out dataLength } else { - Debug.Assert(false, "we have read past the column somehow, this is an error"); + Debug.Assert(false, "We have read past the column somehow, this is an error"); } } else @@ -4537,12 +4575,10 @@ out dataLength // if LegacyRowVersionNullBehavior is enabled, Timestamp type must enter "else" block. if (isNull && (!LocalAppContextSwitches.LegacyRowVersionNullBehavior || columnMetaData.type != SqlDbType.Timestamp)) { - TdsParser.GetNullSqlValue( - _data[_sharedState._nextColumnDataToRead], + TdsParser.GetNullSqlValue(_data[_sharedState._nextColumnDataToRead], columnMetaData, _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, - _parser.Connection - ); + _parser.Connection); if (!readHeaderOnly) { @@ -4558,7 +4594,7 @@ out dataLength // can read it out of order result = _parser.TryReadSqlValue(_data[_sharedState._nextColumnDataToRead], columnMetaData, (int)dataLength, _stateObj, _command != null ? _command.ColumnEncryptionSetting : SqlCommandColumnEncryptionSetting.UseConnectionSetting, - columnMetaData.column); + columnMetaData.column, _command); if (result != TdsOperationStatus.Done) { // will read UDTs as VARBINARY. @@ -4625,12 +4661,11 @@ private bool WillHaveEnoughData(int targetColumn, bool headerOnly = false) // Check NBC first if (!_stateObj.IsNullCompressionBitSet(currentColumn)) { - // NOTE: This is mostly duplicated from TryProcessColumnHeaderNoNBC and TryGetTokenLength var metaType = _metaData[currentColumn].metaType; if ((metaType.IsLong) || (metaType.IsPlp) || (metaType.SqlDbType == SqlDbType.Udt) || (metaType.SqlDbType == SqlDbType.Structured)) { - // Plp, Udt and TVP types have an unknownable size - so return that the estimate failed + // Plp, Udt and TVP types have an unknowable size - so return that the estimate failed return false; } int maxHeaderSize; @@ -4685,7 +4720,7 @@ private TdsOperationStatus TryResetBlobState() { if (_stateObj._longlen != 0) { - result = _stateObj.Parser.TrySkipPlpValue(UInt64.MaxValue, _stateObj, out _); + result = _stateObj.Parser.TrySkipPlpValue(ulong.MaxValue, _stateObj, out _); if (result != TdsOperationStatus.Done) { return result; @@ -4806,6 +4841,7 @@ internal TdsOperationStatus TrySetAltMetaDataSet(_SqlMetaDataSet metaDataSet, bo { _stateObj._accumulateInfoEvents = false; } + result = _stateObj.TryPeekByte(out b); if (result != TdsOperationStatus.Done) { @@ -4910,6 +4946,7 @@ internal TdsOperationStatus TrySetMetaData(_SqlMetaDataSet metaData, bool moreIn { _stateObj._accumulateInfoEvents = false; } + result = _stateObj.TryPeekByte(out b); if (result != TdsOperationStatus.Done) { @@ -4995,11 +5032,11 @@ private void CheckDataIsReady() } } - private void CheckHeaderIsReady(int columnIndex, bool permitAsync = false, string methodName = null) + private void CheckHeaderIsReady(int columnIndex, bool permitAsync = false, [CallerMemberName] string methodName = null) { if (_isClosed) { - throw ADP.DataReaderClosed(methodName ?? "CheckHeaderIsReady"); + throw ADP.DataReaderClosed(methodName ?? nameof(CheckHeaderIsReady)); } if ((!permitAsync) && (_currentTask != null)) { @@ -5021,11 +5058,11 @@ private void CheckHeaderIsReady(int columnIndex, bool permitAsync = false, strin } } - private void CheckDataIsReady(int columnIndex, bool allowPartiallyReadColumn = false, bool permitAsync = false, string methodName = null) + private void CheckDataIsReady(int columnIndex, bool allowPartiallyReadColumn = false, bool permitAsync = false, [CallerMemberName] string methodName = null) { if (_isClosed) { - throw ADP.DataReaderClosed(methodName ?? "CheckDataIsReady"); + throw ADP.DataReaderClosed(methodName ?? nameof(CheckDataIsReady)); } if ((!permitAsync) && (_currentTask != null)) { @@ -5069,12 +5106,11 @@ public override Task NextResultAsync(CancellationToken cancellationToken) using (TryEventScope.Create(" {0}", ObjectID)) using (var registrationHolder = new DisposableTemporaryOnStack()) { - TaskCompletionSource source = new TaskCompletionSource(); if (IsClosed) { - source.SetException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed("NextResultAsync"))); + source.SetException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed())); return source.Task; } @@ -5112,17 +5148,17 @@ private static Task NextResultAsyncExecute(Task task, object state) HasNextResultAsyncCallContext context = (HasNextResultAsyncCallContext)state; if (task != null) { - SqlClientEventSource.Log.TryTraceEvent(" attempt retry {0}", context._reader.ObjectID); - context._reader.PrepareForAsyncContinuation(); + SqlClientEventSource.Log.TryTraceEvent(" attempt retry {0}", context.Reader.ObjectID); + context.Reader.PrepareForAsyncContinuation(); } - if (context._reader.TryNextResult(out bool more) == TdsOperationStatus.Done) + if (context.Reader.TryNextResult(out bool more) == TdsOperationStatus.Done) { // completed return more ? ADP.TrueTask : ADP.FalseTask; } - return context._reader.ExecuteAsyncCall(context); + return context.Reader.ExecuteAsyncCall(context); } // NOTE: This will return null if it completed sequentially @@ -5135,16 +5171,12 @@ internal Task GetBytesAsync(int columnIndex, byte[] buffer, int index, int bytesRead = 0; if (IsClosed) { - TaskCompletionSource source = new TaskCompletionSource(); - source.SetException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed("GetBytesAsync"))); - return source.Task; + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed())); } if (_currentTask != null) { - TaskCompletionSource source = new TaskCompletionSource(); - source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); - return source.Task; + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); } if (cancellationToken.CanBeCanceled) @@ -5157,12 +5189,12 @@ internal Task GetBytesAsync(int columnIndex, byte[] buffer, int index, int var context = new GetBytesAsyncCallContext(this) { - columnIndex = columnIndex, - buffer = buffer, - index = index, - length = length, - timeout = timeout, - cancellationToken = cancellationToken, + _columnIndex = columnIndex, + _buffer = buffer, + _index = index, + _length = length, + _timeout = timeout, + _cancellationToken = cancellationToken, }; // Check if we need to skip columns @@ -5187,18 +5219,18 @@ internal Task GetBytesAsync(int columnIndex, byte[] buffer, int index, int timeoutToken = timeoutCancellationSource.Token; } - context._disposable = timeoutCancellationSource; - context.timeoutToken = timeoutToken; - context._source = source; PrepareAsyncInvocation(useSnapshot: true); + context.Set(this, source, timeoutCancellationSource); + context._timeoutToken = timeoutToken; + return InvokeAsyncCall(context); } else { // We're already at the correct column, just read the data - context.mode = GetBytesAsyncCallContext.OperationMode.Read; + context._mode = GetBytesAsyncCallContext.OperationMode.Read; // Switch to async PrepareAsyncInvocation(useSnapshot: false); @@ -5218,35 +5250,36 @@ internal Task GetBytesAsync(int columnIndex, byte[] buffer, int index, int private static Task GetBytesAsyncSeekExecute(Task task, object state) { GetBytesAsyncCallContext context = (GetBytesAsyncCallContext)state; - SqlDataReader reader = context._reader; + SqlDataReader reader = context.Reader; - Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Seek, "context.mode must be Seek to check if seeking can resume"); + Debug.Assert(context._mode == GetBytesAsyncCallContext.OperationMode.Seek, "context.mode must be Seek to check if seeking can resume"); if (task != null) { reader.PrepareForAsyncContinuation(); } + // Prepare for stateObj timeout reader.SetTimeout(reader._defaultTimeoutMilliseconds); - if (reader.TryReadColumnHeader(context.columnIndex) == TdsOperationStatus.Done) + if (reader.TryReadColumnHeader(context._columnIndex) == TdsOperationStatus.Done) { - // Only once we have read upto where we need to be can we check the cancellation tokens (otherwise we will be in an unknown state) + // Only once we have read up to where we need to be can we check the cancellation tokens (otherwise we will be in an unknown state) - if (context.cancellationToken.IsCancellationRequested) + if (context._cancellationToken.IsCancellationRequested) { // User requested cancellation - return Task.FromCanceled(context.cancellationToken); + return Task.FromCanceled(context._cancellationToken); } - else if (context.timeoutToken.IsCancellationRequested) + else if (context._timeoutToken.IsCancellationRequested) { // Timeout - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); } else { - // Upto the correct column - continue to read - context.mode = GetBytesAsyncCallContext.OperationMode.Read; + // Up to the correct column - continue to read + context._mode = GetBytesAsyncCallContext.OperationMode.Read; reader.SwitchToAsyncWithoutSnapshot(); int totalBytesRead; var readTask = reader.GetBytesAsyncReadDataStage(context, true, out totalBytesRead); @@ -5270,18 +5303,18 @@ private static Task GetBytesAsyncSeekExecute(Task task, object state) private static Task GetBytesAsyncReadExecute(Task task, object state) { var context = (GetBytesAsyncCallContext)state; - SqlDataReader reader = context._reader; + SqlDataReader reader = context.Reader; - Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Read, "context.mode must be Read to check if read can resume"); + Debug.Assert(context._mode == GetBytesAsyncCallContext.OperationMode.Read, "context.mode must be Read to check if read can resume"); reader.PrepareForAsyncContinuation(); - if (context.cancellationToken.IsCancellationRequested) + if (context._cancellationToken.IsCancellationRequested) { // User requested cancellation - return Task.FromCanceled(context.cancellationToken); + return Task.FromCanceled(context._cancellationToken); } - else if (context.timeoutToken.IsCancellationRequested) + else if (context._timeoutToken.IsCancellationRequested) { // Timeout return Task.FromException(ADP.ExceptionWithStackTrace(ADP.IO(SQLMessage.Timeout()))); @@ -5293,18 +5326,18 @@ private static Task GetBytesAsyncReadExecute(Task task, object state) int bytesReadThisIteration; TdsOperationStatus result = reader.TryGetBytesInternalSequential( - context.columnIndex, - context.buffer, - context.index + context.totalBytesRead, - context.length - context.totalBytesRead, + context._columnIndex, + context._buffer, + context._index + context._totalBytesRead, + context._length - context._totalBytesRead, out bytesReadThisIteration ); - context.totalBytesRead += bytesReadThisIteration; - Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required"); + context._totalBytesRead += bytesReadThisIteration; + Debug.Assert(context._totalBytesRead <= context._length, "Read more bytes than required"); if (result == TdsOperationStatus.Done) { - return Task.FromResult(context.totalBytesRead); + return Task.FromResult(context._totalBytesRead); } else { @@ -5315,34 +5348,32 @@ out bytesReadThisIteration private Task GetBytesAsyncReadDataStage(GetBytesAsyncCallContext context, bool isContinuation, out int bytesRead) { - Debug.Assert(context.mode == GetBytesAsyncCallContext.OperationMode.Read, "context.Mode must be Read to read data"); + Debug.Assert(context._mode == GetBytesAsyncCallContext.OperationMode.Read, "context.Mode must be Read to read data"); - _lastColumnWithDataChunkRead = context.columnIndex; + _lastColumnWithDataChunkRead = context._columnIndex; TaskCompletionSource source = null; // Prepare for stateObj timeout SetTimeout(_defaultTimeoutMilliseconds); // Try to read without any continuations (all the data may already be in the stateObj's buffer) - TdsOperationStatus filledBuffer = context._reader.TryGetBytesInternalSequential( - context.columnIndex, - context.buffer, - context.index + context.totalBytesRead, - context.length - context.totalBytesRead, + TdsOperationStatus filledBuffer = context.Reader.TryGetBytesInternalSequential( + context._columnIndex, + context._buffer, + context._index + context._totalBytesRead, + context._length - context._totalBytesRead, out bytesRead ); - context.totalBytesRead += bytesRead; - Debug.Assert(context.totalBytesRead <= context.length, "Read more bytes than required"); + context._totalBytesRead += bytesRead; + Debug.Assert(context._totalBytesRead <= context._length, "Read more bytes than required"); if (filledBuffer != TdsOperationStatus.Done) { // This will be the 'state' for the callback - int totalBytesRead = bytesRead; - if (!isContinuation) { // This is the first async operation which is happening - setup the _currentTask and timeout - Debug.Assert(context._source == null, "context._source should not be non-null when trying to change to async"); + Debug.Assert(context.Source == null, "context._source should not be non-null when trying to change to async"); source = new TaskCompletionSource(); Task original = Interlocked.CompareExchange(ref _currentTask, source.Task, null); if (original != null) @@ -5350,8 +5381,7 @@ out bytesRead source.SetException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); return source.Task; } - - context._source = source; + context.Source = source; // Check if cancellation due to close is requested (this needs to be done after setting _currentTask) if (_cancelAsyncOnCloseToken.IsCancellationRequested) { @@ -5361,29 +5391,29 @@ out bytesRead } // Timeout - Debug.Assert(context.timeoutToken == CancellationToken.None, "TimeoutToken is set when GetBytesAsyncReadDataStage is not a continuation"); - if (context.timeout > 0) + Debug.Assert(context._timeoutToken == CancellationToken.None, "TimeoutToken is set when GetBytesAsyncReadDataStage is not a continuation"); + if (context._timeout > 0) { CancellationTokenSource timeoutCancellationSource = new CancellationTokenSource(); - timeoutCancellationSource.CancelAfter(context.timeout); - Debug.Assert(context._disposable is null, "setting context.disposable would lose the previous dispoable"); - context._disposable = timeoutCancellationSource; - context.timeoutToken = timeoutCancellationSource.Token; + timeoutCancellationSource.CancelAfter(context._timeout); + Debug.Assert(context.Disposable is null, "setting context.disposable would lose the previous disposable"); + context.Disposable = timeoutCancellationSource; + context._timeoutToken = timeoutCancellationSource.Token; } } Task retryTask = ExecuteAsyncCall(context); if (isContinuation) { - // Let the caller handle cleanup\completing + // Let the caller handle cleanup/completing return retryTask; } else { - Debug.Assert(context._source != null, "context._source shuld not be null when continuing"); - // setup for cleanup\completing + Debug.Assert(context.Source != null, "context._source should not be null when continuing"); + // setup for cleanup/completing retryTask.ContinueWith( - continuationAction: AAsyncCallContext.s_completeCallback, + continuationAction: SqlDataReaderBaseAsyncCallContext.s_completeCallback, state: context, TaskScheduler.Default ); @@ -5408,7 +5438,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) { if (IsClosed) { - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed("ReadAsync"))); + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed())); } // Register first to catch any already expired tokens to be able to trigger cancellation event. @@ -5420,13 +5450,13 @@ public override Task ReadAsync(CancellationToken cancellationToken) // If user's token is canceled, return a canceled task if (cancellationToken.IsCancellationRequested) { - return ADP.CreatedTaskWithCancellation(); + return Task.FromCanceled(cancellationToken); } // Check for existing async if (_currentTask != null) { - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(SQL.PendingBeginXXXExists())); + return Task.FromException(ADP.ExceptionWithStackTrace(SQL.PendingBeginXXXExists())); } // These variables will be captured in moreFunc so that we can skip searching for a row token once one has been read @@ -5440,7 +5470,6 @@ public override Task ReadAsync(CancellationToken cancellationToken) // NOTE: If we are in SingleRow mode and we've read that single row (i.e. _haltRead == true), then skip the shortcut if ((!_haltRead) && ((!_sharedState._dataReady) || (WillHaveEnoughData(_metaData.Length - 1)))) { - #if DEBUG try { @@ -5497,7 +5526,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) { throw; } - return ADP.CreatedTaskWithException(ex); + return Task.FromException(ex); } TaskCompletionSource source = new TaskCompletionSource(); @@ -5518,7 +5547,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) var context = Interlocked.Exchange(ref _cachedReadAsyncContext, null) ?? new ReadAsyncCallContext(); - Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ReadAsyncCallContext was not properly disposed"); + Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ReadAsyncCallContext was not properly disposed"); context.Set(this, source, registrationHolder.Take()); context._hasMoreData = more; @@ -5533,7 +5562,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) private static Task ReadAsyncExecute(Task task, object state) { var context = (ReadAsyncCallContext)state; - SqlDataReader reader = context._reader; + SqlDataReader reader = context.Reader; ref bool hasMoreData = ref context._hasMoreData; ref bool hasReadRowToken = ref context._hasReadRowToken; @@ -5541,7 +5570,7 @@ private static Task ReadAsyncExecute(Task task, object state) { reader.PrepareForAsyncContinuation(); } - TdsOperationStatus result; + if (hasReadRowToken || (reader.TryReadInternal(true, out hasMoreData) == TdsOperationStatus.Done)) { // If there are no more rows, or this is Sequential Access, then we are done @@ -5561,7 +5590,7 @@ private static Task ReadAsyncExecute(Task task, object state) } // if non-sequentialaccess then read entire row before returning - result = reader.TryReadColumn(reader._metaData.Length - 1, true); + TdsOperationStatus result = reader.TryReadColumn(reader._metaData.Length - 1, true); if (result == TdsOperationStatus.Done) { // completed @@ -5581,10 +5610,9 @@ private void SetCachedReadAsyncCallContext(ReadAsyncCallContext instance) /// override public Task IsDBNullAsync(int i, CancellationToken cancellationToken) { - try { - CheckHeaderIsReady(columnIndex: i, methodName: "IsDBNullAsync"); + CheckHeaderIsReady(columnIndex: i); } catch (Exception ex) { @@ -5592,7 +5620,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo { throw; } - return ADP.CreatedTaskWithException(ex); + return Task.FromException(ex); } // Shortcut - if there are no issues and the data is already read, then just return the value @@ -5606,7 +5634,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo else { // Reader was closed between the CheckHeaderIsReady and accessing _data - throw closed exception - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed("IsDBNullAsync"))); + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed())); } } else @@ -5614,13 +5642,13 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo // Throw if there is any current task if (_currentTask != null) { - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); } // If user's token is canceled, return a canceled task if (cancellationToken.IsCancellationRequested) { - return ADP.CreatedTaskWithCancellation(); + return Task.FromCanceled(cancellationToken); } // Shortcut - if we have enough data, then run sync @@ -5650,7 +5678,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo { throw; } - return ADP.CreatedTaskWithException(ex); + return Task.FromException(ex); } using (var registrationHolder = new DisposableTemporaryOnStack()) @@ -5680,7 +5708,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo IsDBNullAsyncCallContext context = Interlocked.Exchange(ref _cachedIsDBNullContext, null) ?? new IsDBNullAsyncCallContext(); - Debug.Assert(context._reader == null && context._source == null && context._disposable == null, "cached ISDBNullAsync context not properly disposed"); + Debug.Assert(context.Reader == null && context.Source == null && context.Disposable == default, "cached ISDBNullAsync context not properly disposed"); context.Set(this, source, registrationHolder.Take()); context._columnIndex = i; @@ -5696,7 +5724,7 @@ override public Task IsDBNullAsync(int i, CancellationToken cancellationTo private static Task IsDBNullAsyncExecute(Task task, object state) { IsDBNullAsyncCallContext context = (IsDBNullAsyncCallContext)state; - SqlDataReader reader = context._reader; + SqlDataReader reader = context.Reader; if (task != null) { @@ -5724,7 +5752,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat { try { - CheckDataIsReady(columnIndex: i, methodName: "GetFieldValueAsync"); + CheckDataIsReady(columnIndex: i); // Shortcut - if there are no issues and the data is already read, then just return the value if ((!IsCommandBehavior(CommandBehavior.SequentialAccess)) && (_sharedState._nextColumnDataToRead > i) && (!cancellationToken.IsCancellationRequested) && (_currentTask == null)) @@ -5737,8 +5765,8 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat } else { - // Reader was closed between the CheckDataIsReady and accessing _data\_metaData - throw closed exception - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed("GetFieldValueAsync"))); + // Reader was closed between the CheckDataIsReady and accessing _data/_metaData - throw closed exception + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.DataReaderClosed())); } } } @@ -5748,19 +5776,19 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat { throw; } - return ADP.CreatedTaskWithException(ex); + return Task.FromException(ex); } // Throw if there is any current task if (_currentTask != null) { - return ADP.CreatedTaskWithException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); + return Task.FromException(ADP.ExceptionWithStackTrace(ADP.AsyncOperationPending())); } // If user's token is canceled, return a canceled task if (cancellationToken.IsCancellationRequested) { - return ADP.CreatedTaskWithCancellation(); + return Task.FromCanceled(cancellationToken); } // Shortcut - if we have enough data, then run sync @@ -5789,7 +5817,7 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat { throw; } - return ADP.CreatedTaskWithException(ex); + return Task.FromException(ex); } using (var registrationHolder = new DisposableTemporaryOnStack()) @@ -5817,14 +5845,20 @@ override public Task GetFieldValueAsync(int i, CancellationToken cancellat registrationHolder.Set(cancellationToken.Register(SqlCommand.s_cancelIgnoreFailure, _command)); } - return InvokeAsyncCall(new GetFieldValueAsyncCallContext(this, source, registrationHolder.Take(), i)); + // Setup async + PrepareAsyncInvocation(useSnapshot: true); + + GetFieldValueAsyncCallContext context = new GetFieldValueAsyncCallContext(this, source, registrationHolder.Take()); + context._columnIndex = i; + + return InvokeAsyncCall(context); } } private static Task GetFieldValueAsyncExecute(Task task, object state) { GetFieldValueAsyncCallContext context = (GetFieldValueAsyncCallContext)state; - SqlDataReader reader = context._reader; + SqlDataReader reader = context.Reader; int columnIndex = context._columnIndex; if (task != null) { @@ -5888,71 +5922,62 @@ internal void CompletePendingReadWithFailure(int errorCode, bool resetForcePendi #endif - private class Snapshot + internal abstract class SqlDataReaderBaseAsyncCallContext : AAsyncBaseCallContext { - public bool _dataReady; - public bool _haltRead; - public bool _metaDataConsumed; - public bool _browseModeInfoConsumed; - public bool _hasRows; - public ALTROWSTATUS _altRowStatus; - public int _nextColumnDataToRead; - public int _nextColumnHeaderToRead; - public long _columnDataBytesRead; - public long _columnDataBytesRemaining; + internal static readonly Action, object> s_completeCallback = CompleteAsyncCallCallback; - public _SqlMetaDataSet _metadata; - public _SqlMetaDataSetCollection _altMetaDataSetCollection; - public MultiPartTableName[] _tableNames; + internal static readonly Func> s_executeCallback = ExecuteAsyncCallCallback; - public SqlSequentialStream _currentStream; - public SqlSequentialTextReader _currentTextReader; - } + protected SqlDataReaderBaseAsyncCallContext() + { + } - private abstract class AAsyncCallContext : IDisposable - { - internal static readonly Action, object> s_completeCallback = SqlDataReader.CompleteAsyncCallCallback; + protected SqlDataReaderBaseAsyncCallContext(SqlDataReader owner, TaskCompletionSource source) + { + Set(owner, source); + } - internal static readonly Func> s_executeCallback = SqlDataReader.ExecuteAsyncCallCallback; + internal abstract Func> Execute { get; } - internal SqlDataReader _reader; - internal TaskCompletionSource _source; - internal IDisposable _disposable; + internal SqlDataReader Reader { get => _owner; set => _owner = value; } - protected AAsyncCallContext() - { - } + public TaskCompletionSource Source { get => _source; set => _source = value; } - protected AAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable = null) + private static Task ExecuteAsyncCallCallback(Task task, object state) { - Set(reader, source, disposable); + SqlDataReaderBaseAsyncCallContext context = (SqlDataReaderBaseAsyncCallContext)state; + return context.Reader.ContinueAsyncCall(task, context); } - internal void Set(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable = null) + private static void CompleteAsyncCallCallback(Task task, object state) { - this._reader = reader ?? throw new ArgumentNullException(nameof(reader)); - this._source = source ?? throw new ArgumentNullException(nameof(source)); - this._disposable = disposable; + SqlDataReaderBaseAsyncCallContext context = (SqlDataReaderBaseAsyncCallContext)state; + context.Reader.CompleteAsyncCall(task, context); } + } + + internal abstract class SqlDataReaderAsyncCallContext : SqlDataReaderBaseAsyncCallContext + where TDisposable : IDisposable + { + private TDisposable _disposable; + + public TDisposable Disposable { get => _disposable; set => _disposable = value; } - internal void Clear() + public void Set(SqlDataReader owner, TaskCompletionSource source, TDisposable disposable) { - _source = null; - _reader = null; - IDisposable copyDisposable = _disposable; - _disposable = null; - copyDisposable?.Dispose(); + base.Set(owner, source); + _disposable = disposable; } - internal abstract Func> Execute { get; } - - public virtual void Dispose() + protected override void DisposeCore() { - Clear(); + TDisposable copy = _disposable; + _disposable = default; + copy?.Dispose(); } } - private sealed class ReadAsyncCallContext : AAsyncCallContext + internal sealed class ReadAsyncCallContext : SqlDataReaderAsyncCallContext { internal static readonly Func> s_execute = SqlDataReader.ReadAsyncExecute; @@ -5965,15 +5990,13 @@ internal ReadAsyncCallContext() internal override Func> Execute => s_execute; - public override void Dispose() + protected override void AfterCleared(SqlDataReader owner) { - SqlDataReader reader = this._reader; - base.Dispose(); - reader.SetCachedReadAsyncCallContext(this); + owner.SetCachedReadAsyncCallContext(this); } } - private sealed class IsDBNullAsyncCallContext : AAsyncCallContext + internal sealed class IsDBNullAsyncCallContext : SqlDataReaderAsyncCallContext { internal static readonly Func> s_execute = SqlDataReader.IsDBNullAsyncExecute; @@ -5983,27 +6006,25 @@ internal IsDBNullAsyncCallContext() { } internal override Func> Execute => s_execute; - public override void Dispose() + protected override void AfterCleared(SqlDataReader owner) { - SqlDataReader reader = this._reader; - base.Dispose(); - reader.SetCachedIDBNullAsyncCallContext(this); + owner.SetCachedIDBNullAsyncCallContext(this); } } - private sealed class HasNextResultAsyncCallContext : AAsyncCallContext + private sealed class HasNextResultAsyncCallContext : SqlDataReaderAsyncCallContext { private static readonly Func> s_execute = SqlDataReader.NextResultAsyncExecute; - public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable) - : base(reader, source, disposable) + public HasNextResultAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, CancellationTokenRegistration disposable) { + Set(reader, source, disposable); } internal override Func> Execute => s_execute; } - private sealed class GetBytesAsyncCallContext : AAsyncCallContext + private sealed class GetBytesAsyncCallContext : SqlDataReaderAsyncCallContext { internal enum OperationMode { @@ -6014,63 +6035,66 @@ internal enum OperationMode private static readonly Func> s_executeSeek = SqlDataReader.GetBytesAsyncSeekExecute; private static readonly Func> s_executeRead = SqlDataReader.GetBytesAsyncReadExecute; - internal int columnIndex; - internal byte[] buffer; - internal int index; - internal int length; - internal int timeout; - internal CancellationToken cancellationToken; - internal CancellationToken timeoutToken; - internal int totalBytesRead; + internal int _columnIndex; + internal byte[] _buffer; + internal int _index; + internal int _length; + internal int _timeout; + internal CancellationToken _cancellationToken; + internal CancellationToken _timeoutToken; + internal int _totalBytesRead; - internal OperationMode mode; + internal OperationMode _mode; internal GetBytesAsyncCallContext(SqlDataReader reader) { - this._reader = reader ?? throw new ArgumentNullException(nameof(reader)); + Reader = reader ?? throw new ArgumentNullException(nameof(reader)); } - internal override Func> Execute => mode == OperationMode.Seek ? s_executeSeek : s_executeRead; + internal override Func> Execute => _mode == OperationMode.Seek ? s_executeSeek : s_executeRead; - public override void Dispose() + protected override void Clear() { - buffer = null; - cancellationToken = default; - timeoutToken = default; - base.Dispose(); + _buffer = null; + _cancellationToken = default; + _timeoutToken = default; + base.Clear(); } } - private sealed class GetFieldValueAsyncCallContext : AAsyncCallContext + private sealed class GetFieldValueAsyncCallContext : SqlDataReaderAsyncCallContext { private static readonly Func> s_execute = SqlDataReader.GetFieldValueAsyncExecute; - internal readonly int _columnIndex; + internal int _columnIndex; + + internal GetFieldValueAsyncCallContext() { } - internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, IDisposable disposable, int columnIndex) - : base(reader, source, disposable) + internal GetFieldValueAsyncCallContext(SqlDataReader reader, TaskCompletionSource source, CancellationTokenRegistration disposable) { - _columnIndex = columnIndex; + Set(reader, source, disposable); } - internal override Func> Execute => s_execute; - } - - private static Task ExecuteAsyncCallCallback(Task task, object state) - { - AAsyncCallContext context = (AAsyncCallContext)state; - return context._reader.ExecuteAsyncCall(task, context); - } + protected override void Clear() + { + _columnIndex = -1; + base.Clear(); + } - private static void CompleteAsyncCallCallback(Task task, object state) - { - AAsyncCallContext context = (AAsyncCallContext)state; - context._reader.CompleteAsyncCall(task, context); + internal override Func> Execute => s_execute; } - private Task InvokeAsyncCall(AAsyncCallContext context) + /// + /// Starts the process of executing an async call using an SqlDataReaderAsyncCallContext derived context object. + /// After this call the context lifetime is handled by BeginAsyncCall ContinueAsyncCall and CompleteAsyncCall AsyncCall methods + /// + /// + /// + /// + /// + private Task InvokeAsyncCall(SqlDataReaderBaseAsyncCallContext context) { - TaskCompletionSource source = context._source; + TaskCompletionSource source = context.Source; try { Task task; @@ -6080,7 +6104,7 @@ private Task InvokeAsyncCall(AAsyncCallContext context) } catch (Exception ex) { - task = ADP.CreatedTaskWithException(ex); + task = Task.FromException(ex); } if (task.IsCompleted) @@ -6090,7 +6114,7 @@ private Task InvokeAsyncCall(AAsyncCallContext context) else { task.ContinueWith( - continuationAction: AAsyncCallContext.s_completeCallback, + continuationAction: SqlDataReaderBaseAsyncCallContext.s_completeCallback, state: context, TaskScheduler.Default ); @@ -6105,7 +6129,7 @@ private Task InvokeAsyncCall(AAsyncCallContext context) source.TrySetException(e); } - // Fall through for exceptions\completing async + // Fall through for exceptions/completing async return source.Task; } @@ -6115,7 +6139,7 @@ private Task InvokeAsyncCall(AAsyncCallContext context) /// /// /// - private Task ExecuteAsyncCall(AAsyncCallContext context) + private Task ExecuteAsyncCall(AAsyncBaseCallContext context) { // _networkPacketTaskSource could be null if the connection was closed // while an async invocation was outstanding. @@ -6128,7 +6152,7 @@ private Task ExecuteAsyncCall(AAsyncCallContext context) else { return completionSource.Task.ContinueWith( - continuationFunction: AAsyncCallContext.s_executeCallback, + continuationFunction: SqlDataReaderBaseAsyncCallContext.s_executeCallback, state: context, TaskScheduler.Default ).Unwrap(); @@ -6144,10 +6168,10 @@ private Task ExecuteAsyncCall(AAsyncCallContext context) /// /// /// - private Task ExecuteAsyncCall(Task task, AAsyncCallContext context) + private Task ContinueAsyncCall(Task task, SqlDataReaderBaseAsyncCallContext context) { // this function must be an instance function called from the static callback because otherwise a compiler error - // is caused by accessing the _cancelAsyncOnCloseToken field of a MarchalByRefObject derived class + // is caused by accessing the _cancelAsyncOnCloseToken field of a MarshalByRefObject derived class if (task.IsFaulted) { // Somehow the network task faulted - return the exception @@ -6204,9 +6228,9 @@ private Task ExecuteAsyncCall(Task task, AAsyncCallContext context) /// /// /// - private void CompleteAsyncCall(Task task, AAsyncCallContext context) + private void CompleteAsyncCall(Task task, SqlDataReaderBaseAsyncCallContext context) { - TaskCompletionSource source = context._source; + TaskCompletionSource source = context.Source; context.Dispose(); // If something has forced us to switch to SyncOverAsync mode while in an async task then we need to guarantee that we do the cleanup @@ -6233,6 +6257,27 @@ private void CompleteAsyncCall(Task task, AAsyncCallContext context) } } + private sealed class Snapshot + { + public bool _dataReady; + public bool _haltRead; + public bool _metaDataConsumed; + public bool _browseModeInfoConsumed; + public bool _hasRows; + public ALTROWSTATUS _altRowStatus; + public int _nextColumnDataToRead; + public int _nextColumnHeaderToRead; + public long _columnDataBytesRead; + public long _columnDataBytesRemaining; + + public _SqlMetaDataSet _metadata; + public _SqlMetaDataSetCollection _altMetaDataSetCollection; + public MultiPartTableName[] _tableNames; + + public SqlSequentialStream _currentStream; + public SqlSequentialTextReader _currentTextReader; + } + private void PrepareAsyncInvocation(bool useSnapshot) { // if there is already a snapshot, then the previous async command @@ -6244,28 +6289,27 @@ private void PrepareAsyncInvocation(bool useSnapshot) if (_snapshot == null) { - _snapshot = new Snapshot - { - _dataReady = _sharedState._dataReady, - _haltRead = _haltRead, - _metaDataConsumed = _metaDataConsumed, - _browseModeInfoConsumed = _browseModeInfoConsumed, - _hasRows = _hasRows, - _altRowStatus = _altRowStatus, - _nextColumnDataToRead = _sharedState._nextColumnDataToRead, - _nextColumnHeaderToRead = _sharedState._nextColumnHeaderToRead, - _columnDataBytesRead = _columnDataBytesRead, - _columnDataBytesRemaining = _sharedState._columnDataBytesRemaining, - - // _metadata and _altaMetaDataSetCollection must be Cloned - // before they are updated - _metadata = _metaData, - _altMetaDataSetCollection = _altMetaDataSetCollection, - _tableNames = _tableNames, - - _currentStream = _currentStream, - _currentTextReader = _currentTextReader, - }; + _snapshot = new Snapshot(); + + _snapshot._dataReady = _sharedState._dataReady; + _snapshot._haltRead = _haltRead; + _snapshot._metaDataConsumed = _metaDataConsumed; + _snapshot._browseModeInfoConsumed = _browseModeInfoConsumed; + _snapshot._hasRows = _hasRows; + _snapshot._altRowStatus = _altRowStatus; + _snapshot._nextColumnDataToRead = _sharedState._nextColumnDataToRead; + _snapshot._nextColumnHeaderToRead = _sharedState._nextColumnHeaderToRead; + _snapshot._columnDataBytesRead = _columnDataBytesRead; + _snapshot._columnDataBytesRemaining = _sharedState._columnDataBytesRemaining; + + // _metadata and _altaMetaDataSetCollection must be Cloned + // before they are updated + _snapshot._metadata = _metaData; + _snapshot._altMetaDataSetCollection = _altMetaDataSetCollection; + _snapshot._tableNames = _tableNames; + + _snapshot._currentStream = _currentStream; + _snapshot._currentTextReader = _currentTextReader; _stateObj.SetSnapshot(); } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs similarity index 100% rename from src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs rename to src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AAsyncCallContext.cs