Skip to content

Commit 535eb3f

Browse files
bkchrGeneral-Beck
authored andcommitted
Make sure we remove a peer on disconnect in gossip (paritytech#5104)
* Make sure we remove peers on disconnect in gossip state machine * Clear up the code * Add a comment
1 parent 9d08f54 commit 535eb3f

File tree

3 files changed

+63
-16
lines changed

3 files changed

+63
-16
lines changed

client/network-gossip/src/state_machine.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ impl<B: BlockT> ConsensusGossip<B> {
258258
let mut context = NetworkContext { gossip: self, network, engine_id: engine_id.clone() };
259259
v.peer_disconnected(&mut context, &who);
260260
}
261+
self.peers.remove(&who);
261262
}
262263

263264
/// Perform periodic maintenance
@@ -644,4 +645,52 @@ mod tests {
644645
let _ = consensus.live_message_sinks.remove(&([0, 0, 0, 0], topic));
645646
assert_eq!(stream.next(), None);
646647
}
648+
649+
#[test]
650+
fn peer_is_removed_on_disconnect() {
651+
struct TestNetwork;
652+
impl Network<Block> for TestNetwork {
653+
fn event_stream(
654+
&self,
655+
) -> std::pin::Pin<Box<dyn futures::Stream<Item = crate::Event> + Send>> {
656+
unimplemented!("Not required in tests")
657+
}
658+
659+
fn report_peer(&self, _: PeerId, _: crate::ReputationChange) {
660+
unimplemented!("Not required in tests")
661+
}
662+
663+
fn disconnect_peer(&self, _: PeerId) {
664+
unimplemented!("Not required in tests")
665+
}
666+
667+
fn write_notification(&self, _: PeerId, _: crate::ConsensusEngineId, _: Vec<u8>) {
668+
unimplemented!("Not required in tests")
669+
}
670+
671+
fn register_notifications_protocol(
672+
&self,
673+
_: ConsensusEngineId,
674+
_: std::borrow::Cow<'static, [u8]>,
675+
) {
676+
unimplemented!("Not required in tests")
677+
}
678+
679+
fn announce(&self, _: H256, _: Vec<u8>) {
680+
unimplemented!("Not required in tests")
681+
}
682+
}
683+
684+
let mut consensus = ConsensusGossip::<Block>::new();
685+
consensus.register_validator_internal([0, 0, 0, 0], Arc::new(AllowAll));
686+
687+
let mut network = TestNetwork;
688+
689+
let peer_id = PeerId::random();
690+
consensus.new_peer(&mut network, peer_id.clone(), Roles::FULL);
691+
assert!(consensus.peers.contains_key(&peer_id));
692+
693+
consensus.peer_disconnected(&mut network, peer_id.clone());
694+
assert!(!consensus.peers.contains_key(&peer_id));
695+
}
647696
}

client/network/src/protocol/sync.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,8 +1167,7 @@ impl<B: BlockT> ChainSync<B> {
11671167
}
11681168

11691169
/// Restart the sync process.
1170-
fn restart<'a>(&'a mut self) -> impl Iterator<Item = Result<(PeerId, BlockRequest<B>), BadPeer>> + 'a
1171-
{
1170+
fn restart<'a>(&'a mut self) -> impl Iterator<Item = Result<(PeerId, BlockRequest<B>), BadPeer>> + 'a {
11721171
self.queue_blocks.clear();
11731172
self.blocks.clear();
11741173
let info = self.client.info();

client/network/src/protocol/sync/blocks.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ impl<B: BlockT> BlockCollection<B> {
104104
common: NumberFor<B>,
105105
max_parallel: u32,
106106
max_ahead: u32,
107-
) -> Option<Range<NumberFor<B>>>
108-
{
107+
) -> Option<Range<NumberFor<B>>> {
109108
if peer_best <= common {
110109
// Bail out early
111110
return None;
@@ -165,20 +164,20 @@ impl<B: BlockT> BlockCollection<B> {
165164
pub fn drain(&mut self, from: NumberFor<B>) -> Vec<BlockData<B>> {
166165
let mut drained = Vec::new();
167166
let mut ranges = Vec::new();
168-
{
169-
let mut prev = from;
170-
for (start, range_data) in &mut self.blocks {
171-
match range_data {
172-
&mut BlockRangeState::Complete(ref mut blocks) if *start <= prev => {
173-
prev = *start + (blocks.len() as u32).into();
174-
let mut blocks = mem::replace(blocks, Vec::new());
175-
drained.append(&mut blocks);
176-
ranges.push(*start);
177-
},
178-
_ => break,
179-
}
167+
168+
let mut prev = from;
169+
for (start, range_data) in &mut self.blocks {
170+
match range_data {
171+
&mut BlockRangeState::Complete(ref mut blocks) if *start <= prev => {
172+
prev = *start + (blocks.len() as u32).into();
173+
// Remove all elements from `blocks` and add them to `drained`
174+
drained.append(blocks);
175+
ranges.push(*start);
176+
},
177+
_ => break,
180178
}
181179
}
180+
182181
for r in ranges {
183182
self.blocks.remove(&r);
184183
}

0 commit comments

Comments
 (0)