Skip to content

Commit ab3b823

Browse files
committed
[pkg/signaller] Add a package to handle Pause signal
This essentially is a copy of code used is xcontext to do the same function. The major difference is that now signals (like Pause) are completely decoupled from context cancelling. Thus IsSignaledWith() and Until(nil) do not catch a Context closure. Signed-off-by: Dmitrii Okunev <[email protected]>
1 parent 4d6d981 commit ab3b823

File tree

10 files changed

+353
-0
lines changed

10 files changed

+353
-0
lines changed

pkg/signaller/broadcast_signal.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package signaller
2+
3+
type broadcastSignalT struct{}
4+
5+
func (broadcastSignalT) Error() string { return "broadcast-signal" }
6+
7+
var broadcastSignal = broadcastSignalT{}

pkg/signaller/ctx_key.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package signaller
2+
3+
type signallerCtxKeyT struct{}
4+
5+
var signallerCtxKey = signallerCtxKeyT{}

pkg/signaller/is_signaled_with.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package signaller
2+
3+
import (
4+
"context"
5+
"errors"
6+
)
7+
8+
// IsSignaledWith returns true if the context received a signal equals
9+
// to any of passed ones.
10+
//
11+
// If signals is empty, then returns true if the context received any
12+
// signal.
13+
func IsSignaledWith(ctx context.Context, signals ...error) bool {
14+
s := getSignaller(ctx)
15+
if s == nil {
16+
return false
17+
}
18+
19+
s.locker.Lock()
20+
defer s.locker.Unlock()
21+
22+
return s.isSignaledWith(signals...)
23+
}
24+
25+
func (s *signaller) isSignaledWith(signals ...error) bool {
26+
if len(signals) == 0 && (len(s.receivedSignals) != 0) {
27+
return true
28+
}
29+
30+
for _, receivedErr := range s.receivedSignals {
31+
for _, err := range signals {
32+
if errors.Is(receivedErr, err) {
33+
return true
34+
}
35+
}
36+
}
37+
38+
return false
39+
}

pkg/signaller/received_signals.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package signaller
2+
3+
import "context"
4+
5+
func ReceivedSignals(ctx context.Context) []error {
6+
s := getSignaller(ctx)
7+
return s.ReceivedSignals()
8+
}
9+
10+
// ReceivedSignals returns all the received signals (including events
11+
// received by parents).
12+
//
13+
// This is a read-only value, do not modify it.
14+
func (s *signaller) ReceivedSignals() []error {
15+
if s == nil {
16+
return nil
17+
}
18+
19+
s.locker.Lock()
20+
defer s.locker.Unlock()
21+
22+
return s.receivedSignals
23+
}

pkg/signaller/signaller.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package signaller
2+
3+
import (
4+
"context"
5+
"runtime"
6+
"sync"
7+
)
8+
9+
var (
10+
openChan = make(chan struct{})
11+
closedChan = make(chan struct{})
12+
)
13+
14+
func init() {
15+
close(closedChan)
16+
}
17+
18+
type signaller struct {
19+
// locker is used exclude concurrent access to any data below (in this
20+
// structure).
21+
locker sync.Mutex
22+
23+
// children are all signallers derived from this one. And for example
24+
// if this one will a signal it will be propagated to all the children.
25+
children map[*signaller]struct{}
26+
27+
// receivedSignals accumulates any signals ever received by this instance
28+
// and/or by any parent.
29+
receivedSignals []error
30+
31+
// signalChans are the channels returned by Until.
32+
signalChans map[error]chan struct{}
33+
}
34+
35+
func getSignaller(ctx context.Context) *signaller {
36+
signaller, _ := ctx.Value(signallerCtxKey).(*signaller)
37+
return signaller
38+
}
39+
40+
func withSignaller(ctx context.Context) (*signaller, context.Context) {
41+
parent := getSignaller(ctx)
42+
s := &signaller{
43+
children: make(map[*signaller]struct{}),
44+
signalChans: make(map[error]chan struct{}),
45+
}
46+
ctx = context.WithValue(ctx, signallerCtxKey, s)
47+
if parent == nil {
48+
return s, ctx
49+
}
50+
51+
parent.locker.Lock()
52+
s.receivedSignals = make([]error, len(parent.receivedSignals))
53+
copy(s.receivedSignals, parent.receivedSignals)
54+
parent.children[s] = struct{}{}
55+
parent.locker.Unlock()
56+
57+
runtime.SetFinalizer(s, func(s *signaller) {
58+
parent.locker.Lock()
59+
delete(parent.children, s)
60+
parent.locker.Unlock()
61+
})
62+
63+
return s, ctx
64+
}

pkg/signaller/signaller_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates.
2+
//
3+
// This source code is licensed under the MIT license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
package signaller
7+
8+
import (
9+
"context"
10+
"fmt"
11+
"sync"
12+
"testing"
13+
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
var unitTestCustomSignal0 = fmt.Errorf("unit-test custom signal #0")
18+
var unitTestCustomSignal1 = fmt.Errorf("unit-test custom signal #1")
19+
var unitTestCustomSignal2 = fmt.Errorf("unit-test custom signal #2")
20+
21+
func TestBackgroundContext(t *testing.T) {
22+
ctx := context.Background()
23+
24+
var blocked bool
25+
select {
26+
case <-Until(ctx, nil):
27+
default:
28+
blocked = true
29+
}
30+
31+
require.True(t, blocked)
32+
require.Nil(t, ReceivedSignals(ctx))
33+
}
34+
35+
func TestRace(t *testing.T) {
36+
ctx, sendSignal := WithSignal(context.Background(), unitTestCustomSignal0)
37+
var wg sync.WaitGroup
38+
readyToSelect := make(chan struct{})
39+
startSelect := make(chan struct{})
40+
wg.Add(1)
41+
go func() {
42+
defer wg.Done()
43+
ch := Until(ctx, nil)
44+
select {
45+
case <-ch:
46+
t.Fail()
47+
default:
48+
}
49+
close(readyToSelect)
50+
<-startSelect
51+
<-ch
52+
}()
53+
<-readyToSelect
54+
sendSignal()
55+
close(startSelect)
56+
wg.Wait()
57+
}
58+
59+
func TestSendSignal(t *testing.T) {
60+
ctx, pauseFunc := WithSignal(context.Background(), unitTestCustomSignal0)
61+
require.NotNil(t, ctx)
62+
63+
pauseFunc()
64+
65+
<-Until(ctx, unitTestCustomSignal0)
66+
<-Until(ctx, nil)
67+
68+
require.Nil(t, ctx.Err())
69+
require.Equal(t, []error{unitTestCustomSignal0}, ReceivedSignals(ctx))
70+
71+
var canceled bool
72+
select {
73+
case <-ctx.Done():
74+
canceled = true
75+
default:
76+
}
77+
require.False(t, canceled)
78+
}
79+
80+
func TestSendMultipleSignals(t *testing.T) {
81+
ctx := context.Background()
82+
ctx, sendSignal0 := WithSignal(ctx, unitTestCustomSignal0)
83+
ctx, sendSignal1 := WithSignal(ctx, unitTestCustomSignal1)
84+
require.NotNil(t, ctx)
85+
86+
sendSignal1()
87+
sendSignal0()
88+
sendSignal1()
89+
sendSignal0()
90+
91+
<-Until(ctx, nil)
92+
<-Until(ctx, unitTestCustomSignal0)
93+
<-Until(ctx, unitTestCustomSignal1)
94+
95+
require.Equal(t, []error{unitTestCustomSignal1, unitTestCustomSignal0, unitTestCustomSignal1, unitTestCustomSignal0}, ReceivedSignals(ctx))
96+
}
97+
98+
func TestGrandGrandGrandChild(t *testing.T) {
99+
type myUniqueType string
100+
ctx0, sendSignal0 := WithSignal(context.Background(), unitTestCustomSignal0)
101+
ctx1, _ := WithSignal(context.WithValue(ctx0, myUniqueType("someKey1"), "someValue1"), unitTestCustomSignal1)
102+
ctx2, _ := WithSignal(context.WithValue(ctx1, myUniqueType("someKey2"), "someValue2"), unitTestCustomSignal2)
103+
104+
require.False(t, IsSignaledWith(ctx2, unitTestCustomSignal0))
105+
sendSignal0()
106+
<-Until(ctx2, unitTestCustomSignal0)
107+
require.True(t, IsSignaledWith(ctx2, unitTestCustomSignal0))
108+
require.False(t, IsSignaledWith(ctx2, unitTestCustomSignal1))
109+
110+
select {
111+
case <-Until(ctx2, unitTestCustomSignal1):
112+
require.FailNow(t, "unexpected closed chan for signal1")
113+
case <-Until(ctx2, unitTestCustomSignal2):
114+
require.FailNow(t, "unexpected closed chan for signal2")
115+
default:
116+
}
117+
118+
}

pkg/signaller/until.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package signaller
2+
3+
import (
4+
"context"
5+
)
6+
7+
// Until works similar to Done(), but it is possible to specify specific
8+
// signal to wait for.
9+
//
10+
// If signal is nil, then waits for any event.
11+
func Until(ctx context.Context, signal error) <-chan struct{} {
12+
s := getSignaller(ctx)
13+
if s == nil {
14+
return openChan
15+
}
16+
return s.getSignalChan(signal)
17+
}
18+
19+
func (s *signaller) getSignalChan(signal error) <-chan struct{} {
20+
if signal == nil {
21+
signal = broadcastSignal
22+
}
23+
24+
s.locker.Lock()
25+
defer s.locker.Unlock()
26+
27+
if (signal == broadcastSignal && s.isSignaledWith()) || (signal != broadcastSignal && s.isSignaledWith(signal)) {
28+
return closedChan
29+
}
30+
31+
if s.signalChans[signal] == nil {
32+
s.signalChans[signal] = make(chan struct{})
33+
}
34+
return s.signalChans[signal]
35+
}

pkg/signaller/with_signal.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package signaller
2+
3+
import (
4+
"context"
5+
)
6+
7+
// WithSignal is similar to context.WithCancel, but it does not close the context
8+
// and sends a signal to the Context which use observable using Until and IsSignaledWith
9+
// functions.
10+
func WithSignal(ctx context.Context, signal error, moreSignals ...error) (context.Context, context.CancelFunc) {
11+
s, ctx := withSignaller(ctx)
12+
return ctx, s.signalFunc(signal, moreSignals...)
13+
}
14+
15+
func (s *signaller) signalFunc(signal error, moreSignals ...error) context.CancelFunc {
16+
return func() {
17+
s.signal(signal, moreSignals...)
18+
}
19+
}
20+
21+
func (s *signaller) signal(signal error, moreSignals ...error) {
22+
s.locker.Lock()
23+
defer s.locker.Unlock()
24+
25+
s.receivedSignals = append(s.receivedSignals, signal)
26+
s.receivedSignals = append(s.receivedSignals, moreSignals...)
27+
28+
s.sendOneSignal(signal)
29+
for _, signal := range moreSignals {
30+
s.sendOneSignal(signal)
31+
}
32+
s.sendOneSignal(broadcastSignal)
33+
34+
for child := range s.children {
35+
child.signal(signal, moreSignals...)
36+
}
37+
}
38+
39+
func (s *signaller) sendOneSignal(signal error) {
40+
if s.signalChans[signal] == nil {
41+
return
42+
}
43+
close(s.signalChans[signal])
44+
s.signalChans[signal] = nil
45+
}

pkg/signaller/without_signaller.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package signaller
2+
3+
import (
4+
"context"
5+
)
6+
7+
// WithoutSignaller derives a context without any signals setup.
8+
func WithoutSignaller(ctx context.Context) context.Context {
9+
return context.WithValue(ctx, signallerCtxKey, nil)
10+
}

pkg/signals/paused.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package signals
2+
3+
type pausedT struct{}
4+
5+
func (pausedT) Error() string { return "job is paused" }
6+
7+
var Paused = error(pausedT{})

0 commit comments

Comments
 (0)