@@ -24,7 +24,7 @@ use log::info;
2424use std:: fmt;
2525use std:: fmt:: { Debug , Display , Formatter } ;
2626use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
27- use std:: sync:: { Arc , Mutex , Weak } ;
27+ use std:: sync:: { Arc , Condvar , Mutex , Weak } ;
2828
// NOTE(review): `AtomicUsize` can be updated through a shared reference, so the
// `mut` on this static looks unnecessary — confirm against the use sites before
// changing it.
2929static mut CONSUMER_ID : AtomicUsize = AtomicUsize :: new ( 0 ) ;
3030
@@ -91,23 +91,32 @@ pub trait MemoryConsumer: Send + Sync {
9191 /// reached for this consumer.
9292 async fn try_grow ( & self , required : usize ) -> Result < ( ) > {
9393 let current = self . mem_used ( ) ;
94- let can_grow = self
94+ info ! (
95+ "trying to acquire {} whiling holding {} from consumer {}" ,
96+ human_readable_size( required) ,
97+ human_readable_size( current) ,
98+ self . id( ) ,
99+ ) ;
100+
101+ let can_grow_directly = self
95102 . memory_manager ( )
96- . can_grow ( required, current, self . id ( ) )
103+ . can_grow_directly ( required, current)
97104 . await ;
98- if !can_grow {
105+ if !can_grow_directly {
99106 info ! (
100- "Failed to grow memory of {} from consumer {}, spilling..." ,
107+ "Failed to grow memory of {} directly from consumer {}, spilling first ..." ,
101108 human_readable_size( required) ,
102109 self . id( )
103110 ) ;
104- self . spill ( ) . await ?;
111+ let freed = self . spill ( ) . await ?;
112+ self . memory_manager ( ) . record_free ( freed) ;
105113 }
114+ self . memory_manager ( ) . record_acquire ( required) ;
106115 Ok ( ( ) )
107116 }
108117
109- /// Spill in-memory buffers to disk, free memory
110- async fn spill ( & self ) -> Result < ( ) > ;
118+ /// Spill in-memory buffers to disk and free their memory; returns the amount of memory this consumer held before spilling
119+ async fn spill ( & self ) -> Result < usize > ;
111120
112121 /// Current memory used by this consumer
113122 fn mem_used ( & self ) -> usize ;
@@ -160,10 +169,13 @@ pub struct MemoryManager {
160169 requesters : Arc < Mutex < HashMap < MemoryConsumerId , Weak < dyn MemoryConsumer > > > > ,
161170 trackers : Arc < Mutex < HashMap < MemoryConsumerId , Weak < dyn MemoryConsumer > > > > ,
162171 pool_size : usize ,
172+ requesters_total : Arc < Mutex < usize > > ,
173+ cv : Condvar ,
163174}
164175
165176impl MemoryManager {
166177 /// Create new memory manager based on max available pool_size
178+ #[ allow( clippy:: mutex_atomic) ]
167179 pub fn new ( pool_size : usize ) -> Self {
168180 info ! (
169181 "Creating memory manager with initial size {}" ,
@@ -173,6 +185,8 @@ impl MemoryManager {
173185 requesters : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
174186 trackers : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
175187 pool_size,
188+ requesters_total : Arc :: new ( Mutex :: new ( 0 ) ) ,
189+ cv : Condvar :: new ( ) ,
176190 }
177191 }
178192
@@ -189,10 +203,7 @@ impl MemoryManager {
189203 }
190204
191205 /// Register a new memory consumer for memory usage tracking
192- pub ( crate ) fn register_consumer (
193- self : & Arc < Self > ,
194- consumer : & Arc < dyn MemoryConsumer > ,
195- ) {
206+ pub ( crate ) fn register_consumer ( & self , consumer : & Arc < dyn MemoryConsumer > ) {
196207 let id = consumer. id ( ) . clone ( ) ;
197208 match consumer. type_ ( ) {
198209 ConsumerType :: Requesting => {
@@ -206,32 +217,58 @@ impl MemoryManager {
206217 }
207218 }
208219
220+ fn max_mem_for_requesters ( & self ) -> usize {
221+ let trk_total = self . get_tracker_total ( ) ;
222+ self . pool_size - trk_total
223+ }
224+
209225 /// Grow memory attempt from a consumer, return if we could grant that much to it
210- async fn can_grow (
211- self : & Arc < Self > ,
212- required : usize ,
213- current : usize ,
214- consumer_id : & MemoryConsumerId ,
215- ) -> bool {
216- let tracker_total = self . get_tracker_total ( ) ;
217- let max_per_op = {
218- let total_available = self . pool_size - tracker_total;
219- let ops = self . requesters . lock ( ) . unwrap ( ) . len ( ) ;
220- ( total_available / ops) as usize
221- } ;
222- let granted = required + current < max_per_op;
223- info ! (
224- "trying to acquire {} whiling holding {} from consumer {}, got: {}" ,
225- human_readable_size( required) ,
226- human_readable_size( current) ,
227- consumer_id,
228- granted,
229- ) ;
226+ async fn can_grow_directly ( & self , required : usize , current : usize ) -> bool {
227+ let num_rqt = self . requesters . lock ( ) . unwrap ( ) . len ( ) ;
228+ let mut rqt_current_used = self . requesters_total . lock ( ) . unwrap ( ) ;
229+ let mut rqt_max = self . max_mem_for_requesters ( ) ;
230+
231+ let granted;
232+ loop {
233+ let remaining = rqt_max - * rqt_current_used;
234+ let max_per_rqt = rqt_max / num_rqt;
235+ let min_per_rqt = max_per_rqt / 2 ;
236+
237+ if required + current >= max_per_rqt {
238+ granted = false ;
239+ break ;
240+ }
241+
242+ if remaining >= required {
243+ granted = true ;
244+ break ;
245+ } else if current < min_per_rqt {
246+ // if we cannot acquire at lease 1/2n memory, just wait for others
247+ // to spill instead spill self frequently with limited total mem
248+ rqt_current_used = self . cv . wait ( rqt_current_used) . unwrap ( ) ;
249+ } else {
250+ granted = false ;
251+ break ;
252+ }
253+
254+ rqt_max = self . max_mem_for_requesters ( ) ;
255+ }
256+
230257 granted
231258 }
232259
260+ fn record_free ( & self , freed : usize ) {
261+ let mut requesters_total = self . requesters_total . lock ( ) . unwrap ( ) ;
262+ * requesters_total -= freed;
263+ self . cv . notify_all ( )
264+ }
265+
266+ fn record_acquire ( & self , acquired : usize ) {
267+ * self . requesters_total . lock ( ) . unwrap ( ) += acquired;
268+ }
269+
233270 /// Drop a memory consumer from memory usage tracking
234- pub ( crate ) fn drop_consumer ( self : & Arc < Self > , id : & MemoryConsumerId ) {
271+ pub ( crate ) fn drop_consumer ( & self , id : & MemoryConsumerId ) {
235272 // find in requesters first
236273 {
237274 let mut requesters = self . requesters . lock ( ) . unwrap ( ) ;
@@ -319,8 +356,10 @@ mod tests {
319356 }
320357 }
321358
322- fn set_used ( & self , used : usize ) {
323- self . mem_used . store ( used, Ordering :: SeqCst ) ;
359+ async fn do_with_mem ( & self , grow : usize ) -> Result < ( ) > {
360+ self . try_grow ( grow) . await ?;
361+ self . mem_used . fetch_add ( grow, Ordering :: SeqCst ) ;
362+ Ok ( ( ) )
324363 }
325364
326365 fn get_spills ( & self ) -> usize {
@@ -346,10 +385,10 @@ mod tests {
346385 & ConsumerType :: Requesting
347386 }
348387
349- async fn spill ( & self ) -> Result < ( ) > {
388+ async fn spill ( & self ) -> Result < usize > {
350389 self . spills . fetch_add ( 1 , Ordering :: SeqCst ) ;
351- self . mem_used . store ( 0 , Ordering :: SeqCst ) ;
352- Ok ( ( ) )
390+ let used = self . mem_used . swap ( 0 , Ordering :: SeqCst ) ;
391+ Ok ( used )
353392 }
354393
355394 fn mem_used ( & self ) -> usize {
@@ -391,8 +430,8 @@ mod tests {
391430 & ConsumerType :: Tracking
392431 }
393432
394- async fn spill ( & self ) -> Result < ( ) > {
395- Ok ( ( ) )
433+ async fn spill ( & self ) -> Result < usize > {
434+ Ok ( 0 )
396435 }
397436
398437 fn mem_used ( & self ) -> usize {
@@ -426,21 +465,25 @@ mod tests {
426465 runtime. register_consumer ( & ( requester1. clone ( ) as Arc < dyn MemoryConsumer > ) ) ;
427466
428467 // first requester entered, should be able to use any of the remaining 80
429- requester1. set_used ( 40 ) ;
430- requester1. try_grow ( 10 ) . await ?;
468+ requester1. do_with_mem ( 40 ) . await ? ;
469+ requester1. do_with_mem ( 10 ) . await ?;
431470 assert_eq ! ( requester1. get_spills( ) , 0 ) ;
471+ assert_eq ! ( requester1. mem_used( ) , 50 ) ;
472+ assert_eq ! ( * runtime. memory_manager. requesters_total. lock( ) . unwrap( ) , 50 ) ;
432473
433474 let requester2 = Arc :: new ( DummyRequester :: new ( 0 , runtime. clone ( ) ) ) ;
434475 runtime. register_consumer ( & ( requester2. clone ( ) as Arc < dyn MemoryConsumer > ) ) ;
435476
436- requester2. set_used ( 20 ) ;
437- requester2. try_grow ( 30 ) . await ?;
477+ requester2. do_with_mem ( 20 ) . await ? ;
478+ requester2. do_with_mem ( 30 ) . await ?;
438479 assert_eq ! ( requester2. get_spills( ) , 1 ) ;
439- assert_eq ! ( requester2. mem_used( ) , 0 ) ;
480+ assert_eq ! ( requester2. mem_used( ) , 30 ) ;
440481
441- requester1. try_grow ( 10 ) . await ?;
482+ requester1. do_with_mem ( 10 ) . await ?;
442483 assert_eq ! ( requester1. get_spills( ) , 1 ) ;
443- assert_eq ! ( requester1. mem_used( ) , 0 ) ;
484+ assert_eq ! ( requester1. mem_used( ) , 10 ) ;
485+
486+ assert_eq ! ( * runtime. memory_manager. requesters_total. lock( ) . unwrap( ) , 40 ) ;
444487
445488 Ok ( ( ) )
446489 }
0 commit comments