Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/communication.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl From<UStatus> for RegistrationError {
}

/// General options that clients might want to specify when sending a uProtocol message.
#[derive(Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq)]
pub struct CallOptions {
ttl: u32,
message_id: Option<UUID>,
Expand Down
227 changes: 188 additions & 39 deletions src/communication/in_memory_rpc_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,9 @@ use super::{

fn handle_response_message(response: UMessage) -> Result<Option<UPayload>, ServiceInvocationError> {
let Some(attribs) = response.attributes.as_ref() else {
return Err(ServiceInvocationError::RpcError(UStatus::fail_with_code(
UCode::INTERNAL,
"response message does not contain attributes",
)));
return Err(ServiceInvocationError::InvalidArgument(
"response message does not contain attributes".to_string(),
));
};

match attribs.commstatus.map(|v| v.enum_value_or_default()) {
Expand Down Expand Up @@ -98,14 +97,19 @@ impl ResponseListener {
// channel seems to be closed already
debug!(
request_id = reqid.to_hyphenated_string(),
"failed to deliver response message, channel already closed"
"failed to deliver RPC Response message, channel already closed"
);
} else {
debug!(
request_id = reqid.to_hyphenated_string(),
"successfully delivered RPC Response message"
)
}
} else {
// we seem to have received a duplicate of the response message, ignoring it ...
debug!(
request_id = reqid.to_hyphenated_string(),
"ignoring response message for unknown request"
"ignoring (duplicate?) RPC Response message with unknown request ID"
);
}
}
Expand Down Expand Up @@ -249,14 +253,35 @@ impl RpcClient for InMemoryRpcClient {
self.response_listener.remove_pending_request(&message_id);
e
})?;
debug!(
request_id = message_id.to_hyphenated_string(),
ttl = call_options.ttl(),
"successfully sent RPC Request message"
);

if let Ok(Ok(response_message)) =
timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await
{
handle_response_message(response_message)
} else {
self.response_listener.remove_pending_request(&message_id);
Err(ServiceInvocationError::DeadlineExceeded)
match timeout(Duration::from_millis(call_options.ttl() as u64), receiver).await {
Err(_) => {
debug!(
request_id = message_id.to_hyphenated_string(),
ttl = call_options.ttl(),
"invocation of service operation has timed out"
);
self.response_listener.remove_pending_request(&message_id);
Err(ServiceInvocationError::DeadlineExceeded)
}
Ok(result) => match result {
Ok(response_message) => handle_response_message(response_message),
Err(_e) => {
debug!(
request_id = message_id.to_hyphenated_string(),
"response listener failed to forward response message"
);
self.response_listener.remove_pending_request(&message_id);
Err(ServiceInvocationError::Internal(
"error receiving response message".to_string(),
))
}
},
}
}
}
Expand All @@ -267,6 +292,7 @@ mod tests {
use super::*;

use protobuf::{well_known_types::wrappers::StringValue, Enum};
use tokio::{join, sync::Notify};

use crate::{
utransport::{MockLocalUriProvider, MockTransport},
Expand Down Expand Up @@ -362,20 +388,24 @@ mod tests {
Some(crate::UPriority::UPRIORITY_CS6),
);

let (captured_listener_tx, captured_listener_rx) = std::sync::mpsc::channel();
let (captured_listener_tx, captured_listener_rx) = tokio::sync::oneshot::channel();
let request_sent = Arc::new(Notify::new());
let request_sent_clone = request_sent.clone();

// GIVEN an RPC client
let mut mock_transport = MockTransport::default();
mock_transport.expect_do_register_listener().returning(
move |_source_filter, _sink_filter, listener| {
mock_transport
.expect_do_register_listener()
.once()
.return_once(move |_source_filter, _sink_filter, listener| {
captured_listener_tx
.send(listener)
.map_err(|_e| UStatus::fail("cannot capture listener"))
},
);
});
let expected_message_id = message_id.clone();
mock_transport
.expect_do_send()
.once()
.withf(move |request_message| {
request_message
.attributes
Expand All @@ -387,42 +417,152 @@ mod tests {
&& attribs.token == Some("my_token".to_string())
})
})
.returning(move |request_message| {
let request_payload: StringValue = request_message.extract_protobuf().unwrap();
let response_payload = StringValue {
value: format!("Hello {}", request_payload.value),
..Default::default()
};
.returning(move |_request_message| {
request_sent_clone.notify_one();
Ok(())
});

let response_message = UMessageBuilder::response_for_request(
request_message.attributes.as_ref().unwrap(),
let uri_provider = new_uri_provider();
let rpc_client = Arc::new(
InMemoryRpcClient::new(Arc::new(mock_transport), uri_provider.clone())
.await
.unwrap(),
);
let client: Arc<dyn RpcClient> = rpc_client.clone();

// WHEN invoking a remote service operation
let response_handle = tokio::spawn(async move {
let request_payload = StringValue {
value: "World".to_string(),
..Default::default()
};
client
.invoke_proto_method::<_, StringValue>(
service_method_uri(),
call_options,
request_payload,
)
.build_with_protobuf_payload(&response_payload)
.unwrap();
let captured_listener = captured_listener_rx.recv().unwrap().to_owned();
tokio::spawn(async move { captured_listener.on_receive(response_message).await });
.await
});

// AND the remote service sends the corresponding RPC Response message
let response_payload = StringValue {
value: "Hello World".to_string(),
..Default::default()
};
let response_message = UMessageBuilder::response(
uri_provider.get_source_uri(),
message_id.clone(),
service_method_uri(),
)
.build_with_protobuf_payload(&response_payload)
.unwrap();

// wait for the RPC Request message having been sent
let (response_listener_result, _) = join!(captured_listener_rx, request_sent.notified());
let response_listener = response_listener_result.unwrap();

// send the RPC Response message which completes the request
let cloned_response_message = response_message.clone();
let cloned_response_listener = response_listener.clone();
tokio::spawn(async move {
cloned_response_listener
.on_receive(cloned_response_message)
.await
});

// THEN the response contains the expected payload
let response = response_handle.await.unwrap();
assert!(response.is_ok_and(|payload| payload.value == *"Hello World"));
assert!(!rpc_client.contains_pending_request(&message_id));

// AND if the remote service sends its response message again
response_listener.on_receive(response_message).await;
// the duplicate response is silently ignored
assert!(!rpc_client.contains_pending_request(&message_id));
}

#[tokio::test]
async fn test_invoke_method_fails_on_repeated_invocation() {
let message_id = UUID::build();
let first_request_sent = Arc::new(Notify::new());
let first_request_sent_clone = first_request_sent.clone();

// GIVEN an RPC client
let mut mock_transport = MockTransport::default();
mock_transport
.expect_do_register_listener()
.once()
.return_const(Ok(()));
let expected_message_id = message_id.clone();
mock_transport
.expect_do_send()
.once()
.withf(move |request_message| {
request_message
.attributes
.as_ref()
.map_or(false, |attribs| {
attribs.id.as_ref() == Some(&expected_message_id)
})
})
.returning(move |_request_message| {
first_request_sent_clone.notify_one();
Ok(())
});

let rpc_client = Arc::new(
let in_memory_rpc_client = Arc::new(
InMemoryRpcClient::new(Arc::new(mock_transport), new_uri_provider())
.await
.unwrap(),
);
let client: Arc<dyn RpcClient> = rpc_client.clone();
let rpc_client: Arc<dyn RpcClient> = in_memory_rpc_client.clone();

// WHEN invoking a remote service operation
let call_options =
CallOptions::for_rpc_request(5_000, Some(message_id.clone()), None, None);
let cloned_call_options = call_options.clone();
let cloned_rpc_client = rpc_client.clone();

tokio::spawn(async move {
let request_payload = StringValue {
value: "World".to_string(),
..Default::default()
};
cloned_rpc_client
.invoke_proto_method::<_, StringValue>(
service_method_uri(),
cloned_call_options,
request_payload,
)
.await
});

// we wait for the first request message having been sent via the transport
// in order to be sure that the pending request has been added to the client's
// internal state
first_request_sent.notified().await;

// AND invoking the same operation before the response to the first request has arrived
let request_payload = StringValue {
value: "World".to_string(),
..Default::default()
};
let response: StringValue = client
.invoke_proto_method(service_method_uri(), call_options, request_payload)
.await
.expect("invoking method should have succeeded");
// THEN the response contains the expected payload
assert_eq!(response.value, "Hello World");
assert!(!rpc_client.contains_pending_request(&message_id));
let second_request_handle = tokio::spawn(async move {
rpc_client
.invoke_proto_method::<_, StringValue>(
service_method_uri(),
call_options,
request_payload,
)
.await
});

// THEN the second invocation fails
let response = second_request_handle.await.unwrap();
assert!(response.is_err_and(|e| matches!(e, ServiceInvocationError::AlreadyExists(_))));
// because there is a pending request for the message ID used in both requests
assert!(in_memory_rpc_client.contains_pending_request(&message_id));
}

#[tokio::test]
Expand Down Expand Up @@ -498,4 +638,13 @@ mod tests {
assert!(response.is_err_and(|e| { matches!(e, ServiceInvocationError::DeadlineExceeded) }));
assert!(!client.contains_pending_request(&message_id));
}

#[test]
fn test_handle_response_message_fails_for_missing_attributes() {
let response_msg = UMessage {
..Default::default()
};
let result = handle_response_message(response_msg);
assert!(result.is_err_and(|e| matches!(e, ServiceInvocationError::InvalidArgument(_))));
}
}