@@ -25,7 +25,8 @@ mod timer;
2525use self :: iface:: InterfaceState ;
2626use crate :: behaviour:: { socket:: AsyncSocket , timer:: Builder } ;
2727use crate :: Config ;
28- use futures:: Stream ;
28+ use futures:: channel:: mpsc;
29+ use futures:: { Stream , StreamExt } ;
2930use if_watch:: IfEvent ;
3031use libp2p_core:: { Endpoint , Multiaddr } ;
3132use libp2p_identity:: PeerId ;
@@ -36,6 +37,8 @@ use libp2p_swarm::{
3637} ;
3738use smallvec:: SmallVec ;
3839use std:: collections:: hash_map:: { Entry , HashMap } ;
40+ use std:: future:: Future ;
41+ use std:: sync:: { Arc , RwLock } ;
3942use std:: { cmp, fmt, io, net:: IpAddr , pin:: Pin , task:: Context , task:: Poll , time:: Instant } ;
4043
4144/// An abstraction to allow for compatibility with various async runtimes.
@@ -47,16 +50,27 @@ pub trait Provider: 'static {
4750 /// The IfWatcher type.
4851 type Watcher : Stream < Item = std:: io:: Result < IfEvent > > + fmt:: Debug + Unpin ;
4952
53+ type TaskHandle : Abort ;
54+
5055 /// Create a new instance of the `IfWatcher` type.
5156 fn new_watcher ( ) -> Result < Self :: Watcher , std:: io:: Error > ;
57+
58+ fn spawn ( task : impl Future < Output = ( ) > + Send + ' static ) -> Self :: TaskHandle ;
59+ }
60+
61+ #[ allow( unreachable_pub) ] // Not re-exported.
62+ pub trait Abort {
63+ fn abort ( self ) ;
5264}
5365
5466/// The type of a [`Behaviour`] using the `async-io` implementation.
5567#[ cfg( feature = "async-io" ) ]
5668pub mod async_io {
5769 use super :: Provider ;
58- use crate :: behaviour:: { socket:: asio:: AsyncUdpSocket , timer:: asio:: AsyncTimer } ;
70+ use crate :: behaviour:: { socket:: asio:: AsyncUdpSocket , timer:: asio:: AsyncTimer , Abort } ;
71+ use async_std:: task:: JoinHandle ;
5972 use if_watch:: smol:: IfWatcher ;
73+ use std:: future:: Future ;
6074
6175 #[ doc( hidden) ]
6276 pub enum AsyncIo { }
@@ -65,10 +79,21 @@ pub mod async_io {
6579 type Socket = AsyncUdpSocket ;
6680 type Timer = AsyncTimer ;
6781 type Watcher = IfWatcher ;
82+ type TaskHandle = JoinHandle < ( ) > ;
6883
6984 fn new_watcher ( ) -> Result < Self :: Watcher , std:: io:: Error > {
7085 IfWatcher :: new ( )
7186 }
87+
88+ fn spawn ( task : impl Future < Output = ( ) > + Send + ' static ) -> JoinHandle < ( ) > {
89+ async_std:: task:: spawn ( task)
90+ }
91+ }
92+
93+ impl Abort for JoinHandle < ( ) > {
94+ fn abort ( self ) {
95+ async_std:: task:: spawn ( self . cancel ( ) ) ;
96+ }
7297 }
7398
7499 pub type Behaviour = super :: Behaviour < AsyncIo > ;
@@ -78,8 +103,10 @@ pub mod async_io {
78103#[ cfg( feature = "tokio" ) ]
79104pub mod tokio {
80105 use super :: Provider ;
81- use crate :: behaviour:: { socket:: tokio:: TokioUdpSocket , timer:: tokio:: TokioTimer } ;
106+ use crate :: behaviour:: { socket:: tokio:: TokioUdpSocket , timer:: tokio:: TokioTimer , Abort } ;
82107 use if_watch:: tokio:: IfWatcher ;
108+ use std:: future:: Future ;
109+ use tokio:: task:: JoinHandle ;
83110
84111 #[ doc( hidden) ]
85112 pub enum Tokio { }
@@ -88,10 +115,21 @@ pub mod tokio {
88115 type Socket = TokioUdpSocket ;
89116 type Timer = TokioTimer ;
90117 type Watcher = IfWatcher ;
118+ type TaskHandle = JoinHandle < ( ) > ;
91119
92120 fn new_watcher ( ) -> Result < Self :: Watcher , std:: io:: Error > {
93121 IfWatcher :: new ( )
94122 }
123+
124+ fn spawn ( task : impl Future < Output = ( ) > + Send + ' static ) -> Self :: TaskHandle {
125+ tokio:: spawn ( task)
126+ }
127+ }
128+
129+ impl Abort for JoinHandle < ( ) > {
130+ fn abort ( self ) {
131+ JoinHandle :: abort ( & self )
132+ }
95133 }
96134
97135 pub type Behaviour = super :: Behaviour < Tokio > ;
@@ -110,8 +148,11 @@ where
110148 /// Iface watcher.
111149 if_watch : P :: Watcher ,
112150
113- /// Mdns interface states.
114- iface_states : HashMap < IpAddr , InterfaceState < P :: Socket , P :: Timer > > ,
151+ /// Handles to tasks running the mDNS queries.
152+ if_tasks : HashMap < IpAddr , P :: TaskHandle > ,
153+
154+ query_response_receiver : mpsc:: Receiver < ( PeerId , Multiaddr , Instant ) > ,
155+ query_response_sender : mpsc:: Sender < ( PeerId , Multiaddr , Instant ) > ,
115156
116157 /// List of nodes that we have discovered, the address, and when their TTL expires.
117158 ///
@@ -124,7 +165,11 @@ where
124165 /// `None` if `discovered_nodes` is empty.
125166 closest_expiration : Option < P :: Timer > ,
126167
127- listen_addresses : ListenAddresses ,
168+ /// The current set of listen addresses.
169+ ///
170+ /// This is shared across all interface tasks using an [`RwLock`].
171+ /// The [`Behaviour`] updates this upon new [`FromSwarm`] events where as [`InterfaceState`]s read from it to answer inbound mDNS queries.
172+ listen_addresses : Arc < RwLock < ListenAddresses > > ,
128173
129174 local_peer_id : PeerId ,
130175}
@@ -135,10 +180,14 @@ where
135180{
136181 /// Builds a new `Mdns` behaviour.
137182 pub fn new ( config : Config , local_peer_id : PeerId ) -> io:: Result < Self > {
183+ let ( tx, rx) = mpsc:: channel ( 10 ) ; // Chosen arbitrarily.
184+
138185 Ok ( Self {
139186 config,
140187 if_watch : P :: new_watcher ( ) ?,
141- iface_states : Default :: default ( ) ,
188+ if_tasks : Default :: default ( ) ,
189+ query_response_receiver : rx,
190+ query_response_sender : tx,
142191 discovered_nodes : Default :: default ( ) ,
143192 closest_expiration : Default :: default ( ) ,
144193 listen_addresses : Default :: default ( ) ,
@@ -147,6 +196,7 @@ where
147196 }
148197
149198 /// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
199+ #[ deprecated( note = "Use `discovered_nodes` iterator instead." ) ]
150200 pub fn has_node ( & self , peer_id : & PeerId ) -> bool {
151201 self . discovered_nodes ( ) . any ( |p| p == peer_id)
152202 }
@@ -157,6 +207,7 @@ where
157207 }
158208
159209 /// Expires a node before the ttl.
210+ #[ deprecated( note = "Unused API. Will be removed in the next release." ) ]
160211 pub fn expire_node ( & mut self , peer_id : & PeerId ) {
161212 let now = Instant :: now ( ) ;
162213 for ( peer, _addr, expires) in & mut self . discovered_nodes {
@@ -225,28 +276,10 @@ where
225276 }
226277
227278 fn on_swarm_event ( & mut self , event : FromSwarm < Self :: ConnectionHandler > ) {
228- self . listen_addresses . on_swarm_event ( & event) ;
229-
230- match event {
231- FromSwarm :: NewListener ( _) => {
232- log:: trace!( "waking interface state because listening address changed" ) ;
233- for iface in self . iface_states . values_mut ( ) {
234- iface. fire_timer ( ) ;
235- }
236- }
237- FromSwarm :: ConnectionClosed ( _)
238- | FromSwarm :: ConnectionEstablished ( _)
239- | FromSwarm :: DialFailure ( _)
240- | FromSwarm :: AddressChange ( _)
241- | FromSwarm :: ListenFailure ( _)
242- | FromSwarm :: NewListenAddr ( _)
243- | FromSwarm :: ExpiredListenAddr ( _)
244- | FromSwarm :: ListenerError ( _)
245- | FromSwarm :: ListenerClosed ( _)
246- | FromSwarm :: NewExternalAddrCandidate ( _)
247- | FromSwarm :: ExternalAddrExpired ( _)
248- | FromSwarm :: ExternalAddrConfirmed ( _) => { }
249- }
279+ self . listen_addresses
280+ . write ( )
281+ . unwrap_or_else ( |e| e. into_inner ( ) )
282+ . on_swarm_event ( & event) ;
250283 }
251284
252285 fn poll (
@@ -267,43 +300,50 @@ where
267300 {
268301 continue ;
269302 }
270- if let Entry :: Vacant ( e) = self . iface_states . entry ( addr) {
271- match InterfaceState :: new ( addr, self . config . clone ( ) , self . local_peer_id ) {
303+ if let Entry :: Vacant ( e) = self . if_tasks . entry ( addr) {
304+ match InterfaceState :: < P :: Socket , P :: Timer > :: new (
305+ addr,
306+ self . config . clone ( ) ,
307+ self . local_peer_id ,
308+ self . listen_addresses . clone ( ) ,
309+ self . query_response_sender . clone ( ) ,
310+ ) {
272311 Ok ( iface_state) => {
273- e. insert ( iface_state) ;
312+ e. insert ( P :: spawn ( iface_state) ) ;
274313 }
275314 Err ( err) => log:: error!( "failed to create `InterfaceState`: {}" , err) ,
276315 }
277316 }
278317 }
279318 Ok ( IfEvent :: Down ( inet) ) => {
280- if self . iface_states . contains_key ( & inet. addr ( ) ) {
319+ if let Some ( handle ) = self . if_tasks . remove ( & inet. addr ( ) ) {
281320 log:: info!( "dropping instance {}" , inet. addr( ) ) ;
282- self . iface_states . remove ( & inet. addr ( ) ) ;
321+
322+ handle. abort ( ) ;
283323 }
284324 }
285325 Err ( err) => log:: error!( "if watch returned an error: {}" , err) ,
286326 }
287327 }
288328 // Emit discovered event.
289329 let mut discovered = Vec :: new ( ) ;
290- for iface_state in self . iface_states . values_mut ( ) {
291- while let Poll :: Ready ( ( peer, addr, expiration) ) =
292- iface_state. poll ( cx, & self . listen_addresses )
330+
331+ while let Poll :: Ready ( Some ( ( peer, addr, expiration) ) ) =
332+ self . query_response_receiver . poll_next_unpin ( cx)
333+ {
334+ if let Some ( ( _, _, cur_expires) ) = self
335+ . discovered_nodes
336+ . iter_mut ( )
337+ . find ( |( p, a, _) | * p == peer && * a == addr)
293338 {
294- if let Some ( ( _, _, cur_expires) ) = self
295- . discovered_nodes
296- . iter_mut ( )
297- . find ( |( p, a, _) | * p == peer && * a == addr)
298- {
299- * cur_expires = cmp:: max ( * cur_expires, expiration) ;
300- } else {
301- log:: info!( "discovered: {} {}" , peer, addr) ;
302- self . discovered_nodes . push ( ( peer, addr. clone ( ) , expiration) ) ;
303- discovered. push ( ( peer, addr) ) ;
304- }
339+ * cur_expires = cmp:: max ( * cur_expires, expiration) ;
340+ } else {
341+ log:: info!( "discovered: {} {}" , peer, addr) ;
342+ self . discovered_nodes . push ( ( peer, addr. clone ( ) , expiration) ) ;
343+ discovered. push ( ( peer, addr) ) ;
305344 }
306345 }
346+
307347 if !discovered. is_empty ( ) {
308348 let event = Event :: Discovered ( discovered) ;
309349 return Poll :: Ready ( ToSwarm :: GenerateEvent ( event) ) ;
0 commit comments