diff --git a/rsocket/src/core/client.rs b/rsocket/src/core/client.rs index 667d1e4..2368642 100644 --- a/rsocket/src/core/client.rs +++ b/rsocket/src/core/client.rs @@ -14,7 +14,7 @@ use crate::payload::{Payload, SetupPayload, SetupPayloadBuilder}; use crate::runtime; use crate::spi::{ClientResponder, Flux, RSocket}; use crate::transport::{ - self, Acceptor, Connection, DuplexSocket, FrameSink, FrameStream, Splitter, Transport, + self, Connection, DuplexSocket, FrameSink, FrameStream, Splitter, Transport, }; use crate::Result; @@ -127,7 +127,11 @@ where let mut socket = DuplexSocket::new(1, snd_tx, splitter).await; let mut cloned_socket = socket.clone(); - let acceptor: Option = self.responder.map(|it| Acceptor::Simple(Arc::new(it))); + + if let Some(f) = self.responder { + let responder = f(); + socket.bind_responder(responder).await; + } let conn = tp.connect().await?; let (mut sink, mut stream) = conn.split(); @@ -191,7 +195,7 @@ where runtime::spawn(async move { while let Some(next) = read_rx.next().await { - if let Err(e) = cloned_socket.dispatch(next, &acceptor).await { + if let Err(e) = cloned_socket.dispatch(next, None).await { error!("dispatch frame failed: {}", e); break; } diff --git a/rsocket/src/core/server.rs b/rsocket/src/core/server.rs index 9d48c8c..60da6d3 100644 --- a/rsocket/src/core/server.rs +++ b/rsocket/src/core/server.rs @@ -3,9 +3,7 @@ use crate::frame::{self, Frame}; use crate::payload::SetupPayload; use crate::runtime; use crate::spi::{RSocket, ServerResponder}; -use crate::transport::{ - Acceptor, Connection, DuplexSocket, ServerTransport, Splitter, Transport, MIN_MTU, -}; +use crate::transport::{Connection, DuplexSocket, ServerTransport, Splitter, Transport, MIN_MTU}; use crate::utils::EmptyRSocket; use crate::Result; use futures::{SinkExt, StreamExt}; @@ -70,7 +68,7 @@ where { pub async fn serve(mut self) -> Result<()> { let mut server_transport = self.transport.take().expect("missing transport"); - let acceptor = self.on_setup.map(|v| Acceptor::Generate(Arc::new(v))); + // let acceptor = self.on_setup.map(|v| Acceptor::Generate(Arc::new(v))); let mtu = self.mtu; @@ -80,6 +78,7 @@ where invoke(); } + let acceptor = Arc::new(self.on_setup); while let Some(next) = server_transport.next().await { match next { Ok(tp) => { @@ -99,7 +98,7 @@ where } #[inline] - async fn on_transport(mtu: usize, tp: C, acceptor: Option) -> Result<()> { + async fn on_transport(mtu: usize, tp: C, acceptor: Arc>) -> Result<()> { // Establish connection. let conn = tp.connect().await?; let (mut writer, mut reader) = conn.split(); @@ -148,7 +147,7 @@ where }); while let Some(frame) = read_rx.next().await { - if let Err(e) = socket.dispatch(frame, &acceptor).await { + if let Err(e) = socket.dispatch(frame, acceptor.as_ref().as_ref()).await { error!("dispatch incoming frame failed: {}", e); break; } diff --git a/rsocket/src/lib.rs b/rsocket/src/lib.rs index a1703b7..073d267 100644 --- a/rsocket/src/lib.rs +++ b/rsocket/src/lib.rs @@ -96,6 +96,8 @@ pub use async_stream::stream; /// A re-export of [`async-trait`](https://docs.rs/async-trait) for use with RSocket trait implementation. pub use async_trait::async_trait; +#[macro_use] +extern crate anyhow; #[macro_use] extern crate log; #[macro_use] diff --git a/rsocket/src/transport/socket.rs b/rsocket/src/transport/socket.rs index b3e3b6b..261f88e 100644 --- a/rsocket/src/transport/socket.rs +++ b/rsocket/src/transport/socket.rs @@ -4,7 +4,7 @@ use super::spi::*; use crate::error::{self, RSocketError}; use crate::frame::{self, Body, Frame}; use crate::payload::{Payload, SetupPayload}; -use crate::spi::{Flux, RSocket}; +use crate::spi::{Flux, RSocket, ServerResponder}; use crate::utils::EmptyRSocket; use crate::{runtime, Result}; use async_stream::stream; @@ -103,7 +103,7 @@ impl DuplexSocket { pub(crate) async fn dispatch( &mut self, frame: Frame, - acceptor: &Option, + acceptor: Option<&ServerResponder>, ) -> Result<()> { if let Some(frame) = self.join_frame(frame).await { self.process_once(frame, acceptor).await; @@ -112,7 +112,7 @@ impl DuplexSocket { } #[inline] - async fn process_once(&mut self, msg: Frame, acceptor: &Option) { + async fn process_once(&mut self, msg: Frame, acceptor: Option<&ServerResponder>) { let sid = msg.get_stream_id(); let flag = msg.get_flag(); debug_frame(false, &msg); @@ -343,10 +343,14 @@ impl DuplexSocket { } } + pub(crate) async fn bind_responder(&self, responder: Box) { + self.responder.set(responder).await; + } + #[inline] async fn on_setup( &self, - acceptor: &Option, + acceptor: Option<&ServerResponder>, sid: u32, flag: u16, setup: SetupPayload, @@ -356,11 +360,7 @@ impl DuplexSocket { self.responder.set(Box::new(EmptyRSocket)).await; Ok(()) } - Some(Acceptor::Simple(gen)) => { - self.responder.set(gen()).await; - Ok(()) - } - Some(Acceptor::Generate(gen)) => match gen(setup, Box::new(self.clone())) { + Some(gen) => match gen(setup, Box::new(self.clone())) { Ok(it) => { self.responder.set(it).await; Ok(()) @@ -414,7 +414,7 @@ impl DuplexSocket { Err(e) => { let sending = frame::Error::builder(sid, 0) .set_code(error::ERR_APPLICATION) - .set_data(Bytes::from("TODO: should be error details")) + .set_data(Bytes::from(e.to_string())) .build(); if let Err(e) = tx.send(sending) { error!("respond REQUEST_RESPONSE failed: {}", e); diff --git a/rsocket/src/transport/spi.rs b/rsocket/src/transport/spi.rs index dfd5b01..1198c6c 100644 --- a/rsocket/src/transport/spi.rs +++ b/rsocket/src/transport/spi.rs @@ -15,12 +15,6 @@ use crate::spi::{ClientResponder, RSocket, ServerResponder}; use crate::{error::RSocketError, frame::Frame}; use crate::{Error, Result}; -#[derive(Clone)] -pub(crate) enum Acceptor { - Simple(Arc), - Generate(Arc), -} - pub type FrameSink = dyn Sink + Send + Unpin; pub type FrameStream = dyn Stream> + Send + Unpin; diff --git a/rsocket/src/utils.rs b/rsocket/src/utils.rs index 1bff3bf..c54c58e 100644 --- a/rsocket/src/utils.rs +++ b/rsocket/src/utils.rs @@ -60,23 +60,27 @@ pub(crate) struct EmptyRSocket; #[async_trait] impl RSocket for EmptyRSocket { async fn metadata_push(&self, _req: Payload) -> Result<()> { - Err(RSocketError::ApplicationException("UNIMPLEMENT".into()).into()) + Err(anyhow!("UNIMPLEMENT")) } async fn fire_and_forget(&self, _req: Payload) -> Result<()> { - Err(RSocketError::ApplicationException("UNIMPLEMENT".into()).into()) + Err(anyhow!("UNIMPLEMENT")) } async fn request_response(&self, _req: Payload) -> Result> { - Err(RSocketError::ApplicationException("UNIMPLEMENT".into()).into()) + Err(anyhow!("UNIMPLEMENT")) } fn request_stream(&self, _req: Payload) -> Flux> { - Box::pin(futures::stream::empty()) + Box::pin(stream! { + yield Err(anyhow!("UNIMPLEMENT")); + }) } fn request_channel(&self, _reqs: Flux>) -> Flux> { - Box::pin(futures::stream::empty()) + Box::pin(stream! { + yield Err(anyhow!("UNIMPLEMENT")); + }) } }