Skip to content

Commit 921874a

Browse files
committed
Prevent allocating more memory than we actually have
1 parent a84f329 commit 921874a

File tree

3 files changed

+95
-52
lines changed

3 files changed

+95
-52
lines changed

datafusion/src/execution/memory_manager.rs

Lines changed: 90 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use log::info;
2424
use std::fmt;
2525
use std::fmt::{Debug, Display, Formatter};
2626
use std::sync::atomic::{AtomicUsize, Ordering};
27-
use std::sync::{Arc, Mutex, Weak};
27+
use std::sync::{Arc, Condvar, Mutex, Weak};
2828

2929
static mut CONSUMER_ID: AtomicUsize = AtomicUsize::new(0);
3030

@@ -91,23 +91,32 @@ pub trait MemoryConsumer: Send + Sync {
9191
/// reached for this consumer.
9292
async fn try_grow(&self, required: usize) -> Result<()> {
9393
let current = self.mem_used();
94-
let can_grow = self
94+
info!(
95+
"trying to acquire {} while holding {} from consumer {}",
96+
human_readable_size(required),
97+
human_readable_size(current),
98+
self.id(),
99+
);
100+
101+
let can_grow_directly = self
95102
.memory_manager()
96-
.can_grow(required, current, self.id())
103+
.can_grow_directly(required, current)
97104
.await;
98-
if !can_grow {
105+
if !can_grow_directly {
99106
info!(
100-
"Failed to grow memory of {} from consumer {}, spilling...",
107+
"Failed to grow memory of {} directly from consumer {}, spilling first ...",
101108
human_readable_size(required),
102109
self.id()
103110
);
104-
self.spill().await?;
111+
let freed = self.spill().await?;
112+
self.memory_manager().record_free(freed);
105113
}
114+
self.memory_manager().record_acquire(required);
106115
Ok(())
107116
}
108117

109-
/// Spill in-memory buffers to disk, free memory
110-
async fn spill(&self) -> Result<()>;
118+
/// Spill in-memory buffers to disk, free memory, return the previous used
119+
async fn spill(&self) -> Result<usize>;
111120

112121
/// Current memory used by this consumer
113122
fn mem_used(&self) -> usize;
@@ -160,10 +169,13 @@ pub struct MemoryManager {
160169
requesters: Arc<Mutex<HashMap<MemoryConsumerId, Weak<dyn MemoryConsumer>>>>,
161170
trackers: Arc<Mutex<HashMap<MemoryConsumerId, Weak<dyn MemoryConsumer>>>>,
162171
pool_size: usize,
172+
requesters_total: Arc<Mutex<usize>>,
173+
cv: Condvar,
163174
}
164175

165176
impl MemoryManager {
166177
/// Create new memory manager based on max available pool_size
178+
#[allow(clippy::mutex_atomic)]
167179
pub fn new(pool_size: usize) -> Self {
168180
info!(
169181
"Creating memory manager with initial size {}",
@@ -173,6 +185,8 @@ impl MemoryManager {
173185
requesters: Arc::new(Mutex::new(HashMap::new())),
174186
trackers: Arc::new(Mutex::new(HashMap::new())),
175187
pool_size,
188+
requesters_total: Arc::new(Mutex::new(0)),
189+
cv: Condvar::new(),
176190
}
177191
}
178192

@@ -189,10 +203,7 @@ impl MemoryManager {
189203
}
190204

191205
/// Register a new memory consumer for memory usage tracking
192-
pub(crate) fn register_consumer(
193-
self: &Arc<Self>,
194-
consumer: &Arc<dyn MemoryConsumer>,
195-
) {
206+
pub(crate) fn register_consumer(&self, consumer: &Arc<dyn MemoryConsumer>) {
196207
let id = consumer.id().clone();
197208
match consumer.type_() {
198209
ConsumerType::Requesting => {
@@ -206,32 +217,58 @@ impl MemoryManager {
206217
}
207218
}
208219

220+
fn max_mem_for_requesters(&self) -> usize {
221+
let trk_total = self.get_tracker_total();
222+
self.pool_size - trk_total
223+
}
224+
209225
/// Grow memory attempt from a consumer, return if we could grant that much to it
210-
async fn can_grow(
211-
self: &Arc<Self>,
212-
required: usize,
213-
current: usize,
214-
consumer_id: &MemoryConsumerId,
215-
) -> bool {
216-
let tracker_total = self.get_tracker_total();
217-
let max_per_op = {
218-
let total_available = self.pool_size - tracker_total;
219-
let ops = self.requesters.lock().unwrap().len();
220-
(total_available / ops) as usize
221-
};
222-
let granted = required + current < max_per_op;
223-
info!(
224-
"trying to acquire {} while holding {} from consumer {}, got: {}",
225-
human_readable_size(required),
226-
human_readable_size(current),
227-
consumer_id,
228-
granted,
229-
);
226+
async fn can_grow_directly(&self, required: usize, current: usize) -> bool {
227+
let num_rqt = self.requesters.lock().unwrap().len();
228+
let mut rqt_current_used = self.requesters_total.lock().unwrap();
229+
let mut rqt_max = self.max_mem_for_requesters();
230+
231+
let granted;
232+
loop {
233+
let remaining = rqt_max - *rqt_current_used;
234+
let max_per_rqt = rqt_max / num_rqt;
235+
let min_per_rqt = max_per_rqt / 2;
236+
237+
if required + current >= max_per_rqt {
238+
granted = false;
239+
break;
240+
}
241+
242+
if remaining >= required {
243+
granted = true;
244+
break;
245+
} else if current < min_per_rqt {
246+
// if we cannot acquire at least 1/2n memory, just wait for others
247+
// to spill instead of spilling self frequently with limited total mem
248+
rqt_current_used = self.cv.wait(rqt_current_used).unwrap();
249+
} else {
250+
granted = false;
251+
break;
252+
}
253+
254+
rqt_max = self.max_mem_for_requesters();
255+
}
256+
230257
granted
231258
}
232259

260+
fn record_free(&self, freed: usize) {
261+
let mut requesters_total = self.requesters_total.lock().unwrap();
262+
*requesters_total -= freed;
263+
self.cv.notify_all()
264+
}
265+
266+
fn record_acquire(&self, acquired: usize) {
267+
*self.requesters_total.lock().unwrap() += acquired;
268+
}
269+
233270
/// Drop a memory consumer from memory usage tracking
234-
pub(crate) fn drop_consumer(self: &Arc<Self>, id: &MemoryConsumerId) {
271+
pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId) {
235272
// find in requesters first
236273
{
237274
let mut requesters = self.requesters.lock().unwrap();
@@ -319,8 +356,10 @@ mod tests {
319356
}
320357
}
321358

322-
fn set_used(&self, used: usize) {
323-
self.mem_used.store(used, Ordering::SeqCst);
359+
async fn do_with_mem(&self, grow: usize) -> Result<()> {
360+
self.try_grow(grow).await?;
361+
self.mem_used.fetch_add(grow, Ordering::SeqCst);
362+
Ok(())
324363
}
325364

326365
fn get_spills(&self) -> usize {
@@ -346,10 +385,10 @@ mod tests {
346385
&ConsumerType::Requesting
347386
}
348387

349-
async fn spill(&self) -> Result<()> {
388+
async fn spill(&self) -> Result<usize> {
350389
self.spills.fetch_add(1, Ordering::SeqCst);
351-
self.mem_used.store(0, Ordering::SeqCst);
352-
Ok(())
390+
let used = self.mem_used.swap(0, Ordering::SeqCst);
391+
Ok(used)
353392
}
354393

355394
fn mem_used(&self) -> usize {
@@ -391,8 +430,8 @@ mod tests {
391430
&ConsumerType::Tracking
392431
}
393432

394-
async fn spill(&self) -> Result<()> {
395-
Ok(())
433+
async fn spill(&self) -> Result<usize> {
434+
Ok(0)
396435
}
397436

398437
fn mem_used(&self) -> usize {
@@ -426,21 +465,25 @@ mod tests {
426465
runtime.register_consumer(&(requester1.clone() as Arc<dyn MemoryConsumer>));
427466

428467
// first requester entered, should be able to use any of the remaining 80
429-
requester1.set_used(40);
430-
requester1.try_grow(10).await?;
468+
requester1.do_with_mem(40).await?;
469+
requester1.do_with_mem(10).await?;
431470
assert_eq!(requester1.get_spills(), 0);
471+
assert_eq!(requester1.mem_used(), 50);
472+
assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 50);
432473

433474
let requester2 = Arc::new(DummyRequester::new(0, runtime.clone()));
434475
runtime.register_consumer(&(requester2.clone() as Arc<dyn MemoryConsumer>));
435476

436-
requester2.set_used(20);
437-
requester2.try_grow(30).await?;
477+
requester2.do_with_mem(20).await?;
478+
requester2.do_with_mem(30).await?;
438479
assert_eq!(requester2.get_spills(), 1);
439-
assert_eq!(requester2.mem_used(), 0);
480+
assert_eq!(requester2.mem_used(), 30);
440481

441-
requester1.try_grow(10).await?;
482+
requester1.do_with_mem(10).await?;
442483
assert_eq!(requester1.get_spills(), 1);
443-
assert_eq!(requester1.mem_used(), 0);
484+
assert_eq!(requester1.mem_used(), 10);
485+
486+
assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 40);
444487

445488
Ok(())
446489
}

datafusion/src/physical_plan/sorts/external_sort.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ impl MemoryConsumer for ExternalSorter {
204204
&ConsumerType::Requesting
205205
}
206206

207-
async fn spill(&self) -> Result<()> {
207+
async fn spill(&self) -> Result<usize> {
208208
info!(
209209
"{}[{}] spilling sort data of {} to disk while inserting ({} time(s) so far)",
210210
self.name(),
@@ -217,7 +217,7 @@ impl MemoryConsumer for ExternalSorter {
217217
let mut in_mem_batches = self.in_mem_batches.lock().await;
218218
// we could always get a chance to free some memory as long as we are holding some
219219
if in_mem_batches.len() == 0 {
220-
return Ok(());
220+
return Ok(0);
221221
}
222222

223223
let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
@@ -237,11 +237,11 @@ impl MemoryConsumer for ExternalSorter {
237237
.await?;
238238

239239
let mut spills = self.spills.lock().await;
240-
self.used.store(0, Ordering::SeqCst);
240+
let used = self.used.swap(0, Ordering::SeqCst);
241241
self.spilled_count.fetch_add(1, Ordering::SeqCst);
242242
self.spilled_bytes.fetch_add(total_size, Ordering::SeqCst);
243243
spills.push(path);
244-
Ok(())
244+
Ok(used)
245245
}
246246

247247
fn mem_used(&self) -> usize {

datafusion/src/physical_plan/sorts/sort_preserving_merge.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ impl MemoryConsumer for MergingStreams {
264264
&ConsumerType::Tracking
265265
}
266266

267-
async fn spill(&self) -> Result<()> {
267+
async fn spill(&self) -> Result<usize> {
268268
return Err(DataFusionError::Internal(format!(
269269
"Calling spill on a tracking only consumer {}, {}",
270270
self.name(),

0 commit comments

Comments
 (0)