- 
                Notifications
    
You must be signed in to change notification settings  - Fork 348
 
Verify server ID before KILL QUERY to prevent cross-server cancellation #1575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 4 commits
a6b86d7
              3e141d3
              09e91c5
              780bc2b
              b0ee0f8
              d2619e9
              ae54398
              9da52c6
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -47,6 +47,8 @@ public ServerSession(ILogger logger, IConnectionPoolMetadata pool) | |
| public int ActiveCommandId { get; private set; } | ||
| public int CancellationTimeout { get; private set; } | ||
| public int ConnectionId { get; set; } | ||
| public string? ServerUuid { get; set; } | ||
| public long? ServerId { get; set; } | ||
| public byte[]? AuthPluginData { get; set; } | ||
| public long CreatedTimestamp { get; } | ||
| public ConnectionPool? Pool { get; } | ||
| 
          
            
          
           | 
    @@ -117,6 +119,14 @@ public void DoCancel(ICancellableCommand commandToCancel, MySqlCommand killComma | |
| return; | ||
| } | ||
| 
     | 
||
| // Verify server identity before executing KILL QUERY to prevent cancelling on the wrong server | ||
| var killSession = killCommand.Connection!.Session; | ||
| if (!VerifyServerIdentity(killSession)) | ||
| { | ||
| Log.IgnoringCancellationForDifferentServer(m_logger, Id, killSession.Id, ServerUuid, killSession.ServerUuid, ServerId, killSession.ServerId); | ||
| return; | ||
| } | ||
| 
     | 
||
| // NOTE: This command is executed while holding the lock to prevent race conditions during asynchronous cancellation. | ||
| // For example, if the lock weren't held, the current command could finish and the other thread could set ActiveCommandId | ||
| // to zero, then start executing a new command. By the time this "KILL QUERY" command reached the server, the wrong | ||
| 
        
          
        
         | 
    @@ -137,6 +147,26 @@ public void AbortCancel(ICancellableCommand command) | |
| } | ||
| } | ||
| 
     | 
||
| private bool VerifyServerIdentity(ServerSession otherSession) | ||
                
       | 
||
| { | ||
| // If server UUID is available, use it as the primary identifier (most unique) | ||
| if (!string.IsNullOrEmpty(ServerUuid) && !string.IsNullOrEmpty(otherSession.ServerUuid)) | ||
| { | ||
| return string.Equals(ServerUuid, otherSession.ServerUuid, StringComparison.Ordinal); | ||
| } | ||
| 
     | 
||
| // Fall back to server ID if UUID is not available | ||
| if (ServerId.HasValue && otherSession.ServerId.HasValue) | ||
| { | ||
| return ServerId.Value == otherSession.ServerId.Value; | ||
| } | ||
| 
     | 
||
| // If no server identification is available, allow the operation to proceed | ||
| // This maintains backward compatibility with older MySQL versions | ||
| Log.NoServerIdentificationForVerification(m_logger, Id, otherSession.Id); | ||
| return true; | ||
| } | ||
| 
     | 
||
| public bool IsCancelingQuery => m_state == State.CancelingQuery; | ||
| 
     | 
||
| public async Task PrepareAsync(IMySqlCommand command, IOBehavior ioBehavior, CancellationToken cancellationToken) | ||
| 
          
            
          
           | 
    @@ -635,6 +665,9 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella | |
| ConnectionId = newConnectionId; | ||
| } | ||
| 
     | 
||
| // Get server identification for KILL QUERY verification | ||
| await GetServerIdentificationAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| 
     | 
||
| m_payloadHandler.ByteHandler.RemainingTimeout = Constants.InfiniteTimeout; | ||
| return redirectionUrl; | ||
| } | ||
| 
          
            
          
           | 
    @@ -1951,6 +1984,93 @@ private async Task GetRealServerDetailsAsync(IOBehavior ioBehavior, Cancellation | |
| } | ||
| } | ||
| 
     | 
||
| private async Task GetServerIdentificationAsync(IOBehavior ioBehavior, CancellationToken cancellationToken) | ||
                
       | 
||
| { | ||
| Log.GettingServerIdentification(m_logger, Id); | ||
| try | ||
| { | ||
| PayloadData payload; | ||
| 
     | 
||
| // Try to get both server_uuid and server_id if server supports server_uuid (MySQL 5.6+) | ||
| if (!ServerVersion.IsMariaDb && ServerVersion.Version >= ServerVersions.SupportsServerUuid) | ||
| { | ||
| payload = SupportsQueryAttributes ? s_selectServerIdWithAttributesPayload : s_selectServerIdNoAttributesPayload; | ||
| await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); | ||
| 
     | 
||
| // column count: 2 | ||
| _ = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); | ||
| 
     | 
||
| // @@server_uuid and @@server_id columns | ||
| _ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| _ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| 
     | 
||
| if (!SupportsDeprecateEof) | ||
| { | ||
| payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| _ = EofPayload.Create(payload.Span); | ||
| } | ||
| 
     | 
||
| // first (and only) row | ||
| payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| 
     | 
||
| var reader = new ByteArrayReader(payload.Span); | ||
| var length = reader.ReadLengthEncodedIntegerOrNull(); | ||
| var serverUuid = length > 0 ? Encoding.UTF8.GetString(reader.ReadByteString(length)) : null; | ||
| length = reader.ReadLengthEncodedIntegerOrNull(); | ||
| var serverId = (length > 0 && Utf8Parser.TryParse(reader.ReadByteString(length), out long id, out _)) ? id : default(long?); | ||
| 
     | 
||
| ServerUuid = serverUuid; | ||
| ServerId = serverId; | ||
| 
     | 
||
| Log.RetrievedServerIdentification(m_logger, Id, serverUuid, serverId); | ||
| } | ||
| else | ||
| { | ||
| // Fall back to just server_id for older versions or MariaDB | ||
| payload = SupportsQueryAttributes ? s_selectServerIdOnlyWithAttributesPayload : s_selectServerIdOnlyNoAttributesPayload; | ||
| await SendAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); | ||
| 
     | 
||
| // column count: 1 | ||
| _ = await ReceiveReplyAsync(ioBehavior, cancellationToken).ConfigureAwait(false); | ||
| 
     | 
||
| // @@server_id column | ||
| _ = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| 
     | 
||
| if (!SupportsDeprecateEof) | ||
| { | ||
| payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| _ = EofPayload.Create(payload.Span); | ||
| } | ||
| 
     | 
||
| // first (and only) row | ||
| payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| 
     | 
||
| var reader = new ByteArrayReader(payload.Span); | ||
| var length = reader.ReadLengthEncodedIntegerOrNull(); | ||
| var serverId = (length > 0 && Utf8Parser.TryParse(reader.ReadByteString(length), out long id, out _)) ? id : default(long?); | ||
| 
     | 
||
| ServerUuid = null; | ||
| ServerId = serverId; | ||
| 
     | 
||
| Log.RetrievedServerIdentification(m_logger, Id, null, serverId); | ||
| } | ||
| 
     | 
||
| // OK/EOF payload | ||
| payload = await ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); | ||
| if (OkPayload.IsOk(payload.Span, this)) | ||
| OkPayload.Verify(payload.Span, this); | ||
| else | ||
| EofPayload.Create(payload.Span); | ||
| } | ||
| catch (MySqlException ex) | ||
| { | ||
| Log.FailedToGetServerIdentification(m_logger, ex, Id); | ||
| // Set fallback values to ensure operation can continue | ||
| ServerUuid = null; | ||
| ServerId = null; | ||
| } | ||
| } | ||
| 
     | 
||
| private void ShutdownSocket() | ||
| { | ||
| Log.ClosingStreamSocket(m_logger, Id); | ||
| 
          
            
          
           | 
    @@ -2182,6 +2302,10 @@ protected override void OnStatementBegin(int index) | |
| private static readonly PayloadData s_sleepWithAttributesPayload = QueryPayload.Create(true, "SELECT SLEEP(0) INTO @__MySqlConnector__Sleep;"u8); | ||
| private static readonly PayloadData s_selectConnectionIdVersionNoAttributesPayload = QueryPayload.Create(false, "SELECT CONNECTION_ID(), VERSION();"u8); | ||
| private static readonly PayloadData s_selectConnectionIdVersionWithAttributesPayload = QueryPayload.Create(true, "SELECT CONNECTION_ID(), VERSION();"u8); | ||
| private static readonly PayloadData s_selectServerIdNoAttributesPayload = QueryPayload.Create(false, "SELECT @@server_uuid, @@server_id;"u8); | ||
| private static readonly PayloadData s_selectServerIdWithAttributesPayload = QueryPayload.Create(true, "SELECT @@server_uuid, @@server_id;"u8); | ||
| private static readonly PayloadData s_selectServerIdOnlyNoAttributesPayload = QueryPayload.Create(false, "SELECT @@server_id;"u8); | ||
| private static readonly PayloadData s_selectServerIdOnlyWithAttributesPayload = QueryPayload.Create(true, "SELECT @@server_id;"u8); | ||
| 
     | 
||
| private readonly ILogger m_logger; | ||
| #if NET9_0_OR_GREATER | ||
| 
          
            
          
           | 
    ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| 
          
            
          
           | 
    @@ -19,4 +19,7 @@ internal static class ServerVersions | |
| 
     | 
||
| // https://mariadb.com/kb/en/set-statement/ | ||
| public static readonly Version MariaDbSupportsPerQueryVariables = new(10, 1, 2); | ||
| 
     | 
||
| // https://dev.mysql.com/doc/refman/5.6/en/replication-options.html#sysvar_server_uuid | ||
| public static readonly Version SupportsServerUuid = new(5, 6, 0); | ||
                
       | 
||
| } | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| using System.Diagnostics; | ||
| 
     | 
||
| namespace IntegrationTests; | ||
| 
     | 
||
| public class ServerIdentificationTests : IClassFixture<DatabaseFixture>, IDisposable | ||
| { | ||
| public ServerIdentificationTests(DatabaseFixture database) | ||
| { | ||
| m_database = database; | ||
| } | ||
| 
     | 
||
| public void Dispose() | ||
| { | ||
| } | ||
| 
     | 
||
| [SkippableFact(ServerFeatures.Timeout)] | ||
| public void CancelCommand_WithServerVerification() | ||
| { | ||
| // This test verifies that cancellation still works with server verification | ||
| using var connection = new MySqlConnection(AppConfig.ConnectionString); | ||
| connection.Open(); | ||
| 
     | 
||
| using var cmd = new MySqlCommand("SELECT SLEEP(5)", connection); | ||
| var task = Task.Run(async () => | ||
| { | ||
| await Task.Delay(TimeSpan.FromSeconds(0.5)); | ||
| cmd.Cancel(); | ||
| }); | ||
| 
     | 
||
| var stopwatch = Stopwatch.StartNew(); | ||
| TestUtilities.AssertExecuteScalarReturnsOneOrIsCanceled(cmd); | ||
| Assert.InRange(stopwatch.ElapsedMilliseconds, 250, 2500); | ||
| 
     | 
||
| #pragma warning disable xUnit1031 // Do not use blocking task operations in test method | ||
| task.Wait(); // shouldn't throw | ||
| #pragma warning restore xUnit1031 // Do not use blocking task operations in test method | ||
| 
     | 
||
| TestUtilities.LogInfo("Cancellation with server verification completed successfully"); | ||
                
       | 
||
| } | ||
| 
     | 
||
| [SkippableFact(ServerFeatures.KnownCertificateAuthority)] | ||
| public void ServerHasServerIdentification() | ||
| { | ||
| using var connection = new MySqlConnection(AppConfig.ConnectionString); | ||
| connection.Open(); | ||
| 
     | 
||
| // Test that we can query server identification manually | ||
| using var cmd = new MySqlCommand("SELECT @@server_id", connection); | ||
| var serverId = cmd.ExecuteScalar(); | ||
| Assert.NotNull(serverId); | ||
| TestUtilities.LogInfo($"Server ID: {serverId}"); | ||
| 
     | 
||
| // Test server UUID if available (MySQL 5.6+) | ||
| if (connection.ServerVersion.Version.Major > 5 || | ||
| (connection.ServerVersion.Version.Major == 5 && connection.ServerVersion.Version.Minor >= 6)) | ||
| { | ||
| try | ||
| { | ||
| using var uuidCmd = new MySqlCommand("SELECT @@server_uuid", connection); | ||
| var serverUuid = uuidCmd.ExecuteScalar(); | ||
| Assert.NotNull(serverUuid); | ||
| TestUtilities.LogInfo($"Server UUID: {serverUuid}"); | ||
| } | ||
| catch (MySqlException ex) when (ex.ErrorCode == MySqlErrorCode.UnknownSystemVariable) | ||
| { | ||
| // Some MySQL-compatible servers might not support server_uuid | ||
| TestUtilities.LogInfo("Server UUID not supported on this server"); | ||
| } | ||
| } | ||
| } | ||
| 
     | 
||
| private readonly DatabaseFixture m_database; | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's replace these both with
ServerHostname. It will be filled in with@@hostnameornullif that variable isn't available.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replaced ServerUuid and ServerId properties with a single ServerHostname property that uses @@hostname in commit a6b86d7.