Skip to content

Commit 2756b1c

Browse files
committed
Add AdditionalDataSubscriber to allow users to send additional data to the downstream subscriber
1 parent 0326bf1 commit 2756b1c

File tree

3 files changed

+302
-0
lines changed

3 files changed

+302
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.utils.async;
17+
18+
import java.util.concurrent.atomic.AtomicBoolean;
19+
import java.util.concurrent.atomic.AtomicLong;
20+
import java.util.function.Supplier;
21+
import org.reactivestreams.Subscriber;
22+
import org.reactivestreams.Subscription;
23+
import software.amazon.awssdk.annotations.SdkProtectedApi;
24+
import software.amazon.awssdk.utils.Logger;
25+
import software.amazon.awssdk.utils.Validate;
26+
27+
/**
28+
* Allows to send trailing data before invoking onComplete on the downstream subscriber.
29+
* If the trailingDataSupplier returns null, this class will invoke onComplete directly
30+
*/
31+
@SdkProtectedApi
32+
public class AddingTrailingDataSubscriber<T> extends DelegatingSubscriber<T, T> {
33+
private static final Logger log = Logger.loggerFor(AddingTrailingDataSubscriber.class);
34+
35+
/**
36+
* The subscription to the upstream subscriber.
37+
*/
38+
private Subscription upstreamSubscription;
39+
40+
/**
41+
* The amount of unfulfilled demand the downstream subscriber has opened against us.
42+
*/
43+
private final AtomicLong downstreamDemand = new AtomicLong(0);
44+
45+
/**
46+
* Whether the upstream subscriber has called onComplete on us.
47+
*/
48+
private volatile boolean onCompleteCalledByUpstream = false;
49+
50+
/**
51+
* Whether the upstream subscriber has called onError on us.
52+
*/
53+
private volatile boolean onErrorCalledByUpstream = false;
54+
55+
/**
56+
* Whether we have called onComplete on the downstream subscriber.
57+
*/
58+
private AtomicBoolean onCompleteCalledOnDownstream = new AtomicBoolean(false);
59+
60+
private final Supplier<T> trailingDataSupplier;
61+
private volatile T trailingData;
62+
63+
public AddingTrailingDataSubscriber(Subscriber<? super T> subscriber,
64+
Supplier<T> trailingDataSupplier) {
65+
super(Validate.paramNotNull(subscriber, "subscriber"));
66+
this.trailingDataSupplier = Validate.paramNotNull(trailingDataSupplier, "trailingDataSupplier");
67+
}
68+
69+
@Override
70+
public void onSubscribe(Subscription subscription) {
71+
72+
if (upstreamSubscription != null) {
73+
log.warn(() -> "Received duplicate subscription, cancelling the duplicate.", new IllegalStateException());
74+
subscription.cancel();
75+
return;
76+
}
77+
78+
upstreamSubscription = subscription;
79+
80+
subscriber.onSubscribe(new Subscription() {
81+
82+
@Override
83+
public void request(long l) {
84+
if (onErrorCalledByUpstream) {
85+
return;
86+
}
87+
88+
if (onCompleteCalledByUpstream) {
89+
sendTrailingDataIfNeededAndComplete();
90+
return;
91+
}
92+
93+
addDownstreamDemand(l);
94+
upstreamSubscription.request(l);
95+
}
96+
97+
@Override
98+
public void cancel() {
99+
upstreamSubscription.cancel();
100+
}
101+
});
102+
}
103+
104+
@Override
105+
public void onError(Throwable throwable) {
106+
onErrorCalledByUpstream = true;
107+
subscriber.onError(throwable);
108+
}
109+
110+
@Override
111+
public void onNext(T t) {
112+
Validate.paramNotNull(t, "item");
113+
downstreamDemand.decrementAndGet();
114+
subscriber.onNext(t);
115+
}
116+
117+
@Override
118+
public void onComplete() {
119+
onCompleteCalledByUpstream = true;
120+
121+
trailingData = trailingDataSupplier.get();
122+
if (trailingData == null || downstreamDemand.get() > 0) {
123+
sendTrailingDataIfNeededAndComplete();
124+
}
125+
}
126+
127+
private void addDownstreamDemand(long l) {
128+
129+
if (l > 0) {
130+
downstreamDemand.getAndUpdate(current -> {
131+
long newValue = current + l;
132+
return newValue >= 0 ? newValue : Long.MAX_VALUE;
133+
});
134+
} else {
135+
upstreamSubscription.cancel();
136+
onError(new IllegalArgumentException("Demand must not be negative"));
137+
}
138+
}
139+
140+
private void sendTrailingDataIfNeededAndComplete() {
141+
if (onCompleteCalledOnDownstream.compareAndSet(false, true)) {
142+
if (trailingData != null) {
143+
subscriber.onNext(trailingData);
144+
}
145+
subscriber.onComplete();
146+
}
147+
}
148+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.utils.async;
17+
18+
import java.util.concurrent.CompletableFuture;
19+
import org.reactivestreams.Subscriber;
20+
import org.reactivestreams.Subscription;
21+
import org.reactivestreams.tck.SubscriberWhiteboxVerification;
22+
import org.reactivestreams.tck.TestEnvironment;
23+
24+
public class AddingTrailingDataSubscriberTckTest extends SubscriberWhiteboxVerification<Integer> {
25+
protected AddingTrailingDataSubscriberTckTest() {
26+
super(new TestEnvironment());
27+
}
28+
29+
@Override
30+
public Subscriber<Integer> createSubscriber(WhiteboxSubscriberProbe<Integer> probe) {
31+
Subscriber<Integer> foo = new SequentialSubscriber<>(s -> {}, new CompletableFuture<>());
32+
33+
return new AddingTrailingDataSubscriber<Integer>(foo, () -> Integer.MIN_VALUE) {
34+
@Override
35+
public void onError(Throwable throwable) {
36+
super.onError(throwable);
37+
probe.registerOnError(throwable);
38+
}
39+
40+
@Override
41+
public void onSubscribe(Subscription subscription) {
42+
super.onSubscribe(subscription);
43+
probe.registerOnSubscribe(new SubscriberPuppet() {
44+
@Override
45+
public void triggerRequest(long elements) {
46+
subscription.request(elements);
47+
}
48+
49+
@Override
50+
public void signalCancel() {
51+
subscription.cancel();
52+
}
53+
});
54+
}
55+
56+
@Override
57+
public void onNext(Integer nextItem) {
58+
super.onNext(nextItem);
59+
probe.registerOnNext(nextItem);
60+
}
61+
62+
@Override
63+
public void onComplete() {
64+
super.onComplete();
65+
probe.registerOnComplete();
66+
}
67+
};
68+
}
69+
70+
@Override
71+
public Integer createElement(int i) {
72+
return i;
73+
}
74+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.utils.async;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
20+
21+
import java.util.ArrayList;
22+
import java.util.List;
23+
import java.util.concurrent.CompletableFuture;
24+
import org.junit.jupiter.api.Test;
25+
import org.reactivestreams.Subscriber;
26+
27+
public class AddingTrailingDataSubscriberTest {
28+
29+
@Test
30+
void trailingDataSupplierNull_shouldThrowException() {
31+
SequentialSubscriber<Integer> downstreamSubscriber = new SequentialSubscriber<Integer>(i -> {}, new CompletableFuture());
32+
assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(downstreamSubscriber, null))
33+
.hasMessageContaining("must not be null");
34+
}
35+
36+
@Test
37+
void subscriberNull_shouldThrowException() {
38+
assertThatThrownBy(() -> new AddingTrailingDataSubscriber<>(null, () -> 1))
39+
.hasMessageContaining("must not be null");
40+
}
41+
42+
@Test
43+
void trailingDataNotNull_shouldNotSendAdditionalData() {
44+
List<Integer> result = new ArrayList<>();
45+
CompletableFuture future = new CompletableFuture();
46+
SequentialSubscriber<Integer> downstreamSubscriber = new SequentialSubscriber<Integer>(i -> result.add(i), future);
47+
48+
Subscriber<Integer> subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> Integer.MAX_VALUE);
49+
50+
publishData(subscriber);
51+
52+
future.join();
53+
54+
assertThat(result).containsExactly(0, 1, 2, Integer.MAX_VALUE);
55+
}
56+
57+
@Test
58+
void trailingDataNull_shouldNotSendAdditionalData() {
59+
List<Integer> result = new ArrayList<>();
60+
CompletableFuture future = new CompletableFuture();
61+
SequentialSubscriber<Integer> downstreamSubscriber = new SequentialSubscriber<Integer>(i -> result.add(i), future);
62+
63+
Subscriber<Integer> subscriber = new AddingTrailingDataSubscriber<>(downstreamSubscriber, () -> null);
64+
65+
publishData(subscriber);
66+
67+
future.join();
68+
69+
assertThat(result).containsExactly(0, 1, 2);
70+
}
71+
72+
private void publishData(Subscriber<Integer> subscriber) {
73+
SimplePublisher<Integer> simplePublisher = new SimplePublisher<>();
74+
simplePublisher.subscribe(subscriber);
75+
for (int i = 0; i < 3; i++) {
76+
simplePublisher.send(i);
77+
}
78+
simplePublisher.complete();
79+
}
80+
}

0 commit comments

Comments
 (0)