@@ -219,7 +219,7 @@ func TestConnection(t *testing.T) {
219219 for _ , tc := range testCases {
220220 t .Run (tc .name , func (t * testing.T ) {
221221 var sentCfg * tls.Config
222- var testTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) * tls. Conn {
222+ var testTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) tlsConn {
223223 sentCfg = cfg
224224 return tls .Client (nc , cfg )
225225 }
@@ -252,6 +252,143 @@ func TestConnection(t *testing.T) {
252252 }
253253 })
254254 })
255+ t .Run ("connectTimeout is applied correctly" , func (t * testing.T ) {
256+ testCases := []struct {
257+ name string
258+ contextTimeout time.Duration
259+ connectTimeout time.Duration
260+ maxConnectTime time.Duration
261+ }{
262+ // The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for
263+ // both of the tests declared below. Both tests also specify a 10ms max connect time to provide
264+ // a large buffer for lag and avoid test flakiness.
265+
266+ {"context timeout is lower" , 1 * time .Millisecond , 100 * time .Millisecond , 10 * time .Millisecond },
267+ {"connect timeout is lower" , 100 * time .Millisecond , 1 * time .Millisecond , 10 * time .Millisecond },
268+ }
269+
270+ for _ , tc := range testCases {
271+ t .Run ("timeout applied to socket establishment: " + tc .name , func (t * testing.T ) {
272+ // Ensure the initial connection dial can be timed out and the connection propagates the error
273+ // from the dialer in this case.
274+
275+ connOpts := []ConnectionOption {
276+ WithDialer (func (Dialer ) Dialer {
277+ return DialerFunc (func (ctx context.Context , _ , _ string ) (net.Conn , error ) {
278+ <- ctx .Done ()
279+ return nil , ctx .Err ()
280+ })
281+ }),
282+ WithConnectTimeout (func (time.Duration ) time.Duration {
283+ return tc .connectTimeout
284+ }),
285+ }
286+ conn , err := newConnection ("" , connOpts ... )
287+ assert .Nil (t , err , "newConnection error: %v" , err )
288+
289+ ctx , cancel := context .WithTimeout (context .Background (), tc .contextTimeout )
290+ defer cancel ()
291+ var connectErr error
292+ callback := func () {
293+ conn .connect (ctx )
294+ connectErr = conn .wait ()
295+ }
296+ assert .Soon (t , callback , tc .maxConnectTime )
297+
298+ ce , ok := connectErr .(ConnectionError )
299+ assert .True (t , ok , "expected error %v to be of type %T" , connectErr , ConnectionError {})
300+ assert .Equal (t , context .DeadlineExceeded , ce .Unwrap (), "expected wrapped error to be %v, got %v" ,
301+ context .DeadlineExceeded , ce .Unwrap ())
302+ })
303+ t .Run ("timeout applied to TLS handshake: " + tc .name , func (t * testing.T ) {
304+ // Ensure the TLS handshake can be timed out and the connection propagates the error from the
305+ // tlsConn in this case.
306+
307+ var hangingTLSConnectionSource tlsConnectionSourceFn = func (nc net.Conn , cfg * tls.Config ) tlsConn {
308+ tlsConn := tls .Client (nc , cfg )
309+ return newHangingTLSConn (tlsConn , tc .maxConnectTime )
310+ }
311+
312+ connOpts := []ConnectionOption {
313+ WithConnectTimeout (func (time.Duration ) time.Duration {
314+ return tc .connectTimeout
315+ }),
316+ WithDialer (func (Dialer ) Dialer {
317+ return DialerFunc (func (context.Context , string , string ) (net.Conn , error ) {
318+ return & net.TCPConn {}, nil
319+ })
320+ }),
321+ WithTLSConfig (func (* tls.Config ) * tls.Config {
322+ return & tls.Config {}
323+ }),
324+ withTLSConnectionSource (func (tlsConnectionSource ) tlsConnectionSource {
325+ return hangingTLSConnectionSource
326+ }),
327+ }
328+ conn , err := newConnection ("" , connOpts ... )
329+ assert .Nil (t , err , "newConnection error: %v" , err )
330+
331+ ctx , cancel := context .WithTimeout (context .Background (), tc .contextTimeout )
332+ defer cancel ()
333+ var connectErr error
334+ callback := func () {
335+ conn .connect (ctx )
336+ connectErr = conn .wait ()
337+ }
338+ assert .Soon (t , callback , tc .maxConnectTime )
339+
340+ ce , ok := connectErr .(ConnectionError )
341+ assert .True (t , ok , "expected error %v to be of type %T" , connectErr , ConnectionError {})
342+ assert .Equal (t , context .DeadlineExceeded , ce .Unwrap (), "expected wrapped error to be %v, got %v" ,
343+ context .DeadlineExceeded , ce .Unwrap ())
344+ })
345+ t .Run ("timeout is not applied to handshaker: " + tc .name , func (t * testing.T ) {
346+ // Ensure that no additional timeout is applied to the handshake after the connection has been
347+ // established.
348+
349+ var getInfoCtx , finishCtx context.Context
350+ handshaker := & testHandshaker {
351+ getHandshakeInformation : func (ctx context.Context , _ address.Address , _ driver.Connection ) (driver.HandshakeInformation , error ) {
352+ getInfoCtx = ctx
353+ return driver.HandshakeInformation {}, nil
354+ },
355+ finishHandshake : func (ctx context.Context , _ driver.Connection ) error {
356+ finishCtx = ctx
357+ return nil
358+ },
359+ }
360+
361+ connOpts := []ConnectionOption {
362+ WithConnectTimeout (func (time.Duration ) time.Duration {
363+ return tc .connectTimeout
364+ }),
365+ WithDialer (func (Dialer ) Dialer {
366+ return DialerFunc (func (context.Context , string , string ) (net.Conn , error ) {
367+ return & net.TCPConn {}, nil
368+ })
369+ }),
370+ WithHandshaker (func (Handshaker ) Handshaker {
371+ return handshaker
372+ }),
373+ }
374+ conn , err := newConnection ("" , connOpts ... )
375+ assert .Nil (t , err , "newConnection error: %v" , err )
376+
377+ bgCtx := context .Background ()
378+ conn .connect (bgCtx )
379+ err = conn .wait ()
380+ assert .Nil (t , err , "connect error: %v" , err )
381+
382+ assertNoContextTimeout := func (t * testing.T , ctx context.Context ) {
383+ t .Helper ()
384+ dl , ok := ctx .Deadline ()
385+ assert .False (t , ok , "expected context to have no deadline, but got deadline %v" , dl )
386+ }
387+ assertNoContextTimeout (t , getInfoCtx )
388+ assertNoContextTimeout (t , finishCtx )
389+ })
390+ }
391+ })
255392 })
256393 t .Run ("writeWireMessage" , func (t * testing.T ) {
257394 t .Run ("closed connection" , func (t * testing.T ) {
@@ -993,3 +1130,24 @@ func (t *testCancellationListener) assertMethodsCalled(testingT *testing.T, numL
9931130 assert .Equal (testingT , numStopListening , t .numStopListening , "expected StopListening to be called %d times, got %d" ,
9941131 numListen , t .numListen )
9951132}
1133+
1134+ // hangingTLSConn is an implementation of tlsConn that wraps the tls.Conn type and overrides the Handshake function to
1135+ // sleep for a fixed amount of time.
1136+ type hangingTLSConn struct {
1137+ * tls.Conn
1138+ sleepTime time.Duration
1139+ }
1140+
1141+ var _ tlsConn = (* hangingTLSConn )(nil )
1142+
1143+ func newHangingTLSConn (conn * tls.Conn , sleepTime time.Duration ) * hangingTLSConn {
1144+ return & hangingTLSConn {
1145+ Conn : conn ,
1146+ sleepTime : sleepTime ,
1147+ }
1148+ }
1149+
1150+ func (h * hangingTLSConn ) Handshake () error {
1151+ time .Sleep (h .sleepTime )
1152+ return h .Conn .Handshake ()
1153+ }
0 commit comments