From a57137c0bd2557757fddd29667b036fc90a25c51 Mon Sep 17 00:00:00 2001 From: Mikhail Mazurskiy Date: Wed, 27 Aug 2025 11:30:59 +1000 Subject: [PATCH] Ensure Close() waits for goroutines when called concurrently --- net/multi_listen.go | 43 +++++++++++++++++++--------------------- net/multi_listen_test.go | 8 +------- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/net/multi_listen.go b/net/multi_listen.go index e5d50805..4655075c 100644 --- a/net/multi_listen.go +++ b/net/multi_listen.go @@ -21,7 +21,6 @@ import ( "fmt" "net" "sync" - "sync/atomic" ) // connErrPair pairs conn and error which is returned by accept on sub-listeners. @@ -38,8 +37,8 @@ type multiListener struct { // connCh passes accepted connections, from child listeners to parent. connCh chan connErrPair // stopCh communicates from parent to child listeners. - stopCh chan struct{} - closed atomic.Bool + stopCh chan struct{} + closeOnce sync.Once } // compile time check to ensure *multiListener implements net.Listener @@ -152,29 +151,27 @@ func (ml *multiListener) Accept() (net.Conn, error) { // the go-routines to exit. func (ml *multiListener) Close() error { // Make sure this can be called repeatedly without explosions. - if !ml.closed.CompareAndSwap(false, true) { - return fmt.Errorf("use of closed network connection") - } - - // Tell all sub-listeners to stop. - close(ml.stopCh) - - // Closing the listeners causes Accept() to immediately return an error in - // the sub-listener go-routines. - for _, l := range ml.listeners { - _ = l.Close() - } + ml.closeOnce.Do(func() { + // Tell all sub-listeners to stop. + close(ml.stopCh) + + // Closing the listeners causes Accept() to immediately return an error in + // the sub-listener go-routines. + for _, l := range ml.listeners { + _ = l.Close() + } - // Wait for all the sub-listener go-routines to exit. - ml.wg.Wait() - close(ml.connCh) + // Wait for all the sub-listener go-routines to exit. + ml.wg.Wait() + close(ml.connCh) - // Drain any already-queued connections. - for connErr := range ml.connCh { - if connErr.conn != nil { - _ = connErr.conn.Close() + // Drain any already-queued connections. + for connErr := range ml.connCh { + if connErr.conn != nil { + _ = connErr.conn.Close() + } } - } + }) return nil } diff --git a/net/multi_listen_test.go b/net/multi_listen_test.go index 9a10feb7..32cd14b9 100644 --- a/net/multi_listen_test.go +++ b/net/multi_listen_test.go @@ -227,7 +227,6 @@ func TestMultiListen_Close(t *testing.T) { runner func(listener net.Listener, acceptCalls int) error fakeListeners []*fakeListener acceptCalls int - errString string }{ { name: "close", @@ -327,7 +326,6 @@ func TestMultiListen_Close(t *testing.T) { return nil }, fakeListeners: []*fakeListener{{}, {}, {}}, - errString: "use of closed network connection", }, } @@ -339,11 +337,7 @@ func TestMultiListen_Close(t *testing.T) { t.Errorf("Did not expect error: %v", err) } err = tc.runner(ml, tc.acceptCalls) - if tc.errString != "" { - assertError(t, tc.errString, err) - } else { - assertNoError(t, err) - } + assertNoError(t, err) for _, f := range tc.fakeListeners { if !f.closed.Load() {