diff --git a/spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java b/spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java index 49a1a2819de..f282b0d9c41 100644 --- a/spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java +++ b/spring-integration-core/src/main/java/org/springframework/integration/scattergather/ScatterGatherHandler.java @@ -20,6 +20,7 @@ import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanInitializationException; import org.springframework.context.Lifecycle; +import org.springframework.integration.channel.ChannelInterceptorAware; import org.springframework.integration.channel.FixedSubscriberChannel; import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.channel.ReactiveStreamsSubscribableChannel; @@ -30,7 +31,6 @@ import org.springframework.integration.endpoint.PollingConsumer; import org.springframework.integration.endpoint.ReactiveStreamsConsumer; import org.springframework.integration.handler.AbstractReplyProducingMessageHandler; -import org.springframework.integration.support.channel.HeaderChannelRegistry; import org.springframework.messaging.Message; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.MessageDeliveryException; @@ -38,6 +38,7 @@ import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.PollableChannel; import org.springframework.messaging.SubscribableChannel; +import org.springframework.messaging.support.ChannelInterceptor; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -66,8 +67,6 @@ public class ScatterGatherHandler extends AbstractReplyProducingMessageHandler i private AbstractEndpoint gatherEndpoint; - private HeaderChannelRegistry replyChannelRegistry; - public ScatterGatherHandler(MessageHandler scatterer, MessageHandler gatherer) { this(new FixedSubscriberChannel(scatterer), gatherer); @@ -134,52 +133,64 @@ else if (this.gatherChannel instanceof ReactiveStreamsSubscribableChannel) { ((MessageProducer) this.gatherer) .setOutputChannel(new FixedSubscriberChannel(message -> { MessageHeaders headers = message.getHeaders(); - if (headers.containsKey(GATHER_RESULT_CHANNEL)) { - Object gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL); - if (gatherResultChannel instanceof MessageChannel) { - messagingTemplate.send((MessageChannel) gatherResultChannel, message); - } - else if (gatherResultChannel instanceof String) { - messagingTemplate.send((String) gatherResultChannel, message); - } + MessageChannel gatherResultChannel = headers.get(GATHER_RESULT_CHANNEL, MessageChannel.class); + if (gatherResultChannel != null) { + this.messagingTemplate.send(gatherResultChannel, message); } else { throw new MessageDeliveryException(message, - "The 'gatherResultChannel' header is required to delivery gather result."); + "The 'gatherResultChannel' header is required to deliver the gather result."); } })); - - this.replyChannelRegistry = - beanFactory.getBean(IntegrationContextUtils.INTEGRATION_HEADER_CHANNEL_REGISTRY_BEAN_NAME, - HeaderChannelRegistry.class); } @Override protected Object handleRequestMessage(Message requestMessage) { PollableChannel gatherResultChannel = new QueueChannel(); - Object gatherResultChannelName = this.replyChannelRegistry.channelToChannelName(gatherResultChannel); + MessageChannel replyChannel = this.gatherChannel; + + if (replyChannel instanceof ChannelInterceptorAware) { + ((ChannelInterceptorAware) replyChannel) + .addInterceptor(0, + new ChannelInterceptor() { + + @Override + public Message preSend(Message message, MessageChannel channel) { + return enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage); + } + + }); + } + else { + replyChannel = + new FixedSubscriberChannel(message -> + this.messagingTemplate.send(this.gatherChannel, + enhanceScatterReplyMessage(message, gatherResultChannel, requestMessage))); + } Message scatterMessage = getMessageBuilderFactory() .fromMessage(requestMessage) - .setHeader(GATHER_RESULT_CHANNEL, gatherResultChannelName) - .setReplyChannel(this.gatherChannel) + .setReplyChannel(replyChannel) .setErrorChannelName(this.errorChannelName) .build(); this.messagingTemplate.send(this.scatterChannel, scatterMessage); - Message gatherResult = gatherResultChannel.receive(this.gatherTimeout); - if (gatherResult != null) { - return getMessageBuilderFactory() - .fromMessage(gatherResult) - .removeHeader(GATHER_RESULT_CHANNEL) - .setHeader(MessageHeaders.REPLY_CHANNEL, requestMessage.getHeaders().getReplyChannel()) - .setHeader(MessageHeaders.ERROR_CHANNEL, requestMessage.getHeaders().getErrorChannel()); - } + return gatherResultChannel.receive(this.gatherTimeout); + } + + private Message enhanceScatterReplyMessage(Message message, PollableChannel gatherResultChannel, + Message requestMessage) { - return null; + MessageHeaders requestMessageHeaders = requestMessage.getHeaders(); + return getMessageBuilderFactory() + .fromMessage(message) + .setHeader(GATHER_RESULT_CHANNEL, gatherResultChannel) + .setHeader(MessageHeaders.REPLY_CHANNEL, requestMessageHeaders.getReplyChannel()) + .setHeader(MessageHeaders.ERROR_CHANNEL, requestMessageHeaders.getErrorChannel()) + .build(); } @Override @@ -201,11 +212,11 @@ public boolean isRunning() { return this.gatherEndpoint == null || this.gatherEndpoint.isRunning(); } - private void checkClass(Class gathererClass, String className, String type) throws LinkageError { + private static void checkClass(Class gathererClass, String className, String type) throws LinkageError { try { Class clazz = ClassUtils.forName(className, ClassUtils.getDefaultClassLoader()); - Assert.isAssignable(clazz, gathererClass, () -> "the '" + type + "' must be an " + className + " " + - "instance"); + Assert.isAssignable(clazz, gathererClass, + () -> "the '" + type + "' must be an " + className + " " + "instance"); } catch (ClassNotFoundException e) { throw new IllegalStateException("The class for '" + className + "' cannot be loaded", e); diff --git a/spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java b/spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java index ac58a07ed04..4eab6a7cade 100644 --- a/spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java +++ b/spring-integration-core/src/test/java/org/springframework/integration/dsl/routers/RouterTests.java @@ -16,6 +16,7 @@ package org.springframework.integration.dsl.routers; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; @@ -29,6 +30,7 @@ import java.util.Arrays; import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; import java.util.stream.Collectors; import org.junit.Test; @@ -590,6 +592,16 @@ public void testScatterGatherWithExecutorChannelSubFlow() { assertThat(((List) payload).get(1), instanceOf(RuntimeException.class)); } + @Autowired + @Qualifier("propagateErrorFromGatherer.gateway") + private Function propagateErrorFromGathererGateway; + + @Test + public void propagateErrorFromGatherer() { + assertThatThrownBy(() -> propagateErrorFromGathererGateway.apply("bar")) + .hasMessage("intentional"); + } + @Configuration @EnableIntegration @EnableMessageHistory({ "recipientListOrder*", "recipient1*", "recipient2*" }) @@ -881,6 +893,22 @@ public Message processAsyncScatterError(MessagingException payload) { .build(); } + @Bean + public IntegrationFlow propagateErrorFromGatherer(TaskExecutor taskExecutor) { + return IntegrationFlows.from(Function.class) + .scatterGather(s -> s + .applySequence(true) + .recipientFlow(subFlow -> subFlow + .channel(c -> c.executor(taskExecutor)) + .transform(p -> "foo")), + g -> g + .outputProcessor(group -> { + throw new RuntimeException("intentional"); + }), + sg -> sg.gatherTimeout(100)) + .get(); + } + } private static class RoutingTestBean { diff --git a/spring-integration-jmx/src/test/java/org/springframework/integration/monitor/ScatterGatherHandlerIntegrationTests.java b/spring-integration-jmx/src/test/java/org/springframework/integration/monitor/ScatterGatherHandlerIntegrationTests.java index 3cb22e8e521..ac5a1e5a2d0 100644 --- a/spring-integration-jmx/src/test/java/org/springframework/integration/monitor/ScatterGatherHandlerIntegrationTests.java +++ b/spring-integration-jmx/src/test/java/org/springframework/integration/monitor/ScatterGatherHandlerIntegrationTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2014-2015 the original author or authors. + * Copyright 2014-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ import java.util.Arrays; import java.util.List; -import java.util.concurrent.Executors; +import java.util.concurrent.Executor; import org.junit.Test; import org.junit.runner.RunWith; @@ -278,8 +278,8 @@ public MessageChannel gatherChannel() { } @Bean - public SubscribableChannel scatterAuctionWithGatherChannel() { - PublishSubscribeChannel channel = new PublishSubscribeChannel(Executors.newCachedThreadPool()); + public SubscribableChannel scatterAuctionWithGatherChannel(Executor executor) { + PublishSubscribeChannel channel = new PublishSubscribeChannel(executor); channel.setApplySequence(true); return channel; } @@ -296,7 +296,8 @@ public MessageHandler gatherer2() { @Bean @ServiceActivator(inputChannel = "inputAuctionWithGatherChannel") public MessageHandler scatterGatherAuctionWithGatherChannel() { - ScatterGatherHandler handler = new ScatterGatherHandler(scatterAuctionWithGatherChannel(), gatherer2()); + ScatterGatherHandler handler = + new ScatterGatherHandler(scatterAuctionWithGatherChannel(null), gatherer2()); handler.setGatherChannel(gatherChannel()); handler.setOutputChannel(output()); return handler; diff --git a/src/reference/asciidoc/scatter-gather.adoc b/src/reference/asciidoc/scatter-gather.adoc index f00fa296af8..661c191ea91 100644 --- a/src/reference/asciidoc/scatter-gather.adoc +++ b/src/reference/asciidoc/scatter-gather.adoc @@ -206,3 +206,8 @@ public Message processAsyncScatterError(MessagingException payload) { To produce a proper reply, we have to copy headers (including `replyChannel` and `errorChannel`) from the `failedMessage` of the `MessagingException` that has been sent to the `scatterGatherErrorChannel` by the `MessagePublishingErrorHandler`. This way the target exception is returned to the gatherer of the `ScatterGatherHandler` for reply messages group completion. Such an exception `payload` can be filtered out in the `MessageGroupProcessor` of the gatherer or processed other way downstream, after the scatter-gather endpoint. + +NOTE: Before sending scattering results to the gatherer, `ScatterGatherHandler` reinstates the request message headers, including reply and error channels if any. +This way errors from the `AggregatingMessageHandler` are going to be propagated to the caller, even if async an hand off is applied in scatter recipient subflows. +In this case a reasonable, finite `gatherTimeout` must be configured for the `ScatterGatherHandler`. +Otherwise it is going to be blocked waiting for a reply from the gatherer forever, by default.