Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -567,17 +567,6 @@ internal SqlDataReader FindLiveReader(SqlCommand command)
return reader;
}

internal SqlCommand FindLiveCommand(TdsParserStateObject stateObj)
{
SqlCommand command = null;
SqlReferenceCollection referenceCollection = (SqlReferenceCollection)ReferenceCollection;
if (null != referenceCollection)
{
command = referenceCollection.FindLiveCommand(stateObj);
}
return command;
}

abstract protected byte[] GetDTCAddress();

static private byte[] GetTransactionCookie(Transaction transaction, byte[] whereAbouts)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,38 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Diagnostics;
using System.Threading;
using Microsoft.Data.ProviderBase;

namespace Microsoft.Data.SqlClient
{
sealed internal class SqlReferenceCollection : DbReferenceCollection
{
private sealed class FindLiveReaderContext
{
public readonly Func<SqlDataReader, bool> Func;

private SqlCommand _command;

public FindLiveReaderContext() => Func = Predicate;

public void Setup(SqlCommand command) => _command = command;

public void Clear() => _command = null;

private bool Predicate(SqlDataReader reader) => (!reader.IsClosed) && (_command == reader.Command);
}

internal const int DataReaderTag = 1;
internal const int CommandTag = 2;
internal const int BulkCopyTag = 3;

override public void Add(object value, int tag)
private readonly static Func<SqlDataReader, bool> s_hasOpenReaderFunc = HasOpenReaderPredicate;
private static FindLiveReaderContext s_cachedFindLiveReaderContext;

public override void Add(object value, int tag)
{
Debug.Assert(DataReaderTag == tag || CommandTag == tag || BulkCopyTag == tag, "unexpected tag?");
Debug.Assert(DataReaderTag != tag || value is SqlDataReader, "tag doesn't match object type: SqlDataReader");
Expand All @@ -30,25 +50,24 @@ internal void Deactivate()

internal SqlDataReader FindLiveReader(SqlCommand command)
{
if (command == null)
if (command is null)
{
// if null == command, will find first live datareader
return FindItem<SqlDataReader>(DataReaderTag, (dataReader) => (!dataReader.IsClosed));
return FindItem(DataReaderTag, s_hasOpenReaderFunc);
}
else
{
// else will find live datareader associated with the command
return FindItem<SqlDataReader>(DataReaderTag, (dataReader) => ((!dataReader.IsClosed) && (command == dataReader.Command)));
FindLiveReaderContext context = Interlocked.Exchange(ref s_cachedFindLiveReaderContext, null) ?? new FindLiveReaderContext();
context.Setup(command);
SqlDataReader retval = FindItem(DataReaderTag, context.Func);
context.Clear();
Interlocked.CompareExchange(ref s_cachedFindLiveReaderContext, context, null);
return retval;
}
}

// Finds a SqlCommand associated with the given StateObject
internal SqlCommand FindLiveCommand(TdsParserStateObject stateObj)
{
return FindItem<SqlCommand>(CommandTag, (command) => (command.StateObject == stateObj));
}

override protected void NotifyItem(int message, int tag, object value)
protected override void NotifyItem(int message, int tag, object value)
{
Debug.Assert(0 == message, "unexpected message?");
Debug.Assert(DataReaderTag == tag || CommandTag == tag || BulkCopyTag == tag, "unexpected tag?");
Expand All @@ -74,11 +93,13 @@ override protected void NotifyItem(int message, int tag, object value)
}
}

override public void Remove(object value)
public override void Remove(object value)
{
Debug.Assert(value is SqlDataReader || value is SqlCommand || value is SqlBulkCopy, "SqlReferenceCollection.Remove expected a SqlDataReader or SqlCommand or SqlBulkCopy");

base.RemoveItem(value);
}

private static bool HasOpenReaderPredicate(SqlDataReader reader) => !reader.IsClosed;
}
}