diff --git a/src/DataDog.Tracing.Sql.Tests/EntityFrameworkCore/TraceDbCommandTests.cs b/src/DataDog.Tracing.Sql.Tests/EntityFrameworkCore/TraceDbCommandTests.cs new file mode 100644 index 0000000..138ba80 --- /dev/null +++ b/src/DataDog.Tracing.Sql.Tests/EntityFrameworkCore/TraceDbCommandTests.cs @@ -0,0 +1,98 @@ +using System.Collections.Generic; +using System.Data; +using FluentAssertions; +using Microsoft.Data.Sqlite; +using NUnit.Framework; + +namespace DataDog.Tracing.Sql.Tests.EntityFrameworkCore +{ + [TestFixture] + public class TraceDbCommandTests + { + RootSpan _root; + SqliteConnection _conn; + + [SetUp] + public void SetUp() + { + _conn = new SqliteConnection("Filename=./test.db"); + _conn.Open(); + _root = new RootSpan(); + } + + [Test] + [TestCase(CommandBehavior.Default)] + [TestCase(CommandBehavior.CloseConnection)] + public void ExecuteReader_is_traced(CommandBehavior commandBehavior) + { + var customers = new List(); + using (var command = new Sql.EntityFrameworkCore.TraceDbCommand(_conn.CreateCommand(), _root)) + { + command.CommandText = "SELECT * FROM Customers"; + command.CommandType = CommandType.Text; + using (var reader = command.ExecuteReader(commandBehavior)) + { + while (reader.Read()) + { + customers.Add(new Customer(reader)); + } + } + } + customers.Count.Should().Be(2); + _root.Spans[1].Name.Should().Be("sql." + nameof(IDbCommand.ExecuteReader)); + _root.Spans[1].Service.Should().Be("sql"); + _root.Spans[1].Resource.Should().Be("main"); + _root.Spans[1].Type.Should().Be("sql"); + _root.Spans[1].Error.Should().Be(0); + _root.Spans[1].Meta["sql.CommandBehavior"].Should().Be(commandBehavior.ToString("x")); + _root.Spans[1].Meta["sql.CommandText"].Should().Be("SELECT * FROM Customers"); + _root.Spans[1].Meta["sql.CommandType"].Should().Be("Text"); + } + + [Test] + public void ExecuteNonQuery_is_traced() + { + int rows; + using (var command = new Sql.EntityFrameworkCore.TraceDbCommand(_conn.CreateCommand(), _root)) + { + command.CommandText = "SELECT * FROM Customers"; + command.CommandType = CommandType.Text; + rows = command.ExecuteNonQuery(); + } + _root.Spans[1].Name.Should().Be("sql." + nameof(IDbCommand.ExecuteNonQuery)); + _root.Spans[1].Service.Should().Be("sql"); + _root.Spans[1].Resource.Should().Be("main"); + _root.Spans[1].Type.Should().Be("sql"); + _root.Spans[1].Error.Should().Be(0); + _root.Spans[1].Meta["sql.RowsAffected"].Should().Be(rows.ToString()); + _root.Spans[1].Meta["sql.CommandText"].Should().Be("SELECT * FROM Customers"); + _root.Spans[1].Meta["sql.CommandType"].Should().Be("Text"); + } + + [Test] + public void ExecuteScalar_is_traced() + { + object result; + using (var command = new Sql.EntityFrameworkCore.TraceDbCommand(_conn.CreateCommand(), _root)) + { + command.CommandText = "SELECT COUNT(*) FROM Customers"; + command.CommandType = CommandType.Text; + result = command.ExecuteScalar(); + } + result.Should().Be(2L); + _root.Spans[1].Name.Should().Be("sql." + nameof(IDbCommand.ExecuteScalar)); + _root.Spans[1].Service.Should().Be("sql"); + _root.Spans[1].Resource.Should().Be("main"); + _root.Spans[1].Type.Should().Be("sql"); + _root.Spans[1].Error.Should().Be(0); + _root.Spans[1].Meta["sql.CommandText"].Should().Be("SELECT COUNT(*) FROM Customers"); + _root.Spans[1].Meta["sql.CommandType"].Should().Be("Text"); + } + + [TearDown] + public void TearDown() + { + _conn.Close(); + } + } +} diff --git a/src/DataDog.Tracing.Sql.Tests/EntityFrameworkCore/TraceDbConnectionTests.cs b/src/DataDog.Tracing.Sql.Tests/EntityFrameworkCore/TraceDbConnectionTests.cs new file mode 100644 index 0000000..a09cd3b --- /dev/null +++ b/src/DataDog.Tracing.Sql.Tests/EntityFrameworkCore/TraceDbConnectionTests.cs @@ -0,0 +1,41 @@ +using FluentAssertions; +using Microsoft.Data.Sqlite; +using NUnit.Framework; + +namespace DataDog.Tracing.Sql.Tests.EntityFrameworkCore +{ + [TestFixture] + public class TraceDbConnectionTests + { + RootSpan _root; + SqliteConnection _conn; + + [SetUp] + public void SetUp() + { + _root = new RootSpan(); + _conn = new SqliteConnection("Filename=./test.db"); + } + + [TearDown] + public void TearDown() + { + _conn.Dispose(); + } + + [Test] + public void Open_should_be_traced() + { + var conn = new Sql.EntityFrameworkCore.TraceDbConnection(_conn, _root); + _root.Spans.Count.Should().Be(1); + conn.Open(); + _root.Spans.Count.Should().Be(2); + var s = _root.Spans[1]; + s.Error.Should().Be(0); + s.Name.Should().Be("sql.connect"); + s.Service.Should().Be("sql"); + s.Resource.Should().Be("main"); + s.Type.Should().Be("sql"); + } + } +} diff --git a/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbCommand.cs b/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbCommand.cs new file mode 100644 index 0000000..5d535ce --- /dev/null +++ b/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbCommand.cs @@ -0,0 +1,173 @@ +using System; +using System.Data; +using System.Data.Common; + +namespace DataDog.Tracing.Sql.EntityFrameworkCore +{ + public class TraceDbCommand : DbCommand + { + private const string DefaultServiceName = "sql"; + private const string TypeName = "sql"; + + private string ServiceName { get; } + + private readonly DbCommand _command; + private readonly ISpanSource _spanSource; + + public IDbCommand InnerCommand => _command; + + protected override DbParameterCollection DbParameterCollection => _command.Parameters; + + public override bool DesignTimeVisible + { + get => _command.DesignTimeVisible; + set => _command.DesignTimeVisible = value; + } + + public override string CommandText + { + get => _command.CommandText; + set => _command.CommandText = value; + } + + public override int CommandTimeout + { + get => _command.CommandTimeout; + set => _command.CommandTimeout = value; + } + + public override CommandType CommandType + { + get => _command.CommandType; + set => _command.CommandType = value; + } + + protected override DbConnection DbConnection + { + get => _command.Connection; + set => _command.Connection = value; + } + + public override UpdateRowSource UpdatedRowSource + { + get => _command.UpdatedRowSource; + set => _command.UpdatedRowSource = value; + } + + protected override DbTransaction DbTransaction + { + get => _command.Transaction; + set => _command.Transaction = + value is TraceDbTransaction transaction + ? transaction.Transaction + : value; + } + + public TraceDbCommand(DbCommand command) + : this(command, DefaultServiceName, TraceContextSpanSource.Instance) { } + + public TraceDbCommand(DbCommand command, string serviceName) + : this(command, serviceName, TraceContextSpanSource.Instance) { } + + public TraceDbCommand(DbCommand command, ISpanSource spanSource) + : this(command, DefaultServiceName, spanSource) { } + + public TraceDbCommand(DbCommand command, string serviceName, ISpanSource spanSource) + { + _command = command ?? throw new ArgumentNullException(nameof(command)); + _spanSource = spanSource ?? throw new ArgumentNullException(nameof(spanSource)); + + ServiceName = string.IsNullOrWhiteSpace(serviceName) + ? DefaultServiceName + : serviceName; + } + + public new void Dispose() => _command.Dispose(); + + public override void Cancel() => _command.Cancel(); + + public override void Prepare() => _command.Prepare(); + + protected override DbParameter CreateDbParameter() => _command.CreateParameter(); + + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) + { + const string name = "sql." + nameof(ExecuteReader); + var span = _spanSource.Begin(name, ServiceName, _command.Connection.Database, TypeName); + try + { + if (span != null) + { + const string metaKey = "sql." + nameof(CommandBehavior); + span.SetMeta(metaKey, behavior.ToString("x")); + SetMeta(span); + } + + return _command.ExecuteReader(behavior); + } + catch (Exception ex) + { + span?.SetError(ex); + throw; + } + finally + { + span?.Dispose(); + } + } + + public override int ExecuteNonQuery() + { + const string name = "sql." + nameof(ExecuteNonQuery); + var span = _spanSource.Begin(name, ServiceName, _command.Connection.Database, TypeName); + try + { + var result = _command.ExecuteNonQuery(); + if (span != null) + { + span.SetMeta("sql.RowsAffected", result.ToString()); + SetMeta(span); + } + + return result; + } + catch (Exception ex) + { + span?.SetError(ex); + throw; + } + finally + { + span?.Dispose(); + } + } + + public override object ExecuteScalar() + { + const string name = "sql." + nameof(ExecuteScalar); + var span = _spanSource.Begin(name, ServiceName, _command.Connection.Database, TypeName); + try + { + if (span != null) + SetMeta(span); + + return _command.ExecuteScalar(); + } + catch (Exception ex) + { + span?.SetError(ex); + throw; + } + finally + { + span?.Dispose(); + } + } + + private void SetMeta(ISpan span) + { + span.SetMeta("sql.CommandText", CommandText); + span.SetMeta("sql.CommandType", CommandType.ToString()); + } + } +} diff --git a/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbConnection.cs b/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbConnection.cs new file mode 100644 index 0000000..ca278da --- /dev/null +++ b/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbConnection.cs @@ -0,0 +1,87 @@ +using System; +using System.Data; +using System.Data.Common; + +namespace DataDog.Tracing.Sql.EntityFrameworkCore +{ + public class TraceDbConnection : DbConnection + { + private const string DefaultServiceName = "sql"; + private const string TypeName = "sql"; + + private string ServiceName { get; } + + private readonly ISpanSource _spanSource; + private readonly DbConnection _connection; + + public IDbConnection InnerConnection => _connection; + + public override int ConnectionTimeout => _connection.ConnectionTimeout; + + public override string Database => _connection.Database; + + public override string DataSource => _connection.DataSource; + + public override string ServerVersion => _connection.ServerVersion; + + public override ConnectionState State => _connection.State; + + public override string ConnectionString + { + get => _connection.ConnectionString; + set => _connection.ConnectionString = value; + } + + public TraceDbConnection(DbConnection connection) + : this(connection, DefaultServiceName, TraceContextSpanSource.Instance) { } + + public TraceDbConnection(DbConnection connection, string serviceName) + : this(connection, serviceName, TraceContextSpanSource.Instance) { } + + public TraceDbConnection(DbConnection connection, ISpanSource spanSource) + : this(connection, DefaultServiceName, spanSource) { } + + public TraceDbConnection(DbConnection connection, string serviceName, ISpanSource spanSource) + { + _connection = connection ?? throw new ArgumentNullException(nameof(connection)); + _spanSource = spanSource ?? throw new ArgumentNullException(nameof(spanSource)); + + ServiceName = string.IsNullOrWhiteSpace(serviceName) + ? DefaultServiceName + : serviceName; + } + + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) + => new TraceDbTransaction(this, _connection.BeginTransaction(isolationLevel), ServiceName, _spanSource); + + protected override DbCommand CreateDbCommand() + => new TraceDbCommand(_connection.CreateCommand(), ServiceName, _spanSource); + + public new void Dispose() + => _connection.Dispose(); + + public override void ChangeDatabase(string databaseName) + => _connection.ChangeDatabase(databaseName); + + public override void Close() + => _connection.Close(); + + public override void Open() + { + var span = _spanSource.Begin("sql.connect", ServiceName, _connection.Database, TypeName); + try + { + _connection.Open(); + } + catch (Exception ex) + { + span?.SetError(ex); + throw; + } + finally + { + span?.Dispose(); + } + } + } +} diff --git a/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbTransaction.cs b/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbTransaction.cs new file mode 100644 index 0000000..01a9279 --- /dev/null +++ b/src/DataDog.Tracing.Sql/EntityFrameworkCore/TraceDbTransaction.cs @@ -0,0 +1,85 @@ +using System; +using System.Data; +using System.Data.Common; + +namespace DataDog.Tracing.Sql.EntityFrameworkCore +{ + // Entity Framework has a check like this: + // if (connection.DbConnection != transaction.Connection) + // throw new InvalidOperationException(RelationalStrings.TransactionAssociatedWithDifferentConnection); + // Where connection.DbConnection is of type TraceDbConnection and transaction.Connection is of type SqlConnection + // Because of this we need to implement TraceDbTransaction + public class TraceDbTransaction : DbTransaction + { + private const string DefaultServiceName = "sql"; + private const string TypeName = "sql"; + + private string ServiceName { get; } + + private readonly ISpanSource _spanSource; + + public DbTransaction Transaction { get; } + + protected override DbConnection DbConnection { get; } + + public override IsolationLevel IsolationLevel => Transaction.IsolationLevel; + + public TraceDbTransaction(DbConnection connection, DbTransaction transaction) + : this(connection, transaction, DefaultServiceName, TraceContextSpanSource.Instance) { } + + public TraceDbTransaction(DbConnection connection, DbTransaction transaction, string serviceName) + : this(connection, transaction, serviceName, TraceContextSpanSource.Instance) { } + + public TraceDbTransaction(DbConnection connection, DbTransaction transaction, ISpanSource spanSource) + : this(connection, transaction, DefaultServiceName, spanSource) { } + + public TraceDbTransaction(DbConnection connection, DbTransaction transaction, string serviceName, ISpanSource spanSource) + { + DbConnection = connection ?? throw new ArgumentNullException(nameof(connection)); + Transaction = transaction ?? throw new ArgumentNullException(nameof(transaction)); + _spanSource = spanSource ?? throw new ArgumentNullException(nameof(spanSource)); + + ServiceName = string.IsNullOrWhiteSpace(serviceName) + ? DefaultServiceName + : serviceName; + } + + public override void Commit() + { + const string name = "sql." + nameof(Commit); + var span = _spanSource.Begin(name, ServiceName, Transaction.Connection.Database, TypeName); + try + { + Transaction.Commit(); + } + catch (Exception ex) + { + span?.SetError(ex); + throw; + } + finally + { + span?.Dispose(); + } + } + + public override void Rollback() + { + const string name = "sql." + nameof(Commit); + var span = _spanSource.Begin(name, ServiceName, Transaction.Connection.Database, TypeName); + try + { + Transaction.Rollback(); + } + catch (Exception ex) + { + span?.SetError(ex); + throw; + } + finally + { + span?.Dispose(); + } + } + } +}