diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/AbstractMqttMessageDrivenChannelAdapter.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/AbstractMqttMessageDrivenChannelAdapter.java index 6712b00687d..a6f76bdfd18 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/AbstractMqttMessageDrivenChannelAdapter.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/inbound/AbstractMqttMessageDrivenChannelAdapter.java @@ -94,14 +94,22 @@ public AbstractMqttMessageDrivenChannelAdapter(ClientManager clientManager this.clientId = null; } - private static Map initTopics(String[] topic) { - Assert.notNull(topic, "'topics' cannot be null"); - Assert.noNullElements(topic, "'topics' cannot have null elements"); + private static Map initTopics(String[] topics) { + validateTopics(topics); - return Arrays.stream(topic) + return Arrays.stream(topics) .collect(Collectors.toMap(Function.identity(), (key) -> 1, (x, y) -> y, LinkedHashMap::new)); } + private static void validateTopics(String[] topics) { + Assert.notNull(topics, "'topics' cannot be null"); + Assert.noNullElements(topics, "'topics' cannot have null elements"); + + for (String topic : topics) { + Assert.hasText(topic, "The topic to subscribe cannot be empty string"); + } + } + public void setConverter(MqttMessageConverter converter) { Assert.notNull(converter, "'converter' cannot be null"); this.converter = converter; @@ -178,7 +186,7 @@ public String[] getTopic() { /** * Set the completion timeout when disconnecting. - * Default {@value #DISCONNECT_COMPLETION_TIMEOUT} milliseconds. + * Default {@value ClientManager#DISCONNECT_COMPLETION_TIMEOUT} milliseconds. * @param completionTimeout The timeout. * @since 5.1.10 */ @@ -256,6 +264,7 @@ protected long getCompletionTimeout() { */ @ManagedOperation public void addTopic(String topic, int qos) { + validateTopics(new String[] {topic}); this.topicLock.lock(); try { if (this.topics.containsKey(topic)) { @@ -271,16 +280,16 @@ public void addTopic(String topic, int qos) { /** * Add a topic (or topics) to the subscribed list (qos=1). - * @param topic The topics. - * @throws MessagingException if the topic is already in the list. + * @param topics The topics. + * @throws MessagingException if the topics is already in the list. * @since 4.1 */ @ManagedOperation - public void addTopic(String... topic) { - Assert.notNull(topic, "'topic' cannot be null"); + public void addTopic(String... topics) { + validateTopics(topics); this.topicLock.lock(); try { - for (String t : topic) { + for (String t : topics) { addTopic(t, 1); } } @@ -291,25 +300,24 @@ public void addTopic(String... topic) { /** * Add topics to the subscribed list. - * @param topic The topics. + * @param topics The topics. * @param qos The qos for each topic. - * @throws MessagingException if a topic is already in the list. + * @throws MessagingException if a topics is already in the list. * @since 4.1 */ @ManagedOperation - public void addTopics(String[] topic, int[] qos) { - Assert.notNull(topic, "'topic' cannot be null."); - Assert.noNullElements(topic, "'topic' cannot contain any null elements."); - Assert.isTrue(topic.length == qos.length, "topic and qos arrays must the be the same length."); + public void addTopics(String[] topics, int[] qos) { + validateTopics(topics); + Assert.isTrue(topics.length == qos.length, "topics and qos arrays must the be the same length."); this.topicLock.lock(); try { - for (String newTopic : topic) { + for (String newTopic : topics) { if (this.topics.containsKey(newTopic)) { throw new MessagingException("Topic '" + newTopic + "' is already subscribed."); } } - for (int i = 0; i < topic.length; i++) { - addTopic(topic[i], qos[i]); + for (int i = 0; i < topics.length; i++) { + addTopic(topics[i], qos[i]); } } finally { diff --git a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/AbstractMqttMessageHandler.java b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/AbstractMqttMessageHandler.java index 6054f6b7fe3..4b43a6722c9 100644 --- a/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/AbstractMqttMessageHandler.java +++ b/spring-integration-mqtt/src/main/java/org/springframework/integration/mqtt/outbound/AbstractMqttMessageHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2022 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; @@ -126,6 +127,7 @@ protected ApplicationEventPublisher getApplicationEventPublisher() { * @param defaultTopic the default topic. */ public void setDefaultTopic(String defaultTopic) { + Assert.hasText(defaultTopic, "'defaultTopic' must not be empty"); this.defaultTopic = defaultTopic; } @@ -320,14 +322,17 @@ protected ClientManager getClientManager() { @Override protected void onInit() { super.onInit(); - if (this.topicProcessor instanceof BeanFactoryAware && getBeanFactory() != null) { - ((BeanFactoryAware) this.topicProcessor).setBeanFactory(getBeanFactory()); - } - if (this.qosProcessor instanceof BeanFactoryAware && getBeanFactory() != null) { - ((BeanFactoryAware) this.qosProcessor).setBeanFactory(getBeanFactory()); - } - if (this.retainedProcessor instanceof BeanFactoryAware && getBeanFactory() != null) { - ((BeanFactoryAware) this.retainedProcessor).setBeanFactory(getBeanFactory()); + BeanFactory beanFactory = getBeanFactory(); + if (beanFactory != null) { + if (this.topicProcessor instanceof BeanFactoryAware beanFactoryAware) { + beanFactoryAware.setBeanFactory(beanFactory); + } + if (this.qosProcessor instanceof BeanFactoryAware beanFactoryAware) { + beanFactoryAware.setBeanFactory(beanFactory); + } + if (this.retainedProcessor instanceof BeanFactoryAware beanFactoryAware) { + beanFactoryAware.setBeanFactory(beanFactory); + } } } @@ -358,11 +363,13 @@ public boolean isRunning() { protected void handleMessageInternal(Message message) { Object mqttMessage = this.converter.fromMessage(message, Object.class); String topic = this.topicProcessor.processMessage(message); - if (topic == null && this.defaultTopic == null) { - throw new IllegalStateException( - "No topic could be determined from the message and no default topic defined"); + if (topic == null) { + topic = this.defaultTopic; } - publish(topic == null ? this.defaultTopic : topic, mqttMessage, message); + + Assert.state(topic != null, "No topic could be determined from the message and no default topic defined"); + + publish(topic, mqttMessage, message); } protected abstract void publish(String topic, Object mqttMessage, Message message); diff --git a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java index ed82b6f6e43..501f37ef659 100644 --- a/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java +++ b/spring-integration-mqtt/src/test/java/org/springframework/integration/mqtt/MqttAdapterTests.java @@ -54,6 +54,7 @@ import org.springframework.integration.channel.QueueChannel; import org.springframework.integration.handler.MessageProcessor; import org.springframework.integration.mqtt.core.DefaultMqttPahoClientFactory; +import org.springframework.integration.mqtt.core.MqttPahoClientFactory; import org.springframework.integration.mqtt.core.Mqttv3ClientManager; import org.springframework.integration.mqtt.event.MqttConnectionFailedEvent; import org.springframework.integration.mqtt.event.MqttIntegrationEvent; @@ -73,6 +74,7 @@ import org.springframework.util.ReflectionUtils; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; @@ -515,6 +517,19 @@ public void testDifferentQos() throws Exception { verify(client).disconnectForcibly(5_000L); } + @Test + public void emptyTopicNotAllowed() { + assertThatIllegalArgumentException() + .isThrownBy(() -> + new MqttPahoMessageDrivenChannelAdapter("client_id", mock(MqttPahoClientFactory.class), "")) + .withMessage("The topic to subscribe cannot be empty string"); + + var adapter = new MqttPahoMessageDrivenChannelAdapter("client_id", mock(MqttPahoClientFactory.class), "topic1"); + assertThatIllegalArgumentException() + .isThrownBy(() -> adapter.addTopic("")) + .withMessage("The topic to subscribe cannot be empty string"); + } + private MqttPahoMessageDrivenChannelAdapter buildAdapterIn(final IMqttAsyncClient client, Boolean cleanSession) throws MqttException {