Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions src/Client/Connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <Common/FailPoint.h>

#include <Common/config_version.h>
#include <Common/scope_guard_safe.h>
#include <Core/Types.h>
#include "config.h"

Expand Down Expand Up @@ -220,7 +221,7 @@ void Connection::connect(const ConnectionTimeouts & timeouts)
connected = true;
setDescription();

sendHello();
sendHello(timeouts.handshake_timeout);
receiveHello(timeouts.handshake_timeout);

if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_CHUNKED_PACKETS)
Expand Down Expand Up @@ -371,7 +372,7 @@ void Connection::disconnect()
}


void Connection::sendHello()
void Connection::sendHello([[maybe_unused]] const Poco::Timespan & handshake_timeout)
{
/** Disallow control characters in user controlled parameters
* to mitigate the possibility of SSRF.
Expand Down Expand Up @@ -424,7 +425,7 @@ void Connection::sendHello()
writeStringBinary(String(EncodedUserInfo::SSH_KEY_AUTHENTICAION_MARKER) + user, *out);
writeStringBinary(password, *out);

performHandshakeForSSHAuth();
performHandshakeForSSHAuth(handshake_timeout);
}
#endif
else if (!jwt.empty())
Expand Down Expand Up @@ -461,8 +462,10 @@ void Connection::sendAddendum()


#if USE_SSH
void Connection::performHandshakeForSSHAuth()
void Connection::performHandshakeForSSHAuth(const Poco::Timespan & handshake_timeout)
{
TimeoutSetter timeout_setter(*socket, handshake_timeout, handshake_timeout);

String challenge;
{
writeVarUInt(Protocol::Client::SSHChallengeRequest, *out);
Expand All @@ -479,11 +482,7 @@ void Connection::performHandshakeForSSHAuth()
else if (packet_type == Protocol::Server::Exception)
receiveException()->rethrow();
else
{
/// Close connection, to not stay in unsynchronised state.
disconnect();
throwUnexpectedPacket(packet_type, "SSHChallenge or Exception");
}
throwUnexpectedPacket(timeout_setter, packet_type, "SSHChallenge or Exception");
}

writeVarUInt(Protocol::Client::SSHChallengeResponse, *out);
Expand Down Expand Up @@ -569,15 +568,7 @@ void Connection::receiveHello(const Poco::Timespan & handshake_timeout)
else if (packet_type == Protocol::Server::Exception)
receiveException()->rethrow();
else
{
/// Reset timeout_setter before disconnect,
/// because after disconnect socket will be invalid.
timeout_setter.reset();

/// Close connection, to not stay in unsynchronised state.
disconnect();
throwUnexpectedPacket(packet_type, "Hello or Exception");
}
throwUnexpectedPacket(timeout_setter, packet_type, "Hello or Exception");
}

void Connection::setDefaultDatabase(const String & database)
Expand Down Expand Up @@ -702,7 +693,7 @@ bool Connection::ping(const ConnectionTimeouts & timeouts)
}

if (pong != Protocol::Server::Pong)
throwUnexpectedPacket(pong, "Pong");
throwUnexpectedPacket(timeout_setter, pong, "Pong");
}
catch (const Poco::Exception & e)
{
Expand Down Expand Up @@ -741,7 +732,7 @@ TablesStatusResponse Connection::getTablesStatus(const ConnectionTimeouts & time
if (response_type == Protocol::Server::Exception)
receiveException()->rethrow();
else if (response_type != Protocol::Server::TablesStatusResponse)
throwUnexpectedPacket(response_type, "TablesStatusResponse");
throwUnexpectedPacket(timeout_setter, response_type, "TablesStatusResponse");

TablesStatusResponse response;
response.read(*in, server_revision);
Expand Down Expand Up @@ -810,6 +801,14 @@ void Connection::sendQuery(

query_id = query_id_;

/// Avoid reusing connections that had been left in the intermediate state
/// (i.e. not all packets had been sent).
bool completed = false;
SCOPE_EXIT({
if (!completed)
disconnect();
});

writeVarUInt(Protocol::Client::Query, *out);
writeStringBinary(query_id, *out);

Expand Down Expand Up @@ -910,6 +909,8 @@ void Connection::sendQuery(
sendData(Block(), "", false);
out->next();
}

completed = true;
}


Expand Down Expand Up @@ -1436,8 +1437,14 @@ InitialAllRangesAnnouncement Connection::receiveInitialParallelReadAnnouncement(
}


void Connection::throwUnexpectedPacket(UInt64 packet_type, const char * expected) const
void Connection::throwUnexpectedPacket(TimeoutSetter & timeout_setter, UInt64 packet_type, const char * expected)
{
/// Reset timeout_setter before disconnect, because after disconnect socket will be invalid.
timeout_setter.reset();

/// Close connection, to avoid leaving it in an unsynchronised state.
disconnect();

throw NetException(ErrorCodes::UNEXPECTED_PACKET_FROM_SERVER,
"Unexpected packet from server {} (expected {}, got {})",
getDescription(), expected, String(Protocol::Server::toString(packet_type)));
Expand Down
7 changes: 4 additions & 3 deletions src/Client/Connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace DB
{

struct Settings;
struct TimeoutSetter;

class Connection;
struct ConnectionParameters;
Expand Down Expand Up @@ -275,10 +276,10 @@ class Connection : public IServerConnection
AsyncCallback async_callback = {};

void connect(const ConnectionTimeouts & timeouts);
void sendHello();
void sendHello(const Poco::Timespan & handshake_timeout);

#if USE_SSH
void performHandshakeForSSHAuth();
void performHandshakeForSSHAuth(const Poco::Timespan & handshake_timeout);
#endif

void sendAddendum();
Expand Down Expand Up @@ -306,7 +307,7 @@ class Connection : public IServerConnection
void initBlockLogsInput();
void initBlockProfileEventsInput();

[[noreturn]] void throwUnexpectedPacket(UInt64 packet_type, const char * expected) const;
[[noreturn]] void throwUnexpectedPacket(TimeoutSetter & timeout_setter, UInt64 packet_type, const char * expected);
};

template <typename Conn>
Expand Down
Loading